stateful_trainer module
Main weighted dataset stateful trainer API class
- class StatefulTrainer(*args, **kwargs)
Bases: zunis.training.weighted_dataset.weighted_dataset_trainer.BasicStatefulTrainer
High-level API for stateful trainers using weighted datasets (datasets consisting of tuples of point, function value, and point PDF).
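A weighted dataset of this kind can be pictured as a list of (point, function value, point PDF) triples. The sketch below builds one by sampling uniformly on the unit hypercube, where the sampling PDF is 1 everywhere. The helpers `make_weighted_dataset` and `integral_estimate` and the example integrand are hypothetical and only illustrate the data layout, not the data structures zunis uses internally. The ratio f/pdf in each triple is the importance weight whose mean estimates the integral.

```python
import random
from typing import Callable, List, Tuple

Point = Tuple[float, ...]


def make_weighted_dataset(
    f: Callable[[Point], float], d: int, n: int, seed: int = 0
) -> List[Tuple[Point, float, float]]:
    """Hypothetical helper: sample n points uniformly in [0, 1]^d and
    return (point, f(point), pdf(point)) triples."""
    rng = random.Random(seed)
    dataset = []
    for _ in range(n):
        x = tuple(rng.random() for _ in range(d))
        # The uniform density on the unit hypercube is 1 at every point.
        dataset.append((x, f(x), 1.0))
    return dataset


def integral_estimate(dataset: List[Tuple[Point, float, float]]) -> float:
    # Monte Carlo estimate: average of f(x) / p(x) over the samples.
    return sum(fx / px for _, fx, px in dataset) / len(dataset)


# Example integrand f(x) = x_0 + x_1, whose integral over [0, 1]^2 is 1.
data = make_weighted_dataset(lambda x: x[0] + x[1], d=2, n=10_000)
print(integral_estimate(data))  # approximately 1.0
```

A trainer working on such data only needs the triples, so the same layout applies whatever distribution produced the points.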
- Parameters
d (int) – dimensionality of the space
loss (str or function) – loss function. If this argument is a string, it is mapped to a function using zunis.training.weighted_dataset.stateful_trainer.loss_map
flow (str or zunis.models.flows.general_flow.GeneralFlow) – if this argument is a string, it is a cell key used in zunis.models.flows.sequential.repeated_cell.RepeatedCellFlow; otherwise it can be an actual flow model
flow_prior (None or str or zunis.models.flows.sampling.FactorizedFlowSampler) – PDF used for sampling the latent space. If None (default), the “natural choice” defined in the class variable zunis.training.weighted_dataset.stateful_trainer.StatefulTrainer.default_flow_priors is used. A string argument is mapped using zunis.training.weighted_dataset.stateful_trainer.StatefulTrainer.flow_priors
flow_options (None or dict) – options to be passed to the zunis.models.flows.sequential.repeated_cell.RepeatedCellFlow model if flow is a string
prior_options (None or dict) – options to be passed to the latent prior constructor if a “natural choice” prior is used, i.e. if flow_prior is None or a str
device – device on which to run the model and the sampling
n_epochs (int) – number of epochs per batch of data during training
optim (None or torch.optim.Optimizer subclass) – optimizer to use for training. If None, the default Adam optimizer is used
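The string handling described for loss and flow_prior amounts to dictionary lookups with fallbacks. The sketch below mimics that logic in plain Python under stated assumptions: the names `LOSS_MAP`, `FLOW_PRIORS`, `DEFAULT_FLOW_PRIORS`, `resolve_loss`, `resolve_prior`, `variance_loss` and the stand-in sampler classes are hypothetical, and the actual keys of loss_map and flow_priors in zunis may differ. Only the cell-key-to-sampler pairing in `DEFAULT_FLOW_PRIORS` is taken from this page. It is not the library's implementation.

```python
from typing import Callable, Dict, Optional, Sequence, Union


# Hypothetical stand-ins for the zunis sampler classes.
class UniformSampler:
    def __init__(self, d: int, **options):
        self.d, self.options = d, options


class FactorizedGaussianSampler:
    def __init__(self, d: int, **options):
        self.d, self.options = d, options


def variance_loss(fx: Sequence[float], px: Sequence[float]) -> float:
    # Hypothetical loss: variance of the importance weights f/p.
    w = [f / p for f, p in zip(fx, px)]
    mean = sum(w) / len(w)
    return sum((wi - mean) ** 2 for wi in w) / len(w)


LOSS_MAP: Dict[str, Callable] = {"variance": variance_loss}  # assumed keys
FLOW_PRIORS = {  # assumed keys
    "uniform": UniformSampler,
    "gaussian": FactorizedGaussianSampler,
}
# Pairing copied from default_flow_priors as documented on this page.
DEFAULT_FLOW_PRIORS = {
    "pwlinear": UniformSampler,
    "pwquad": UniformSampler,
    "realnvp": FactorizedGaussianSampler,
}


def resolve_loss(loss: Union[str, Callable]) -> Callable:
    # Strings go through the map; callables are used as given.
    return LOSS_MAP[loss] if isinstance(loss, str) else loss


def resolve_prior(d: int, flow: str, flow_prior=None,
                  prior_options: Optional[dict] = None):
    prior_options = prior_options or {}
    if flow_prior is None:
        # "Natural choice": pick the sampler class from the cell key.
        # This sketch assumes flow is a string cell key in this branch.
        return DEFAULT_FLOW_PRIORS[flow](d, **prior_options)
    if isinstance(flow_prior, str):
        return FLOW_PRIORS[flow_prior](d, **prior_options)
    # Anything else is treated as an already-built sampler;
    # prior_options are ignored, as documented for prior_options.
    return flow_prior


prior = resolve_prior(d=2, flow="realnvp")
print(type(prior).__name__)  # FactorizedGaussianSampler
```

Keeping the mapping in class-level dictionaries lets the string-based API stay declarative while still accepting fully constructed objects.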
- default_flow_priors = {'pwlinear': <class 'zunis.models.flows.sampling.UniformSampler'>, 'pwquad': <class 'zunis.models.flows.sampling.UniformSampler'>, 'realnvp': <class 'zunis.models.flows.sampling.FactorizedGaussianSampler'>}
Dictionary used by the string-based API to define the distribution of the data in latent space according to the choice of coupling cell.
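The pairing above presumably reflects the domain each coupling cell works on: piecewise-linear and piecewise-quadratic couplings map the unit hypercube onto itself, so a uniform latent density fits naturally, while RealNVP's affine couplings act on all of R^d, where a factorized Gaussian is the usual choice. The sketch below draws a latent point together with its PDF for either choice. The function names are hypothetical and do not reproduce the zunis sampler API.

```python
import math
import random
from typing import List, Tuple


def sample_uniform(d: int, rng: random.Random) -> Tuple[List[float], float]:
    # Uniform on [0, 1]^d: the density is 1 at every point.
    return [rng.random() for _ in range(d)], 1.0


def gaussian_pdf(x: List[float]) -> float:
    # Factorized standard normal: product of one-dimensional N(0, 1) densities.
    return math.prod(
        math.exp(-xi * xi / 2) / math.sqrt(2 * math.pi) for xi in x
    )


def sample_factorized_gaussian(
    d: int, rng: random.Random
) -> Tuple[List[float], float]:
    x = [rng.gauss(0.0, 1.0) for _ in range(d)]
    return x, gaussian_pdf(x)


rng = random.Random(0)
x_u, p_u = sample_uniform(3, rng)
x_g, p_g = sample_factorized_gaussian(3, rng)
print(p_u)  # 1.0
```

Either way, the sampler yields both a point and its density, which is exactly what a (point, function value, point PDF) triple needs.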