Using Checkpoints
Model checkpoints
When training either the Foundation Model or a Downstream Model, BaseModel can save model checkpoints. To enable checkpointing, provide the checkpoint_dir field in the training parameters of the Foundation or Downstream Model configuration.
A checkpoint has the following structure:
-
best_model.ckpt - Stores the weights of the best checkpoint, as evaluated at each validation epoch. Model quality is measured by the value of the metric_to_monitor specified in the training configuration. If it is not provided, the default metric for the task is used:

Task | Metric |
---|---|
MultiLabelClassification | AUROC |
MultiClassClassification | Precision |
BinaryClassification | AUROC |
Regression | Mean Squared Error |
Recommendation | HitRate@10 |

For a Foundation Model, metric_to_monitor is set by BaseModel, so you don't need to configure it.
-
data.yaml - Stores the configuration of the datasources
-
dataloader.yaml - Stores the configuration of the dataloader
-
task.yaml - Stores the name of the task with some task-specific model parameters.
-
target_fn.bin - Stores the code of a target function, provided by the user to train downstream model, in a binary format. Does not apply to a Foundation Model checkpoint.
-
lightning_checkpoints - Directory containing model weights saved after every epoch. At the end of training, the best model is selected and saved as best_model.ckpt.
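To make the layout above concrete, here is a minimal sketch of a sanity check for a finished checkpoint directory. It uses only the Python standard library; `missing_artifacts` is a hypothetical helper and not part of BaseModel, and the artifact names are copied from the list above.

```python
from pathlib import Path

# Artifact names copied from the checkpoint structure described above.
COMMON_ARTIFACTS = (
    "best_model.ckpt",
    "data.yaml",
    "dataloader.yaml",
    "task.yaml",
    "lightning_checkpoints",
)
# target_fn.bin does not apply to a Foundation Model checkpoint.
DOWNSTREAM_ONLY_ARTIFACTS = ("target_fn.bin",)


def missing_artifacts(checkpoint_dir, downstream=True):
    """Hypothetical helper: list expected artifacts absent from checkpoint_dir."""
    root = Path(checkpoint_dir)
    expected = COMMON_ARTIFACTS + (DOWNSTREAM_ONLY_ARTIFACTS if downstream else ())
    return [name for name in expected if not (root / name).exists()]
```

An empty list means every expected artifact is present. Pass `downstream=False` when checking a Foundation Model checkpoint.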
Loading from checkpoint
There are two methods for loading a saved checkpoint:
- Loading a trained Foundation Model to train a Downstream Model.
- Loading a trained Downstream Model to do predictions or to further train the model.
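If you are unsure which kind of checkpoint a directory holds, one possible heuristic is to look for target_fn.bin, which is stored only for Downstream Model checkpoints. The sketch below assumes that this file's presence is a reliable indicator. `checkpoint_kind` is a hypothetical helper, not a BaseModel function.

```python
from pathlib import Path


def checkpoint_kind(checkpoint_dir):
    """Hypothetical helper: guess the checkpoint kind from its artifacts.

    Assumption: target_fn.bin exists only in Downstream Model checkpoints.
    """
    has_target_fn = (Path(checkpoint_dir) / "target_fn.bin").is_file()
    return "downstream" if has_target_fn else "foundation"
```

A result of "foundation" corresponds to the first loading method, and "downstream" to the second.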
Loading a trained Foundation Model
It is essential to use this method when you want to train a Downstream task, since it performs the necessary validation and modifies the underlying neural network accordingly.
from monad.ui.module import load_from_foundation_model

trainer = load_from_foundation_model(
    checkpoint_path="<path/to/store/pretrain/artifacts>",
    downstream_task=...,
    target_fn=...,
)
Creates a MonadModuleImpl from a MonadCheckpoint, where the model saved under checkpoint_path is assumed to be a Foundation Model.
Parameters
Name | Type | Description | Default |
---|---|---|---|
checkpoint_path | str | Directory where all the checkpoint artifacts are stored. | required |
downstream_task | Task | One of the machine learning tasks defined in BaseModel. Possible values are RecommendationTask(), BinaryClassificationTask(), MultilabelClassificationTask(), MulticlassClassificationTask() | required |
target_fn | Callable[[Events, Events, Attributes, Dict], Union[Tensor, ndarray, Sketch]] | Target function for the specified task. | required |
pl_logger | Optional[Logger] | Instance of a PyTorch Lightning logger. | None |
loading_config | Optional[LoadingConfigParams] | A dictionary containing a mapping from datasource name (or from datasource name and mode) to the fields of DataSourceLoadingConfig. If provided, the listed parameters will be overwritten. Field datasource_cfg can't be changed. | None |
Returns
Name | Type | Description |
---|---|---|
MonadModuleImpl | MonadModuleImpl | Instance of the monad module, loaded from the checkpoint. |
Loading a trained Downstream Model
from monad.ui.module import load_from_checkpoint

trainer = load_from_checkpoint(
    checkpoint_path="<path/to/checkpoint>",
)
Creates a MonadModuleImpl from a MonadCheckpoint.
Parameters
Name | Type | Description | Default |
---|---|---|---|
checkpoint_path | str | Directory where all the checkpoint artifacts are stored. | required |
pl_logger | Optional[Logger] | An instance of PyTorch Lightning logger to use. | None |
loading_config | Optional[LoadingConfigParams] | A dictionary containing a mapping from datasource name (or from datasource name and mode) to the fields of DataSourceLoadingConfig. If provided, the listed parameters will be overwritten. Field datasource_cfg can't be changed. | None |
kwargs | | Data parameters to change. | {} |
Returns
Name | Type | Description |
---|---|---|
MonadModuleImpl | MonadModuleImpl | Instance of the monad module, loaded from the checkpoint. |
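The override behaviour described for loading_config can be illustrated with a small sketch. This is not BaseModel's implementation. Plain dictionaries stand in for DataSourceLoadingConfig, keys are datasource names only (the name-and-mode form is omitted), and both the `apply_loading_overrides` helper and the `cache_path` field used to exercise it are hypothetical. Raising an error for an unknown datasource or for datasource_cfg is an assumption about how such a restriction might be enforced.

```python
def apply_loading_overrides(saved, overrides=None):
    """Hypothetical sketch: overwrite saved per-datasource fields with overrides.

    saved and overrides both map a datasource name to a dict of fields.
    """
    # Copy so the saved configuration itself is left untouched.
    merged = {name: dict(fields) for name, fields in saved.items()}
    for name, fields in (overrides or {}).items():
        if name not in merged:
            # Assumption: overriding a datasource absent from the checkpoint is an error.
            raise KeyError(f"unknown datasource: {name}")
        if "datasource_cfg" in fields:
            # The docs state that datasource_cfg can't be changed.
            raise ValueError("datasource_cfg can't be changed")
        merged[name].update(fields)
    return merged
```

Only the listed fields are overwritten. Everything else keeps the value stored in the checkpoint.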