stateful_trainer module

Main API class for stateful trainers on weighted datasets.

class StatefulTrainer(*args, **kwargs)

Bases: zunis.training.weighted_dataset.weighted_dataset_trainer.BasicStatefulTrainer

High-level API for stateful trainers using weighted datasets, i.e. datasets whose entries are tuples of (point, function value, point PDF).
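To make the tuple layout concrete, here is a minimal, stdlib-only sketch of a weighted dataset. It is not ZüNIS code: `make_weighted_dataset`, `integral_estimate` and the sampling density p(x) = 2x are hypothetical choices made for illustration. Each entry stores the point, the integrand value there, and the PDF of the distribution the point was drawn from, which is what lets the mean of f(x)/p(x) serve as an importance-sampling estimate of the integral of f.

```python
import math
import random
from typing import Callable, List, Tuple

# One weighted-dataset entry: (point, function value, pdf of the point under its sampler).
Sample = Tuple[float, float, float]


def make_weighted_dataset(f: Callable[[float], float], n: int, seed: int = 0) -> List[Sample]:
    """Draw n points on (0, 1] from p(x) = 2x and record (x, f(x), p(x)).

    Hypothetical helper: p(x) = 2x is an arbitrary sampling density for this example.
    """
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        u = 1.0 - rng.random()  # u in (0, 1], so x > 0 and p(x) > 0
        x = math.sqrt(u)        # inverse CDF of p(x) = 2x on [0, 1]
        data.append((x, f(x), 2.0 * x))
    return data


def integral_estimate(data: List[Sample]) -> float:
    """Importance-sampling estimate of the integral of f: the mean of f(x) / p(x)."""
    return sum(fx / px for _, fx, px in data) / len(data)


dataset = make_weighted_dataset(lambda x: 3.0 * x**2, n=10_000)
print(integral_estimate(dataset))  # should be close to 1.0, the exact integral of 3x^2 on [0, 1]
```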

Parameters
default_flow_priors = {'pwlinear': <class 'zunis.models.flows.sampling.UniformSampler'>, 'pwquad': <class 'zunis.models.flows.sampling.UniformSampler'>, 'realnvp': <class 'zunis.models.flows.sampling.FactorizedGaussianSampler'>}

Dictionary used by the string-based API to select the default distribution of the data in latent space according to the choice of coupling cell.
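The `default_flow_priors` mapping amounts to a keyed lookup from coupling-cell name to prior class. The sketch below mirrors its keys with hypothetical stub classes standing in for `UniformSampler` and `FactorizedGaussianSampler`; `default_prior_for` is an illustrative helper, not part of ZüNIS. A plausible reading of the pairing (not stated in the source) is that the piecewise-linear and piecewise-quadratic cells act on the unit hypercube, where a uniform prior is natural, while RealNVP cells act on all of R^d, where a Gaussian prior is natural.

```python
class UniformSamplerStub:
    """Hypothetical stand-in for zunis.models.flows.sampling.UniformSampler."""


class FactorizedGaussianSamplerStub:
    """Hypothetical stand-in for zunis.models.flows.sampling.FactorizedGaussianSampler."""


# Same keys and pairings as default_flow_priors, with stubs in place of the real classes.
DEFAULT_FLOW_PRIORS = {
    "pwlinear": UniformSamplerStub,
    "pwquad": UniformSamplerStub,
    "realnvp": FactorizedGaussianSamplerStub,
}


def default_prior_for(flow: str) -> type:
    """Return the default prior class for a coupling-cell name (hypothetical helper)."""
    try:
        return DEFAULT_FLOW_PRIORS[flow]
    except KeyError:
        # Re-raise with a message listing the accepted names instead of a bare KeyError.
        raise ValueError(
            f"unknown coupling cell {flow!r}; expected one of {sorted(DEFAULT_FLOW_PRIORS)}"
        ) from None


print(default_prior_for("realnvp").__name__)  # FactorizedGaussianSamplerStub
```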

flow_priors = {'gaussian': <class 'zunis.models.flows.sampling.FactorizedGaussianSampler'>, 'uniform': <class 'zunis.models.flows.sampling.UniformSampler'>}

Dictionary used by the string-based API to map a prior name to the distribution of the data in latent space.
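A string-based API like this one resolves a short name such as `'gaussian'` to a class. The following sketch assumes, for illustration only, a resolver that also passes class objects through unchanged; whether `StatefulTrainer` accepts prior classes directly is not stated here. `resolve_prior` and the stub classes are hypothetical.

```python
from typing import Union


class UniformSamplerStub:
    """Hypothetical stand-in for zunis.models.flows.sampling.UniformSampler."""


class FactorizedGaussianSamplerStub:
    """Hypothetical stand-in for zunis.models.flows.sampling.FactorizedGaussianSampler."""


# Same keys as flow_priors, with stubs in place of the real classes.
FLOW_PRIORS = {
    "gaussian": FactorizedGaussianSamplerStub,
    "uniform": UniformSamplerStub,
}


def resolve_prior(spec: Union[str, type]) -> type:
    """Map a prior name to its class; pass class objects through (assumed behaviour)."""
    if isinstance(spec, str):
        if spec not in FLOW_PRIORS:
            raise ValueError(f"unknown prior {spec!r}; expected one of {sorted(FLOW_PRIORS)}")
        return FLOW_PRIORS[spec]
    return spec


print(resolve_prior("uniform").__name__)  # UniformSamplerStub
print(resolve_prior(FactorizedGaussianSamplerStub) is FactorizedGaussianSamplerStub)  # True
```

Keeping the name-to-class table in a dictionary means a new prior can be offered by adding one entry, without touching the resolution logic.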

loss_map = {'dkl': <function weighted_dkl_loss>, 'variance': <function weighted_variance_loss>}

Dictionary used by the string-based API to map a loss name to the loss function used in training.
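The two entries of `loss_map` name training objectives that can be estimated from a weighted dataset. The sketch below shows one common way to estimate them from points drawn from a sampling density p, given the model density q at those points: the variance uses Var_q[f/q] = E_p[f²/(q p)] − I² with I = E_p[f/p], and the KL divergence uses D_KL(f/I ‖ q) = E_p[(f/(I p)) log(f/(I q))]. The helpers `variance_loss` and `dkl_loss` are hypothetical; the actual `weighted_variance_loss` and `weighted_dkl_loss` may differ in signature, normalisation or constant terms. A deterministic midpoint grid stands in for samples from a uniform p so that the results are reproducible.

```python
import math
from typing import Callable, Dict, Sequence


def variance_loss(fx: Sequence[float], px: Sequence[float], qx: Sequence[float]) -> float:
    """Estimate Var_q[f/q] from points drawn from p (hypothetical helper)."""
    n = len(fx)
    integral = sum(f / p for f, p in zip(fx, px)) / n                     # I = E_p[f / p]
    second_moment = sum(f * f / (q * p) for f, p, q in zip(fx, px, qx)) / n  # E_q[(f / q)^2]
    return second_moment - integral**2


def dkl_loss(fx: Sequence[float], px: Sequence[float], qx: Sequence[float]) -> float:
    """Estimate D_KL(f/I || q) from points drawn from p; requires f > 0 (hypothetical helper)."""
    n = len(fx)
    integral = sum(f / p for f, p in zip(fx, px)) / n
    return sum(
        f / (integral * p) * math.log(f / (integral * q)) for f, p, q in zip(fx, px, qx)
    ) / n


# String-keyed dispatch in the spirit of loss_map.
LOSS_MAP: Dict[str, Callable[..., float]] = {"variance": variance_loss, "dkl": dkl_loss}

# Midpoint grid on [0, 1], uniform sampling pdf p = 1, target f(x) = 2x (integral 1).
n = 1000
xs = [(i + 0.5) / n for i in range(n)]
fx = [2.0 * x for x in xs]
px = [1.0] * n
perfect_q = [2.0 * x for x in xs]  # model density equal to the normalised target
uniform_q = [1.0] * n              # flat, untrained model density

for name, loss in LOSS_MAP.items():
    print(name, loss(fx, px, perfect_q), loss(fx, px, uniform_q))
```

With the model density equal to the normalised target, both losses vanish up to rounding; with the flat model density they come out near 1/3 and log 2 − 1/2 respectively, so lower values indicate a better-adapted model.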