jaxili.train module
Train.
This module implements an object to perform the training of Normalizing Flows and more generally of neural networks.
- class jaxili.train.TrainState(step: int | Array, apply_fn: Callable, params: FrozenDict[str, Any], tx: GradientTransformation, opt_state: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], batch_stats: Any = (None,), rng: Any = None)
Bases: TrainState
A simple extension of TrainState to also include batch statistics.
If a model has no batch statistics, batch_stats is None. The state also keeps an rng key for dropout or initialization.
- Attributes:
- rng
Methods
apply_gradients(*, grads, **kwargs) – Updates step, params, opt_state and **kwargs in return value.
create(*, apply_fn, params, tx, **kwargs) – Creates a new instance with step=0 and initialized opt_state.
replace(**updates) – Returns a new object replacing the specified fields with new values.
- batch_stats: Any = (None,)
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- rng: Any = None
- class jaxili.train.TrainerModule(model_class: NDENetwork, model_hparams: Dict[str, Any], optimizer_hparams: Dict[str, Any], loss_fn: Callable, exmp_input: Any, seed: int = 42, logger_params: Dict[str, Any] = None, enable_progress_bar: bool = True, debug: bool = False, check_val_every_epoch: int = 1, nde_class: str = 'NPE', **kwargs)
Bases: object
A module to perform the training of Normalizing Flows.
This module contains the training loop, evaluation, logging, and checkpointing. It can also be used to load a model from a checkpoint.
Methods
bind_model() – Return a model with parameters bound to it.
create_functions() – Create and return functions for the training and evaluation step.
create_jitted_functions() – Create jitted versions of the training, validation and evaluation functions.
eval_model(data_loader[, log_prefix]) – Evaluate the model on a dataset.
generate_config(logger_params) – Generate a configuration dictionary for the trainer.
Initialize a default apply function for the model.
Initialize the checkpointer to save the model.
init_logger([logger_params]) – Initialize the logger and create a logging directory.
init_model(exmp_input) – Create an initial training state with newly generated network parameters.
init_optimizer(num_epochs, num_steps_per_epoch) – Initialize the optimizer and learning rate scheduler.
is_new_model_better(new_metrics, old_metrics) – Compare two sets of evaluation metrics to decide whether the new model is better than the previous best one.
load_from_checkpoints(model_class, ...) – Create a Trainer object with the same hyperparameters and the model loaded from a checkpoint directory.
Load model and batch statistics from the logging directory.
on_training_epoch_end(epoch_idx) – Perform any necessary operations at the end of each training epoch.
on_training_start() – Perform any necessary operations before the training starts.
on_validation_epoch_end(epoch_idx, ...) – Perform any necessary operations at the end of each validation epoch.
print_tabulate(exmp_input) – Print a summary of the model represented as a table.
run_model_init(exmp_input, init_rng) – Initialize the model by calling it on the example input.
save_metrics(filename, metrics) – Save a dictionary of metrics to file.
save_model([step]) – Save the current training state at a given training iteration.
tracker(iterator, **kwargs) – Wrap an iterator in a progress bar tracker (tqdm) if the progress bar is enabled.
train_epoch(train_loader) – Train the model for one epoch.
train_model(train_loader, val_loader[, ...]) – Start a training loop for the given number of epochs.
write_config(log_dir) – Write the config of the trainer in a JSON file.
- bind_model()
Return a model with parameters bound to it. This makes the model easier to use for inference.
- Return type:
The model with its parameters and, if present, its batch statistics bound to it.
- create_functions() -> Tuple[Callable[[TrainState, Any], Tuple[TrainState, Dict]], Callable[[TrainState, Any], Tuple[TrainState, Dict]]]
Create and return functions for the training and evaluation step.
The functions take as input the training state and a batch from the train/val/test loader. Both functions are expected to return a dictionary of logging metrics, and the training function also returns a new train state. This function can be overridden by a subclass. The train_step and eval_step functions here are examples of the expected signatures.
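The contract described above can be illustrated with a minimal plain-Python sketch. It is not jaxili's implementation: `ToyState`, the scalar least-squares loss, the learning rate and the metric key `loss` are all hypothetical and chosen only to show the `(state, batch) -> (state, metrics)` shape of both functions.

```python
from typing import Any, Callable, Dict, NamedTuple, Tuple


class ToyState(NamedTuple):
    """Hypothetical stand-in for TrainState: one scalar weight and a step counter."""
    step: int
    w: float


def create_functions(lr: float = 0.1) -> Tuple[Callable, Callable]:
    def train_step(state: ToyState, batch: Any) -> Tuple[ToyState, Dict[str, float]]:
        x, y = batch
        err = state.w * x - y
        # Gradient of the squared error (w*x - y)**2 with respect to w.
        grad = 2.0 * err * x
        new_state = ToyState(step=state.step + 1, w=state.w - lr * grad)
        return new_state, {"loss": err ** 2}

    def eval_step(state: ToyState, batch: Any) -> Tuple[ToyState, Dict[str, float]]:
        x, y = batch
        # Evaluation leaves the state unchanged and only reports metrics.
        return state, {"loss": (state.w * x - y) ** 2}

    return train_step, eval_step


train_step, eval_step = create_functions()
state = ToyState(step=0, w=0.0)
for _ in range(50):
    state, train_metrics = train_step(state, (1.0, 2.0))
_, eval_metrics = eval_step(state, (1.0, 2.0))
```

A subclass overriding `create_functions` would keep the same return shape so that the rest of the trainer can call the two functions interchangeably.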
- create_jitted_functions()
Create jitted versions of the training, validation and evaluation functions.
If self.debug is True, no jitting is applied.
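The debug switch can be pictured with the following sketch, which assumes the step functions are ordinary JAX-traceable callables. The helper `maybe_jit` and the example function `squared_norm` are hypothetical names, not part of jaxili.

```python
import jax
import jax.numpy as jnp


def maybe_jit(fn, debug: bool):
    # With debug=True the plain Python function is returned, so print
    # statements and debuggers see concrete values instead of tracers.
    return fn if debug else jax.jit(fn)


def squared_norm(x):
    return jnp.sum(x ** 2)


fast = maybe_jit(squared_norm, debug=False)  # compiled on first call
slow = maybe_jit(squared_norm, debug=True)   # the original function
result_fast = float(fast(jnp.arange(3.0)))
result_slow = float(slow(jnp.arange(3.0)))
```

Both variants compute the same value; only the compiled one pays a one-time tracing cost and then runs faster on repeated calls with the same input shapes.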
- eval_model(data_loader: Iterator, log_prefix: str | None = '') -> Dict[str, Any]
Evaluate the model on a dataset.
- Parameters:
data_loader (Iterator) – An iterator over the data.
log_prefix (str) – A prefix to add to all metrics.
- Returns:
A dictionary of the evaluation metrics, averaged over data points in the dataset
- Return type:
Dict[str, Any]
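One way to average "over data points" rather than over batches is to weight each batch metric by its batch size. The sketch below assumes that the evaluation step returns per-batch means and that `log_prefix` is simply prepended to each metric name; both are assumptions, and `eval_model_sketch` and `mean_abs` are hypothetical names.

```python
from collections import defaultdict
from typing import Callable, Dict, Iterable, Sequence


def eval_model_sketch(
    eval_step: Callable[[Sequence[float]], Dict[str, float]],
    data_loader: Iterable[Sequence[float]],
    log_prefix: str = "",
) -> Dict[str, float]:
    totals = defaultdict(float)
    num_points = 0
    for batch in data_loader:
        batch_metrics = eval_step(batch)  # per-batch mean of each metric
        for key, value in batch_metrics.items():
            # Weight by batch size so a short final batch does not bias the mean.
            totals[key] += value * len(batch)
        num_points += len(batch)
    return {log_prefix + key: total / num_points for key, total in totals.items()}


def mean_abs(batch):
    return {"mae": sum(abs(v) for v in batch) / len(batch)}


metrics = eval_model_sketch(mean_abs, [[1.0, 3.0], [5.0]], log_prefix="val/")
```

Here the two batches have sizes 2 and 1, so the result is the mean over all three points rather than the mean of the two batch means.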
- init_logger(logger_params: Dict | None = None)
Initialize the logger and create a logging directory.
- Parameters:
logger_params (Dict[str, Any]) – A dictionary containing the specifications of the logger.
- init_model(exmp_input: Any)
Create an initial training state with newly generated network parameters.
- Parameters:
exmp_input (Any) – An input to the model with which the shapes are inferred.
- init_optimizer(num_epochs: int, num_steps_per_epoch: int)
Initialize the optimizer and learning rate scheduler.
- Parameters:
num_epochs (int) – Number of epochs to train.
num_steps_per_epoch (int) – Number of steps per epoch.
- is_new_model_better(new_metrics: Dict[str, Any], old_metrics: Dict[str, Any]) -> bool
Compare two sets of evaluation metrics to decide whether the new model is better than the previous best one.
- Parameters:
new_metrics (Dict[str, Any]) – A dictionary of the evaluation metrics of the new model.
old_metrics (Dict[str, Any]) – A dictionary of the evaluation metrics of the previously best model.
- Returns:
True if the new model is better, False otherwise.
- Return type:
bool
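The comparison rule itself is not spelled out here. A plausible sketch, assuming the monitored metric is stored under the hypothetical key `val/loss` and that lower values are better, could look like this:

```python
from typing import Any, Dict, Optional


def is_new_model_better(
    new_metrics: Dict[str, Any], old_metrics: Optional[Dict[str, Any]]
) -> bool:
    # Assumption: metrics are keyed "val/loss" and lower is better.
    if old_metrics is None:
        return True  # nothing to compare against yet
    return new_metrics["val/loss"] < old_metrics["val/loss"]


first = is_new_model_better({"val/loss": 0.5}, None)
improved = is_new_model_better({"val/loss": 0.4}, {"val/loss": 0.5})
worse = is_new_model_better({"val/loss": 0.6}, {"val/loss": 0.5})
```

A subclass that tracks a metric where higher is better (for example, an accuracy) would flip the comparison.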
- classmethod load_from_checkpoints(model_class: NDENetwork, checkpoint: str, exmp_input: Any) -> Any
Create a Trainer object with the same hyperparameters and the model loaded from a checkpoint directory.
- Parameters:
model_class (jaxili.model.NDENetwork) – The class of the model that should be loaded.
checkpoint (str) – Folder in which the checkpoint and hyperparameter file are stored.
exmp_input (Any) – An input to the model with which the shapes are inferred.
- Return type:
A Trainer object with model loaded from the checkpoint folder.
- on_training_epoch_end(epoch_idx: int)
Perform any necessary operations at the end of each training epoch.
Method called at the end of each training epoch. Can be used for additional logging or similar.
- Parameters:
epoch_idx (int) – Index of the epoch that just finished.
- on_training_start()
Perform any necessary operations before the training starts.
Method called before training is started. Can be used for additional initialization operations etc.
- on_validation_epoch_end(epoch_idx: int, eval_metrics: Dict[str, Any], val_loader: Iterator)
Perform any necessary operations at the end of each validation epoch.
Method called at the end of each validation epoch. Can be used for additional logging and evaluation.
- Parameters:
epoch_idx (int) – Index of the epoch that just finished.
eval_metrics (Dict[str, Any]) – A dictionary of the evaluation metrics.
val_loader (Iterator) – DataLoader of the validation set to support additional evaluation.
- print_tabulate(exmp_input: Any)
Print a summary of the model represented as a table.
- Parameters:
exmp_input (Any) – An input to the model with which the shapes are inferred.
- run_model_init(exmp_input: Any, init_rng: Any) -> Dict
Initialize the model by calling it on the example input.
- Parameters:
exmp_input (Dict[str, Any]) – An input to the model with which the shapes are inferred.
init_rng (Array) – A jax.random.PRNGKey
- Return type:
The initialized variable dictionary.
- save_metrics(filename: str, metrics: Dict[str, Any])
Save a dictionary of metrics to file.
This can be used as a textual representation of the validation performance for checking in the terminal.
- Parameters:
filename (str) – Name of the metrics file, without directory or file extension.
metrics (Dict[str, Any]) – A dictionary of the metrics to save
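Because the filename is given without folders or extension, the method has to assemble the full path itself. The sketch below shows one way this could work; the `metrics` subfolder, the `.json` extension and the explicit `log_dir` argument are assumptions made for illustration and are not confirmed by this page.

```python
import json
import os
import tempfile
from typing import Any, Dict


def save_metrics(log_dir: str, filename: str, metrics: Dict[str, Any]) -> str:
    # Assumption: metrics go to <log_dir>/metrics/<filename>.json.
    metrics_dir = os.path.join(log_dir, "metrics")
    os.makedirs(metrics_dir, exist_ok=True)
    path = os.path.join(metrics_dir, f"{filename}.json")
    with open(path, "w") as f:
        json.dump(metrics, f, indent=4)
    return path


with tempfile.TemporaryDirectory() as tmp:
    saved_path = save_metrics(tmp, "eval_epoch_005", {"val/loss": 0.25})
    with open(saved_path) as f:
        reloaded = json.load(f)
```

A human-readable format such as indented JSON fits the stated purpose of checking validation performance from the terminal.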
- save_model(step: int = 0)
Save the current training state at a given training iteration.
Only the model parameters and batch statistics are saved to reduce memory footprint. To allow the training to be continued from a checkpoint, this method can be extended to include the optimizer state as well.
- Parameters:
step (int) – Index of the step to save the model at, e.g. epoch.
- tracker(iterator: Iterator, **kwargs) -> Iterator
Wrap an iterator in a progress bar tracker (tqdm) if the progress bar is enabled.
- Parameters:
iterator (Iterator) – Iterator to wrap in tqdm.
kwargs (Any) – additional arguments to tqdm.
- Returns:
The wrapped iterator if the progress bar is enabled, otherwise the same iterator as the input.
- Return type:
Iterator
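The behaviour described above amounts to a conditional wrapper. In this sketch the flag is passed as an explicit `enable_progress_bar` argument for self-containment, whereas the class presumably reads it from the value given to the constructor; that difference is an assumption of the example.

```python
from typing import Any, Iterable


def tracker(iterator: Iterable, enable_progress_bar: bool = True, **kwargs: Any) -> Iterable:
    if enable_progress_bar:
        # Imported lazily so the function still works without tqdm
        # when the progress bar is disabled.
        from tqdm.auto import tqdm
        return tqdm(iterator, **kwargs)
    return iterator


batches = [1, 2, 3]
# With the bar disabled, extra tqdm arguments such as desc are ignored.
plain = tracker(batches, enable_progress_bar=False, desc="Training")
```

Because the disabled path returns the input object unchanged, callers can always loop over `tracker(loader)` without branching on the flag.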
- train_epoch(train_loader: Iterator) -> Dict[str, Any]
Train the model for one epoch.
- Parameters:
train_loader (Iterator) – An iterator over the training data.
- Returns:
A dictionary of the training metrics, averaged over all batches, for logging.
- Return type:
Dict[str, Any]
- train_model(train_loader: Iterator, val_loader: Iterator, test_loader: Iterator | None = None, num_epochs: int = 500, min_delta: float = 0.001, patience: int = 20) -> Dict[str, Any]
Start a training loop for the given number of epochs.
- Parameters:
train_loader (Iterator) – An iterator over the training data.
val_loader (Iterator) – An iterator over the validation data.
test_loader (Iterator) – If given, the best model will be evaluated on the test set.
num_epochs (int) – Number of epochs for which to train the model.
min_delta (float) – Minimum change in the monitored metric to qualify as an improvement.
patience (int) – Number of epochs with no improvement after which training will be stopped. Default is 20.
- Returns:
A dictionary of the train, validation and, if a test loader was given, test metrics for the best model on the validation set.
- Return type:
Dict[str, Any]
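The interaction between `min_delta` and `patience` can be sketched as a standard early-stopping loop. This example assumes the monitored metric is a validation loss where lower is better, and it replaces the real training loop with a precomputed list of losses; the function name `early_stopping_epoch` is hypothetical.

```python
from typing import List, Tuple


def early_stopping_epoch(
    val_losses: List[float], min_delta: float = 0.001, patience: int = 20
) -> Tuple[int, float]:
    """Return (index of the last epoch run, best validation loss seen)."""
    best = float("inf")
    epochs_without_improvement = 0
    last_epoch = -1
    for epoch_idx, loss in enumerate(val_losses):
        last_epoch = epoch_idx
        if best - loss > min_delta:
            # Improvement large enough to count: record it and reset the counter.
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    return last_epoch, best


stopped_at, best_loss = early_stopping_epoch(
    [1.0, 0.5, 0.4995, 0.4999, 0.4998, 0.3], min_delta=0.001, patience=3
)
```

In this run the loss changes by less than `min_delta` for three consecutive epochs after reaching 0.5, so the loop stops at epoch index 4 and never sees the later value 0.3. A larger `patience` trades longer training for a lower risk of stopping on a plateau.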