weighted_dataset_trainer module
- class BasicStatefulTrainer(*args, **kwargs)
Bases: zunis.training.weighted_dataset.weighted_dataset_trainer.BasicTrainer, zunis.training.weighted_dataset.weighted_dataset_trainer.GenericTrainerAPI
- process_loss(loss)
Handle invalid losses by reloading a checkpoint; for valid losses, handle logging and checkpointing.
- restore_checkpoint_except()
Try to restore from a checkpoint in response to an exception.
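The rollback behaviour described for process_loss can be sketched in plain Python. This is a minimal, hypothetical illustration rather than the zunis implementation: the class name CheckpointingTrainerSketch, the dict standing in for model state, the helper _restore_checkpoint, and the use of math.isfinite to decide what counts as a "valid" loss are all assumptions. A real trainer would save and reload torch module and optimizer state instead.

```python
import copy
import math


class CheckpointingTrainerSketch:
    """Hypothetical stand-in for the loss handling of BasicStatefulTrainer.

    Model state is a plain dict here; a real trainer would save and reload
    torch module and optimizer state instead.
    """

    def __init__(self, state):
        self.state = state
        self.checkpoint = copy.deepcopy(state)  # last known-good state
        self.loss_history = []

    def _restore_checkpoint(self):
        # Roll back to the last known-good state
        self.state = copy.deepcopy(self.checkpoint)

    def process_loss(self, loss):
        """Return True if the loss was accepted, False if it caused a rollback."""
        if not math.isfinite(loss):
            # Invalid loss (NaN or infinite): discard the current state
            self._restore_checkpoint()
            return False
        # Valid loss: log it, then checkpoint the current state
        self.loss_history.append(loss)
        self.checkpoint = copy.deepcopy(self.state)
        return True


trainer = CheckpointingTrainerSketch({"weight": 1.0})
trainer.state["weight"] = 2.0
trainer.process_loss(0.5)           # accepted: weight 2.0 is checkpointed
trainer.state["weight"] = 99.0
trainer.process_loss(float("nan"))  # rejected: state is rolled back
print(trainer.state)                # {'weight': 2.0}
```

Deep copies keep the checkpoint independent of later in-place updates to the live state, which is what makes the rollback safe.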
- train_on_batch(x, px, fx, **kwargs)
Train on a batch of points using the saved configuration
- train_on_target_batches_from_posterior(*, f, batch_size, n_batches=1, n_epochs_per_batch, minibatch_size=None, target_posterior, optim)
Train for several epochs on each of several batches, with explicit arguments that override the current config. The full config of this run is saved here, which is why it uses a different config format.
- train_step_on_target_batch(x, px, fx, optim, minibatch_size=None)
Training function on a fixed batch: iterate once over the whole batch
- Parameters
x (torch.Tensor) – batch of points in target space sampled from some PDF p(x)
px (torch.Tensor) – values of p(x) for the batch
fx (torch.Tensor) – values of the target (un-normalized) PDF
optim (torch.optim.Optimizer) – pytorch optimizer object
minibatch_size (None or int) – Optional. Size of each minibatch for gradient steps.
Notes
If minibatch_size is unset (or None), it is set to the size of the full batch and a single gradient step is taken.
- class BasicTrainer(*args, **kwargs)
Bases: better_abc.ABC
Basic trainer implementation: sample points in target space from a fixed distribution and train for a fixed number of epochs on each batch, over a fixed number of batches.
This is the low-level implementation of all training facilities for this training mode: no automation, tracking, checkpointing, etc. is performed, and no state or history is kept.
Rationale: the training implementation should be independent of the tracking and of the API.
- static generate_target_batch_from_posterior(n_points, f, target_posterior)
Generate a batch of training examples in target space from a specified distribution
- Parameters
n_points – size of the batch
f – function to evaluate on the sampled points
target_posterior – distribution from which to sample points in target_space
- Returns
(x, px, fx): sampled points, sampling distribution PDF values, function values
- Return type
tuple of torch.Tensor
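A dependency-free sketch of what generate_target_batch_from_posterior returns, assuming a hypothetical posterior object whose sample(n) method returns both the points and their PDF values. Plain Python lists stand in for torch.Tensor, and UniformPosterior and its sample interface are illustrative assumptions, not zunis classes.

```python
import random


class UniformPosterior:
    """Hypothetical posterior: uniform on the unit hypercube [0, 1)^d.

    The sample(n) -> (points, pdf_values) interface is an assumption made
    for this illustration.
    """

    def __init__(self, d, seed=0):
        self.d = d
        self.rng = random.Random(seed)

    def sample(self, n_points):
        x = [[self.rng.random() for _ in range(self.d)] for _ in range(n_points)]
        px = [1.0] * n_points  # the uniform PDF on the unit hypercube is 1
        return x, px


def generate_target_batch_from_posterior(n_points, f, target_posterior):
    # Sample points in target space along with their sampling PDF values
    x, px = target_posterior.sample(n_points)
    # Evaluate the function on every sampled point
    fx = [f(point) for point in x]
    return x, px, fx


x, px, fx = generate_target_batch_from_posterior(
    4, lambda p: sum(c * c for c in p), UniformPosterior(d=2)
)
print(len(x), len(px), len(fx))  # 4 4 4
```

Returning px alongside x keeps the sampling density available to the trainer, since the points were not drawn from the target itself.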
- train_on_target_batch(x, px, fx, optim, n_epochs, minibatch_size=None)
Training function on a fixed batch: train for a fixed number of epochs (n_epochs) on that batch
- Parameters
x – batch of points in target space sampled from some PDF p(x)
px – values of p(x) for the batch
fx – values of the target (un-normalized) PDF
optim – pytorch optimizer object
n_epochs – number of iterations over the full batch
minibatch_size – Optional. Size of each minibatch for gradient steps
Notes
If minibatch_size is unset (or None), it is set to the size of the full batch, so a single gradient step is taken per epoch.
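The epoch loop of train_on_target_batch can be sketched as follows. This is a hypothetical illustration, not the zunis code: instead of an optimizer, the sketch takes a callable `step` (an assumed stand-in for train_step_on_target_batch) that performs one full pass over the batch and returns a loss.

```python
def train_on_target_batch(x, px, fx, step, n_epochs, minibatch_size=None):
    """Sketch of the epoch loop on one fixed batch.

    `step` is a hypothetical callable standing in for
    train_step_on_target_batch; each call is one full pass over the batch.
    """
    losses = []
    for _ in range(n_epochs):
        # The same batch is reused every epoch: no new points are sampled here
        losses.append(step(x, px, fx, minibatch_size))
    return losses


calls = []


def fake_step(x, px, fx, minibatch_size):
    # Placeholder for a real optimization pass; records the batch size it saw
    calls.append(len(x))
    return 1.0 / len(calls)  # fake, decreasing "loss"


losses = train_on_target_batch([0.1, 0.2, 0.3], [1.0] * 3, [0.5] * 3,
                               fake_step, n_epochs=3)
print(losses)  # [1.0, 0.5, 0.3333333333333333]
```

Passing the per-pass logic in as a function keeps the epoch count separate from how a single pass is split into gradient steps.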
- train_on_target_batches_from_posterior(*, f, batch_size, n_batches=1, n_epochs_per_batch, minibatch_size=None, target_posterior, optim)
Main training function: iterate over a fixed number of batches sampled in target space and train for a fixed number of epochs on each batch
- Parameters
f – un-normalized target PDF
batch_size – number of points per batch
n_batches – number of batches to train over
n_epochs_per_batch – number of iterations over each full batch before sampling a new one
minibatch_size – Optional. Size of each minibatch for gradient steps
target_posterior – distribution from which points are sampled in target space
optim – pytorch optimizer object
Notes
If minibatch_size is unset (or None), it is set to the size of the full batch, so a single gradient step is taken per epoch.
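The outer loop of train_on_target_batches_from_posterior can be sketched with two hypothetical stand-ins: `sample(n)` returns points and their PDF values (in place of target_posterior), and `train_on_batch(x, px, fx, n_epochs)` trains on one batch (in place of train_on_target_batch). Neither is a zunis API; the sketch only shows how batches and epochs nest.

```python
import random


def train_on_target_batches_from_posterior(*, f, batch_size, n_batches=1,
                                           n_epochs_per_batch, sample,
                                           train_on_batch):
    """Sketch of the outer loop: draw a fresh batch, then train on it.

    `sample(n) -> (x, px)` and `train_on_batch(x, px, fx, n_epochs)` are
    hypothetical stand-ins for target_posterior and train_on_target_batch.
    """
    history = []
    for _ in range(n_batches):
        x, px = sample(batch_size)      # new points from the fixed posterior
        fx = [f(point) for point in x]  # un-normalized target values
        history.append(train_on_batch(x, px, fx, n_epochs_per_batch))
    return history


rng = random.Random(0)


def sample_uniform(n):
    # Uniform on [0, 1): every point has PDF value 1
    return [rng.random() for _ in range(n)], [1.0] * n


def report(x, px, fx, n_epochs):
    # Placeholder trainer: report what it would train on
    return (len(x), n_epochs)


result = train_on_target_batches_from_posterior(
    f=lambda t: t * t, batch_size=8, n_batches=2, n_epochs_per_batch=5,
    sample=sample_uniform, train_on_batch=report)
print(result)  # [(8, 5), (8, 5)]
```

With minibatch_size unset, each epoch is one gradient step, so this configuration would take 2 * 5 = 10 gradient steps over 16 sampled points in total.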
- train_step_on_target_batch(x, px, fx, optim, minibatch_size=None)
Training function on a fixed batch: iterate once over the whole batch
- Parameters
x (torch.Tensor) – batch of points in target space sampled from some PDF p(x)
px (torch.Tensor) – values of p(x) for the batch
fx (torch.Tensor) – values of the target (un-normalized) PDF
optim (torch.optim.Optimizer) – pytorch optimizer object
minibatch_size (None or int) – Optional. Size of each minibatch for gradient steps.
Notes
If minibatch_size is unset (or None), it is set to the size of the full batch and a single gradient step is taken.
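The minibatch rule in the Notes above can be sketched as a small generator. The name iterate_minibatches is hypothetical, and the sketch assumes points are visited in order without shuffling, which the documentation does not specify. It works on lists here; len() and slicing along the first dimension behave the same way on torch.Tensor.

```python
def iterate_minibatches(x, px, fx, minibatch_size=None):
    """Yield aligned (x, px, fx) slices that cover the batch exactly once.

    minibatch_size=None means a single minibatch equal to the full batch,
    i.e. one gradient step per pass. Points are visited in order (no shuffling).
    """
    n = len(x)
    if minibatch_size is None:
        minibatch_size = max(n, 1)  # full batch; max() avoids a zero step on empty input
    if minibatch_size < 1:
        raise ValueError("minibatch_size must be a positive integer or None")
    for start in range(0, n, minibatch_size):
        stop = start + minibatch_size
        yield x[start:stop], px[start:stop], fx[start:stop]


points = list(range(10))
print([len(b[0]) for b in iterate_minibatches(points, points, points, 4)])  # [4, 4, 2]
print(len(list(iterate_minibatches(points, points, points))))  # 1
```

Each yielded triple would correspond to one gradient step, so a pass over 10 points with minibatch_size=4 takes three steps, while the default takes one.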
- class GenericTrainerAPI(*args, **kwargs)
Bases: better_abc.ABC
Weighted dataset trainer API definition
This API specification defines how a trainer interacts with the Integrator.
- reset()
Reinitialize the trainer and the model. Optional element of the API: raises an error if not implemented.
- abstract train_on_batch(x, px, fx, **kwargs)
Training function
- Parameters
x – batch of points in target space sampled from a distribution p(x)
px – the corresponding batch of p(x) values
fx – the (un-normalized) target weights
kwargs – trainer-specific options; these override the config
Notes
This is not enforced at the level of this API, but passing kwargs is intended to lead to a call to set_config, so that these options are saved, at least by default.
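The shape of this API can be mirrored with the standard library abc module. This is a hypothetical sketch: it uses abc.ABC instead of better_abc.ABC, and it places a set_config method on the base class, which is an assumption (the Notes say the kwargs-to-set_config behaviour is intended but not enforced by the API). EchoTrainer is an illustrative subclass, not part of zunis.

```python
from abc import ABC, abstractmethod


class GenericTrainerAPISketch(ABC):
    """Hypothetical mirror of the documented API, using the stdlib abc
    module instead of better_abc."""

    def __init__(self, **config):
        self.config = dict(config)

    def set_config(self, **kwargs):
        # Assumed helper: persist options so later calls reuse them
        self.config.update(kwargs)

    def reset(self):
        # Optional element of the API: fail loudly unless a subclass overrides it
        raise NotImplementedError("reset is not implemented by this trainer")

    @abstractmethod
    def train_on_batch(self, x, px, fx, **kwargs):
        """Train on one batch; kwargs override the saved config."""


class EchoTrainer(GenericTrainerAPISketch):
    def train_on_batch(self, x, px, fx, **kwargs):
        self.set_config(**kwargs)  # kwargs override and update the saved config
        return {"n_points": len(x), **self.config}


trainer = EchoTrainer(n_epochs=10, minibatch_size=None)
print(trainer.train_on_batch([0.1, 0.2], [1.0, 1.0], [0.3, 0.4], n_epochs=2))
# {'n_points': 2, 'n_epochs': 2, 'minibatch_size': None}
print(trainer.config["n_epochs"])  # 2: the override was saved
```

Because train_on_batch is abstract, instantiating the base class directly raises TypeError, while reset stays optional and only fails when called.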