jaxili.train module

Train.

This module implements an object to perform the training of Normalizing Flows and more generally of neural networks.

class jaxili.train.TrainState(step: int | Array, apply_fn: Callable, params: FrozenDict[str, Any], tx: GradientTransformation, opt_state: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], batch_stats: Any = (None,), rng: Any = None)

Bases: TrainState

A simple extension of TrainState to also include batch statistics.

If a model has no batch statistics, batch_stats is None. The state also keeps an rng state for dropout or initialization.

Attributes:
rng

Methods

apply_gradients(*, grads, **kwargs)

Updates step, params, opt_state and **kwargs in return value.

create(*, apply_fn, params, tx, **kwargs)

Creates a new instance with step=0 and initialized opt_state.

replace(**updates)

Returns a new object replacing the specified fields with new values.

batch_stats: Any = (None,)

replace(**updates)

Returns a new object replacing the specified fields with new values.

rng: Any = None

class jaxili.train.TrainerModule(model_class: NDENetwork, model_hparams: Dict[str, Any], optimizer_hparams: Dict[str, Any], loss_fn: Callable, exmp_input: Any, seed: int = 42, logger_params: Dict[str, Any] = None, enable_progress_bar: bool = True, debug: bool = False, check_val_every_epoch: int = 1, nde_class: str = 'NPE', **kwargs)

Bases: object

A module to perform the training of Normalizing Flows.

This module contains the training loop, evaluation, logging, and checkpointing. It can also be used to load a model from a checkpoint.

Methods

bind_model()

Return a model with parameters bound to it.

create_functions()

Create and return functions for the training and evaluation steps.

create_jitted_functions()

Create jitted versions of the training, validation and evaluation functions.

eval_model(data_loader[, log_prefix])

Evaluate the model on a dataset.

generate_config(logger_params)

Generate a configuration dictionary for the trainer.

init_apply_fn()

Initialize a default apply function for the model.

init_checkpointer()

Initialize the checkpointer to save the model.

init_logger([logger_params])

Initialize the logger and create a logging directory.

init_model(exmp_input)

Create an initial training state with newly generated network parameters.

init_optimizer(num_epochs, num_steps_per_epoch)

Initialize the optimizer and learning rate scheduler.

is_new_model_better(new_metrics, old_metrics)

Compare two sets of evaluation metrics to decide whether the new model is better than the previous one.

load_from_checkpoints(model_class, ...)

Create a Trainer object with same hyperparameters and loaded model from a checkpoint directory.

load_model()

Load model and batch statistics from the logging directory.

on_training_epoch_end(epoch_idx)

Perform any necessary operations at the end of each training epoch.

on_training_start()

Perform any necessary operations before the training starts.

on_validation_epoch_end(epoch_idx, ...)

Perform any necessary operations at the end of each validation epoch.

print_tabulate(exmp_input)

Print a summary of the model represented as a table.

run_model_init(exmp_input, init_rng)

Initialize the model by calling it on the example input.

save_metrics(filename, metrics)

Save a dictionary of metrics to file.

save_model([step])

Save current training state at certain training iteration.

tracker(iterator, **kwargs)

Wrap an iterator in a progress bar tracker (tqdm) if the progress bar is enabled.

train_epoch(train_loader)

Train the model for one epoch.

train_model(train_loader, val_loader[, ...])

Start a training loop for the given number of epochs.

write_config(log_dir)

Write the config of the trainer in a JSON file.

bind_model()

Return a model with parameters bound to it. Enables an easier inference access.

Return type:

The model with parameters and, if present, batch statistics bound to it.

create_functions() → Tuple[Callable[[TrainState, Any], Tuple[TrainState, Dict]], Callable[[TrainState, Any], Tuple[TrainState, Dict]]]

Create and return functions for the training and evaluation steps.

The functions take as input the training state and a batch from the train/val/test loader. Both functions are expected to return a dictionary of logging metrics, and the training function must also return a new train state. This method can be overridden by a subclass. The train_step and eval_step functions defined here serve as examples of the expected signatures.
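The expected shape of the two step functions can be illustrated without JAX. The following is a minimal plain-Python sketch, not jaxili's code: ToyState, Batch, mse and the learning rate lr are hypothetical names, and the gradient of a one-parameter least-squares model is written out by hand instead of being computed by automatic differentiation.

```python
from dataclasses import dataclass, replace
from typing import Dict, List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs; hypothetical batch format


@dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for a train state: a step counter and one weight."""

    step: int
    w: float


def mse(w: float, batch: Batch) -> float:
    """Mean squared error of the model y = w * x on a batch."""
    return sum((w * x - y) ** 2 for x, y in batch) / len(batch)


def train_step(
    state: ToyState, batch: Batch, lr: float = 0.1
) -> Tuple[ToyState, Dict[str, float]]:
    # Hand-derived gradient of the mean squared error with respect to w.
    grad = sum(2 * (state.w * x - y) * x for x, y in batch) / len(batch)
    new_state = replace(state, step=state.step + 1, w=state.w - lr * grad)
    # The reported loss is the one measured before the update.
    return new_state, {"loss": mse(state.w, batch)}


def eval_step(state: ToyState, batch: Batch) -> Tuple[ToyState, Dict[str, float]]:
    # Evaluation leaves the state untouched and only reports metrics.
    return state, {"loss": mse(state.w, batch)}
```

Both functions return a (state, metrics) pair, which matches the return annotation shown in the signature above; only train_step produces a modified state.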

create_jitted_functions()

Create jitted versions of the training, validation and evaluation functions.

If self.debug is True, no jitting is applied.

eval_model(data_loader: Iterator, log_prefix: str | None = '') → Dict[str, Any]

Evaluate the model on a dataset.

Parameters:
  • data_loader (Iterator) – An iterator over the data.

  • log_prefix (str) – A prefix to add to all metrics.

Returns:

A dictionary of the evaluation metrics, averaged over data points in the dataset

Return type:

Dict[str, Any]
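Averaging over data points rather than over batches matters when the last batch is smaller than the others. The sketch below shows one way to do it in plain Python; it is an illustration under stated assumptions, not the library's implementation. The names average_metrics and batch_mean are hypothetical, and eval_step is assumed to return per-batch means.

```python
from collections import defaultdict
from typing import Callable, Dict, Iterable, Sequence


def average_metrics(
    batches: Iterable[Sequence[float]],
    eval_step: Callable[[Sequence[float]], Dict[str, float]],
    log_prefix: str = "",
) -> Dict[str, float]:
    """Average per-batch metrics, weighting each batch by its size."""
    totals: Dict[str, float] = defaultdict(float)
    num_points = 0
    for batch in batches:
        batch_size = len(batch)
        for name, value in eval_step(batch).items():
            # Undo the per-batch mean so that every data point counts once.
            totals[name] += value * batch_size
        num_points += batch_size
    if num_points == 0:
        return {}  # empty loader: nothing to average
    return {log_prefix + name: total / num_points for name, total in totals.items()}


def batch_mean(batch: Sequence[float]) -> Dict[str, float]:
    """Toy evaluation step: the 'loss' is the mean of the batch values."""
    return {"loss": sum(batch) / len(batch)}


result = average_metrics([[1.0, 2.0, 3.0], [5.0]], batch_mean, log_prefix="val/")
print(result)  # {'val/loss': 2.75}
```

An unweighted mean of the two batch losses would give 3.5 here, which over-counts the single point in the short batch.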

generate_config(logger_params)

Generate a configuration dictionary for the trainer.

init_apply_fn()

Initialize a default apply function for the model.

init_checkpointer()

Initialize the checkpointer to save the model.

init_logger(logger_params: Dict | None = None)

Initialize the logger and create a logging directory.

Parameters:

logger_params (Dict[str, Any]) – A dictionary containing the specifications of the logger.

init_model(exmp_input: Any)

Create an initial training state with newly generated network parameters.

Parameters:

exmp_input (Any) – An input to the model with which the shapes are inferred.

init_optimizer(num_epochs: int, num_steps_per_epoch: int)

Initialize the optimizer and learning rate scheduler.

Parameters:
  • num_epochs (int) – Number of epochs to train.

  • num_steps_per_epoch (int) – Number of steps per epoch.

is_new_model_better(new_metrics: Dict[str, Any], old_metrics: Dict[str, Any]) → bool

Compare two sets of evaluation metrics to decide whether the new model is better than the previous one.

Parameters:
  • new_metrics (Dict[str, Any]) – A dictionary of the evaluation metrics of the new model.

  • old_metrics (Dict[str, Any]) – A dictionary of the evaluation metrics of the previously best model.

Returns:

True if the new model is better, False otherwise.

Return type:

bool
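The comparison rule itself is not spelled out on this page. One plausible rule is sketched below as an assumption: walk a priority list of metric keys and compare the first key present in both dictionaries. The key "val/loss", the DEFAULT_CRITERIA list and the handling of a missing previous model are all hypothetical choices made for illustration.

```python
from typing import Any, Dict, Optional, Tuple

# Hypothetical priority list: (metric key, whether a larger value is better).
DEFAULT_CRITERIA: Tuple[Tuple[str, bool], ...] = (("val/loss", False),)


def is_new_model_better(
    new_metrics: Dict[str, Any],
    old_metrics: Optional[Dict[str, Any]],
    criteria: Tuple[Tuple[str, bool], ...] = DEFAULT_CRITERIA,
) -> bool:
    """Return True if new_metrics beats old_metrics on the first shared criterion."""
    if old_metrics is None:
        return True  # no previous model to compare against
    for key, higher_is_better in criteria:
        if key in new_metrics and key in old_metrics:
            if higher_is_better:
                return new_metrics[key] > old_metrics[key]
            return new_metrics[key] < old_metrics[key]
    raise KeyError(f"None of the criteria keys found: {[k for k, _ in criteria]}")
```

Raising on missing keys, rather than silently returning False, makes a mismatch between the logged metric names and the comparison criteria visible early.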

classmethod load_from_checkpoints(model_class: NDENetwork, checkpoint: str, exmp_input: Any) → Any

Create a Trainer object with same hyperparameters and loaded model from a checkpoint directory.

Parameters:
  • model_class (jaxili.model.NDENetwork) – The class of the model that should be loaded.

  • checkpoint (str) – Folder in which the checkpoint and hyperparameter file are stored.

  • exmp_input (Any) – An input to the model with which the shapes are inferred.

Return type:

A Trainer object with model loaded from the checkpoint folder.

load_model()

Load model and batch statistics from the logging directory.

on_training_epoch_end(epoch_idx: int)

Perform any necessary operations at the end of each training epoch.

Method called at the end of each training epoch. Can be used for additional logging or similar.

Parameters:

epoch_idx (int) – Index of the epoch that just finished.

on_training_start()

Perform any necessary operations before the training starts.

Method called before training is started. Can be used for additional initialization operations etc.

on_validation_epoch_end(epoch_idx: int, eval_metrics: Dict[str, Any], val_loader: Iterator)

Perform any necessary operations at the end of each validation epoch.

Method called at the end of each validation epoch. Can be used for additional logging and evaluation.

Parameters:
  • epoch_idx (int) – Index of the epoch that just finished.

  • eval_metrics (Dict[str, Any]) – A dictionary of the evaluation metrics.

  • val_loader (Iterator) – DataLoader of the validation set to support additional evaluation.
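The three hooks on_training_start, on_training_epoch_end and on_validation_epoch_end are meant to be overridden in a subclass. The sketch below shows the pattern with a hypothetical stand-in class, MiniTrainer, that only reuses the hook names and signatures documented here. Its fit loop, the placeholder metric and the assumption that validation runs after every epoch are illustrative and do not reflect the internals of TrainerModule.

```python
from typing import Any, Dict, Iterator, List, Tuple


class MiniTrainer:
    """Hypothetical stand-in exposing the same three hook names."""

    def on_training_start(self) -> None:
        pass

    def on_training_epoch_end(self, epoch_idx: int) -> None:
        pass

    def on_validation_epoch_end(
        self, epoch_idx: int, eval_metrics: Dict[str, Any], val_loader: Iterator
    ) -> None:
        pass

    def fit(self, val_loader: Iterator, num_epochs: int) -> None:
        self.on_training_start()
        for epoch_idx in range(1, num_epochs + 1):
            # ... training batches would be processed here ...
            self.on_training_epoch_end(epoch_idx)
            eval_metrics = {"val/loss": 1.0 / epoch_idx}  # placeholder metric
            self.on_validation_epoch_end(epoch_idx, eval_metrics, val_loader)


class HistoryTrainer(MiniTrainer):
    """Subclass that records the validation loss after every epoch."""

    def on_training_start(self) -> None:
        self.history: List[Tuple[int, float]] = []

    def on_validation_epoch_end(self, epoch_idx, eval_metrics, val_loader) -> None:
        self.history.append((epoch_idx, eval_metrics["val/loss"]))


trainer = HistoryTrainer()
trainer.fit(val_loader=iter(()), num_epochs=3)
```

Because the base-class hooks do nothing, a subclass only needs to override the ones it cares about.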

print_tabulate(exmp_input: Any)

Print a summary of the model represented as a table.

Parameters:

exmp_input (Any) – An input to the model with which the shapes are inferred.

run_model_init(exmp_input: Any, init_rng: Any) → Dict

Initialize the model by calling it on the example input.

Parameters:
  • exmp_input (Any) – An input to the model with which the shapes are inferred.

  • init_rng (Array) – A jax.random.PRNGKey

Return type:

The initialized variable dictionary.

save_metrics(filename: str, metrics: Dict[str, Any])

Save a dictionary of metrics to file.

This can be used as a textual representation of the validation performance for checking in the terminal.

Parameters:
  • filename (str) – Name of the metrics file, without directory or file extension.

  • metrics (Dict[str, Any]) – A dictionary of the metrics to save
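A minimal sketch of what such a helper could look like is given below. The JSON format, the metrics/ subdirectory and the extra log_dir argument are assumptions made for this example; the documented method takes only filename and metrics, and its on-disk layout is not specified on this page.

```python
import json
import os
import tempfile
from typing import Any, Dict


def save_metrics(log_dir: str, filename: str, metrics: Dict[str, Any]) -> str:
    """Write metrics to <log_dir>/metrics/<filename>.json and return the path."""
    metrics_dir = os.path.join(log_dir, "metrics")  # assumed layout
    os.makedirs(metrics_dir, exist_ok=True)
    path = os.path.join(metrics_dir, f"{filename}.json")
    with open(path, "w") as f:
        # Indentation keeps the file readable when inspected in a terminal.
        json.dump(metrics, f, indent=4)
    return path


# Usage: write to a temporary directory and read the file back.
with tempfile.TemporaryDirectory() as tmp:
    path = save_metrics(tmp, "best_eval", {"val/loss": 0.25})
    with open(path) as f:
        loaded = json.load(f)
```

Values must be JSON-serializable, so array-valued metrics would need to be converted to Python floats or lists before saving.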

save_model(step: int = 0)

Save current training state at certain training iteration.

Only the model parameters and batch statistics are saved to reduce memory footprint. To allow the training to be continued from a checkpoint, this method can be extended to include the optimizer state as well.

Parameters:

step (int) – Index of the step to save the model at, e.g. epoch.

tracker(iterator: Iterator, **kwargs) → Iterator

Wrap an iterator in a progress bar tracker (tqdm) if the progress bar is enabled.

Parameters:
  • iterator (Iterator) – Iterator to wrap in tqdm.

  • kwargs (Any) – additional arguments to tqdm.

Returns:

The wrapped iterator if the progress bar is enabled, otherwise the input iterator unchanged.

Return type:

Iterator
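The behaviour described above amounts to a small conditional wrapper. The sketch below is a standalone illustration, not the method itself: enable_progress_bar is passed as an argument here, whereas the trainer stores it as a constructor option, and the fallback when tqdm is not installed is an addition made for this example.

```python
from typing import Any, Iterable

try:
    from tqdm import tqdm  # optional dependency
except ImportError:  # fall back to plain iteration if tqdm is missing
    tqdm = None


def tracker(iterator: Iterable, enable_progress_bar: bool = True, **kwargs: Any) -> Iterable:
    """Wrap iterator in tqdm when enabled and available, else return it unchanged."""
    if enable_progress_bar and tqdm is not None:
        return tqdm(iterator, **kwargs)
    return iterator


# Usage: the loop body is identical with or without the progress bar.
total = 0
for value in tracker([1, 2, 3], enable_progress_bar=False, desc="Training"):
    total += value
```

Keyword arguments such as desc or leave are forwarded to tqdm and ignored when the progress bar is disabled.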

train_epoch(train_loader: Iterator) → Dict[str, Any]

Train the model for one epoch.

Parameters:

train_loader (Iterator) – An iterator over the training data.

Returns:

A dictionary of the average training metrics over all batches for logging

Return type:

Dict[str, Any]

train_model(train_loader: Iterator, val_loader: Iterator, test_loader: Iterator | None = None, num_epochs: int = 500, min_delta: float = 0.001, patience: int = 20) → Dict[str, Any]

Start a training loop for the given number of epochs.

Parameters:
  • train_loader (Iterator) – An iterator over the training data.

  • val_loader (Iterator) – An iterator over the validation data.

  • test_loader (Iterator) – If given, best model will be evaluated on the test set.

  • num_epochs (int) – Number of epochs for which to train the model.

  • min_delta (float) – Minimum change in the monitored metric to qualify as an improvement.

  • patience (int) – Number of epochs with no improvement after which training will be stopped. Default is 20.

Returns:

A dictionary of the train, validation and, if a test loader is given, test metrics for the best model on the validation set.

Return type:

Dict[str, Any]
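The interaction of min_delta and patience can be sketched as a small early-stopping loop. This is an illustration of the usual rule under stated assumptions, not the library's training loop: the monitored metric is taken to be a validation loss where lower is better, and early_stopping_epoch is a hypothetical helper that consumes a precomputed sequence of losses.

```python
from typing import Iterable, Tuple


def early_stopping_epoch(
    val_losses: Iterable[float], min_delta: float = 0.001, patience: int = 20
) -> Tuple[int, float]:
    """Return (epoch at which training stops, best loss); epochs are 1-indexed."""
    best = float("inf")
    epochs_without_improvement = 0
    last_epoch = 0
    for epoch, loss in enumerate(val_losses, start=1):
        last_epoch = epoch
        if best - loss > min_delta:
            # The loss dropped by more than min_delta: count it as an improvement.
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # no improvement for `patience` consecutive epochs
    return last_epoch, best


losses = [1.0, 0.8, 0.7995, 0.7999, 0.7, 0.6]
stop_epoch, best = early_stopping_epoch(losses, min_delta=0.001, patience=2)
print(stop_epoch, best)  # 4 0.8
```

In this run the decreases at epochs 3 and 4 are smaller than min_delta, so they do not reset the counter and training stops at epoch 4, before the later, larger improvements are seen. A larger patience trades training time for robustness to such plateaus.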

write_config(log_dir)

Write the config of the trainer in a JSON file.