Source code for jaxili.model

"""Model.

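This module contains neural density estimators built with neural networks: conditional normalizing flows (RealNVP and MAF), mixture density networks, and wrappers that combine a density estimator with a compressor/embedding network or a standardization step.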

"""

from abc import abstractmethod
from functools import partial
from typing import Any, Callable, Optional

import distrax
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_probability as tfp
from flax import linen as nn
from jax.scipy.stats import multivariate_normal
from jaxtyping import Array

# Switch to the JAX substrate of TensorFlow Probability and expose the short
# aliases used throughout this module: tfb (bijectors) and tfd (distributions).
tfp = tfp.substrates.jax
tfb, tfd = tfp.bijectors, tfp.distributions


class NDENetwork(nn.Module):
    """
    Base class for a Normalizing Flow.

    A Normalizing Flow parent class to implement normalizing flows using
    neural networks. Child classes are expected to override :meth:`log_prob`
    and :meth:`sample`; the base implementations only raise
    ``NotImplementedError``.
    """

    @abstractmethod
    def log_prob(self, x, y=None, **kwargs):
        """
        Log probability of the data point x conditioned by y.

        Parameters
        ----------
        x : jnp.Array
            Data point.
        y : jnp.Array
            Conditioning variable.

        Returns
        -------
        jnp.Array
            Log probability of the data point given y.

        Raises
        ------
        NotImplementedError
            If the child class does not override this method.
        """
        raise NotImplementedError(
            "log_prob method not implemented in your child class of NDENetwork"
        )

    @abstractmethod
    def sample(self, y, num_samples, key):
        """
        Sample from the distribution conditioned by y.

        Parameters
        ----------
        y : jnp.Array
            Conditioning variable.
        num_samples : int
            Number of samples.
        key : jnp.Array
            Random key.

        Returns
        -------
        jnp.Array
            num_samples samples from the distribution.

        Raises
        ------
        NotImplementedError
            If the child class does not override this method.
        """
        raise NotImplementedError(
            "sample method not implemented in your child class of NDENetwork"
        )
class Compressor_w_NDE(NDENetwork):
    """
    Base class for a normalizing flow preceded by a compressor.

    Child classes combine a compression network with a neural density
    estimator. This is meant for Implicit Likelihood Inference in high
    dimensions, where the data must be compressed before the density
    estimator is applied. Every method below must be overridden; the base
    versions only raise ``NotImplementedError``.
    """

    @abstractmethod
    def compress(self, x):
        """
        Run the compressor on the data point x.

        Parameters
        ----------
        x : jnp.Array
            Data point to compress.

        Returns
        -------
        jnp.Array
            The compressed representation of x.
        """
        message = "compress method not implemented in your child class of Compressor_w_NDE"
        raise NotImplementedError(message)

    @abstractmethod
    def log_prob(self, x, y=None, **kwargs):
        """
        Evaluate the log probability of x conditioned by y.

        Parameters
        ----------
        x : jnp.Array
            Data point.
        y : jnp.Array
            Conditioning variable.

        Returns
        -------
        jnp.Array
            Log probability of x given y.
        """
        message = "log_prob method not implemented in your child class of Compressor_w_NDE"
        raise NotImplementedError(message)

    @abstractmethod
    def log_prob_from_compressed(self, z, y=None, **kwargs):
        """
        Evaluate the log probability of an already-compressed point z given y.

        Parameters
        ----------
        z : jnp.Array
            Output of the compressor.
        y : jnp.Array
            Conditioning variable.

        Returns
        -------
        jnp.Array
            Log probability of z given y.
        """
        message = (
            "log_prob_from_compressed method not implemented in your child class of Compressor_w_NDE"
        )
        raise NotImplementedError(message)

    @abstractmethod
    def sample(self, y, num_samples, key):
        """
        Draw samples from the distribution conditioned by y.

        Parameters
        ----------
        y : jnp.Array
            Conditioning variable.
        num_samples : int
            How many samples to draw.
        key : jnp.Array
            JAX random key.

        Returns
        -------
        jnp.Array
            num_samples samples from the distribution.
        """
        message = "sample method not implemented in your child class of Compressor_w_NDE"
        raise NotImplementedError(message)
class MixtureDensityNetwork(NDENetwork):
    """
    Base class for a Mixture Density Network.

    A mixture of Gaussian densities modeled using neural networks. The mixture
    logits, the component means and the lower-triangular scale factors of each
    Gaussian component are predicted by the network from the conditioning
    variable.
    """

    n_in: int  # Dimension of the input
    n_cond: int  # Dimension of conditioning variable
    n_components: int  # Number of mixture components
    layers: list[int]  # List of hidden layer sizes
    activation: Callable  # Activation function

    @nn.compact
    def __call__(self, y, **kwargs):
        """
        Build the Gaussian mixture distribution conditioned by y.

        The network maps y to, for every component, one logit, ``n_in`` mean
        entries and ``n_in * (n_in + 1) // 2`` entries of a lower-triangular
        scale matrix.

        Parameters
        ----------
        y : jnp.Array
            Conditioning variable. Leading axes are treated as batch axes.
        **kwargs
            ``kernel_init`` may be given to override the Dense initializer.

        Returns
        -------
        distrax.MixtureSameFamily
            Mixture of Gaussian distributions with batch shape ``y.shape[:-1]``.
        """
        kernel_init = kwargs.get(
            "kernel_init",
            nn.initializers.variance_scaling(
                scale=1.0, mode="fan_in", distribution="normal"
            ),
        )
        for size in self.layers:
            y = self.activation(nn.Dense(size, kernel_init=kernel_init)(y))

        n_tril = self.n_in * (self.n_in + 1) // 2
        final_size = self.n_components * (1 + self.n_in + n_tril)
        y = nn.Dense(final_size, kernel_init=kernel_init)(y)

        # Keep the caller's batch shape so that the mixture logits and the
        # component parameters agree for any number of leading batch axes
        # (flattening with -1 only worked for a single batch axis).
        batch_shape = y.shape[:-1]
        logits = jax.nn.log_softmax(y[..., : self.n_components])
        locs = y[..., self.n_components : self.n_components * (self.n_in + 1)]
        scale_tril = y[..., self.n_components * (self.n_in + 1) :]

        # NOTE(review): the diagonal of the scale matrix is not constrained
        # (no softplus/exp), it is used exactly as produced by the network.
        distribution = distrax.MixtureSameFamily(
            mixture_distribution=distrax.Categorical(logits=logits),
            components_distribution=tfd.MultivariateNormalTriL(
                loc=jnp.reshape(
                    locs, (*batch_shape, self.n_components, self.n_in)
                ),
                scale_tril=tfp.math.fill_triangular(
                    jnp.reshape(
                        scale_tril, (*batch_shape, self.n_components, n_tril)
                    )
                ),
            ),
        )
        return distribution

    def log_prob(self, x, y, **kwargs):
        """
        Return the log probability of the data point x conditioned by y.

        Parameters
        ----------
        x : jnp.Array
            Data point.
        y : jnp.Array
            Conditioning variable.

        Returns
        -------
        jnp.Array
            Log probability of the data point.
        """
        distribution = self.__call__(y, **kwargs)
        return distribution.log_prob(x)

    def sample(self, y, num_samples, key, **kwargs):
        """
        Sample from the distribution conditioned by y.

        Parameters
        ----------
        y : jnp.Array
            Conditioning variable. A 1-D y is given a leading batch axis.
        num_samples : int
            Number of samples.
        key : jnp.Array
            Random key.

        Returns
        -------
        jnp.Array
            num_samples samples from the distribution, with all singleton
            axes squeezed out.
        """
        if y.ndim == 1:
            y = y[None, :]
        distribution = self.__call__(y, **kwargs)
        return distribution.sample(sample_shape=num_samples, seed=key).squeeze()
class AffineCoupling(nn.Module):
    """
    Base class for an Affine Coupling layer for RealNVP.

    Parameters
    ----------
    y : Any
        Conditioning variable.
    layers : list
        List of hidden layer sizes.
    activation : Callable
        Activation function.
    """

    y: Any  # Conditioning variable
    layers: list  # List of hidden layer sizes
    activation: Callable  # Activation function

    @nn.compact
    def __call__(self, x, output_units, **kwargs):
        """
        Build the affine bijector whose shift and scale are learned by a network.

        The network input is ``concatenate([x, y], axis=-1)``. Leading (batch)
        axes of x and y are broadcast against each other first, so a single
        conditioning vector can be used with a batch of samples.

        Parameters
        ----------
        x : jnp.Array
            Data point (the unmasked part of the RealNVP input).
        output_units : int
            Dimension of the output.

        Returns
        -------
        tfb.Chain
            Chain of Shift and Scale bijectors.
        """
        y = jnp.asarray(self.y)
        # Broadcast batch axes so concatenation along the feature axis works
        # even when y has no (or a size-1) sample axis; no-op if shapes match.
        batch_shape = jnp.broadcast_shapes(x.shape[:-1], y.shape[:-1])
        x = jnp.broadcast_to(x, batch_shape + x.shape[-1:])
        y = jnp.broadcast_to(y, batch_shape + y.shape[-1:])

        h = jnp.concatenate([x, y], axis=-1)
        for layer_size in self.layers:
            h = self.activation(
                nn.Dense(
                    layer_size, kernel_init=nn.initializers.truncated_normal(0.001)
                )(h)
            )

        # Shift and Scale parameters
        shift = nn.Dense(
            output_units, kernel_init=nn.initializers.truncated_normal(0.001)
        )(h)
        # softplus + 1e-3 keeps the scale strictly positive.
        scale = (
            nn.softplus(
                nn.Dense(
                    output_units, kernel_init=nn.initializers.truncated_normal(0.001)
                )(h)
            )
            + 1e-3
        )
        return tfb.Chain([tfb.Shift(shift), tfb.Scale(scale)])
class ConditionalRealNVP(NDENetwork):
    """
    Base class for a Conditional RealNVP.

    A normalizing flow made of ``n_layers`` RealNVP affine-coupling blocks,
    each followed by a reversal of the feature order, on top of a standard
    multivariate Gaussian base distribution. Each coupling network is
    conditioned on y.

    Parameters
    ----------
    n_in : int
        Dimension of the input.
    n_cond : int
        Dimension of the conditioning variable.
    n_layers : int
        Number of coupling blocks.
    layers : list[int]
        Hidden layer sizes of each coupling network.
    activation : Callable
        Activation function.
    """

    n_in: int  # Dimension of the input
    n_cond: int  # Dimension of the conditioning variable
    n_layers: int  # Number of coupling blocks
    layers: list[int]  # Hidden layer sizes of each coupling network
    activation: Callable  # Activation function

    @nn.compact
    def __call__(self, y, **kwargs):
        """
        Assemble the conditional flow distribution.

        Parameters
        ----------
        y : jnp.Array
            Conditioning variable.

        Returns
        -------
        distrax.Transformed
            The standard Gaussian pushed through the chain of coupling blocks.

        Raises
        ------
        ValueError
            If ``n_in == 1``.
        """
        if self.n_in == 1:
            raise ValueError(
                "Flows can't be used to learn a one dimensional distribution. Consider using the `MixtureDensityNetwork`."
            )

        make_coupling = partial(
            AffineCoupling, layers=self.layers, activation=self.activation
        )
        base = distrax.MultivariateNormalDiag(
            jnp.zeros(self.n_in), jnp.ones(self.n_in)
        )
        reversed_order = jnp.arange(self.n_in)[::-1]
        n_masked = self.n_in // 2

        blocks = []
        for i in range(self.n_layers):
            coupling = tfb.RealNVP(
                n_masked, bijector_fn=make_coupling(y, name="b%d" % i)
            )
            # Permute(...)(coupling) composes the reversal after the coupling.
            blocks.append(tfb.Permute(reversed_order)(coupling))

        return distrax.Transformed(base, bijector=distrax.Chain(blocks))

    def sample(self, y, num_samples, key, **kwargs):
        """
        Draw samples from the flow conditioned by y.

        Parameters
        ----------
        y : jnp.Array
            Conditioning variable; singleton axes are squeezed out first.
        num_samples : int
            How many samples to draw.
        key : jnp.Array
            JAX random key.

        Returns
        -------
        jnp.Array
            num_samples samples from the distribution.
        """
        flow = self.__call__(y.squeeze())
        return flow.sample(sample_shape=num_samples, seed=key)

    def log_prob(self, x, y, **kwargs):
        """
        Evaluate the flow log density of x conditioned by y.

        Parameters
        ----------
        x : jnp.Array
            Data point.
        y : jnp.Array
            Conditioning variable.

        Returns
        -------
        jnp.Array
            Log probability of x given y.
        """
        flow = self.__call__(y)
        return flow.log_prob(x)
# The MADE and MAF implementations below are adapted from https://github.com/e-hulten/maf/blob/master/made.py
class MaskedLinear(nn.Module):
    """
    Base class for a Masked Linear layer.

    Linear transformation with masked-out weights:
    ``out = x @ (mask * W) + b``, where ``W`` is the Flax Dense kernel of
    shape ``(in_features, n_out)``.

    Parameters
    ----------
    n_out : int
        Output dimension.
    bias : bool
        Whether to include bias. Default True.
    mask : Any
        Mask multiplied element-wise with the kernel. Default None (no mask).
    """

    n_out: int
    bias: bool = True
    mask: Any = None

    def initialize_mask(self, mask: Any):
        """
        Set the mask applied to the kernel.

        Parameters
        ----------
        mask : Any
            Mask to apply to the weights; must broadcast against the kernel.
        """
        self.mask = mask

    @nn.compact
    def __call__(self, x):
        """
        Apply the masked linear transformation.

        On the first call (parameter initialisation) the plain Dense layer is
        applied so that its parameters get created. Afterwards the stored
        kernel is masked before the product.

        Parameters
        ----------
        x : jnp.Array
            Input vector.

        Returns
        -------
        jnp.Array
            Output vector.
        """
        # NOTE(review): float64 parameters presumably require jax_enable_x64;
        # confirm the intended precision.
        layer = nn.Dense(
            self.n_out,
            use_bias=self.bias,
            param_dtype=jnp.float64,
            kernel_init=nn.initializers.truncated_normal(0.01),
        )
        if not self.has_variable("params", "Dense_0"):
            # Initialisation pass: let Dense create kernel (and bias).
            return layer(x)

        params = layer.variables["params"]
        kernel = params["kernel"]
        if self.mask is not None:
            kernel = self.mask * kernel
        out = jnp.dot(x, kernel)
        # Dense only creates a "bias" parameter when use_bias is True.
        if self.bias:
            out = out + params["bias"]
        return out
class ConditionalMADE(nn.Module):
    """
    Base class for Conditional Masked Autoencoder Density Estimator (MADE).

    MADE is a neural network that parameterizes the conditional distribution of
    a random variable using masked linear layers. The conditioning variable (if
    any) is concatenated after the input and connected to every unit of the
    first hidden layer.

    Parameters
    ----------
    n_in : int
        Size of the input vector.
    hidden_dims : list[int]
        List of hidden dimensions.
    activation : Callable
        Activation function.
    n_cond : int
        Size of the conditioning variable. 0 if None.
    gaussian : bool
        Whether the outputs are the parameters of a Gaussian conditional
        (output size ``2 * n_in``). Default True.
    random_order : bool
        Whether to use a random order of the input for masking. Default False.
    seed : Optional[int]
        Seed for numpy's global RNG, used to label the hidden nodes.
        !!Default is None but the MADE will not work unless a seed is applied!!
    """

    n_in: int  # Size of the input vector
    hidden_dims: list[int]  # list of hidden dimensions
    activation: Callable  # Activation function
    n_cond: int = 0  # Size of the conditioning variable. 0 if None.
    gaussian: bool = (
        True  # whether the outputs are the parameters of a Gaussian conditional
    )
    random_order: bool = False  # Whether to use random order of the input for masking
    seed: Optional[int] = None  # Random seed to label nodes

    def setup(self):
        """Set the network creating the masks and the masked linear layers."""
        # Reseeds numpy's *global* RNG; mask labels below are drawn from it.
        np.random.seed(self.seed)
        self.n_out = 2 * self.n_in if self.gaussian else self.n_in
        masks = {}
        mask_matrix = []
        layers = []
        # Layer widths: [input + conditioning, hidden..., output]
        dim_list = [self.n_in + self.n_cond, *self.hidden_dims, self.n_out]
        # Make layers and activation functions
        for i in range(len(dim_list) - 2):
            layers.append(MaskedLinear(dim_list[i + 1]))
            layers.append(self.activation)
        # Last hidden layer to output layer
        layers.append(MaskedLinear(dim_list[-1]))
        # Create masks (attached in-place to the MaskedLinear layers)
        self._create_masks(mask_matrix, masks, layers)
        # Create model
        self.layers = layers
        self.model = nn.Sequential(self.layers)

    def _create_masks(self, mask_matrix: list, masks: dict, layers: list):
        """
        Create the connectivity masks and attach them to the MaskedLinear layers.

        ``masks[k]`` holds the integer label of each unit of layer k; a
        connection i -> j is kept when label_i <= label_j between the input and
        hidden layers, and when label_i < label_j into the output layer.

        Parameters
        ----------
        mask_matrix : list
            Output list, filled with one mask matrix per MaskedLinear layer.
        masks : dict
            Output dict, filled with the unit labels of every layer.
        layers : list
            Layers built in ``setup``; each MaskedLinear receives a mask in order.
        """
        L = len(self.hidden_dims)  # Number of hidden layers
        D = self.n_in  # Number of input parameters
        C = self.n_cond  # Number of conditioning parameters
        # Whether to use random or natural order of the input
        masks[0] = np.random.permutation(D) if self.random_order else np.arange(D)
        # Set the connectivity number for the hidden layers
        for l in range(L):
            low = masks[l].min()  # Get the lowest index in the previous layer
            size = self.hidden_dims[l]  # The size of the current hidden layer
            if D > 1:
                # Labels drawn in [low, D - 2] (randint's upper bound is exclusive)
                masks[l + 1] = np.random.randint(low, D - 1, size=size)
            else:
                masks[l + 1] = np.zeros(size)
        # Order of the output layer is the same as the input layer
        masks[L + 1] = masks[0]
        # NOTE(review): assumes at least one hidden layer. With hidden_dims=[]
        # two matrices are built for the single MaskedLinear layer and only the
        # first one (built with '<=' below) is attached.
        # Create mask matrix for input -> hidden_layer_1
        m = masks[0]
        m_next = masks[1]
        M = np.ones((len(m), len(m_next)))
        for j in range(len(m_next)):
            M[:, j] = (m <= m_next[j]).astype(int)
        # Conditioning inputs (appended after x) connect to every first-layer unit
        M_cond = np.ones((C, len(m_next)))
        M = np.concatenate([M, M_cond], axis=0)
        mask_matrix.append(jnp.array(M))
        # Create mask matrix for hidden_layer_1 -> ... -> last_hidden_layers
        for i in range(1, len(masks) - 2):
            m = masks[i]
            m_next = masks[i + 1]
            # Initialise mask matrix
            M = np.zeros((len(m), len(m_next)))
            for j in range(len(m_next)):
                # Compare m_next[j] to each element of m
                M[:, j] = (m <= m_next[j]).astype(int)
            # append matrix to mask list
            mask_matrix.append(jnp.array(M))
        # Create mask matrix for last_hidden_layer -> output
        # Strict '<' here: an output only connects to hidden units with a
        # strictly smaller label.
        # NOTE(review): hidden labels are never below masks[0].min(), so the
        # output carrying that minimum label gets no hidden connections and
        # depends only on its bias (not on the conditioning variable) —
        # confirm this is intended.
        m = masks[len(masks) - 2]
        m_next = masks[len(masks) - 1]
        M = np.zeros((len(m), len(m_next)))
        for j in range(len(m)):
            # Compare m[j] to each element of m_next
            M[j, :] = (m[j] < m_next).astype(int)
        # append matrix to mask list
        mask_matrix.append(jnp.array(M))
        # If the output is Gaussian, double the number of outputs; both halves
        # share the same mask (MAFLayer splits them into mu and logp).
        if self.gaussian:
            m = mask_matrix.pop(-1)
            mask_matrix.append(jnp.concatenate([m, m], axis=1))
        # Attach the masks to the MaskedLinear layers, in order
        mask_iter = iter(mask_matrix)
        for module in layers:
            if isinstance(module, MaskedLinear):
                module.initialize_mask(next(mask_iter))

    def __call__(self, x, y=None):
        """
        Forward pass of the model.

        Parameters
        ----------
        x : jnp.Array
            Input vector.
        y : jnp.Array
            Conditioning variable (used only when ``n_cond != 0``).

        Returns
        -------
        jnp.Array
            Output vector. If gaussian, the raw network output (both halves of
            size n_in); otherwise the sigmoid of the network output.
        """
        if self.n_cond != 0:
            x = jnp.concatenate([x, y], axis=-1)
        if self.gaussian:
            return self.model(x)
        else:
            return jax.nn.sigmoid(self.model(x))
class MAFLayer(nn.Module):
    """
    Base class for a Masked Autoregressive Flow layer.

    A single layer of a Masked Autoregressive Flow built on a Gaussian
    ConditionalMADE, whose output is split into ``mu`` and ``logp``.

    Parameters
    ----------
    n_in : int
        Size of the input vector.
    n_cond : int
        Size of the conditioning variable.
    hidden_dims : list[int]
        List of hidden dimensions.
    reverse : bool
        Whether to reverse the order of the features after the layer.
    activation : Callable
        Activation function.
    seed : Optional[int]
        Random seed to label nodes.
        !!Default is None but the MAF will not work unless a seed is applied!!
    """

    n_in: int  # Size of the input vector
    n_cond: int  # Size of the conditioning variable
    hidden_dims: list[int]  # list of hidden dimensions
    reverse: bool  # Whether to reverse the order of the input
    activation: Callable  # Activation function
    seed: Optional[int] = None  # Random seed to label nodes

    def forward(self, x, y=None):
        """
        Forward pass: map data x to the latent vector u.

        ``u = (x - mu) * exp(0.5 * logp)``, optionally reversed along the
        feature axis.

        Parameters
        ----------
        x : jnp.Array
            Input vector.
        y : jnp.Array
            Conditioning variable.

        Returns
        -------
        jnp.Array
            Transformed vector.
        jnp.Array
            Log-determinant of the Jacobian, ``0.5 * sum(logp)``.
        """
        out = self.__call__(x, y)
        mu, logp = jnp.split(out, 2, axis=-1)
        u = (x - mu) * jnp.exp(0.5 * logp)
        u = jnp.flip(u, axis=-1) if self.reverse else u
        log_det = 0.5 * jnp.sum(logp, axis=-1)
        return u, log_det

    def backward(self, u, y=None):
        """
        Backward pass: invert the layer, mapping u back to x.

        The inverse is computed one feature at a time (``n_in`` network
        evaluations), since each feature of x depends on the previous ones.
        Works for any number of leading batch axes.

        Parameters
        ----------
        u : jnp.Array
            Input vector.
        y : jnp.Array
            Conditioning variable.

        Returns
        -------
        jnp.Array
            Transformed vector.
        jnp.Array
            Log-determinant of the Jacobian of the inverse.
        """
        u = jnp.flip(u, axis=-1) if self.reverse else u
        x = jnp.zeros_like(u)
        for dim in range(self.n_in):
            out = self.__call__(x, y)
            mu, logp = jnp.split(out, 2, axis=-1)
            # -0.5 * logp is capped at 10 to keep exp() from overflowing.
            mod_logp = jax.lax.clamp(-jnp.inf, -0.5 * logp, max=10.0)
            x = x.at[..., dim].set(
                mu[..., dim] + jnp.exp(mod_logp[..., dim]) * u[..., dim]
            )
        log_det = jnp.sum(mod_logp, axis=-1)
        return x, log_det

    @nn.compact
    def __call__(self, x, y=None):
        """
        Evaluate the underlying ConditionalMADE.

        Parameters
        ----------
        x : jnp.Array
            Input vector.
        y : jnp.Array
            Conditioning variable.

        Returns
        -------
        jnp.Array
            MADE output of size ``2 * n_in``; split into ``mu`` and ``logp``
            by :meth:`forward` and :meth:`backward`.
        """
        x = ConditionalMADE(
            n_in=self.n_in,
            hidden_dims=self.hidden_dims,
            n_cond=self.n_cond,
            seed=self.seed,
            activation=self.activation,
        )(x, y)
        return x
class ConditionalMAF(NDENetwork):
    """
    Base class of a Conditional Masked Autoregressive Flow.

    A Conditional Masked Autoregressive Flow to model the conditional
    distribution of a random variable. It is obtained by stacking
    ``n_layers`` MAF layers on top of a standard multivariate Gaussian.

    Parameters
    ----------
    n_in : int
        Size of the input vector.
    n_cond : int
        Size of the conditioning variable.
    n_layers : int
        Number of layers (i.e. number of stacked MAFs).
    layers : list[int]
        List of hidden dimensions in each MAF.
    activation : Callable
        Activation function.
    use_reverse : bool
        Whether to reverse the order of the input between each MAF.
    seed : Optional[int]
        Random seed to label nodes.
        !!Default is None but the MAF will not work unless a seed is applied!!
    """

    n_in: int  # Size of the input vector
    n_cond: int  # Size of the conditioning variable
    n_layers: int  # Number of layers (i.e. number of stacked MADEs)
    layers: list[int]  # List of hidden dimensions in each MADE
    activation: Callable  # Activation function
    use_reverse: bool  # Whether to reverse the order of the input between each MADE
    seed: Optional[int] = None  # Random seed to label nodes

    def setup(self):
        """
        Create the MAF layers and the standard Gaussian base distribution.

        Raises
        ------
        ValueError
            If ``n_in == 1``.
        """
        # Reseeds numpy's global RNG; each layer gets its own seed drawn from it.
        np.random.seed(self.seed)
        if self.n_in == 1:
            raise ValueError(
                "Flows can't be used to learn a one dimensional distribution. Consider using the `MixtureDensityNetwork`."
            )
        layer_list = []
        for _ in range(self.n_layers):
            layer_list.append(
                MAFLayer(
                    n_in=self.n_in,
                    n_cond=self.n_cond,
                    hidden_dims=self.layers,
                    reverse=self.use_reverse,
                    seed=np.random.randint(0, 1000),
                    activation=self.activation,
                )
            )
        self.layer_list = layer_list
        self.mean = jnp.zeros(self.n_in)
        self.cov = jnp.eye(self.n_in)

    @nn.compact
    def __call__(self, x, y=None):
        """
        Push x through every MAF layer (data -> latent).

        Parameters
        ----------
        x : jnp.Array
            Input vector; leading axes are batch axes.
        y : jnp.Array
            Conditioning variable.

        Returns
        -------
        jnp.Array
            Transformed vector.
        jnp.Array
            Summed log-determinant of the Jacobians, shape ``x.shape[:-1]``.
        """
        # One accumulator per sample (a scalar for an unbatched 1-D x).
        log_det_sum = jnp.zeros(x.shape[:-1])
        for layer in self.layer_list:
            x, log_det = layer.forward(x, y)
            log_det_sum += log_det
        return x, log_det_sum

    def backward(self, u, y=None):
        """
        Invert the flow (latent -> data), traversing the layers in reverse.

        Parameters
        ----------
        u : jnp.Array
            Input vector; leading axes are batch axes.
        y : jnp.Array
            Conditioning variable.

        Returns
        -------
        x : jnp.Array
            Transformed vector.
        log_det_sum : jnp.Array
            Summed log-determinant of the inverse Jacobians.
        """
        log_det_sum = jnp.zeros(u.shape[:-1])
        for layer in reversed(self.layer_list):
            u, log_det = layer.backward(u, y)
            log_det_sum += log_det
        return u, log_det_sum

    def log_prob(self, x, y=None):
        """
        Compute the log-probability of x conditioned on y.

        Parameters
        ----------
        x : jnp.Array
            Input vector.
        y : jnp.Array
            Conditioning variable.

        Returns
        -------
        jnp.Array
            Base Gaussian log-density of the latent plus the log-determinant.
        """
        u, log_det_sum = self.__call__(x, y)
        log_pdf = multivariate_normal.logpdf(u, self.mean, self.cov)
        return log_pdf + log_det_sum

    def sample(self, y=None, num_samples=1, key=None):
        """
        Sample from the distribution emulated by the neural network.

        Parameters
        ----------
        y : jnp.Array
            Conditioning variable, broadcast to ``(num_samples, n_cond)``.
        num_samples : int
            Number of samples.
        key : jnp.Array
            Random key.

        Returns
        -------
        jnp.Array
            Samples from the distribution.
        """
        u = jax.random.multivariate_normal(
            key, self.mean, self.cov, shape=(num_samples,)
        )
        if y is not None:
            y = y * jnp.ones((num_samples, 1))
        x, _ = self.backward(u, y)
        return x
class NDE_Compressor(Compressor_w_NDE):
    """
    Base class for a normalizing flow with a compressor.

    WARNING: This class will likely be removed in the future as it is obsolete.

    A general class to implement a compressor followed by a normalizing flow
    implementing standard methods to compute the log-probability of the target
    distribution or sample from it. In NPE mode the compressor is applied to
    y; in NLE mode it is applied to x.
    """

    compressor: nn.Module  # Compressor network class, built from compressor_hparams
    nde: NDENetwork  # Normalizing Flow or Mixture Density network class
    compressor_hparams: dict  # Hyperparameters of the compressor
    nde_hparams: dict  # Hyperparameters of the Neural Density Estimator

    def setup(self):
        """Instantiate the compressor and the normalizing flow from their hparams."""
        self.compressor_nn = self.compressor(**self.compressor_hparams)
        self.nde_nn = self.nde(**self.nde_hparams)

    def __call__(self, x, y, model="NPE"):
        """
        Forward pass returning a log-probability.

        NPE: log p(x | compressor(y)). NLE: log p(compressor(x) | y).

        Parameters
        ----------
        x : jnp.Array
            Data point
        y : jnp.Array
            Conditioning variable
        model : str
            "NPE" or "NLE". Default "NPE".

        Returns
        -------
        jnp.Array
            Log probability.
        """
        assert model in ["NPE", "NLE"], "Model should be either 'NPE' or 'NLE'."
        if model == "NPE":
            z = self.compressor_nn(y)
            return self.nde_nn.log_prob(x, z)
        else:
            z = self.compressor_nn(x)
            return self.nde_nn.log_prob(z, y)

    def log_prob(self, x, y, model="NPE"):
        """
        Return the log-probability computed by :meth:`__call__`.

        Parameters
        ----------
        x : jnp.Array
            Data point
        y : jnp.Array
            Conditioning variable
        model : str
            "NPE" or "NLE". Default "NPE".

        Returns
        -------
        jnp.Array
            Log probability.
        """
        return self.__call__(x, y, model)

    def log_prob_compressed(self, z, y, model="NPE"):
        """
        Return the log-probability when z has already been compressed.

        NPE: log p(y | z). NLE: log p(z | y).

        Parameters
        ----------
        z : jnp.Array
            Compressed data point
        y : jnp.Array
            Second argument of the density (see above)
        model : str
            "NPE" or "NLE". Default "NPE".

        Returns
        -------
        jnp.Array
            Log probability.
        """
        assert model in ["NPE", "NLE"], "Model should be either 'NPE' or 'NLE'."
        if model == "NPE":
            return self.nde_nn.log_prob(y, z)
        else:
            return self.nde_nn.log_prob(z, y)

    def log_prob_from_compressed(self, z, y=None, model="NPE", **kwargs):
        """
        Implement the Compressor_w_NDE interface by delegating to
        :meth:`log_prob_compressed` (extra keyword arguments are ignored).

        Parameters
        ----------
        z : jnp.Array
            Compressed data point
        y : jnp.Array
            Second argument of the density
        model : str
            "NPE" or "NLE". Default "NPE".

        Returns
        -------
        jnp.Array
            Log probability.
        """
        return self.log_prob_compressed(z, y, model)

    def sample(self, y, num_samples, key, model="NPE"):
        """
        Sample from the distribution conditioned by y.

        In NPE mode y is compressed before being passed to the NDE.

        Parameters
        ----------
        y : jnp.Array
            Conditioning variable
        num_samples : int
            Number of samples
        key : jnp.Array
            Random key
        model : str
            "NPE" or "NLE". Default "NPE".

        Returns
        -------
        jnp.Array
            num_samples samples from the distribution
        """
        assert model in ["NPE", "NLE"], "Model should be either 'NPE' or 'NLE'."
        if model == "NPE":
            z = self.compressor_nn(y)
            return self.nde_nn.sample(z, num_samples, key)
        else:
            return self.nde_nn.sample(y, num_samples, key)
class NDE_w_Standardization(NDENetwork):
    """
    Base class to implement a normalizing flow with a standardization step.

    This class creates an NDE network where the modelled variable is first
    standardized. It takes a neural density estimator, an embedding net and a
    transformation. The embedding net is applied to the conditioning variable
    (e.g. to compress it to a lower dimensional space) before it is given to
    the NDE. The inverse of the transformation standardizes the variable
    learned by the NDE, for stability purposes.
    """

    nde: NDENetwork  # Neural Density Estimator
    embedding_net: nn.Module  # Embedding network applied to the conditioning variable
    transformation: distrax.Bijector  # Standardization bijector (TBC)

    def __call__(self, x, y, model="NPE"):
        """
        Return the log-probability of x given y (NPE) or of y given x (NLE).

        The modelled variable is mapped through ``transformation.inverse`` and
        the corresponding log-determinant is added (change of variables). The
        conditioning variable is passed through ``embedding_net``.

        Parameters
        ----------
        x : jnp.Array
            Parameters
        y : jnp.Array
            Conditioning variable
        model : str
            Whether the network is trained using NPE or NLE. Default: NPE.

        Returns
        -------
        jnp.Array
            Log probability.
        """
        assert model in ["NPE", "NLE"], "Model should be either 'NPE' or 'NLE'."
        if model == "NLE":
            x, y = y, x  # Learn the distribution p(y|x). Exchange the two.
        x, logprob_std = self.transformation.inverse_and_log_det(x)
        # NOTE(review): summing over the last axis assumes the bijector returns
        # a per-dimension log-det (elementwise transformation) — confirm.
        logprob_std = jnp.sum(logprob_std, axis=-1)
        z = self.embedding_net(y)
        log_prob = self.nde.log_prob(x, z)
        return log_prob + logprob_std

    def standardize(self, x):
        """Standardize the data point x (``transformation.inverse``)."""
        return self.transformation.inverse(x)

    def unstandardize(self, x):
        """Unstandardize the data point x (``transformation.forward``)."""
        return self.transformation.forward(x)

    def embedding(self, x):
        """Embed the data point x with the embedding network."""
        return self.embedding_net(x)

    def log_prob(self, x, y=None, model="NPE"):
        """Return the log probability computed by :meth:`__call__`."""
        return self.__call__(x, y, model)

    def sample(self, y, num_samples, key, model="NPE"):
        """
        Sample from the distribution conditioned by y.

        As in :meth:`__call__`, the NDE is conditioned on ``embedding_net(y)``
        for both NPE and NLE, and the samples are mapped back to the original
        (unstandardized) space with ``transformation.forward``.

        Parameters
        ----------
        y : jnp.Array
            Conditioning variable.
        num_samples : int
            Number of samples.
        key : jnp.Array
            Random key.
        model : str
            "NPE" or "NLE". Default "NPE".

        Returns
        -------
        jnp.Array
            Samples in the original space.
        """
        assert model in ["NPE", "NLE"], "Model should be either 'NPE' or 'NLE'."
        # __call__ always conditions the NDE on embedding_net(conditioning), so
        # sampling must too; previously NLE sampling skipped the embedding.
        z = self.embedding_net(y)
        samples = self.nde.sample(z, num_samples, key)
        samples = self.transformation.forward(samples)
        return samples