general_flow module

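Implementation of the abstract GeneralFlow class, the most generic variable transformation:

- it takes in a point x and -log(PDF(x));
- it outputs a transformed point y and -log(PDF(y)) = -log(PDF(x)) + log(dy/dx).

In d dimensions, dy/dx stands for the absolute value of the determinant of the Jacobian matrix of the transformation.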

Reminder:

dx p(x) = dy q(y) = dx (dy/dx) q(y)
=> q(y) = p(x) / (dy/dx)
=> -log q(y) = -log p(x) + log(dy/dx)
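As a quick numerical sanity check of this change-of-variables rule, the pure-Python snippet below pushes a standard normal p(x) through the map y = exp(x), whose output density q(y) is the standard log-normal, and compares both sides of -log q(y) = -log p(x) + log(dy/dx). The choice of distribution and map is purely illustrative and not part of the module.

```python
import math


def neg_log_normal_pdf(x: float) -> float:
    """-log p(x) for a standard normal p."""
    return 0.5 * math.log(2 * math.pi) + 0.5 * x * x


def neg_log_lognormal_pdf(y: float) -> float:
    """-log q(y) for the log-normal q obtained by pushing N(0, 1) through y = exp(x)."""
    return math.log(y) + 0.5 * math.log(2 * math.pi) + 0.5 * math.log(y) ** 2


x = 0.7
y = math.exp(x)        # the transformation
log_dy_dx = x          # dy/dx = exp(x), so log(dy/dx) = x

lhs = neg_log_lognormal_pdf(y)
rhs = neg_log_normal_pdf(x) + log_dy_dx
print(abs(lhs - rhs) < 1e-12)  # True
```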

class GeneralFlow(*args, **kwargs)

Bases: torch.nn.modules.module.Module, better_abc.ABC

General abstract class for flows.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

abstract flow(x)

Transform the batch of points x with shape (…, d). This is an abstract method that should be overridden.

forward(xj)

Compute the flow transformation on some input xj:

- In training mode, xj.shape == (:, d+1) and the last column holds -log(PDF(x)) for the points x = xj[:, :-1].
- In eval mode, xj.shape == (:, d) and no Jacobian is passed: pure sampling mode.
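The training/eval dispatch described above can be sketched as follows. This is a minimal sketch, not the library's implementation: `SketchFlow` is a hypothetical stand-in for GeneralFlow that uses the standard `abc` module instead of `better_abc`, and `ShiftScaleFlow` (an elementwise map y = a * x + b with a fixed a > 0) is a hypothetical subclass introduced only to exercise `forward`.

```python
import abc
import math

import torch


class SketchFlow(torch.nn.Module, abc.ABC):
    """Hypothetical stand-in mirroring the documented GeneralFlow interface."""

    @abc.abstractmethod
    def flow(self, x: torch.Tensor) -> torch.Tensor:
        """Map points of shape (..., d) to transformed points of shape (..., d)."""

    @abc.abstractmethod
    def transform_and_compute_jacobian(self, xj: torch.Tensor) -> torch.Tensor:
        """Map inputs of shape (..., d+1) to outputs of shape (..., d+1)."""

    def forward(self, xj: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Training mode: the last column carries -log PDF(x) and must be updated.
            return self.transform_and_compute_jacobian(xj)
        # Eval mode: pure sampling, xj holds only the d coordinates.
        return self.flow(xj)


class ShiftScaleFlow(SketchFlow):
    """Hypothetical elementwise flow y = a * x + b with a fixed scale a > 0."""

    def __init__(self, a: float = 2.0, b: float = 1.0):
        super().__init__()
        self.a = a
        self.b = b

    def flow(self, x: torch.Tensor) -> torch.Tensor:
        return self.a * x + self.b

    def transform_and_compute_jacobian(self, xj: torch.Tensor) -> torch.Tensor:
        x, neg_log_p = xj[..., :-1], xj[..., -1:]
        y = self.flow(x)
        # The Jacobian is diagonal with entries a, so log|det J| = d * log(a).
        log_det = x.shape[-1] * math.log(self.a)
        return torch.cat([y, neg_log_p + log_det], dim=-1)


f = ShiftScaleFlow()
f.train()                           # nn.Module default, set explicitly for clarity
out_train = f(torch.zeros(4, 4))    # 4 points in d = 3, plus the -log PDF column
f.eval()
out_eval = f(torch.zeros(4, 3))     # 4 points in d = 3, no -log PDF column
print(out_train.shape, out_eval.shape)  # torch.Size([4, 4]) torch.Size([4, 3])
```

Keeping the dispatch in `forward` means a subclass only has to provide the two transformation methods; the `training` flag toggled by `train()` and `eval()` selects between them.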

training: bool

abstract transform_and_compute_jacobian(xj)

Compute the flow transformation and its Jacobian simultaneously on xj with xj.shape == (…, d+1).

This is an abstract method that should be overridden.
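A concrete override has to return the transformed coordinates together with the updated last column -log(PDF(y)) = -log(PDF(x)) + log|det(dy/dx)|. The sketch below does this for a hypothetical elementwise flow y = exp(x), written as a free function `exp_transform_and_compute_jacobian` (an illustrative name, not part of the module), and cross-checks the analytic log-determinant against one computed with `torch.autograd.functional.jacobian` and `torch.linalg.slogdet`.

```python
import torch


def exp_transform_and_compute_jacobian(xj: torch.Tensor) -> torch.Tensor:
    """Hypothetical elementwise flow y = exp(x) on inputs of shape (..., d+1)."""
    x, neg_log_p = xj[..., :-1], xj[..., -1:]
    y = torch.exp(x)
    # dy_i/dx_i = exp(x_i) and the Jacobian is diagonal, so log|det J| = sum_i x_i.
    log_det = x.sum(dim=-1, keepdim=True)
    return torch.cat([y, neg_log_p + log_det], dim=-1)


torch.manual_seed(0)
d = 3
x = torch.randn(d, dtype=torch.float64)
# Set -log PDF(x) to 0 so the last output entry is exactly the log-determinant.
xj = torch.cat([x, torch.zeros(1, dtype=torch.float64)])
out = exp_transform_and_compute_jacobian(xj)

# Cross-check the analytic log-determinant against autograd.
J = torch.autograd.functional.jacobian(torch.exp, x)  # (d, d) matrix
sign, logabsdet = torch.linalg.slogdet(J)
print(torch.allclose(out[-1], logabsdet))  # True
```

For elementwise maps the Jacobian is diagonal, so the log-determinant reduces to a sum of per-coordinate log-derivatives; the autograd check is only practical for small d and serves here as a test of the analytic expression.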