general_flow module
Implementation of the abstract GeneralFlow class, the most generic variable transformation:
- it takes in a point x and -log(PDF(x));
- it outputs a transformed point y and -log(PDF(y)) = -log(PDF(x)) + log|det(dy/dx)|, where dy/dx is the Jacobian of the transformation.
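Reminder (change of variables, written in one dimension): probability is conserved, so

$$p(x)\,|dx| = q(y)\,|dy| = q(y)\left|\frac{dy}{dx}\right||dx| \;\Rightarrow\; q(y) = \frac{p(x)}{|dy/dx|} \;\Rightarrow\; -\log q(y) = -\log p(x) + \log\left|\frac{dy}{dx}\right|$$

In d dimensions, |dy/dx| becomes the absolute value of the Jacobian determinant, |det(∂y/∂x)|.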
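As a quick numerical sanity check of the reminder, the sketch below (standard library only) pushes a standard normal through an illustrative affine map y = a·x + b and compares -log q(y) obtained from the change-of-variables formula with the exact negative log-density of the resulting normal N(b, a²). The map, the values of a, b and x, and the helper neg_log_normal_pdf are assumptions made for illustration; they are not part of general_flow. A negative slope is used on purpose, to show why the absolute value |dy/dx| is needed for decreasing maps.

```python
import math


def neg_log_normal_pdf(z: float, mean: float = 0.0, std: float = 1.0) -> float:
    """Exact -log of the N(mean, std**2) density evaluated at z."""
    return 0.5 * ((z - mean) / std) ** 2 + math.log(std) + 0.5 * math.log(2 * math.pi)


# Illustrative affine map y = a * x + b (hypothetical, not from general_flow).
a, b = -2.5, 1.0
x = 0.7

y = a * x + b
dy_dx = a  # derivative of the affine map; negative, so the map is decreasing

# Change of variables: -log q(y) = -log p(x) + log|dy/dx|, with p = N(0, 1).
neg_log_q_formula = neg_log_normal_pdf(x) + math.log(abs(dy_dx))

# Ground truth: if x ~ N(0, 1), then y = a * x + b ~ N(b, a**2).
neg_log_q_exact = neg_log_normal_pdf(y, mean=b, std=abs(a))

print(neg_log_q_formula, neg_log_q_exact)
```

Both printed values agree up to floating-point rounding; dropping the abs() would make log(dy/dx) undefined for this decreasing map.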
- class GeneralFlow(*args, **kwargs)
Bases: torch.nn.modules.module.Module, better_abc.ABC
General abstract class for flows.
The constructor is inherited from torch.nn.Module, which initializes internal Module state, shared by both nn.Module and ScriptModule.
- abstract flow(x)
Transform the batch of points x with shape (…, d). This is an abstract method that should be overridden by subclasses.
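A minimal sketch of how a concrete flow might subclass an abstract base like GeneralFlow, assuming only what is documented here: the base is a torch.nn.Module with an abstract flow(x) method acting on a batch of shape (…, d). Since better_abc is not a standard package, the stand-in base below uses the standard library's abc.ABC instead. The names GeneralFlowSketch, ElementwiseAffineFlow and transform_and_neg_log_pdf are hypothetical illustrations, not the library's API. The subclass applies the elementwise map y = scale * x + shift, whose Jacobian is diagonal, so log|det(dy/dx)| is the sum of log|scale_i|, and it updates -log(PDF) as described at the top of this module.

```python
import abc

import torch
from torch import nn


class GeneralFlowSketch(nn.Module, abc.ABC):
    """Hypothetical stand-in for GeneralFlow: an nn.Module with an abstract flow()."""

    @abc.abstractmethod
    def flow(self, x: torch.Tensor) -> torch.Tensor:
        """Transform a batch of points x with shape (..., d)."""


class ElementwiseAffineFlow(GeneralFlowSketch):
    """Hypothetical concrete flow: y = scale * x + shift, applied per coordinate."""

    def __init__(self, d: int):
        super().__init__()
        # Learnable parameters; scale starts away from zero so the map is invertible.
        self.scale = nn.Parameter(torch.full((d,), 2.0))
        self.shift = nn.Parameter(torch.zeros(d))

    def flow(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcasts over any leading batch dimensions of x.
        return self.scale * x + self.shift

    def transform_and_neg_log_pdf(self, x: torch.Tensor, neg_log_px: torch.Tensor):
        # -log PDF(y) = -log PDF(x) + log|det(dy/dx)|; the Jacobian is diagonal here.
        log_abs_det = torch.log(torch.abs(self.scale)).sum()
        return self.flow(x), neg_log_px + log_abs_det


# Usage: push standard-normal samples through the flow and check the density update.
torch.manual_seed(0)
flow = ElementwiseAffineFlow(d=3)
x = torch.randn(5, 3)

with torch.no_grad():
    base = torch.distributions.Normal(0.0, 1.0)
    neg_log_px = -base.log_prob(x).sum(-1)  # shape (5,)
    y, neg_log_py = flow.transform_and_neg_log_pdf(x, neg_log_px)

    # Exact reference: y is distributed as N(shift, scale**2) in each coordinate.
    target = torch.distributions.Normal(flow.shift, flow.scale.abs())
    exact = -target.log_prob(y).sum(-1)
```

Keeping the abstract method on the base class means a subclass that forgets to implement flow() fails at instantiation time with a TypeError rather than later during training.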