jaxili.model module
Model.
This module contains classes to implement normalizing flows using neural networks.
- class jaxili.model.AffineCoupling(y: typing.Any, layers: list, activation: callable, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Bases: Module
Base class for an Affine Coupling layer for RealNVP.
- Parameters:
y (Any) – Conditioning variable.
layers (list) – Sizes of the hidden layers.
activation (Callable) – Activation function.
- Attributes:
- name
- parent
- path – Get the path of this Module.
- scope
- variables – Returns the variables in this module.
Methods
__call__(x, output_units, **kwargs) – Build the bijector using tensorflow_probability, where the scale and the shift are learned by a neural network.
apply(variables, *args[, rngs, method, ...]) – Applies a module method to variables and returns output and modified variables.
bind(variables, *args[, rngs, mutable]) – Creates an interactive Module instance by binding variables and RNGs.
clone(*[, parent, _deep_clone, _reset_names]) – Creates a clone of this Module, with optionally updated arguments.
copy(*[, parent, name]) – Creates a copy of this Module, with optionally updated arguments.
get_variable(col, name[, default]) – Retrieves the value of a Variable.
has_rng(name) – Returns true if a PRNGSequence with the given name exists.
has_variable(col, name) – Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]) – Initializes a module method with variables and returns modified variables.
init_with_output(rngs, *args[, method, ...]) – Initializes a module method with variables and returns output and modified variables.
is_initializing() – Returns True if running under self.init(...) or nn.init(...)().
is_mutable_collection(col) – Returns true if the collection col is mutable.
lazy_init(rngs, *args[, method, mutable]) – Initializes a module without computing on an actual input.
make_rng([name]) – Returns a new RNG key from a given RNG sequence for this Module.
module_paths(rngs, *args[, show_repeated, ...]) – Returns a dictionary mapping module paths to module instances.
param(name, init_fn, *init_args[, unbox]) – Declares and returns a parameter in this Module.
perturb(name, value[, collection]) – Adds a zero-value variable ('perturbation') to the intermediate value.
put_variable(col, name, value) – Updates the value of the given variable if it is mutable, or raises an error otherwise.
setup() – Initializes a Module lazily (similar to a lazy __init__).
sow(col, name, value[, reduce_fn, init_fn]) – Stores a value in a collection.
tabulate(rngs, *args[, depth, ...]) – Creates a summary of the Module represented as a table.
unbind() – Returns an unbound copy of a Module and its variables.
variable(col, name[, init_fn, unbox]) – Declares and returns a variable in this Module.
- activation: callable
- layers: list
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- scope: Scope | None = None
- y: Any
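The coupling transform behind AffineCoupling can be illustrated with a short NumPy sketch. It is not jaxili's code, which builds a tensorflow_probability bijector and learns the shift and scale with a Flax network. Here the helper conditioner is a hypothetical fixed random MLP standing in for that network, and splitting the input into two halves is an assumption. The first half passes through unchanged, so the transform is cheap to invert and its log-determinant is simply the sum of the log-scales.

```python
import numpy as np

rng = np.random.default_rng(0)
D, D_COND, HIDDEN = 4, 2, 16
D1 = D // 2  # the first D1 coordinates are copied and condition the rest

# Hypothetical fixed weights standing in for the learned network.
W1 = 0.1 * rng.normal(size=(D1 + D_COND, HIDDEN))
W2 = 0.1 * rng.normal(size=(HIDDEN, 2 * (D - D1)))

def conditioner(x1, y):
    """Return (shift, log_scale) for the second half, computed from (x1, y)."""
    h = np.tanh(np.concatenate([x1, y], axis=-1) @ W1)
    shift, log_scale = np.split(h @ W2, 2, axis=-1)
    return shift, np.tanh(log_scale)  # bounded log-scale, a common stabilizer

def coupling_forward(x, y):
    x1, x2 = x[..., :D1], x[..., D1:]
    shift, log_scale = conditioner(x1, y)
    z2 = x2 * np.exp(log_scale) + shift
    # The Jacobian is triangular, so log|det J| is the sum of the log-scales.
    return np.concatenate([x1, z2], axis=-1), log_scale.sum(axis=-1)

def coupling_inverse(z, y):
    z1, z2 = z[..., :D1], z[..., D1:]
    shift, log_scale = conditioner(z1, y)  # z1 equals x1, so the parameters match
    x2 = (z2 - shift) * np.exp(-log_scale)
    return np.concatenate([z1, x2], axis=-1), -log_scale.sum(axis=-1)

x = rng.normal(size=(5, D))
y = rng.normal(size=(5, D_COND))
z, log_det = coupling_forward(x, y)
x_rec, inv_log_det = coupling_inverse(z, y)
```

Stacking several such layers while alternating which half is copied gives a RealNVP-style flow.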
- class jaxili.model.Compressor_w_NDE(parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Bases: NDENetwork
Base class to create a normalizing flow with a compression of the conditioning variable.
A parent class to implement a compressor followed by a normalizing flow. This is useful for Implicit Likelihood Inference in high dimensions, where compression is required and can sometimes be done with a normalizing flow.
- Attributes:
- name
- parent
- path – Get the path of this Module.
- scope
- variables – Returns the variables in this module.
Methods
apply(variables, *args[, rngs, method, ...]) – Applies a module method to variables and returns output and modified variables.
bind(variables, *args[, rngs, mutable]) – Creates an interactive Module instance by binding variables and RNGs.
clone(*[, parent, _deep_clone, _reset_names]) – Creates a clone of this Module, with optionally updated arguments.
compress(x) – Compress the data point x using the compressor.
copy(*[, parent, name]) – Creates a copy of this Module, with optionally updated arguments.
get_variable(col, name[, default]) – Retrieves the value of a Variable.
has_rng(name) – Returns true if a PRNGSequence with the given name exists.
has_variable(col, name) – Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]) – Initializes a module method with variables and returns modified variables.
init_with_output(rngs, *args[, method, ...]) – Initializes a module method with variables and returns output and modified variables.
is_initializing() – Returns True if running under self.init(...) or nn.init(...)().
is_mutable_collection(col) – Returns true if the collection col is mutable.
lazy_init(rngs, *args[, method, mutable]) – Initializes a module without computing on an actual input.
log_prob(x[, y]) – Log probability of the data point x conditioned by y.
log_prob_from_compressed(z[, y]) – Log probability of the compressed data point z conditioned by y.
make_rng([name]) – Returns a new RNG key from a given RNG sequence for this Module.
module_paths(rngs, *args[, show_repeated, ...]) – Returns a dictionary mapping module paths to module instances.
param(name, init_fn, *init_args[, unbox]) – Declares and returns a parameter in this Module.
perturb(name, value[, collection]) – Adds a zero-value variable ('perturbation') to the intermediate value.
put_variable(col, name, value) – Updates the value of the given variable if it is mutable, or raises an error otherwise.
sample(y, num_samples, key) – Sample from the distribution conditioned by y.
setup() – Initializes a Module lazily (similar to a lazy __init__).
sow(col, name, value[, reduce_fn, init_fn]) – Stores a value in a collection.
tabulate(rngs, *args[, depth, ...]) – Creates a summary of the Module represented as a table.
unbind() – Returns an unbound copy of a Module and its variables.
variable(col, name[, init_fn, unbox]) – Declares and returns a variable in this Module.
- abstract compress(x)
Compress the data point x using the compressor.
- Parameters:
x (jnp.Array) – Data point.
- Returns:
Compressed data point.
- Return type:
jnp.Array
- abstract log_prob(x, y=None, **kwargs)
Log probability of the data point x conditioned by y.
- Parameters:
x (jnp.Array) – Data point.
y (jnp.Array) – Conditioning variable.
- Returns:
Log probability of the data point conditioned by y.
- Return type:
jnp.Array
- abstract log_prob_from_compressed(z, y=None, **kwargs)
Log probability of the data point z conditioned by y, where z has already been compressed.
- Parameters:
z (jnp.Array) – Compressed data point.
y (jnp.Array) – Conditioning variable.
- Returns:
Log probability of the data point conditioned by y.
- Return type:
jnp.Array
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- abstract sample(y, num_samples, key)
Sample from the distribution conditioned by y.
- Parameters:
y (jnp.Array) – Conditioning variable.
num_samples (int) – Number of samples.
key (jnp.Array) – Random key.
- Returns:
num_samples samples from the distribution.
- Return type:
jnp.Array
- scope: Scope | None = None
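The division of labour between compress, log_prob_from_compressed and log_prob can be illustrated with a plain-Python analogue. This sketch assumes that log_prob(x, y) amounts to log_prob_from_compressed(compress(x), y). The docstrings suggest this but do not state it. The class names, the linear compressor and the unit-variance Gaussian density are hypothetical. They are not Flax modules, and sample is omitted for brevity.

```python
from abc import ABC, abstractmethod

import numpy as np

class CompressorWithNDE(ABC):
    """Hypothetical plain-Python analogue of the Compressor_w_NDE methods."""

    @abstractmethod
    def compress(self, x):
        """Map a high-dimensional data point to a low-dimensional summary."""

    @abstractmethod
    def log_prob_from_compressed(self, z, y=None):
        """Density of an already-compressed point z, conditioned by y."""

    def log_prob(self, x, y=None):
        # Assumed composition: compress first, then evaluate the density.
        return self.log_prob_from_compressed(self.compress(x), y)

class LinearCompressorGaussian(CompressorWithNDE):
    """Toy subclass: linear projection, then a unit-variance Gaussian around y."""

    def __init__(self, projection):
        self.projection = projection  # shape (dim_x, dim_z)

    def compress(self, x):
        return x @ self.projection

    def log_prob_from_compressed(self, z, y=None):
        mean = np.zeros_like(z) if y is None else y
        dim_z = z.shape[-1]
        return -0.5 * (np.sum((z - mean) ** 2, axis=-1) + dim_z * np.log(2 * np.pi))

rng = np.random.default_rng(1)
model = LinearCompressorGaussian(rng.normal(size=(100, 3)) / 10.0)
x = rng.normal(size=(8, 100))  # 8 high-dimensional data points
y = rng.normal(size=(8, 3))    # conditioning values in the compressed space
logp = model.log_prob(x, y)    # shape (8,)
```

Separating the two steps lets the compressed summaries be computed once and reused when only the density part needs to be evaluated.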
- class jaxili.model.ConditionalMADE(n_in: int, hidden_dims: list[int], activation: typing.Callable, n_cond: int = 0, gaussian: bool = True, random_order: bool = False, seed: int | None = None, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Bases: Module
Base class for a Conditional Masked Autoencoder for Distribution Estimation (MADE).
MADE is a neural network that parameterizes the conditional distribution of a random variable using masked linear layers.
- Parameters:
n_in (int) – Size of the input vector.
hidden_dims (list[int]) – List of hidden dimensions.
activation (Callable) – Activation function.
n_cond (int) – Size of the conditioning variable, or 0 if there is none. Default 0.
gaussian (bool) – Whether the outputs are the mean and variance of a Gaussian conditional. Default True.
random_order (bool) – Whether to use a random ordering of the inputs for masking. Default False.
seed (Optional[int]) – Random seed used to label the nodes. Defaults to None, but the MADE will not work unless a seed is provided.
- Attributes:
- name
- parent
- path – Get the path of this Module.
- scope
- seed
- variables – Returns the variables in this module.
Methods
__call__(x[, y]) – Forward pass of the model.
apply(variables, *args[, rngs, method, ...]) – Applies a module method to variables and returns output and modified variables.
bind(variables, *args[, rngs, mutable]) – Creates an interactive Module instance by binding variables and RNGs.
clone(*[, parent, _deep_clone, _reset_names]) – Creates a clone of this Module, with optionally updated arguments.
copy(*[, parent, name]) – Creates a copy of this Module, with optionally updated arguments.
get_variable(col, name[, default]) – Retrieves the value of a Variable.
has_rng(name) – Returns true if a PRNGSequence with the given name exists.
has_variable(col, name) – Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]) – Initializes a module method with variables and returns modified variables.
init_with_output(rngs, *args[, method, ...]) – Initializes a module method with variables and returns output and modified variables.
is_initializing() – Returns True if running under self.init(...) or nn.init(...)().
is_mutable_collection(col) – Returns true if the collection col is mutable.
lazy_init(rngs, *args[, method, mutable]) – Initializes a module without computing on an actual input.
make_rng([name]) – Returns a new RNG key from a given RNG sequence for this Module.
module_paths(rngs, *args[, show_repeated, ...]) – Returns a dictionary mapping module paths to module instances.
param(name, init_fn, *init_args[, unbox]) – Declares and returns a parameter in this Module.
perturb(name, value[, collection]) – Adds a zero-value variable ('perturbation') to the intermediate value.
put_variable(col, name, value) – Updates the value of the given variable if it is mutable, or raises an error otherwise.
setup() – Sets up the network by creating the masks and the masked linear layers.
sow(col, name, value[, reduce_fn, init_fn]) – Stores a value in a collection.
tabulate(rngs, *args[, depth, ...]) – Creates a summary of the Module represented as a table.
unbind() – Returns an unbound copy of a Module and its variables.
variable(col, name[, init_fn, unbox]) – Declares and returns a variable in this Module.
- activation: Callable
- gaussian: bool = True
- n_cond: int = 0
- n_in: int
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- random_order: bool = False
- scope: Scope | None = None
- seed: int | None = None
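The masks that make a (Conditional) MADE autoregressive can be built from integer "degrees" assigned to every unit, following the scheme of Germain et al. (2015). The NumPy sketch below is not jaxili's code. The helper name made_masks is hypothetical. Two choices are assumptions here: giving conditioning inputs degree 0, and allowing hidden degree 0 when n_cond > 0. Both are a common convention for conditional MADEs. The sketch also shows why a seed matters: the hidden-unit labels are drawn at random.

```python
import numpy as np

def made_masks(n_in, hidden_dims, n_cond=0, seed=0, random_order=False):
    """Return (masks, in_degrees, out_degrees) for a conditional MADE.

    Each mask has shape (fan_in, fan_out). The network input is assumed to be
    concat([x, y]) with len(x) == n_in and len(y) == n_cond.
    Requires n_in >= 2 or n_cond > 0.
    """
    rng = np.random.default_rng(seed)  # the seed fixes the random node labels
    order = rng.permutation(n_in) + 1 if random_order else np.arange(1, n_in + 1)
    # Conditioning inputs get degree 0, so any unit may depend on them.
    in_degrees = np.concatenate([order, np.zeros(n_cond, dtype=int)])
    # With conditioning, hidden degree 0 lets the first output still see y.
    low = 0 if n_cond > 0 else 1
    degrees = [in_degrees] + [rng.integers(low, n_in, size=h) for h in hidden_dims]
    # Hidden layers: a unit may see units whose degree is <= its own.
    masks = [degrees[k][:, None] <= degrees[k + 1][None, :] for k in range(len(hidden_dims))]
    # Output layer uses a strict inequality; two outputs per dimension
    # (e.g. mean and scale of a Gaussian conditional).
    out_degrees = np.concatenate([order, order])
    masks.append(degrees[-1][:, None] < out_degrees[None, :])
    return masks, in_degrees, out_degrees

masks, in_deg, out_deg = made_masks(n_in=4, hidden_dims=[16, 16], n_cond=2, seed=42)
print([m.shape for m in masks])  # [(6, 16), (16, 16), (16, 8)]
```

Multiplying the masks together counts the paths from each input to each output. For every data input with a path to an output, the input's degree is strictly smaller than the output's. This is the autoregressive property.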
- class jaxili.model.ConditionalMAF(n_in: int, n_cond: int, n_layers: int, layers: list[int], activation: typing.Callable, use_reverse: bool, seed: int | None = None, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Bases: NDENetwork
Base class for a Conditional Masked Autoregressive Flow.
A Conditional Masked Autoregressive Flow that models the conditional distribution of a random variable, obtained by stacking n_layers MAF layers.
- Parameters:
n_in (int) – Size of the input vector.
n_cond (int) – Size of the conditioning variable.
n_layers (int) – Number of layers (i.e. number of stacked MAFs).
layers (list[int]) – List of hidden dimensions in each MAF.
activation (Callable) – Activation function.
use_reverse (bool) – Whether to reverse the order of the input between consecutive MAFs.
seed (Optional[int]) – Random seed used to label the nodes. Defaults to None, but the MAF will not work unless a seed is provided.
- Attributes:
- name
- parent
- path – Get the path of this Module.
- scope
- seed
- variables – Returns the variables in this module.
Methods
__call__(x[, y]) – Forward pass of the model.
apply(variables, *args[, rngs, method, ...]) – Applies a module method to variables and returns output and modified variables.
backward(u[, y]) – Backward pass of the model.
bind(variables, *args[, rngs, mutable]) – Creates an interactive Module instance by binding variables and RNGs.
clone(*[, parent, _deep_clone, _reset_names]) – Creates a clone of this Module, with optionally updated arguments.
copy(*[, parent, name]) – Creates a copy of this Module, with optionally updated arguments.
get_variable(col, name[, default]) – Retrieves the value of a Variable.
has_rng(name) – Returns true if a PRNGSequence with the given name exists.
has_variable(col, name) – Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]) – Initializes a module method with variables and returns modified variables.
init_with_output(rngs, *args[, method, ...]) – Initializes a module method with variables and returns output and modified variables.
is_initializing() – Returns True if running under self.init(...) or nn.init(...)().
is_mutable_collection(col) – Returns true if the collection col is mutable.
lazy_init(rngs, *args[, method, mutable]) – Initializes a module without computing on an actual input.
log_prob(x[, y]) – Compute the log-probability of x conditioned on the conditioning variable y.
make_rng([name]) – Returns a new RNG key from a given RNG sequence for this Module.
module_paths(rngs, *args[, show_repeated, ...]) – Returns a dictionary mapping module paths to module instances.
param(name, init_fn, *init_args[, unbox]) – Declares and returns a parameter in this Module.
perturb(name, value[, collection]) – Adds a zero-value variable ('perturbation') to the intermediate value.
put_variable(col, name, value) – Updates the value of the given variable if it is mutable, or raises an error otherwise.
sample([y, num_samples, key]) – Sample from the distribution emulated by the neural network.
setup() – Sets up the network by creating the MAF layers.
sow(col, name, value[, reduce_fn, init_fn]) – Stores a value in a collection.
tabulate(rngs, *args[, depth, ...]) – Creates a summary of the Module represented as a table.
unbind() – Returns an unbound copy of a Module and its variables.
variable(col, name[, init_fn, unbox]) – Declares and returns a variable in this Module.
- activation: Callable
- backward(u, y=None)
Backward pass of the model.
Returns the vector x obtained by applying the inverse flow to u, together with the log-determinant of the Jacobian of the inverse flow.
- Parameters:
u (jnp.Array) – Input vector.
y (jnp.Array) – Conditioning variable.
- Returns:
x (jnp.Array) – Transformed vector.
log_det_sum (jnp.Array) – Log-determinant of the Jacobian.
- layers: list[int]
- log_prob(x, y=None)
Compute the log-probability of x conditioned on the conditioning variable y.
- Parameters:
x (jnp.Array) – Input vector.
y (jnp.Array) – Conditioning variable.
- Returns:
Log probability of the data point.
- Return type:
jnp.Array
- n_cond: int
- n_in: int
- n_layers: int
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- sample(y=None, num_samples=1, key=None)
Sample from the distribution emulated by the neural network.
- Parameters:
y (jnp.Array) – Conditioning variable.
num_samples (int) – Number of samples.
key (jnp.Array) – Random key.
- Returns:
Samples from the distribution.
- Return type:
jnp.Array
- scope: Scope | None = None
- seed: int | None = None
- use_reverse: bool
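The interaction of n_layers and use_reverse can be sketched with the change-of-variables formula. Here log p(x | y) is the base log-density of the fully transformed point plus the summed log-determinants of every layer. Sampling runs the inverse layers in the opposite order. In this NumPy sketch the per-layer transform is a hypothetical placeholder standing in for MAFLayer: an elementwise affine map whose shift and log-scale depend only on y. The standard-normal base distribution is also an assumption. The sketch shows the bookkeeping, not jaxili's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
N_IN, N_COND, N_LAYERS, USE_REVERSE = 3, 2, 4, True

# Hypothetical placeholder layers: shift and log-scale are linear in y.
PARAMS = [(0.3 * rng.normal(size=(N_COND, N_IN)), 0.3 * rng.normal(size=(N_COND, N_IN)))
          for _ in range(N_LAYERS)]

def to_noise(x, y):
    """Map data x to base noise u; return (u, summed log|det J|)."""
    log_det_sum = np.zeros(x.shape[0])
    for a, b in PARAMS:
        shift, log_scale = y @ a, y @ b
        x = (x - shift) * np.exp(-log_scale)
        log_det_sum -= log_scale.sum(axis=-1)
        if USE_REVERSE:
            x = x[:, ::-1]  # permutation between layers, |det| = 1
    return x, log_det_sum

def from_noise(u, y):
    """Inverse map: undo the layers (and reversals) in the opposite order."""
    for a, b in reversed(PARAMS):
        if USE_REVERSE:
            u = u[:, ::-1]
        shift, log_scale = y @ a, y @ b
        u = u * np.exp(log_scale) + shift
    return u

def log_prob(x, y):
    u, log_det_sum = to_noise(x, y)
    base = -0.5 * np.sum(u**2, axis=-1) - 0.5 * N_IN * np.log(2 * np.pi)
    return base + log_det_sum

def sample(y, num_samples, seed):
    """Draw num_samples points for a single y of shape (N_COND,); seed replaces a JAX key."""
    u = np.random.default_rng(seed).normal(size=(num_samples, N_IN))
    return from_noise(u, np.broadcast_to(y, (num_samples, N_COND)))
```

Reversing the order matters for real MAF layers. Each layer can only make a coordinate depend on the coordinates before it, so alternating the order lets every coordinate depend on every other after two layers.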
- class jaxili.model.ConditionalRealNVP(n_in: int, n_cond: int, n_layers: int, layers: list[int], activation: typing.Callable, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Bases: NDENetwork
Base class for a Conditional RealNVP.
A normalizing flow based on RealNVP with a conditioning variable.
- Parameters:
n_in (int) – Dimension of the input.
n_cond (int) – Dimension of the conditioning variable.
n_layers (int) – Number of layers.
layers (list[int]) – Sizes of the hidden layers.
activation (Callable) – Activation function.
- Attributes:
- name
- parent
- path – Get the path of this Module.
- scope
- variables – Returns the variables in this module.
Methods
__call__(y, **kwargs) – Build the bijector using tensorflow_probability.
apply(variables, *args[, rngs, method, ...]) – Applies a module method to variables and returns output and modified variables.
bind(variables, *args[, rngs, mutable]) – Creates an interactive Module instance by binding variables and RNGs.
clone(*[, parent, _deep_clone, _reset_names]) – Creates a clone of this Module, with optionally updated arguments.
copy(*[, parent, name]) – Creates a copy of this Module, with optionally updated arguments.
get_variable(col, name[, default]) – Retrieves the value of a Variable.
has_rng(name) – Returns true if a PRNGSequence with the given name exists.
has_variable(col, name) – Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]) – Initializes a module method with variables and returns modified variables.
init_with_output(rngs, *args[, method, ...]) – Initializes a module method with variables and returns output and modified variables.
is_initializing() – Returns True if running under self.init(...) or nn.init(...)().
is_mutable_collection(col) – Returns true if the collection col is mutable.
lazy_init(rngs, *args[, method, mutable]) – Initializes a module without computing on an actual input.
log_prob(x, y, **kwargs) – Compute the log probability of the data point x conditioned by y from the normalizing flow.
make_rng([name]) – Returns a new RNG key from a given RNG sequence for this Module.
module_paths(rngs, *args[, show_repeated, ...]) – Returns a dictionary mapping module paths to module instances.
param(name, init_fn, *init_args[, unbox]) – Declares and returns a parameter in this Module.
perturb(name, value[, collection]) – Adds a zero-value variable ('perturbation') to the intermediate value.
put_variable(col, name, value) – Updates the value of the given variable if it is mutable, or raises an error otherwise.
sample(y, num_samples, key, **kwargs) – Sample from the distribution modeled by the RealNVP.
setup() – Initializes a Module lazily (similar to a lazy __init__).
sow(col, name, value[, reduce_fn, init_fn]) – Stores a value in a collection.
tabulate(rngs, *args[, depth, ...]) – Creates a summary of the Module represented as a table.
unbind() – Returns an unbound copy of a Module and its variables.
variable(col, name[, init_fn, unbox]) – Declares and returns a variable in this Module.
- activation: Callable
- layers: list[int]
- log_prob(x, y, **kwargs)
Compute the log probability of the data point x conditioned by y from the normalizing flow.
- Parameters:
x (jnp.Array) – Data point.
y (jnp.Array) – Conditioning variable.
- Returns:
Log probability of the data point conditioned by y.
- Return type:
jnp.Array
- n_cond: int
- n_in: int
- n_layers: int
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- sample(y, num_samples, key, **kwargs)
Sample from the distribution modeled by the RealNVP.
- Parameters:
y (jnp.Array) – Conditioning variable.
num_samples (int) – Number of samples.
key (jnp.Array) – Random key.
- Returns:
num_samples samples from the distribution.
- Return type:
jnp.Array
- scope: Scope | None = None
- class jaxili.model.MAFLayer(n_in: int, n_cond: int, hidden_dims: list[int], reverse: bool, activation: typing.Callable, seed: int | None = None, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Bases: Module
Base class for a Masked Autoregressive Flow layer.
A single layer of a Masked Autoregressive Flow.
- Parameters:
n_in (int) – Size of the input vector.
n_cond (int) – Size of the conditioning variable.
hidden_dims (list[int]) – List of hidden dimensions.
reverse (bool) – Whether to reverse the order of the input.
activation (Callable) – Activation function.
seed (Optional[int]) – Random seed used to label the nodes. Defaults to None, but the MAF will not work unless a seed is provided.
- Attributes:
- name
- parent
- path – Get the path of this Module.
- scope
- seed
- variables – Returns the variables in this module.
Methods
__call__(x[, y]) – Forward pass of the model.
apply(variables, *args[, rngs, method, ...]) – Applies a module method to variables and returns output and modified variables.
backward(u[, y]) – Backward pass of the model.
bind(variables, *args[, rngs, mutable]) – Creates an interactive Module instance by binding variables and RNGs.
clone(*[, parent, _deep_clone, _reset_names]) – Creates a clone of this Module, with optionally updated arguments.
copy(*[, parent, name]) – Creates a copy of this Module, with optionally updated arguments.
forward(x[, y]) – Forward pass of the model.
get_variable(col, name[, default]) – Retrieves the value of a Variable.
has_rng(name) – Returns true if a PRNGSequence with the given name exists.
has_variable(col, name) – Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]) – Initializes a module method with variables and returns modified variables.
init_with_output(rngs, *args[, method, ...]) – Initializes a module method with variables and returns output and modified variables.
is_initializing() – Returns True if running under self.init(...) or nn.init(...)().
is_mutable_collection(col) – Returns true if the collection col is mutable.
lazy_init(rngs, *args[, method, mutable]) – Initializes a module without computing on an actual input.
make_rng([name]) – Returns a new RNG key from a given RNG sequence for this Module.
module_paths(rngs, *args[, show_repeated, ...]) – Returns a dictionary mapping module paths to module instances.
param(name, init_fn, *init_args[, unbox]) – Declares and returns a parameter in this Module.
perturb(name, value[, collection]) – Adds a zero-value variable ('perturbation') to the intermediate value.
put_variable(col, name, value) – Updates the value of the given variable if it is mutable, or raises an error otherwise.
setup() – Initializes a Module lazily (similar to a lazy __init__).
sow(col, name, value[, reduce_fn, init_fn]) – Stores a value in a collection.
tabulate(rngs, *args[, depth, ...]) – Creates a summary of the Module represented as a table.
unbind() – Returns an unbound copy of a Module and its variables.
variable(col, name[, init_fn, unbox]) – Declares and returns a variable in this Module.
- activation: Callable
- backward(u, y=None)
Backward pass of the model.
Returns the vector x obtained by applying the inverse flow to u, together with the log-determinant of the Jacobian of the inverse flow.
- Parameters:
u (jnp.Array) – Input vector.
y (jnp.Array) – Conditioning variable.
- Returns:
jnp.Array – Transformed vector.
jnp.Array – Log-determinant of the Jacobian.
- forward(x, y=None)
Forward pass of the model.
Returns the vector u obtained by applying the flow to x, together with the log-determinant of the Jacobian of the flow.
- Parameters:
x (jnp.Array) – Input vector.
y (jnp.Array) – Conditioning variable.
- Returns:
jnp.Array – Transformed vector.
jnp.Array – Log-determinant of the Jacobian.
- n_cond: int
- n_in: int
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- reverse: bool
- scope: Scope | None = None
- seed: int | None = None
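The asymmetry between forward and backward in a MAF layer can be shown with a small NumPy sketch of the standard MAF parameterization (Papamakarios et al., 2017). The forward (density) direction needs a single pass. The backward (sampling) direction has to be computed one coordinate at a time. In this sketch the autoregressive conditioner is a hypothetical strictly lower-triangular linear map plus a linear term in y. It stands in for the masked network a MAFLayer presumably uses. The sign convention of the log-determinants is also an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
N_IN, N_COND = 4, 2

# Hypothetical conditioner weights. Strictly lower-triangular matrices make
# mu_i and alpha_i depend only on x_1..x_{i-1} (and y): the autoregressive property.
A_MU = np.tril(rng.normal(size=(N_IN, N_IN)), k=-1)
A_ALPHA = np.tril(0.1 * rng.normal(size=(N_IN, N_IN)), k=-1)
C_MU = rng.normal(size=(N_COND, N_IN))
C_ALPHA = 0.1 * rng.normal(size=(N_COND, N_IN))

def conditioner(x, y):
    mu = x @ A_MU.T + y @ C_MU
    alpha = np.tanh(x @ A_ALPHA.T + y @ C_ALPHA)  # log-scale
    return mu, alpha

def forward(x, y):
    """x -> u in a single pass; returns (u, log|det du/dx|)."""
    mu, alpha = conditioner(x, y)
    return (x - mu) * np.exp(-alpha), -alpha.sum(axis=-1)

def backward(u, y):
    """u -> x one coordinate at a time, since mu_i needs x_1..x_{i-1}."""
    x = np.zeros_like(u)
    for i in range(N_IN):
        mu, alpha = conditioner(x, y)  # only the already-filled x[:, :i] matter here
        x[:, i] = u[:, i] * np.exp(alpha[:, i]) + mu[:, i]
    _, alpha = conditioner(x, y)
    return x, alpha.sum(axis=-1)  # log|det dx/du|

x = rng.normal(size=(5, N_IN))
y = rng.normal(size=(5, N_COND))
u, log_det_fwd = forward(x, y)
x_rec, log_det_bwd = backward(u, y)
```

This is why MAFs evaluate densities quickly but sample slowly. The loop in backward has N_IN sequential steps.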
- class jaxili.model.MaskedLinear(n_out: int, bias: bool = True, mask: typing.Any = None, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Bases: Module
Base class for a Masked Linear layer.
Linear transformation in which masked-out weights are set to zero:
y = x.dot(mask*W.T)+b
- Parameters:
n_out (int) – Output dimension.
bias (bool) – Whether to include bias. Default True.
mask (Any) – Mask to apply to the weights. Default None.
- Attributes:
- mask
- name
- parent
- path – Get the path of this Module.
- scope
- variables – Returns the variables in this module.
Methods
__call__(x) – Apply the masked linear transformation.
apply(variables, *args[, rngs, method, ...]) – Applies a module method to variables and returns output and modified variables.
bind(variables, *args[, rngs, mutable]) – Creates an interactive Module instance by binding variables and RNGs.
clone(*[, parent, _deep_clone, _reset_names]) – Creates a clone of this Module, with optionally updated arguments.
copy(*[, parent, name]) – Creates a copy of this Module, with optionally updated arguments.
get_variable(col, name[, default]) – Retrieves the value of a Variable.
has_rng(name) – Returns true if a PRNGSequence with the given name exists.
has_variable(col, name) – Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]) – Initializes a module method with variables and returns modified variables.
init_with_output(rngs, *args[, method, ...]) – Initializes a module method with variables and returns output and modified variables.
initialize_mask(mask) – Set the mask applied to the weights.
is_initializing() – Returns True if running under self.init(...) or nn.init(...)().
is_mutable_collection(col) – Returns true if the collection col is mutable.
lazy_init(rngs, *args[, method, mutable]) – Initializes a module without computing on an actual input.
make_rng([name]) – Returns a new RNG key from a given RNG sequence for this Module.
module_paths(rngs, *args[, show_repeated, ...]) – Returns a dictionary mapping module paths to module instances.
param(name, init_fn, *init_args[, unbox]) – Declares and returns a parameter in this Module.
perturb(name, value[, collection]) – Adds a zero-value variable ('perturbation') to the intermediate value.
put_variable(col, name, value) – Updates the value of the given variable if it is mutable, or raises an error otherwise.
setup() – Initializes a Module lazily (similar to a lazy __init__).
sow(col, name, value[, reduce_fn, init_fn]) – Stores a value in a collection.
tabulate(rngs, *args[, depth, ...]) – Creates a summary of the Module represented as a table.
unbind() – Returns an unbound copy of a Module and its variables.
variable(col, name[, init_fn, unbox]) – Declares and returns a variable in this Module.
- bias: bool = True
- initialize_mask(mask: Any)
Set the mask applied to the weights.
- Parameters:
mask (Any) – Boolean mask to apply to the weights.
- mask: Any = None
- n_out: int
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- scope: Scope | None = None
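The effect of the formula y = x.dot(mask*W.T)+b can be checked directly in NumPy. The sketch assumes W has shape (n_out, n_in) and the mask has shape (n_in, n_out), which is what the formula requires for the product to be defined. masked_linear is a hypothetical helper, not jaxili's Flax implementation. A zero in the mask at (i, j) means input i cannot influence output j.

```python
import numpy as np

def masked_linear(x, W, b, mask):
    """y = x.dot(mask * W.T) + b, with W of shape (n_out, n_in)."""
    return x.dot(mask * W.T) + b

rng = np.random.default_rng(0)
n_in, n_out = 3, 2
W = rng.normal(size=(n_out, n_in))
b = np.zeros(n_out)
# Output 0 may see only input 0; output 1 may see inputs 0 and 1; input 2 is cut off.
mask = np.array([[1, 1],
                 [0, 1],
                 [0, 0]], dtype=float)
x = rng.normal(size=(4, n_in))
y = masked_linear(x, W, b, mask)
```

Chaining such layers with suitably built masks is how a MADE enforces its autoregressive structure.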
- class jaxili.model.MixtureDensityNetwork(n_in: int, n_cond: int, n_components: int, layers: list[int], activation: typing.Callable, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Bases: NDENetwork
Base class for a Mixture Density Network.
A mixture-of-Gaussians density modeled with a neural network. The weight, mean and covariance of each Gaussian component are learned by the network.
- Attributes:
- name
- parent
- path – Get the path of this Module.
- scope
- variables – Returns the variables in this module.
Methods
__call__(y, **kwargs) – Build a bijector that transforms a multivariate Gaussian distribution into a mixture-of-Gaussians distribution using a neural network.
apply(variables, *args[, rngs, method, ...]) – Applies a module method to variables and returns output and modified variables.
bind(variables, *args[, rngs, mutable]) – Creates an interactive Module instance by binding variables and RNGs.
clone(*[, parent, _deep_clone, _reset_names]) – Creates a clone of this Module, with optionally updated arguments.
copy(*[, parent, name]) – Creates a copy of this Module, with optionally updated arguments.
get_variable(col, name[, default]) – Retrieves the value of a Variable.
has_rng(name) – Returns true if a PRNGSequence with the given name exists.
has_variable(col, name) – Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]) – Initializes a module method with variables and returns modified variables.
init_with_output(rngs, *args[, method, ...]) – Initializes a module method with variables and returns output and modified variables.
is_initializing() – Returns True if running under self.init(...) or nn.init(...)().
is_mutable_collection(col) – Returns true if the collection col is mutable.
lazy_init(rngs, *args[, method, mutable]) – Initializes a module without computing on an actual input.
log_prob(x, y, **kwargs) – Return the log probability of the data point x conditioned by y.
make_rng([name]) – Returns a new RNG key from a given RNG sequence for this Module.
module_paths(rngs, *args[, show_repeated, ...]) – Returns a dictionary mapping module paths to module instances.
param(name, init_fn, *init_args[, unbox]) – Declares and returns a parameter in this Module.
perturb(name, value[, collection]) – Adds a zero-value variable ('perturbation') to the intermediate value.
put_variable(col, name, value) – Updates the value of the given variable if it is mutable, or raises an error otherwise.
sample(y, num_samples, key, **kwargs) – Sample from the distribution conditioned by y.
setup() – Initializes a Module lazily (similar to a lazy __init__).
sow(col, name, value[, reduce_fn, init_fn]) – Stores a value in a collection.
tabulate(rngs, *args[, depth, ...]) – Creates a summary of the Module represented as a table.
unbind() – Returns an unbound copy of a Module and its variables.
variable(col, name[, init_fn, unbox]) – Declares and returns a variable in this Module.
- activation: Callable
- layers: list[int]
- log_prob(x, y, **kwargs)
Return the log probability of the data point x conditioned by y.
- Parameters:
x (jnp.Array) – Data point.
y (jnp.Array) – Conditioning variable.
- Returns:
Log probability of the data point.
- Return type:
jnp.Array
- n_components: int
- n_cond: int
- n_in: int
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- sample(y, num_samples, key, **kwargs)
Sample from the distribution conditioned by y.
- Parameters:
y (jnp.Array) – Conditioning variable.
num_samples (int) – Number of samples.
key (jnp.Array) – Random key.
- Returns:
num_samples samples from the distribution.
- Return type:
jnp.Array
- scope: Scope | None = None
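The log-probability of a Gaussian mixture is usually computed with a log-sum-exp over components for numerical stability. Sampling first picks a component according to the weights and then draws from that Gaussian. The NumPy sketch below assumes diagonal covariances. It takes the per-component logits, means and standard deviations as given; in a MixtureDensityNetwork they would be produced by the network from y. The function names are hypothetical, and jaxili may parameterize the covariance differently.

```python
import numpy as np

def logsumexp(a, axis=-1):
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))

def mdn_log_prob(x, logits, means, stds):
    """x: (B, D); logits: (B, K); means, stds: (B, K, D). Returns (B,)."""
    log_w = logits - logsumexp(logits)[:, None]  # normalized log mixture weights
    z = (x[:, None, :] - means) / stds
    log_comp = -0.5 * np.sum(z**2 + np.log(2 * np.pi * stds**2), axis=-1)  # (B, K)
    return logsumexp(log_w + log_comp)

def mdn_sample(logits, means, stds, num_samples, seed):
    """Sample for a single conditioning value: logits (K,), means and stds (K, D)."""
    rng = np.random.default_rng(seed)
    w = np.exp(logits - logsumexp(logits[None])[0])
    k = rng.choice(len(w), size=num_samples, p=w)  # pick a component per sample
    return means[k] + stds[k] * rng.normal(size=(num_samples, means.shape[-1]))
```

Working with unnormalized logits and normalizing them inside log_prob keeps the weights positive and summing to one without constrained optimization.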
- class jaxili.model.NDENetwork(parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Bases: Module
Base class for a normalizing flow.
Parent class for implementing normalizing flows with neural networks.
- Attributes:
- name
- parent
- path – Get the path of this Module.
- scope
- variables – Returns the variables in this module.
Methods
apply(variables, *args[, rngs, method, ...]) – Applies a module method to variables and returns output and modified variables.
bind(variables, *args[, rngs, mutable]) – Creates an interactive Module instance by binding variables and RNGs.
clone(*[, parent, _deep_clone, _reset_names]) – Creates a clone of this Module, with optionally updated arguments.
copy(*[, parent, name]) – Creates a copy of this Module, with optionally updated arguments.
get_variable(col, name[, default]) – Retrieves the value of a Variable.
has_rng(name) – Returns true if a PRNGSequence with the given name exists.
has_variable(col, name) – Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]) – Initializes a module method with variables and returns modified variables.
init_with_output(rngs, *args[, method, ...]) – Initializes a module method with variables and returns output and modified variables.
is_initializing() – Returns True if running under self.init(...) or nn.init(...)().
is_mutable_collection(col) – Returns true if the collection col is mutable.
lazy_init(rngs, *args[, method, mutable]) – Initializes a module without computing on an actual input.
log_prob(x[, y]) – Log probability of the data point x conditioned by y.
make_rng([name]) – Returns a new RNG key from a given RNG sequence for this Module.
module_paths(rngs, *args[, show_repeated, ...]) – Returns a dictionary mapping module paths to module instances.
param(name, init_fn, *init_args[, unbox]) – Declares and returns a parameter in this Module.
perturb(name, value[, collection]) – Adds a zero-value variable ('perturbation') to the intermediate value.
put_variable(col, name, value) – Updates the value of the given variable if it is mutable, or raises an error otherwise.
sample(y, num_samples, key) – Sample from the distribution conditioned by y.
setup() – Initializes a Module lazily (similar to a lazy __init__).
sow(col, name, value[, reduce_fn, init_fn]) – Stores a value in a collection.
tabulate(rngs, *args[, depth, ...]) – Creates a summary of the Module represented as a table.
unbind() – Returns an unbound copy of a Module and its variables.
variable(col, name[, init_fn, unbox]) – Declares and returns a variable in this Module.
- abstract log_prob(x, y=None, **kwargs)[source]#
Log probability of the data point x conditioned by y.
- Parameters:
x (jnp.Array) – Data point.
y (jnp.Array) – Conditioning variable.
- Returns:
Log probability of the data point given y.
- Return type:
jnp.Array
- name: str | None = None#
- parent: Module | Scope | _Sentinel | None = None#
- abstract sample(y, num_samples, key)[source]#
Sample from the distribution conditioned by y.
- Parameters:
y (jnp.Array) – Conditioning variable.
num_samples (int) – Number of samples.
key (jnp.Array) – Random key.
- Returns:
num_samples samples from the distribution.
- Return type:
jnp.Array
- scope: Scope | None = None#
- class jaxili.model.NDE_Compressor(compressor: ~flax.linen.module.Module, nde: ~jaxili.model.NDENetwork, compressor_hparams: dict, nde_hparams: dict, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)[source]#
Bases: Compressor_w_NDE
Base class for a normalizing flow with a compressor.
WARNING: This class is obsolete and will likely be removed in a future release. It is a general class that implements a compressor followed by a normalizing flow, with standard methods to compute the log-probability of the target distribution or to sample from it.
- Attributes:
- name
- parent
- path – Get the path of this Module.
- scope
- variables – Returns the variables in this module.
Methods
__call__(x, y[, model]) – Perform a forward pass in the network and return the log-probability of x given y.
apply(variables, *args[, rngs, method, ...]) – Applies a module method to variables and returns output and modified variables.
bind(variables, *args[, rngs, mutable]) – Creates an interactive Module instance by binding variables and RNGs.
clone(*[, parent, _deep_clone, _reset_names]) – Creates a clone of this Module, with optionally updated arguments.
compress(x) – Compress the data point x using the compressor.
copy(*[, parent, name]) – Creates a copy of this Module, with optionally updated arguments.
get_variable(col, name[, default]) – Retrieves the value of a Variable.
has_rng(name) – Returns true if a PRNGSequence with the given name exists.
has_variable(col, name) – Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]) – Initializes a module method with variables and returns modified variables.
init_with_output(rngs, *args[, method, ...]) – Initializes a module method with variables and returns output and modified variables.
is_initializing() – Returns True if running under self.init(...) or nn.init(...)().
is_mutable_collection(col) – Returns true if the collection col is mutable.
lazy_init(rngs, *args[, method, mutable]) – Initializes a module without computing on an actual input.
log_prob(x, y[, model]) – Return the log-probability of the parameters y conditioned by the data point x.
log_prob_compressed(z, y[, model]) – Return the log-probability of the parameters y conditioned by the compressed data z (if NPE).
log_prob_from_compressed(z[, y]) – Log probability of the data point z conditioned by y.
make_rng([name]) – Returns a new RNG key from a given RNG sequence for this Module.
module_paths(rngs, *args[, show_repeated, ...]) – Returns a dictionary mapping module paths to module instances.
param(name, init_fn, *init_args[, unbox]) – Declares and returns a parameter in this Module.
perturb(name, value[, collection]) – Adds a zero-value variable ('perturbation') to the intermediate value.
put_variable(col, name, value) – Updates the value of the given variable if it is mutable, or raises an error otherwise.
sample(y, num_samples, key[, model]) – Sample from the distribution conditioned by y.
setup() – Set the compressor and the normalizing flow.
sow(col, name, value[, reduce_fn, init_fn]) – Stores a value in a collection.
tabulate(rngs, *args[, depth, ...]) – Creates a summary of the Module represented as a table.
unbind() – Returns an unbound copy of a Module and its variables.
variable(col, name[, init_fn, unbox]) – Declares and returns a variable in this Module.
- compressor: Module#
- compressor_hparams: dict#
- log_prob(x, y, model='NPE')[source]#
Return the log-probability of the parameters y conditioned by the data point x.
- Parameters:
x (jnp.Array) – Data point
y (jnp.Array) – Conditioning variable
- Returns:
Log probability of the parameters y
- Return type:
jnp.Array
- log_prob_compressed(z, y, model='NPE')[source]#
Return the log-probability of the parameters y conditioned by the compressed data z (if NPE).
- Parameters:
z (jnp.Array) – Compressed data point
y (jnp.Array) – Conditioning variable
- Returns:
Log probability of the parameters y
- Return type:
jnp.Array
- name: str | None = None#
- nde: NDENetwork#
- nde_hparams: dict#
- parent: Module | Scope | _Sentinel | None = None#
- sample(y, num_samples, key, model='NPE')[source]#
Sample from the distribution conditioned by y.
- Parameters:
y (jnp.Array) – Conditioning variable
num_samples (int) – Number of samples
key (jnp.Array) – Random key
- Returns:
num_samples samples from the distribution
- Return type:
jnp.Array
- scope: Scope | None = None#
- class jaxili.model.NDE_w_Standardization(nde: ~jaxili.model.NDENetwork, embedding_net: ~flax.linen.module.Module, transformation: ~distrax._src.bijectors.bijector.Bijector, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)[source]#
Bases: NDENetwork
Base class to implement a normalizing flow with a standardization step.
This class creates an NDE network in which the input data is first standardized. It takes as input a neural density estimator, an embedding network, and a transformation. The embedding network maps the data point into a latent space where the NDE is applied, which compresses the data into a lower-dimensional space. The transformation standardizes the variable learned by the normalizing flow, for numerical stability.
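The standardization step relies on the change-of-variables formula. If an elementwise map T(theta) = (theta - mu) / sigma is applied and the flow models u = T(theta) with density q, then log p(theta) = log q(T(theta)) - sum_i log sigma_i, and samples drawn in the standardized space are mapped back with the inverse T^-1(u) = sigma * u + mu. The sketch below assumes such a per-dimension affine standardization and writes it with plain JAX functions rather than a distrax Bijector; the helper names and the standard-normal stand-in for a trained flow are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def standardize(theta, mean, std):
    """Affine standardization T(theta) = (theta - mean) / std."""
    return (theta - mean) / std


def unstandardize(u, mean, std):
    """Inverse map T^-1(u) = u * std + mean."""
    return u * std + mean


def log_prob_original(theta, base_log_prob, mean, std):
    """Change of variables: log p(theta) = log q(T(theta)) + log|det dT/dtheta|.

    For the diagonal affine T above, log|det dT/dtheta| = -sum(log std).
    """
    u = standardize(theta, mean, std)
    return base_log_prob(u) - jnp.sum(jnp.log(std))


def sample_original(base_sample, key, num_samples, mean, std):
    """Draw in the standardized space, then map back to the original scale."""
    return unstandardize(base_sample(key, num_samples), mean, std)


# Hypothetical stand-in for a trained flow: a standard normal in u-space.
def base_log_prob(u):
    return jnp.sum(norm.logpdf(u), axis=-1)


def base_sample(key, n):
    return jax.random.normal(key, (n, 2))


mean, std = jnp.array([10.0, -3.0]), jnp.array([5.0, 0.1])
theta = jnp.array([[12.0, -3.05], [9.0, -2.9]])
lp = log_prob_original(theta, base_log_prob, mean, std)
```

With a standard-normal base, the result coincides with the log-density of N(mean, diag(std^2)), which is a quick way to check that the Jacobian term has the right sign.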
- Attributes:
- name
- parent
- path – Get the path of this Module.
- scope
- variables – Returns the variables in this module.
Methods
__call__(x, y[, model]) – Return the log-probability of x given y for NPE and y given x for NLE.
apply(variables, *args[, rngs, method, ...]) – Applies a module method to variables and returns output and modified variables.
bind(variables, *args[, rngs, mutable]) – Creates an interactive Module instance by binding variables and RNGs.
clone(*[, parent, _deep_clone, _reset_names]) – Creates a clone of this Module, with optionally updated arguments.
copy(*[, parent, name]) – Creates a copy of this Module, with optionally updated arguments.
embedding(x) – Embed the data point x.
get_variable(col, name[, default]) – Retrieves the value of a Variable.
has_rng(name) – Returns true if a PRNGSequence with the given name exists.
has_variable(col, name) – Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]) – Initializes a module method with variables and returns modified variables.
init_with_output(rngs, *args[, method, ...]) – Initializes a module method with variables and returns output and modified variables.
is_initializing() – Returns True if running under self.init(...) or nn.init(...)().
is_mutable_collection(col) – Returns true if the collection col is mutable.
lazy_init(rngs, *args[, method, mutable]) – Initializes a module without computing on an actual input.
log_prob(x[, y, model]) – Return the log probability of the data point x conditioned by y.
make_rng([name]) – Returns a new RNG key from a given RNG sequence for this Module.
module_paths(rngs, *args[, show_repeated, ...]) – Returns a dictionary mapping module paths to module instances.
param(name, init_fn, *init_args[, unbox]) – Declares and returns a parameter in this Module.
perturb(name, value[, collection]) – Adds a zero-value variable ('perturbation') to the intermediate value.
put_variable(col, name, value) – Updates the value of the given variable if it is mutable, or raises an error otherwise.
sample(y, num_samples, key[, model]) – Sample from the distribution conditioned by y.
setup() – Initializes a Module lazily (similar to a lazy __init__).
sow(col, name, value[, reduce_fn, init_fn]) – Stores a value in a collection.
standardize(x) – Standardize the data point x.
tabulate(rngs, *args[, depth, ...]) – Creates a summary of the Module represented as a table.
unbind() – Returns an unbound copy of a Module and its variables.
Unstandardize the data point x.
variable(col, name[, init_fn, unbox]) – Declares and returns a variable in this Module.
- embedding_net: Module#
- log_prob(x, y=None, model='NPE')[source]#
Return the log probability of the data point x conditioned by y.
- name: str | None = None#
- nde: NDENetwork#
- parent: Module | Scope | _Sentinel | None = None#
- scope: Scope | None = None#
- transformation: Bijector#