variance_training module
Optimization of invertible flows in the weighted dataset problem using the variance loss
Reminder: we have a dataset (x, p(x), f(x)) such that
- x ~ p(x)
- we want to learn a model that draws points according to f(x), which is positive and known up to normalization
We want to optimize a function q(x) such that doing importance sampling with it to compute the integral of f(x) minimizes the variance.
The variance of the importance sampling estimator is our proto-loss
pL(f,q) = ∫ dx q(x) (f(x)/q(x))^2 - (∫ dx q(x) f(x)/q(x))^2
        = ∫ dx f(x)^2/q(x) - I(f)^2
where I(f) is the integral we want to compute and is independent of q, so our real loss is
L(f,q) = ∫ dx f(x)^2/q(x)
which we can further compute using importance sampling from p(x):
L(f,q) = ∫ dx p(x) f(x)^2/(q(x) p(x))
and therefore estimate from our dataset as the expectation value
L(f,q) = E(f(x)^2/(q(x) p(x))), x ~ p(x)
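The derivation above can be checked numerically. A minimal NumPy sketch, with toy choices of p, f and q that are illustrative assumptions, not taken from the library:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup on [0, 1]:
# p(x) = 1   (uniform sampling distribution of the dataset)
# f(x) = x   (positive integrand, I(f) = 1/2)
# q(x) = 2x  (candidate model density, normalized on [0, 1])
x = rng.uniform(0.0, 1.0, 100_000)
px = np.ones_like(x)
fx = x
qx = 2.0 * x

# L(f,q) = E(f(x)^2 / (q(x) p(x))), x ~ p(x)
loss = np.mean(fx ** 2 / (qx * px))
print(loss)  # ≈ 0.25 = ∫_0^1 x^2/(2x) dx
```

Here q is in fact the optimal choice q = f/I(f), for which the loss saturates its lower bound L(f,q) = I(f)^2 = 1/4.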
- class BasicStatefulVarTrainer(*args, **kwargs)[source]
Bases: zunis.training.weighted_dataset.weighted_dataset_trainer.BasicStatefulTrainer
Basic stateful trainer based on the variance loss
- class BasicVarTrainer(*args, **kwargs)[source]
Bases: zunis.training.weighted_dataset.weighted_dataset_trainer.BasicTrainer
Basic trainer based on the variance loss
- weighted_variance_loss(fx, px, logqx)[source]
Proxy variance loss for the integral of a function f using importance sampling from q, but where the variance is estimated with importance sampling from p.
We want to optimize a function q(x) such that doing importance sampling with it to compute the integral of f(x) minimizes the variance.
The variance of the importance sampling estimator is our proto-loss
pL(f,q) = ∫ dx q(x) (f(x)/q(x))^2 - (∫ dx q(x) f(x)/q(x))^2
        = ∫ dx f(x)^2/q(x) - I(f)^2
where I(f) is the integral we want to compute and is independent of q, so our real loss is
L(f,q) = ∫ dx f(x)^2/q(x)
which we can further compute using importance sampling from p(x):
L(f,q) = ∫ dx p(x) f(x)^2/(q(x) p(x))
and therefore estimate from our dataset as the expectation value
L(f,q) = E(f(x)^2/(q(x) p(x))), x ~ p(x)
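Concretely, this expectation is a batch mean over the dataset. A minimal NumPy sketch of what the documented signature computes (the library itself is PyTorch-based, so this is an illustration, not the actual implementation):

```python
import numpy as np

def weighted_variance_loss(fx, px, logqx):
    """L(f, q) estimated as the batch mean of f(x)^2 / (q(x) p(x)).

    fx    : f(x) evaluated on a batch of points x ~ p(x)
    px    : p(x) on the same batch
    logqx : log q(x) from the model on the same batch
    """
    # q(x) is provided through its logarithm, as flow models
    # typically output log-densities
    return np.mean(fx ** 2 / (px * np.exp(logqx)))

# With f(x) = [1, 2], p(x) = 1 and q(x) = 1 (log q = 0),
# the loss is mean([1, 4]) = 2.5
print(weighted_variance_loss(np.array([1.0, 2.0]), np.ones(2), np.zeros(2)))
```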