jaxili.loss module
Loss.
This module contains loss functions used to train the neural networks.
- jaxili.loss.gaussian_kernel_matrix(x, y, sigmas=None)
Compute a Gaussian radial basis function (RBF) kernel matrix between the samples of x and y.
The kernel is a sum of multiple Gaussian kernels, each with its own width sigma_i.
- Parameters:
x (array of shape (num_draws_x, num_features))
y (array of shape (num_draws_y, num_features))
sigmas (list(float), optional, default: None) – Widths of the individual Gaussians in the kernel. A default range of widths is used if sigmas is None.
- Returns:
kernel values
- Return type:
array of shape (num_draws_x, num_draws_y)
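As a minimal sketch of the kernel described above, assuming the common parameterisation k(x, y) = Σ_i exp(-‖x − y‖² / (2 sigma_i²)), the following JAX function builds the kernel matrix by broadcasting. The name gaussian_kernel_matrix_sketch and the DEFAULT_SIGMAS list are hypothetical; jaxili's own default widths and its exact scaling of sigma may differ.

```python
import jax.numpy as jnp

# Hypothetical default widths; jaxili's actual default list may differ.
DEFAULT_SIGMAS = jnp.array([1e-2, 1e-1, 1.0, 10.0, 100.0])


def gaussian_kernel_matrix_sketch(x, y, sigmas=None):
    """Sum of Gaussian RBF kernels: k(x, y) = sum_i exp(-||x - y||^2 / (2 sigma_i^2)).

    x: (num_draws_x, num_features), y: (num_draws_y, num_features).
    Returns an array of shape (num_draws_x, num_draws_y).
    """
    sigmas = DEFAULT_SIGMAS if sigmas is None else jnp.asarray(sigmas)
    # Pairwise squared Euclidean distances, shape (num_draws_x, num_draws_y).
    sq_dist = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    # One Gaussian per width, shape (num_sigmas, num_draws_x, num_draws_y).
    per_sigma = jnp.exp(-sq_dist[None, :, :] / (2.0 * sigmas[:, None, None] ** 2))
    # Sum over the widths to obtain the final kernel matrix.
    return jnp.sum(per_sigma, axis=0)


x = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
K = gaussian_kernel_matrix_sketch(x, x, sigmas=[1.0, 2.0])
print(K.shape)  # (3, 3)
```

Summing over several widths keeps the kernel informative across length scales, so the result depends less on picking a single good bandwidth.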
- jaxili.loss.loss_mmd_npe(model, params, batch)
Compute the Maximum Mean Discrepancy (MMD) loss for Neural Posterior Estimation.
- Parameters:
model (Any) – Neural network model from jaxili.model, combining a network that compresses the data and a density estimator that computes the log-probability conditionally on a random variable.
params (PyTree) – Parameters of the neural network.
batch (jnp.array) – Batch of data.
- Returns:
Maximum Mean Discrepancy (MMD) loss.
- Return type:
float
- jaxili.loss.loss_nll_nle(model: Any, params: PyTree, batch: Any)
Negative log-likelihood loss function for NLE methods using a given neural network as a model.
In NLE, the log-probability is given by the density estimator in observation space conditioned on the parameters. The batch should be a tuple of arrays where the first element contains the parameters and the second the simulation outputs.
- Parameters:
model (Any) – Neural network model from jaxili.model.
params (PyTree) – Parameters of the neural network.
batch (Any) – Batch of (parameters, outputs) to compute the loss.
- Returns:
Mean of the negative log-likelihood loss across the batch.
- Return type:
Array
- jaxili.loss.loss_nll_npe(model: Any, params: PyTree, batch: Any) → Array
Negative log-likelihood loss function for NPE methods using a given neural network as a model.
In NPE, the log-probability is given by the density estimator in parameter space conditioned on the data. The batch should be a tuple of arrays where the first element contains the parameters and the second the simulation outputs.
- Parameters:
model (Any) – Neural network model from jaxili.model.
params (PyTree) – Parameters of the neural network.
batch (Any) – Batch of (parameters, outputs) to compute the loss.
- Returns:
Mean of the negative log-likelihood loss across the batch.
- Return type:
Array
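The two negative log-likelihood losses differ only in which element of the (parameters, outputs) batch is modelled and which one is the condition. The sketch below assumes a hypothetical toy_log_prob(params, value, condition) density (a unit-variance Gaussian whose mean is params["w"] * condition) in place of a jaxili.model network, and drops the model argument; it illustrates the idea and is not the library's implementation.

```python
import jax.numpy as jnp


def toy_log_prob(params, value, condition):
    """Hypothetical conditional density: value ~ N(params["w"] * condition, I).

    Stands in for a trained density estimator; not the jaxili.model API.
    value, condition: arrays of shape (batch_size, num_features).
    Returns log-probabilities of shape (batch_size,).
    """
    mean = params["w"] * condition
    d = value.shape[-1]
    return -0.5 * jnp.sum((value - mean) ** 2, axis=-1) - 0.5 * d * jnp.log(2.0 * jnp.pi)


def nll_nle_sketch(params, batch):
    # NLE: density over simulation outputs x, conditioned on the parameters theta.
    theta, x = batch
    return -jnp.mean(toy_log_prob(params, x, theta))


def nll_npe_sketch(params, batch):
    # NPE: density over the parameters theta, conditioned on simulation outputs x.
    theta, x = batch
    return -jnp.mean(toy_log_prob(params, theta, x))


theta = jnp.array([[0.0], [1.0]])  # parameters, shape (batch_size, 1)
x = jnp.array([[0.5], [1.5]])      # simulation outputs, shape (batch_size, 1)
params = {"w": 2.0}
nle_loss = nll_nle_sketch(params, (theta, x))
npe_loss = nll_npe_sketch(params, (theta, x))
```

Averaging over the batch, rather than summing, keeps the loss scale independent of the batch size.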
- jaxili.loss.maximum_mean_discrepancy(source_samples, target_samples, kernel='gaussian', mmd_weight=1.0, minimum=0.0)
Compute the Maximum Mean Discrepancy (MMD) between source and target samples.
- Parameters:
source_samples (array of shape (N, num_features)) – Samples from the source distribution.
target_samples (array of shape (M, num_features)) – Samples from the target distribution.
kernel (str, optional) – Kernel function to use for the MMD computation. Default: “gaussian”
mmd_weight (float, optional) – Weight for the MMD loss. Default: 1.0
minimum (float, optional) – Minimum value for the MMD loss. Default: 0.0
- Returns:
Maximum Mean Discrepancy (MMD) between source and target samples.
- Return type:
float
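Reading mmd_weight as a multiplicative weight and minimum as a lower bound on the estimate (an interpretation of the parameter names, not something the docstring states), the returned loss would be the following, where \widehat{\mathrm{MMD}} denotes the kernel estimate between the two sample sets:

```latex
\mathcal{L}(\text{source}, \text{target})
  = w \cdot \max\left(\widehat{\mathrm{MMD}}(\text{source}, \text{target}),\; m\right),
\qquad w = \texttt{mmd\_weight}, \quad m = \texttt{minimum}
```

With the defaults w = 1 and m = 0, this is the MMD estimate with any negative value clipped to zero.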
- jaxili.loss.mmd_kernel(x, y, kernel)
Compute the Maximum Mean Discrepancy (MMD) between samples of x and y.
- Parameters:
x (array of shape (num_draws_x, num_features))
y (array of shape (num_draws_y, num_features))
kernel (function) – A kernel function which computes the similarity between two sets of samples.
- Returns:
Maximum Mean Discrepancy (MMD) between x and y.
- Return type:
float
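Since mmd_kernel receives the kernel as a function, the estimator itself only combines three kernel means. A minimal sketch, assuming the common biased (V-statistic) estimator of the squared MMD, which averages the kernel over all pairs including i = j; rbf_kernel and mmd_kernel_sketch are hypothetical names, and jaxili may use a different estimator.

```python
import jax.numpy as jnp


def rbf_kernel(x, y, sigma=1.0):
    # Single-width Gaussian kernel matrix, shape (num_draws_x, num_draws_y).
    sq_dist = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dist / (2.0 * sigma**2))


def mmd_kernel_sketch(x, y, kernel):
    """Biased (V-statistic) estimate of the squared MMD between samples x and y.

    MMD^2 = mean k(x, x') + mean k(y, y') - 2 * mean k(x, y)
    """
    return (
        jnp.mean(kernel(x, x))
        + jnp.mean(kernel(y, y))
        - 2.0 * jnp.mean(kernel(x, y))
    )


x = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]])
same = mmd_kernel_sketch(x, x, rbf_kernel)              # identical samples
shifted = mmd_kernel_sketch(x, x + 5.0, rbf_kernel)     # well-separated samples
```

Passing the same samples as x and y yields zero, while well-separated samples drive the cross term towards zero, so the estimate approaches the sum of the two within-sample means.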
- jaxili.loss.mmd_summary_space(summary_outputs, rng, z_dist='gaussian', kernel='gaussian')
Compute the Maximum Mean Discrepancy (MMD) between the summary outputs and samples from a unit Gaussian distribution.
- Parameters:
summary_outputs (array of shape (num_samples, num_features)) – Summary outputs from the neural network.
rng (jax.random.PRNGKey) – Random key for reproducibility.
z_dist (str, optional) – Distribution of the samples. Default: “gaussian”
kernel (str, optional) – Kernel function to use for the MMD computation. Default: “gaussian”
- Returns:
Maximum Mean Discrepancy (MMD) between the summary outputs and samples from a unit Gaussian distribution.
- Return type:
float
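A minimal sketch of the comparison described above, assuming that z_dist="gaussian" means drawing reference samples from N(0, I) with the same shape as summary_outputs, and that a biased MMD estimate with a multi-width Gaussian kernel is used. The function mmd_to_unit_gaussian_sketch and its sigmas default are hypothetical.

```python
import jax
import jax.numpy as jnp


def mmd_to_unit_gaussian_sketch(summary_outputs, rng, sigmas=(0.5, 1.0, 2.0)):
    """Biased squared-MMD estimate between summary outputs and draws from N(0, I).

    Hypothetical stand-in for mmd_summary_space; the kernel widths are assumptions.
    """
    # Reference samples with the same shape as the summaries.
    z = jax.random.normal(rng, summary_outputs.shape)

    def kernel(a, b):
        # Sum of Gaussian kernels over the assumed widths.
        sq_dist = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
        s = jnp.asarray(sigmas)[:, None, None]
        return jnp.sum(jnp.exp(-sq_dist[None] / (2.0 * s**2)), axis=0)

    return (
        jnp.mean(kernel(summary_outputs, summary_outputs))
        + jnp.mean(kernel(z, z))
        - 2.0 * jnp.mean(kernel(summary_outputs, z))
    )


key_data, key_ref = jax.random.split(jax.random.PRNGKey(0))
gaussian_like = jax.random.normal(key_data, (500, 2))  # already close to N(0, I)
loss_close = mmd_to_unit_gaussian_sketch(gaussian_like, key_ref)
loss_far = mmd_to_unit_gaussian_sketch(gaussian_like + 4.0, key_ref)
```

Penalising this quantity during training encourages the summary network's outputs to follow a standard normal distribution.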