jaxili.model module

Model.

This module contains classes that implement normalizing flows and other neural density estimators (such as mixture density networks) using neural networks.

class jaxili.model.AffineCoupling(y: ~typing.Any, layers: list, activation: callable, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Module

Base class for an Affine Coupling layer for RealNVP.

Parameters:
  • y (Any) – Conditioning variable.

  • layers (list) – List of hidden layer sizes.

  • activation (Callable) – Activation function.
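The affine coupling transform behind this layer can be sketched in plain JAX. This is an illustration under stated assumptions, not jaxili's implementation (which builds a tensorflow_probability bijector): the input is split in half, and a hypothetical one-layer conditioner with weights W and b, fed with the untouched half and the conditioning variable y, stands in for the network built from layers and activation.

```python
import jax
import jax.numpy as jnp


def conditioner(params, x1, y):
    """Hypothetical stand-in for the neural network: predicts shift and log-scale."""
    h = jnp.concatenate([x1, y], axis=-1)
    # tanh keeps the log-scale bounded, which helps numerical stability.
    out = jnp.tanh(h @ params["W"] + params["b"])
    shift, log_scale = jnp.split(out, 2, axis=-1)
    return shift, log_scale


def coupling_forward(params, x, y):
    """Map x -> z: the first half is copied, the second half is scaled and shifted."""
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    shift, log_scale = conditioner(params, x1, y)
    z2 = x2 * jnp.exp(log_scale) + shift
    # The Jacobian is triangular, so its log-determinant is the sum of log-scales.
    return jnp.concatenate([x1, z2], axis=-1), log_scale.sum(axis=-1)


def coupling_inverse(params, z, y):
    """Map z -> x; exact because z1 == x1 yields the same shift and log-scale."""
    d = z.shape[-1] // 2
    z1, z2 = z[..., :d], z[..., d:]
    shift, log_scale = conditioner(params, z1, y)
    x2 = (z2 - shift) * jnp.exp(-log_scale)
    return jnp.concatenate([z1, x2], axis=-1)


n_in, n_cond = 4, 2
d = n_in // 2
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
params = {
    "W": 0.5 * jax.random.normal(key_w, (d + n_cond, 2 * (n_in - d))),
    "b": jnp.zeros(2 * (n_in - d)),
}
x = jax.random.normal(key_x, (5, n_in))
y = jnp.ones((5, n_cond))
z, log_det = coupling_forward(params, x, y)
x_rec = coupling_inverse(params, z, y)
```

Stacking such layers while alternating which half is copied is what turns single couplings into a RealNVP flow.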

Attributes:
  • name

  • parent

  • path – Get the path of this Module.

  • scope

  • variables – Returns the variables in this module.

Methods

__call__(x, output_units, **kwargs)

Build the bijector using tensorflow_probability, with the scale and the shift learned by a neural network.

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

bind(variables, *args[, rngs, mutable])

Creates an interactive Module instance by binding variables and RNGs.

clone(*[, parent, _deep_clone, _reset_names])

Creates a clone of this Module, with optionally updated arguments.

copy(*[, parent, name])

Creates a copy of this Module, with optionally updated arguments.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns true if a PRNGSequence with the given name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns true if the collection col is mutable.

lazy_init(rngs, *args[, method, mutable])

Initializes a module without computing on an actual input.

make_rng([name])

Returns a new RNG key from a given RNG sequence for this Module.

module_paths(rngs, *args[, show_repeated, ...])

Returns a dictionary mapping module paths to module instances.

param(name, init_fn, *init_args[, unbox])

Declares and returns a parameter in this Module.

perturb(name, value[, collection])

Add a zero-value variable ('perturbation') to the intermediate value.

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or an error otherwise.

setup()

Initializes a Module lazily (similar to a lazy __init__).

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

tabulate(rngs, *args[, depth, ...])

Creates a summary of the Module represented as a table.

unbind()

Returns an unbound copy of a Module and its variables.

variable(col, name[, init_fn, unbox])

Declares and returns a variable in this Module.

activation: callable
layers: list
name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
y: Any
class jaxili.model.Compressor_w_NDE(parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: NDENetwork

Base class for a normalizing flow combined with a compression of the conditioning variable.

A parent class for implementing a compressor followed by a normalizing flow. This is useful for Implicit Likelihood Inference in high dimensions, where compression is required and can sometimes itself be performed by a normalizing flow.

Attributes:
  • name

  • parent

  • path – Get the path of this Module.

  • scope

  • variables – Returns the variables in this module.

Methods

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

bind(variables, *args[, rngs, mutable])

Creates an interactive Module instance by binding variables and RNGs.

clone(*[, parent, _deep_clone, _reset_names])

Creates a clone of this Module, with optionally updated arguments.

compress(x)

Compress the data point x using the compressor.

copy(*[, parent, name])

Creates a copy of this Module, with optionally updated arguments.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns true if a PRNGSequence with the given name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns true if the collection col is mutable.

lazy_init(rngs, *args[, method, mutable])

Initializes a module without computing on an actual input.

log_prob(x[, y])

Log probability of the data point x conditioned by y.

log_prob_from_compressed(z[, y])

Log probability of the compressed data point z conditioned by y.

make_rng([name])

Returns a new RNG key from a given RNG sequence for this Module.

module_paths(rngs, *args[, show_repeated, ...])

Returns a dictionary mapping module paths to module instances.

param(name, init_fn, *init_args[, unbox])

Declares and returns a parameter in this Module.

perturb(name, value[, collection])

Add a zero-value variable ('perturbation') to the intermediate value.

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or an error otherwise.

sample(y, num_samples, key)

Sample from the distribution conditioned by y.

setup()

Initializes a Module lazily (similar to a lazy __init__).

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

tabulate(rngs, *args[, depth, ...])

Creates a summary of the Module represented as a table.

unbind()

Returns an unbound copy of a Module and its variables.

variable(col, name[, init_fn, unbox])

Declares and returns a variable in this Module.

abstract compress(x)

Compress the data point x using the compressor.

Parameters:

x (jnp.Array) – Data point.

Returns:

Compressed data point.

Return type:

jnp.Array

abstract log_prob(x, y=None, **kwargs)

Log probability of the data point x conditioned by y.

Parameters:
  • x (jnp.Array) – Data point.

  • y (jnp.Array) – Conditioning variable.

Returns:

Log probability of the data point conditioned by y.

Return type:

jnp.Array

abstract log_prob_from_compressed(z, y=None, **kwargs)

Log probability of the compressed data point z conditioned by y. z must already have been compressed (see compress).

Parameters:
  • z (jnp.Array) – Compressed data point.

  • y (jnp.Array) – Conditioning variable.

Returns:

Log probability of the data point conditioned by y.

Return type:

jnp.Array

name: str | None = None
parent: Module | Scope | _Sentinel | None = None
abstract sample(y, num_samples, key)

Sample from the distribution conditioned by y.

Parameters:
  • y (jnp.Array) – Conditioning variable.

  • num_samples (int) – Number of samples.

  • key (jnp.Array) – Random key.

Returns:

num_samples samples from the distribution.

Return type:

jnp.Array

scope: Scope | None = None
class jaxili.model.ConditionalMADE(n_in: int, hidden_dims: list[int], activation: ~typing.Callable, n_cond: int = 0, gaussian: bool = True, random_order: bool = False, seed: int | None = None, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Module

Base class for a Conditional Masked Autoencoder for Distribution Estimation (MADE).

MADE is a neural network that parameterizes the conditional distribution of a random variable using masked linear layers.

Parameters:
  • n_in (int) – Size of the input vector.

  • hidden_dims (list[int]) – List of hidden dimensions.

  • activation (Callable) – Activation function.

  • n_cond (int) – Size of the conditioning variable; 0 if there is none.

  • gaussian (bool) – Whether the outputs are the mean and variance of a Gaussian conditional. Default True.

  • random_order (bool) – Whether to use a random ordering of the inputs for masking. Default False.

  • seed (Optional[int]) – Random seed used to label the nodes. Warning: the default is None, but the MADE will not work unless a seed is provided.
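The role of seed and random_order is easiest to see in the standard MADE mask construction, sketched here with NumPy. Assumptions, not taken from jaxili: masks have shape (fan_in, fan_out) to match the x.dot(mask*W.T) convention of MaskedLinear, conditioning inputs get degree 0 so they may feed every hidden unit, hidden degrees are drawn uniformly from [1, n_in - 1] (so n_in >= 2), and made_masks is a hypothetical helper name.

```python
import numpy as np


def made_masks(n_in, hidden_dims, n_cond=0, random_order=False, seed=0):
    """Return MADE masks, each of shape (fan_in, fan_out), plus the input degrees.

    Assumes n_in >= 2. Conditioning inputs get degree 0, so they may feed
    every hidden unit without breaking the autoregressive property.
    """
    rng = np.random.default_rng(seed)
    in_deg = rng.permutation(n_in) + 1 if random_order else np.arange(1, n_in + 1)
    degrees = [np.concatenate([in_deg, np.zeros(n_cond, dtype=int)])]
    for h in hidden_dims:
        degrees.append(rng.integers(1, n_in, size=h))  # degrees in [1, n_in - 1]
    # Hidden layers: unit k may see unit j of the previous layer if deg_k >= deg_j.
    masks = [
        (d_out[None, :] >= d_in[:, None]).astype(np.float32)
        for d_in, d_out in zip(degrees[:-1], degrees[1:])
    ]
    # Output d may only see hidden units whose degree is strictly below its own.
    masks.append((in_deg[None, :] > degrees[-1][:, None]).astype(np.float32))
    return masks, in_deg


masks, in_deg = made_masks(n_in=3, hidden_dims=[8, 8], n_cond=2, seed=42)
# Entry (i, j) counts the paths from input i to output j through the hidden layers.
connectivity = masks[0] @ masks[1] @ masks[2]
```

Because the hidden degrees are random draws, the masks (and hence the model) are only reproducible once the seed is fixed. With gaussian=True, one common choice is to tile the output mask, e.g. np.tile(masks[-1], (1, 2)), so that each input dimension gets both a mean and a scale output; this is an assumption, not a description of jaxili's internals.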

Attributes:
  • name

  • parent

  • path – Get the path of this Module.

  • scope

  • seed

  • variables – Returns the variables in this module.

Methods

__call__(x[, y])

Forward pass of the model.

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

bind(variables, *args[, rngs, mutable])

Creates an interactive Module instance by binding variables and RNGs.

clone(*[, parent, _deep_clone, _reset_names])

Creates a clone of this Module, with optionally updated arguments.

copy(*[, parent, name])

Creates a copy of this Module, with optionally updated arguments.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns true if a PRNGSequence with the given name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns true if the collection col is mutable.

lazy_init(rngs, *args[, method, mutable])

Initializes a module without computing on an actual input.

make_rng([name])

Returns a new RNG key from a given RNG sequence for this Module.

module_paths(rngs, *args[, show_repeated, ...])

Returns a dictionary mapping module paths to module instances.

param(name, init_fn, *init_args[, unbox])

Declares and returns a parameter in this Module.

perturb(name, value[, collection])

Add a zero-value variable ('perturbation') to the intermediate value.

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or an error otherwise.

setup()

Set up the network by creating the masks and the masked linear layers.

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

tabulate(rngs, *args[, depth, ...])

Creates a summary of the Module represented as a table.

unbind()

Returns an unbound copy of a Module and its variables.

variable(col, name[, init_fn, unbox])

Declares and returns a variable in this Module.

activation: Callable
gaussian: bool = True
hidden_dims: list[int]
n_cond: int = 0
n_in: int
name: str | None = None
parent: Module | Scope | _Sentinel | None = None
random_order: bool = False
scope: Scope | None = None
seed: int | None = None
setup()

Set up the network by creating the masks and the masked linear layers.

class jaxili.model.ConditionalMAF(n_in: int, n_cond: int, n_layers: int, layers: list[int], activation: ~typing.Callable, use_reverse: bool, seed: int | None = None, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: NDENetwork

Base class for a Conditional Masked Autoregressive Flow (MAF).

A Conditional Masked Autoregressive Flow models the conditional distribution of a random variable. It is obtained by stacking n_layers MAF layers.

Parameters:
  • n_in (int) – Size of the input vector.

  • n_cond (int) – Size of the conditioning variable.

  • n_layers (int) – Number of layers (i.e. number of stacked MAFs).

  • layers (list[int]) – List of hidden dimensions in each MAF.

  • activation (Callable) – Activation function.

  • use_reverse (bool) – Whether to reverse the order of the input between each MAF.

  • seed (Optional[int]) – Random seed used to label the nodes. Warning: the default is None, but the MAF will not work unless a seed is provided.

Attributes:
  • name

  • parent

  • path – Get the path of this Module.

  • scope

  • seed

  • variables – Returns the variables in this module.

Methods

__call__(x[, y])

Forward pass of the model.

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

backward(u[, y])

Backward pass of the model.

bind(variables, *args[, rngs, mutable])

Creates an interactive Module instance by binding variables and RNGs.

clone(*[, parent, _deep_clone, _reset_names])

Creates a clone of this Module, with optionally updated arguments.

copy(*[, parent, name])

Creates a copy of this Module, with optionally updated arguments.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns true if a PRNGSequence with the given name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns true if the collection col is mutable.

lazy_init(rngs, *args[, method, mutable])

Initializes a module without computing on an actual input.

log_prob(x[, y])

Compute the log-probability of x conditioned on the conditioning variable y.

make_rng([name])

Returns a new RNG key from a given RNG sequence for this Module.

module_paths(rngs, *args[, show_repeated, ...])

Returns a dictionary mapping module paths to module instances.

param(name, init_fn, *init_args[, unbox])

Declares and returns a parameter in this Module.

perturb(name, value[, collection])

Add a zero-value variable ('perturbation') to the intermediate value.

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or an error otherwise.

sample([y, num_samples, key])

Sample from the distribution emulated by the neural network.

setup()

Set up the network by creating the MAF layers.

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

tabulate(rngs, *args[, depth, ...])

Creates a summary of the Module represented as a table.

unbind()

Returns an unbound copy of a Module and its variables.

variable(col, name[, init_fn, unbox])

Declares and returns a variable in this Module.

activation: Callable
backward(u, y=None)

Backward pass of the model.

Returns x, the vector obtained by applying the inverse flow to u, together with the log-determinant of the Jacobian of the inverse flow.

Parameters:
  • u (jnp.Array) – Input vector.

  • y (jnp.Array) – Conditioning variable.

Returns:

  • x (jnp.Array) – Transformed vector.

  • log_det_sum (jnp.Array) – Log-determinant of the Jacobian.

layers: list[int]
log_prob(x, y=None)

Compute the log-probability of x conditioned on the conditioning variable y.

Parameters:
  • x (jnp.Array) – Input vector.

  • y (jnp.Array) – Conditioning variable.

Returns:

Log probability of the data point.

Return type:

jnp.Array
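log_prob follows the change-of-variables formula: if the forward pass maps x to u, then log p(x | y) = log N(u; 0, I) plus the summed log-determinants of the layer Jacobians. A minimal JAX sketch, assuming a standard normal base distribution and toy Gaussian MAF layers whose shift and log-scale for output j are linear in x[..., :j] (via a strictly triangular mask) and in y; init_layer, layer_forward, maf_forward and maf_log_prob are hypothetical names, not jaxili functions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def init_layer(key, n_in, n_cond):
    """Random parameters for one toy Gaussian MAF layer (hypothetical helper)."""
    k1, k2, k3, k4 = jax.random.split(key, 4)
    return {
        "W_mu": 0.5 * jax.random.normal(k1, (n_in, n_in)),
        "W_alpha": 0.1 * jax.random.normal(k2, (n_in, n_in)),
        "V_mu": 0.5 * jax.random.normal(k3, (n_cond, n_in)),
        "V_alpha": 0.1 * jax.random.normal(k4, (n_cond, n_in)),
    }


def layer_forward(p, x, y):
    """x -> u for one layer; shift and log-scale of output j use only x[..., :j] and y."""
    n = x.shape[-1]
    mask = jnp.triu(jnp.ones((n, n)), k=1)  # mask[i, j] = 1 only if i < j
    mu = x @ (p["W_mu"] * mask) + y @ p["V_mu"]
    alpha = x @ (p["W_alpha"] * mask) + y @ p["V_alpha"]
    return (x - mu) * jnp.exp(-alpha), -alpha.sum(axis=-1)


def maf_forward(layers, x, y):
    """Push x through all layers, reversing the ordering in between (use_reverse=True)."""
    u, log_det_sum = x, jnp.zeros(x.shape[:-1])
    for p in layers:
        u, log_det = layer_forward(p, u, y)
        log_det_sum = log_det_sum + log_det
        u = u[..., ::-1]  # a permutation: |det| = 1, so no log-det term
    return u, log_det_sum


def maf_log_prob(layers, x, y):
    u, log_det_sum = maf_forward(layers, x, y)
    return norm.logpdf(u).sum(axis=-1) + log_det_sum  # standard normal base


keys = jax.random.split(jax.random.PRNGKey(0), 4)
n_in, n_cond = 3, 2
layers = [init_layer(k, n_in, n_cond) for k in keys[:3]]
x = jax.random.normal(keys[3], (6, n_in))
y = jnp.ones((6, n_cond))
lp = maf_log_prob(layers, x, y)
```

Reversing the ordering between layers lets every dimension eventually depend on every other one, while costing nothing in the log-determinant.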

n_cond: int
n_in: int
n_layers: int
name: str | None = None
parent: Module | Scope | _Sentinel | None = None
sample(y=None, num_samples=1, key=None)

Sample from the distribution emulated by the neural network.

Parameters:
  • y (jnp.Array) – Conditioning variable.

  • num_samples (int) – Number of samples.

  • key (jnp.Array) – Random key.

Returns:

Samples from the distribution.

Return type:

jnp.Array

scope: Scope | None = None
seed: int | None = None
setup()

Set up the network by creating the MAF layers.

use_reverse: bool
class jaxili.model.ConditionalRealNVP(n_in: int, n_cond: int, n_layers: int, layers: list[int], activation: ~typing.Callable, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: NDENetwork

Base class for a Conditional RealNVP.

A normalizing flow based on RealNVP, conditioned on a conditioning variable.

Parameters:
  • n_in (int) – Dimension of the input.

  • n_cond (int) – Dimension of the conditioning variable.

  • n_layers (int) – Number of layers.

  • layers (list[int]) – List of hidden layer sizes.

  • activation (Callable) – Activation function.
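Sampling from a conditional RealNVP amounts to drawing base Gaussian noise and running the coupling layers backwards. The sketch below is a hypothetical plain-JAX illustration, not jaxili's code (which builds its bijector with tensorflow_probability): each layer's network is assumed to be a single tanh linear map, a binary mask alternating between layers decides which dimensions are transformed, and init_params, layer_masks, forward, inverse and sample are illustrative names.

```python
import jax
import jax.numpy as jnp


def init_params(key, n_in, n_cond, n_layers):
    """One toy weight matrix per coupling layer (hypothetical helper)."""
    keys = jax.random.split(key, n_layers)
    return [0.3 * jax.random.normal(k, (n_in + n_cond, 2 * n_in)) for k in keys]


def layer_masks(n_in, n_layers):
    """Alternating binary masks: 1 = copied unchanged, 0 = transformed."""
    b = (jnp.arange(n_in) % 2).astype(jnp.float32)
    return [b if i % 2 == 0 else 1.0 - b for i in range(n_layers)]


def shift_log_scale(W, x_masked, y, b):
    out = jnp.tanh(jnp.concatenate([x_masked, y], axis=-1) @ W)
    t, s = jnp.split(out, 2, axis=-1)
    return t * (1 - b), s * (1 - b)  # act only on the transformed dimensions


def forward(params, x, y):
    """Data -> base space, with log|det| of the Jacobian."""
    log_det = jnp.zeros(x.shape[:-1])
    for W, b in zip(params, layer_masks(x.shape[-1], len(params))):
        t, s = shift_log_scale(W, x * b, y, b)
        x = x * b + (1 - b) * (x * jnp.exp(s) + t)
        log_det = log_det + s.sum(axis=-1)
    return x, log_det


def inverse(params, z, y):
    """Base space -> data: undo the layers in reverse order."""
    for W, b in reversed(list(zip(params, layer_masks(z.shape[-1], len(params))))):
        t, s = shift_log_scale(W, z * b, y, b)  # z * b equals x * b for this layer
        z = z * b + (1 - b) * (z - t) * jnp.exp(-s)
    return z


def sample(params, y, num_samples, key):
    n_in = params[0].shape[1] // 2
    z = jax.random.normal(key, (num_samples, n_in))  # standard normal base
    y = jnp.broadcast_to(y, (num_samples, y.shape[-1]))
    return inverse(params, z, y)


params = init_params(jax.random.PRNGKey(0), n_in=3, n_cond=2, n_layers=4)
samples = sample(params, jnp.ones(2), num_samples=7, key=jax.random.PRNGKey(1))
```

Both directions cost a single pass per layer, which is the main practical difference from a MAF, whose inverse is sequential.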

Attributes:
  • name

  • parent

  • path – Get the path of this Module.

  • scope

  • variables – Returns the variables in this module.

Methods

__call__(y, **kwargs)

Build the bijector using tensorflow_probability.

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

bind(variables, *args[, rngs, mutable])

Creates an interactive Module instance by binding variables and RNGs.

clone(*[, parent, _deep_clone, _reset_names])

Creates a clone of this Module, with optionally updated arguments.

copy(*[, parent, name])

Creates a copy of this Module, with optionally updated arguments.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns true if a PRNGSequence with the given name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns true if the collection col is mutable.

lazy_init(rngs, *args[, method, mutable])

Initializes a module without computing on an actual input.

log_prob(x, y, **kwargs)

Compute the log probability of the data point x conditioned by y from the normalizing flow.

make_rng([name])

Returns a new RNG key from a given RNG sequence for this Module.

module_paths(rngs, *args[, show_repeated, ...])

Returns a dictionary mapping module paths to module instances.

param(name, init_fn, *init_args[, unbox])

Declares and returns a parameter in this Module.

perturb(name, value[, collection])

Add a zero-value variable ('perturbation') to the intermediate value.

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or an error otherwise.

sample(y, num_samples, key, **kwargs)

Sample from the distribution modeled by the RealNVP.

setup()

Initializes a Module lazily (similar to a lazy __init__).

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

tabulate(rngs, *args[, depth, ...])

Creates a summary of the Module represented as a table.

unbind()

Returns an unbound copy of a Module and its variables.

variable(col, name[, init_fn, unbox])

Declares and returns a variable in this Module.

activation: Callable
layers: list[int]
log_prob(x, y, **kwargs)

Compute the log probability of the data point x conditioned by y from the normalizing flow.

Parameters:
  • x (jnp.Array) – Data point.

  • y (jnp.Array) – Conditioning variable.

Returns:

Log probability of the data point conditioned by y.

Return type:

jnp.Array

n_cond: int
n_in: int
n_layers: int
name: str | None = None
parent: Module | Scope | _Sentinel | None = None
sample(y, num_samples, key, **kwargs)

Sample from the distribution modeled by the RealNVP.

Parameters:
  • y (jnp.Array) – Conditioning variable.

  • num_samples (int) – Number of samples.

  • key (jnp.Array) – Random key.

Returns:

num_samples samples from the distribution.

Return type:

jnp.Array

scope: Scope | None = None
class jaxili.model.MAFLayer(n_in: int, n_cond: int, hidden_dims: list[int], reverse: bool, activation: ~typing.Callable, seed: int | None = None, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Module

Base class for a Masked Autoregressive Flow layer.

A single layer of a Masked Autoregressive Flow.

Parameters:
  • n_in (int) – Size of the input vector.

  • n_cond (int) – Size of the conditioning variable.

  • hidden_dims (list[int]) – List of hidden dimensions.

  • reverse (bool) – Whether to reverse the order of the input.

  • activation (Callable) – Activation function.

  • seed (Optional[int]) – Random seed used to label the nodes. Warning: the default is None, but the MAF will not work unless a seed is provided.

Attributes:
  • name

  • parent

  • path – Get the path of this Module.

  • scope

  • seed

  • variables – Returns the variables in this module.

Methods

__call__(x[, y])

Forward pass of the model.

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

backward(u[, y])

Backward pass of the model.

bind(variables, *args[, rngs, mutable])

Creates an interactive Module instance by binding variables and RNGs.

clone(*[, parent, _deep_clone, _reset_names])

Creates a clone of this Module, with optionally updated arguments.

copy(*[, parent, name])

Creates a copy of this Module, with optionally updated arguments.

forward(x[, y])

Forward pass of the model.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns true if a PRNGSequence with the given name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns true if the collection col is mutable.

lazy_init(rngs, *args[, method, mutable])

Initializes a module without computing on an actual input.

make_rng([name])

Returns a new RNG key from a given RNG sequence for this Module.

module_paths(rngs, *args[, show_repeated, ...])

Returns a dictionary mapping module paths to module instances.

param(name, init_fn, *init_args[, unbox])

Declares and returns a parameter in this Module.

perturb(name, value[, collection])

Add a zero-value variable ('perturbation') to the intermediate value.

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or an error otherwise.

setup()

Initializes a Module lazily (similar to a lazy __init__).

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

tabulate(rngs, *args[, depth, ...])

Creates a summary of the Module represented as a table.

unbind()

Returns an unbound copy of a Module and its variables.

variable(col, name[, init_fn, unbox])

Declares and returns a variable in this Module.

activation: Callable
backward(u, y=None)

Backward pass of the model.

Returns x, the vector obtained by applying the inverse flow to u, together with the log-determinant of the Jacobian of the inverse flow.

Parameters:
  • u (jnp.Array) – Input vector.

  • y (jnp.Array) – Conditioning variable.

Returns:

  • jnp.Array – Transformed vector.

  • jnp.Array – Log-determinant of the Jacobian.

forward(x, y=None)

Forward pass of the model.

Returns u, the vector obtained by applying the flow to x, together with the log-determinant of the Jacobian of the flow.

Parameters:
  • x (jnp.Array) – Input vector.

  • y (jnp.Array) – Conditioning variable.

Returns:

  • jnp.Array – Transformed vector.

  • jnp.Array – Log-determinant of the Jacobian.
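The asymmetry between forward and backward is characteristic of a MAF layer: forward computes every output in one parallel pass, while backward must recover x one dimension at a time because x_j depends on x_<j. A hypothetical JAX sketch, assuming a Gaussian MAF layer whose shift mu and log-scale alpha come from strictly triangular linear maps of x plus a linear map of y (standing in for the MADE); made_stats, forward and backward are illustrative names, not jaxili's methods.

```python
import jax
import jax.numpy as jnp


def made_stats(p, x, y):
    """Shift mu and log-scale alpha; entry j depends only on x[..., :j] and y."""
    n = x.shape[-1]
    mask = jnp.triu(jnp.ones((n, n)), k=1)  # mask[i, j] = 1 only if i < j
    mu = x @ (p["W_mu"] * mask) + y @ p["V_mu"]
    alpha = x @ (p["W_alpha"] * mask) + y @ p["V_alpha"]
    return mu, alpha


def forward(p, x, y):
    """x -> u in one parallel pass; log|det du/dx| = -sum(alpha)."""
    mu, alpha = made_stats(p, x, y)
    return (x - mu) * jnp.exp(-alpha), -alpha.sum(axis=-1)


def backward(p, u, y):
    """u -> x in n_in sequential passes; log|det dx/du| = +sum(alpha)."""
    x = jnp.zeros_like(u)
    for j in range(u.shape[-1]):
        mu, alpha = made_stats(p, x, y)  # entry j only needs the already-filled x[..., :j]
        x = x.at[..., j].set(u[..., j] * jnp.exp(alpha[..., j]) + mu[..., j])
    _, alpha = made_stats(p, x, y)
    return x, alpha.sum(axis=-1)


k1, k2, k3, k4, kx = jax.random.split(jax.random.PRNGKey(0), 5)
n_in, n_cond = 4, 2
p = {
    "W_mu": 0.5 * jax.random.normal(k1, (n_in, n_in)),
    "W_alpha": 0.1 * jax.random.normal(k2, (n_in, n_in)),
    "V_mu": 0.5 * jax.random.normal(k3, (n_cond, n_in)),
    "V_alpha": 0.1 * jax.random.normal(k4, (n_cond, n_in)),
}
x = jax.random.normal(kx, (5, n_in))
y = jnp.ones((5, n_cond))
u, log_det = forward(p, x, y)
x_rec, inv_log_det = backward(p, u, y)
```

This is why density evaluation with a MAF is cheap while sampling scales with n_in passes.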

hidden_dims: list[int]
n_cond: int
n_in: int
name: str | None = None
parent: Module | Scope | _Sentinel | None = None
reverse: bool
scope: Scope | None = None
seed: int | None = None
class jaxili.model.MaskedLinear(n_out: int, bias: bool = True, mask: ~typing.Any = None, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Module

Base class for a Masked Linear layer.

Linear transformation in which masked-out weights are set to zero:

y = x.dot(mask*W.T)+b

Parameters:
  • n_out (int) – Output dimension.

  • bias (bool) – Whether to include bias. Default True.

  • mask (Any) – Mask to apply to the weights. Default None.

Attributes:
  • mask

  • name

  • parent

  • path – Get the path of this Module.

  • scope

  • variables – Returns the variables in this module.

Methods

__call__(x)

Apply masked linear transformation.

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

bind(variables, *args[, rngs, mutable])

Creates an interactive Module instance by binding variables and RNGs.

clone(*[, parent, _deep_clone, _reset_names])

Creates a clone of this Module, with optionally updated arguments.

copy(*[, parent, name])

Creates a copy of this Module, with optionally updated arguments.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns true if a PRNGSequence with the given name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

initialize_mask(mask)

Initialize the mask.

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns true if the collection col is mutable.

lazy_init(rngs, *args[, method, mutable])

Initializes a module without computing on an actual input.

make_rng([name])

Returns a new RNG key from a given RNG sequence for this Module.

module_paths(rngs, *args[, show_repeated, ...])

Returns a dictionary mapping module paths to module instances.

param(name, init_fn, *init_args[, unbox])

Declares and returns a parameter in this Module.

perturb(name, value[, collection])

Add a zero-value variable ('perturbation') to the intermediate value.

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or an error otherwise.

setup()

Initializes a Module lazily (similar to a lazy __init__).

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

tabulate(rngs, *args[, depth, ...])

Creates a summary of the Module represented as a table.

unbind()

Returns an unbound copy of a Module and its variables.

variable(col, name[, init_fn, unbox])

Declares and returns a variable in this Module.

bias: bool = True
initialize_mask(mask: Any)

Initialize the mask.

Parameters:

mask (Any) – Boolean mask to apply to the weights.

mask: Any = None
n_out: int
name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
class jaxili.model.MixtureDensityNetwork(n_in: int, n_cond: int, n_components: int, layers: list[int], activation: ~typing.Callable, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: NDENetwork

Base class for a Mixture Density Network.

A Gaussian mixture density modeled with a neural network. The weight, mean, and covariance of each Gaussian component are learned by the network.

Attributes:
  • name

  • parent

  • path – Get the path of this Module.

  • scope

  • variables – Returns the variables in this module.

Methods

__call__(y, **kwargs)

Build a bijector that transforms a multivariate Gaussian distribution into a Gaussian mixture distribution using a neural network.

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

bind(variables, *args[, rngs, mutable])

Creates an interactive Module instance by binding variables and RNGs.

clone(*[, parent, _deep_clone, _reset_names])

Creates a clone of this Module, with optionally updated arguments.

copy(*[, parent, name])

Creates a copy of this Module, with optionally updated arguments.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns true if a PRNGSequence with the given name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns true if the collection col is mutable.

lazy_init(rngs, *args[, method, mutable])

Initializes a module without computing on an actual input.

log_prob(x, y, **kwargs)

Return the log probability of the data point x conditioned by y.

make_rng([name])

Returns a new RNG key from a given RNG sequence for this Module.

module_paths(rngs, *args[, show_repeated, ...])

Returns a dictionary mapping module paths to module instances.

param(name, init_fn, *init_args[, unbox])

Declares and returns a parameter in this Module.

perturb(name, value[, collection])

Add a zero-value variable ('perturbation') to the intermediate value.

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or an error otherwise.

sample(y, num_samples, key, **kwargs)

Sample from the distribution conditioned by y.

setup()

Initializes a Module lazily (similar to a lazy __init__).

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

tabulate(rngs, *args[, depth, ...])

Creates a summary of the Module represented as a table.

unbind()

Returns an unbound copy of a Module and its variables.

variable(col, name[, init_fn, unbox])

Declares and returns a variable in this Module.

activation: Callable
layers: list[int]
log_prob(x, y, **kwargs)

Return the log probability of the data point x conditioned by y.

Parameters:
  • x (jnp.Array) – Data point.

  • y (jnp.Array) – Conditioning variable.

Returns:

Log probability of the data point.

Return type:

jnp.Array
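The log-probability of a Gaussian mixture is usually computed in log space with logsumexp for numerical stability. A minimal JAX sketch, assuming diagonal covariances (the class description does not specify the covariance structure) and a hypothetical single linear head, mixture_params, that maps y to the mixture logits, means and log standard deviations; neither function is part of jaxili.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def mixture_params(W, y, n_components, n_in):
    """Hypothetical network head: one linear map from y to all mixture parameters."""
    out = y @ W  # (batch, K + 2 * K * n_in)
    logits = out[:, :n_components]
    means, log_stds = jnp.split(out[:, n_components:], 2, axis=-1)
    shape = (y.shape[0], n_components, n_in)
    return logits, means.reshape(shape), log_stds.reshape(shape)


def mdn_log_prob(logits, means, log_stds, x):
    """log p(x) for a diagonal Gaussian mixture.

    logits: (batch, K); means, log_stds: (batch, K, n_in); x: (batch, n_in).
    """
    log_weights = jax.nn.log_softmax(logits, axis=-1)  # mixture weights, in log space
    # Log density of x under each component: sum of independent 1D normals.
    comp = norm.logpdf(x[:, None, :], loc=means, scale=jnp.exp(log_stds)).sum(axis=-1)
    return logsumexp(log_weights + comp, axis=-1)


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
n_in, n_cond, K = 2, 3, 4
W = 0.1 * jax.random.normal(key_w, (n_cond, K + 2 * K * n_in))
y = jnp.ones((5, n_cond))
x = jax.random.normal(key_x, (5, n_in))
lp = mdn_log_prob(*mixture_params(W, y, K, n_in), x)
```

Predicting log standard deviations and unnormalized logits keeps every network output unconstrained, while exp and log_softmax enforce positivity and normalization.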

n_components: int
n_cond: int
n_in: int
name: str | None = None
parent: Module | Scope | _Sentinel | None = None
sample(y, num_samples, key, **kwargs)

Sample from the distribution conditioned by y.

Parameters:
  • y (jnp.Array) – Conditioning variable.

  • num_samples (int) – Number of samples.

  • key (jnp.Array) – Random key.

Returns:

num_samples samples from the distribution.

Return type:

jnp.Array
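Sampling from a mixture is typically done in two stages: pick a component according to the weights, then draw from that Gaussian. A hypothetical JAX sketch for a single conditioning value, again assuming diagonal covariances; in practice the mixture parameters would come from the network evaluated at y, and mdn_sample is an illustrative name, not jaxili's method.

```python
import jax
import jax.numpy as jnp


def mdn_sample(logits, means, log_stds, num_samples, key):
    """Ancestral sampling from a diagonal Gaussian mixture (hypothetical helper).

    logits: (K,); means, log_stds: (K, n_in). Returns (num_samples, n_in).
    """
    key_c, key_n = jax.random.split(key)
    # Stage 1: choose a component for each sample according to softmax(logits).
    comp = jax.random.categorical(key_c, logits, shape=(num_samples,))
    # Stage 2: draw from the chosen Gaussian by scaling and shifting standard noise.
    eps = jax.random.normal(key_n, (num_samples, means.shape[-1]))
    return means[comp] + jnp.exp(log_stds[comp]) * eps


logits = jnp.array([0.0, -1e9])  # all weight on component 0
means = jnp.array([[5.0, 5.0], [-5.0, -5.0]])
log_stds = jnp.full((2, 2), -10.0)  # near-zero spread
samples = mdn_sample(logits, means, log_stds, num_samples=100, key=jax.random.PRNGKey(0))
```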

scope: Scope | None = None
class jaxili.model.NDENetwork(parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Module

Base class for a neural density estimator (NDE).

A parent class for implementing neural density estimators, such as normalizing flows and mixture density networks, with neural networks.

Attributes:
  • name

  • parent

  • path – Get the path of this Module.

  • scope

  • variables – Returns the variables in this module.

Methods

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

bind(variables, *args[, rngs, mutable])

Creates an interactive Module instance by binding variables and RNGs.

clone(*[, parent, _deep_clone, _reset_names])

Creates a clone of this Module, with optionally updated arguments.

copy(*[, parent, name])

Creates a copy of this Module, with optionally updated arguments.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns true if a PRNGSequence with name name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns true if the collection col is mutable.

lazy_init(rngs, *args[, method, mutable])

Initializes a module without computing on an actual input.

log_prob(x[, y])

Log probability of the data point x conditioned by y.

make_rng([name])

Returns a new RNG key from a given RNG sequence for this Module.

module_paths(rngs, *args[, show_repeated, ...])

Returns a dictionary mapping module paths to module instances.

param(name, init_fn, *init_args[, unbox])

Declares and returns a parameter in this Module.

perturb(name, value[, collection])

Add a zero-value variable ('perturbation') to the intermediate value.

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or an error otherwise.

sample(y, num_samples, key)

Sample from the distribution conditioned by y.

setup()

Initializes a Module lazily (similar to a lazy __init__).

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

tabulate(rngs, *args[, depth, ...])

Creates a summary of the Module represented as a table.

unbind()

Returns an unbound copy of a Module and its variables.

variable(col, name[, init_fn, unbox])

Declares and returns a variable in this Module.

abstract log_prob(x, y=None, **kwargs)

Log probability of the data point x conditioned by y.

Parameters:
  • x (jnp.Array) – Data point.

  • y (jnp.Array) – Conditioning variable.

Returns:

Log probability of the data point given y.

Return type:

jnp.Array

name: str | None = None
parent: Module | Scope | _Sentinel | None = None
abstract sample(y, num_samples, key)

Sample from the distribution conditioned by y.

Parameters:
  • y (jnp.Array) – Conditioning variable.

  • num_samples (int) – Number of samples.

  • key (jnp.Array) – Random key.

Returns:

num_samples samples from the distribution.

Return type:

jnp.Array

scope: Scope | None = None
class jaxili.model.NDE_Compressor(compressor: ~flax.linen.module.Module, nde: ~jaxili.model.NDENetwork, compressor_hparams: dict, nde_hparams: dict, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Compressor_w_NDE

Base class for a normalizing flow with a compressor.

WARNING: This class is obsolete and will likely be removed in a future release.

A general class that chains a compressor with a normalizing flow and provides standard methods to compute the log-probability of the target distribution or to sample from it.
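The compressor-then-flow composition can be sketched with plain functions. This assumes a hypothetical linear compressor and a hypothetical unit-variance Gaussian in place of the flow; the function names are illustrative and are not the jaxili API. It follows the description of NDE_Compressor.log_prob, in which y are the parameters and x is the data point.

```python
import jax.numpy as jnp


def compress(x, A):
    # Hypothetical linear compressor: data of shape (..., n_data) is mapped
    # to summaries of shape (..., n_summary). In jaxili this is a Flax module.
    return x @ A.T


def nde_log_prob(y, z):
    # Hypothetical stand-in for the normalizing flow: log N(y; z, I), a
    # unit-variance Gaussian on the parameters y centred on the summary z.
    d = y.shape[-1]
    return -0.5 * jnp.sum((y - z) ** 2, axis=-1) - 0.5 * d * jnp.log(2 * jnp.pi)


def log_prob(x, y, A):
    # Compress the data first, then evaluate the conditional density of y.
    return nde_log_prob(y, compress(x, A))


A = jnp.array([[1.0, 0.0, 1.0, 0.0],
               [0.0, 1.0, 0.0, 1.0]])   # (n_summary=2, n_data=4)
x = jnp.arange(12.0).reshape(3, 4)       # batch of 3 data points
y = jnp.zeros((3, 2))                    # batch of 3 parameter vectors
lp = log_prob(x, y, A)                   # shape (3,)
```

Because the density estimator only ever sees compress(x), the compressor can reduce high-dimensional data to a few summaries before density estimation.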

Attributes:
name
parent
path

Get the path of this Module.

scope
variables

Returns the variables in this module.

Methods

__call__(x, y[, model])

Perform a forward pass through the network and return the log-probability of x given y.

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

bind(variables, *args[, rngs, mutable])

Creates an interactive Module instance by binding variables and RNGs.

clone(*[, parent, _deep_clone, _reset_names])

Creates a clone of this Module, with optionally updated arguments.

compress(x)

Compress the data point x using the compressor.

copy(*[, parent, name])

Creates a copy of this Module, with optionally updated arguments.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns true if a PRNGSequence with name name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns true if the collection col is mutable.

lazy_init(rngs, *args[, method, mutable])

Initializes a module without computing on an actual input.

log_prob(x, y[, model])

Return the log-probability of the parameters y conditioned by the data point x.

log_prob_compressed(z, y[, model])

Return the log-probability of the compressed data z conditioned by the parameters y (if NPE).

log_prob_from_compressed(z[, y])

Log probability of the data point z conditioned by y.

make_rng([name])

Returns a new RNG key from a given RNG sequence for this Module.

module_paths(rngs, *args[, show_repeated, ...])

Returns a dictionary mapping module paths to module instances.

param(name, init_fn, *init_args[, unbox])

Declares and returns a parameter in this Module.

perturb(name, value[, collection])

Add a zero-value variable ('perturbation') to the intermediate value.

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or an error otherwise.

sample(y, num_samples, key[, model])

Sample from the distribution conditioned by y.

setup()

Set up the compressor and the normalizing flow.

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

tabulate(rngs, *args[, depth, ...])

Creates a summary of the Module represented as a table.

unbind()

Returns an unbound copy of a Module and its variables.

variable(col, name[, init_fn, unbox])

Declares and returns a variable in this Module.

compressor: Module
compressor_hparams: dict
log_prob(x, y, model='NPE')

Return the log-probability of the parameters y conditioned by the data point x.

Parameters:
  • x (jnp.Array) – Data point.

  • y (jnp.Array) – Conditioning variable.

Returns:

Log probability of the parameters y.

Return type:

jnp.Array

log_prob_compressed(z, y, model='NPE')

Return the log-probability of the compressed data z conditioned by the parameters y (if NPE).

Parameters:
  • z (jnp.Array) – Compressed data point.

  • y (jnp.Array) – Conditioning variable.

Returns:

Log probability of the parameters y.

Return type:

jnp.Array

name: str | None = None
nde: NDENetwork
nde_hparams: dict
parent: Module | Scope | _Sentinel | None = None
sample(y, num_samples, key, model='NPE')

Sample from the distribution conditioned by y.

Parameters:
  • y (jnp.Array) – Conditioning variable.

  • num_samples (int) – Number of samples.

  • key (jnp.Array) – Random key.

Returns:

num_samples samples from the distribution.

Return type:

jnp.Array

scope: Scope | None = None
setup()

Set up the compressor and the normalizing flow.

class jaxili.model.NDE_w_Standardization(nde: ~jaxili.model.NDENetwork, embedding_net: ~flax.linen.module.Module, transformation: ~distrax._src.bijectors.bijector.Bijector, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: NDENetwork

Base class implementing a normalizing flow with a standardization step.

This class builds an NDE network whose input data are first standardized. It takes as input a neural density estimator (nde), an embedding network (embedding_net) and a transformation (transformation). The embedding network maps the data point to a latent space in which the NDE is applied, which allows the data to be compressed to a lower-dimensional space. The transformation standardizes the variable learned by the normalizing flow, for numerical stability.
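Standardizing the variable modelled by the flow is an affine change of variables, so evaluating the density on the original scale requires a log-determinant correction. The sketch below assumes a per-dimension affine map u = (x - mean) / std and uses an independent standard normal as a hypothetical stand-in for the flow's density in standardized space. In jaxili this role is played by the transformation bijector, whose exact parametrization is not reproduced here.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def standardize(x, mean, std):
    # Per-dimension affine map to zero mean and unit scale.
    return (x - mean) / std


def base_log_prob(u):
    # Hypothetical stand-in for the flow's density in standardized space.
    return jnp.sum(norm.logpdf(u), axis=-1)


def log_prob_original_scale(x, mean, std):
    # Change of variables for u = (x - mean) / std:
    # log p_x(x) = log p_u(u) + log|du/dx| = log p_u(u) - sum(log std).
    u = standardize(x, mean, std)
    return base_log_prob(u) - jnp.sum(jnp.log(std))


mean = jnp.array([1.0, -2.0])
std = jnp.array([2.0, 0.5])
x = jnp.array([0.3, 0.7])
lp = log_prob_original_scale(x, mean, std)
```

With a standard-normal base density, the corrected value coincides with the log-density of N(mean, std**2) evaluated at x, which is a quick sanity check that the Jacobian term has the right sign.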

Attributes:
name
parent
path

Get the path of this Module.

scope
variables

Returns the variables in this module.

Methods

__call__(x, y[, model])

Return the log-probability of x given y for NPE and y given x for NLE.

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

bind(variables, *args[, rngs, mutable])

Creates an interactive Module instance by binding variables and RNGs.

clone(*[, parent, _deep_clone, _reset_names])

Creates a clone of this Module, with optionally updated arguments.

copy(*[, parent, name])

Creates a copy of this Module, with optionally updated arguments.

embedding(x)

Embed the data point x.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns true if a PRNGSequence with name name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns true if the collection col is mutable.

lazy_init(rngs, *args[, method, mutable])

Initializes a module without computing on an actual input.

log_prob(x[, y, model])

Return the log probability of the data point x conditioned by y.

make_rng([name])

Returns a new RNG key from a given RNG sequence for this Module.

module_paths(rngs, *args[, show_repeated, ...])

Returns a dictionary mapping module paths to module instances.

param(name, init_fn, *init_args[, unbox])

Declares and returns a parameter in this Module.

perturb(name, value[, collection])

Add a zero-value variable ('perturbation') to the intermediate value.

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or an error otherwise.

sample(y, num_samples, key[, model])

Sample from the distribution conditioned by y.

setup()

Initializes a Module lazily (similar to a lazy __init__).

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

standardize(x)

Standardize the data point x.

tabulate(rngs, *args[, depth, ...])

Creates a summary of the Module represented as a table.

unbind()

Returns an unbound copy of a Module and its variables.

unstandardize(x)

Unstandardize the data point x.

variable(col, name[, init_fn, unbox])

Declares and returns a variable in this Module.

embedding(x)

Embed the data point x.

embedding_net: Module
log_prob(x, y=None, model='NPE')

Return the log probability of the data point x conditioned by y.

name: str | None = None
nde: NDENetwork
parent: Module | Scope | _Sentinel | None = None
sample(y, num_samples, key, model='NPE')

Sample from the distribution conditioned by y.

scope: Scope | None = None
standardize(x)

Standardize the data point x.

transformation: Bijector
unstandardize(x)

Unstandardize the data point x.
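Sampling runs in the opposite direction: draws are made in standardized space and mapped back to the original scale with the inverse map. This is a hypothetical sketch under the same per-dimension affine assumption as standardize, with a standard normal standing in for the trained flow; it illustrates how standardize and unstandardize relate, not the jaxili implementation of sample.

```python
import jax
import jax.numpy as jnp


def standardize(x, mean, std):
    # Per-dimension affine map to zero mean and unit scale.
    return (x - mean) / std


def unstandardize(u, mean, std):
    # Inverse of standardize.
    return u * std + mean


def sample(num_samples, key, mean, std):
    # Hypothetical stand-in for the trained flow: draw from a standard normal
    # in standardized space, then map the draws back to the original scale.
    u = jax.random.normal(key, (num_samples, mean.shape[0]))
    return unstandardize(u, mean, std)


mean = jnp.array([1.0, -2.0])
std = jnp.array([2.0, 0.5])
draws = sample(100_000, jax.random.PRNGKey(0), mean, std)
```

Keeping the flow in standardized space means its inputs stay of order one regardless of the physical units of the variable, which is the stability motivation given above.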