Using Checkpoints

Model checkpoints

When training either the Foundation Model or the Downstream Model, BaseModel can save model checkpoints. To enable checkpointing, provide the checkpoint_dir field in the training parameters of the Foundation or Downstream Model configuration.
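
For example, a minimal sketch of the relevant training parameters. The dictionary shape and the metric value are illustrative; only checkpoint_dir and metric_to_monitor are taken from this page:

training_params = {
    "checkpoint_dir": "<path/to/store/checkpoints>",  # enables checkpointing
    "metric_to_monitor": "AUROC",  # optional; defaults per task, see below
    # ... other training parameters
}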

A checkpoint has the following structure:

  • best_model.ckpt - Stores the weights of the best checkpoint recorded at each validation epoch. The "goodness" of the model is measured by the value of the metric_to_monitor specified in the training configuration. If it is not provided, the default metric for the task is used. The defaults are:

    Task                        Default metric
    MultiLabelClassification    AUROC
    MultiClassClassification    Precision
    BinaryClassification        AUROC
    Regression                  Mean Squared Error
    Recommendation              HitRate@10

    In the case of a Foundation Model, metric_to_monitor is set by BaseModel and does not need to be set by the user.

  • data.yaml - Stores the configuration of the datasources.

  • dataloader.yaml - Stores the configuration of the dataloader.

  • task.yaml - Stores the name of the task with some task-specific model parameters.

  • target_fn.bin - Stores, in binary format, the code of the target function provided by the user to train the Downstream Model. Does not apply to a Foundation Model checkpoint.

  • lightning_checkpoints - Directory containing model weights saved after every epoch. At the end of training, the best model is selected and saved as best_model.ckpt.
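
Putting the pieces together, a checkpoint directory looks roughly like this:

<checkpoint_dir>/
    best_model.ckpt
    data.yaml
    dataloader.yaml
    task.yaml
    target_fn.bin           (Downstream Model checkpoints only)
    lightning_checkpoints/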

Loading from a checkpoint

There are two methods for loading a saved checkpoint:

  1. Loading a trained Foundation Model to train a Downstream Model.
  2. Loading a trained Downstream Model to do predictions or to further train the model.

Loading a trained Foundation Model

It is essential to use this method when you want to train a Downstream task, since it performs the necessary validation and modifies the underlying neural network accordingly.

from monad.ui.module import load_from_foundation_model

trainer = load_from_foundation_model(
    checkpoint_path="<path/to/store/pretrain/artifacts>",
    downstream_task=...,
    target_fn=...,
)

Creates a MonadModuleImpl from a MonadCheckpoint; the model saved under checkpoint_path is assumed to be a Foundation Model.

Parameters

  • checkpoint_path (str, required) - Directory where all the checkpoint artifacts are stored.

  • downstream_task (Task, required) - One of the machine learning tasks defined in BaseModel. Possible values: RecommendationTask(), BinaryClassificationTask(), MultilabelClassificationTask(), MulticlassClassificationTask().

  • target_fn (Callable[[Events, Events, Attributes, Dict], Union[Tensor, ndarray, Sketch]], required) - Target function for the specified task (see the sketch below).

  • pl_logger (Optional[Logger], default None) - Instance of a PyTorch Lightning logger.

  • loading_config (Optional[LoadingConfigParams], default None) - A dictionary mapping a datasource name (or a datasource name and mode) to the fields of DataSourceLoadingConfig. If provided, the listed parameters are overwritten. The datasource_cfg field can't be changed.
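
As an illustration, a hypothetical target_fn matching the signature above. The semantics of the Events and Attributes arguments are assumed here, not taken from this page:

import numpy as np

def target_fn(history_events, future_events, attributes, ctx):
    # Hypothetical binary label: 1.0 if the entity has at least one event
    # in the future window, 0.0 otherwise. Assumes Events supports len().
    return np.array([1.0 if len(future_events) > 0 else 0.0], dtype=np.float32)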

Returns

  • MonadModuleImpl (MonadModuleImpl) - Instance of the monad module, loaded from the checkpoint.

Loading a trained Downstream Model

from monad.ui.module import load_from_checkpoint

trainer = load_from_checkpoint(
    checkpoint_path="<path/to/checkpoint>",
)

Creates a MonadModuleImpl from a MonadCheckpoint.

Parameters

  • checkpoint_path (str, required) - Directory where all the checkpoint artifacts are stored.

  • pl_logger (Optional[Logger], default None) - An instance of a PyTorch Lightning logger to use.

  • loading_config (Optional[LoadingConfigParams], default None) - A dictionary mapping a datasource name (or a datasource name and mode) to the fields of DataSourceLoadingConfig. If provided, the listed parameters are overwritten. The datasource_cfg field can't be changed. An example appears at the end of this section.

  • kwargs (default {}) - Data parameters to change.

Returns

  • MonadModuleImpl (MonadModuleImpl) - Instance of the monad module, loaded from the checkpoint.
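
For instance, a sketch of overriding a loading parameter when restoring a Downstream Model. The datasource name "transactions" and the batch_size field are assumptions for illustration, not documented values:

from monad.ui.module import load_from_checkpoint

trainer = load_from_checkpoint(
    checkpoint_path="<path/to/checkpoint>",
    # "transactions" is an assumed datasource name and "batch_size" an
    # assumed DataSourceLoadingConfig field, shown for illustration only.
    loading_config={"transactions": {"batch_size": 512}},
)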