Working with checkpoints

Loading and navigating model checkpoints for later use

Model checkpoints

When training either the Foundation Model or the Downstream Model, BaseModel can save model checkpoints. To enable checkpointing, provide the checkpoint_dir field in the training parameters of the Foundation or Downstream Model configuration.
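For illustration, a minimal sketch of enabling checkpointing is shown below. Only the checkpoint_dir field (and the existence of metric_to_monitor, described below) comes from this page; the surrounding structure of the training parameters is an assumption.

# A minimal sketch, assuming training parameters are passed as a plain mapping.
# The path and the metric value are illustrative placeholders.
training_params = {
    "checkpoint_dir": "<path/to/store/checkpoints>",  # enables checkpoint saving
    "metric_to_monitor": "AUROC",  # optional; task defaults are listed below
}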

A checkpoint has the following structure:

  • best_model.ckpt - Stores the weights of the best model observed so far, updated at each validation epoch. Model quality is measured by the value of the metric_to_monitor specified in the training configuration. If not provided, the default metric for the task is used. The defaults are:

    | Task                     | Metric             |
    |--------------------------|--------------------|
    | MultiLabelClassification | AUROC              |
    | MultiClassClassification | Precision          |
    | BinaryClassification     | AUROC              |
    | Regression               | Mean Squared Error |
    | Recommendation           | HitRate@10         |

    In the case of a Foundation Model, metric_to_monitor is set by BaseModel and does not need to be specified by the user.

  • data.yaml - Stores the configuration of the datasources.

  • dataloader.yaml - Stores the configuration of the dataloader.

  • task.yaml - Stores the name of the task with some task-specific model parameters.

  • target_fn.bin - Stores, in binary format, the code of the target function provided by the user to train the Downstream Model. Does not apply to a Foundation Model checkpoint.

  • lightning_checkpoints - Directory containing model weights saved after every epoch. At the end of training, the best model is selected and saved as best_model.ckpt.
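Because later steps expect these artifacts to be present, it can be useful to sanity-check a checkpoint directory before loading it. The helper below is a hypothetical sketch based only on the file names listed above and is not part of the BaseModel API; target_fn.bin is omitted because it only exists for Downstream Model checkpoints.

from pathlib import Path

# Hypothetical helper (not part of BaseModel): report which of the artifacts
# described above are missing from a checkpoint directory.
EXPECTED_ARTIFACTS = [
    "best_model.ckpt",
    "data.yaml",
    "dataloader.yaml",
    "task.yaml",
    "lightning_checkpoints",
]

def missing_artifacts(checkpoint_dir: str) -> list[str]:
    root = Path(checkpoint_dir)
    return [name for name in EXPECTED_ARTIFACTS if not (root / name).exists()]

print(missing_artifacts("<path/to/checkpoint>"))  # [] for a complete checkpoint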

Loading from a checkpoint

There are two methods for loading a saved checkpoint:

  1. Loading a trained Foundation Model to train a Downstream Model.
  2. Loading a trained Downstream Model to do predictions or to further train the model.

Loading a trained Foundation Model

It is essential to use this method when you want to train a Downstream Model, since it performs the necessary validation and modifies the underlying neural network accordingly.

from monad.ui.module import load_from_foundation_model

trainer = load_from_foundation_model(
    checkpoint_path="<path/to/store/pretrain/artifacts>",
    downstream_task=...,
    target_fn=...,
)

Creates a MonadModuleImpl from a MonadCheckpoint; the model saved under checkpoint_path is assumed to be a Foundation Model.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| checkpoint_path | str | Directory where all the checkpoint artifacts are stored. | required |
| downstream_task | Task | One of the machine learning tasks defined in BaseModel. Possible values are RecommendationTask(), BinaryClassificationTask(), MultilabelClassificationTask(), MulticlassClassificationTask(). | required |
| target_fn | Callable[[Events, Events, Attributes, Dict], Union[Tensor, ndarray, Sketch]] | Target function for the specified task. | required |
| pl_logger | Optional[Logger] | Instance of a PyTorch Lightning logger. | None |
| loading_config | Optional[LoadingConfigParams] | A dictionary mapping a datasource name (or a datasource name and mode) to fields of DataSourceLoadingConfig. If provided, the listed parameters are overwritten. The datasource_cfg field cannot be changed. | None |

Returns

| Name | Type | Description |
|------|------|-------------|
| MonadModuleImpl | MonadModuleImpl | Instance of the monad module, loaded from the checkpoint. |
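Putting it together, the sketch below shows one way to wire a custom target function into load_from_foundation_model. The body of the target function and the import of BinaryClassificationTask are assumptions made for illustration; only load_from_foundation_model, its parameters, and the target_fn signature come from this page.

import torch

from monad.ui.module import load_from_foundation_model
# BinaryClassificationTask() is one of the valid tasks listed above; its import
# path is not shown on this page, so adjust the import to your installation.

def example_target(events_a, events_b, attributes, ctx):
    # Illustrative only: a real target function derives the label from the
    # entity's events and attributes. Here a single positive label is returned
    # as a placeholder Tensor.
    return torch.tensor([1.0])

trainer = load_from_foundation_model(
    checkpoint_path="<path/to/store/pretrain/artifacts>",
    downstream_task=BinaryClassificationTask(),
    target_fn=example_target,
)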

Loading a trained Downstream Model

from monad.ui.module import load_from_checkpoint

trainer = load_from_checkpoint(
    checkpoint_path="<path/to/checkpoint>",
)

Creates a MonadModuleImpl from a MonadCheckpoint.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| checkpoint_path | str | Directory where all the checkpoint artifacts are stored. | required |
| pl_logger | Optional[Logger] | An instance of a PyTorch Lightning logger to use. | None |
| loading_config | Optional[LoadingConfigParams] | A dictionary mapping a datasource name (or a datasource name and mode) to fields of DataSourceLoadingConfig. If provided, the listed parameters are overwritten. The datasource_cfg field cannot be changed. | None |
| kwargs | | Data parameters to change. | {} |

Returns

| Name | Type | Description |
|------|------|-------------|
| MonadModuleImpl | MonadModuleImpl | Instance of the monad module, loaded from the checkpoint. |
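As an example of the optional arguments, the sketch below reloads a Downstream Model checkpoint with a PyTorch Lightning logger attached. CSVLogger is a standard Lightning logger; depending on your Lightning version it lives under pytorch_lightning.loggers or lightning.pytorch.loggers, so treat the import path as an assumption.

from pytorch_lightning.loggers import CSVLogger  # or: from lightning.pytorch.loggers import CSVLogger
from monad.ui.module import load_from_checkpoint

trainer = load_from_checkpoint(
    checkpoint_path="<path/to/checkpoint>",
    pl_logger=CSVLogger(save_dir="logs"),  # training/validation metrics written as CSV
)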