Working with checkpoints
Loading and navigating model checkpoints for later use
Model checkpoints
When training either the Foundation Model or a Downstream Model, BaseModel can save model checkpoints. To enable checkpointing, set the checkpoint_dir field in the training parameters of the Foundation or Downstream Model configuration.
A checkpoint directory has the following structure:
- best_model.ckpt - Stores the weights of the best checkpoint, as evaluated at each validation epoch. How "good" a model is gets measured by the value of the metric_to_monitor specified in the training configuration. If it is not provided, the default metric for the task is used:

  Task | Metric |
  ---|---|
  MultiLabelClassification | AUROC |
  MultiClassClassification | Precision |
  BinaryClassification | AUROC |
  Regression | Mean Squared Error |
  Recommendation | HitRate@10 |

  For a Foundation Model, metric_to_monitor is set by BaseModel and does not need to be configured by the user.
- data.yaml - Stores the configuration of the data sources.
- dataloader.yaml - Stores the configuration of the dataloader.
- task.yaml - Stores the name of the task together with some task-specific model parameters.
- target_fn.bin - Stores, in binary format, the code of the target function provided by the user to train the Downstream Model. Not present in Foundation Model checkpoints.
- lightning_checkpoints - Directory containing the model weights saved after every epoch. At the end of training, the best model is selected and saved as best_model.ckpt.
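The layout above can be checked programmatically before loading. The sketch below uses only the Python standard library; the helper name inspect_checkpoint is hypothetical and not part of BaseModel, and it assumes (based on the note on target_fn.bin) that a checkpoint containing target_fn.bin is a Downstream Model checkpoint, while one without it is a Foundation Model checkpoint.

```python
from pathlib import Path

# File names taken from the checkpoint layout described above.
COMMON_ARTIFACTS = ("best_model.ckpt", "data.yaml", "dataloader.yaml", "task.yaml")
TARGET_FN_FILE = "target_fn.bin"
LIGHTNING_DIR = "lightning_checkpoints"


def inspect_checkpoint(checkpoint_dir: str) -> dict:
    """Hypothetical helper: report which documented artifacts exist.

    Assumption: the presence of target_fn.bin marks a Downstream Model
    checkpoint; its absence marks a Foundation Model checkpoint.
    """
    root = Path(checkpoint_dir)
    if not root.is_dir():
        raise FileNotFoundError(f"Checkpoint directory not found: {root}")

    # Artifacts expected in every checkpoint that are not present on disk.
    missing = [name for name in COMMON_ARTIFACTS if not (root / name).is_file()]

    # Names of the entries saved under lightning_checkpoints, if that directory exists.
    lightning_dir = root / LIGHTNING_DIR
    epoch_entries = (
        sorted(p.name for p in lightning_dir.iterdir()) if lightning_dir.is_dir() else []
    )

    has_target_fn = (root / TARGET_FN_FILE).is_file()
    return {
        "kind": "downstream" if has_target_fn else "foundation",
        "missing": missing,
        "epoch_checkpoints": epoch_entries,
    }
```

A result with kind "foundation" would be passed to load_from_foundation_model, and one with kind "downstream" to load_from_checkpoint.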
Loading from a checkpoint.
There are two methods for loading a saved checkpoint:
- Loading a trained Foundation Model to train a Downstream Model.
- Loading a trained Downstream Model to make predictions or to train it further.
Loading a trained Foundation Model.
Always use this method when you want to train a Downstream task: it performs the necessary validation and modifies the underlying neural network accordingly.
```python
from monad.ui.module import load_from_foundation_model

trainer = load_from_foundation_model(
    checkpoint_path="<path/to/store/pretrain/artifacts>",
    downstream_task=...,
    target_fn=...,
)
```
Creates a MonadModuleImpl from a MonadCheckpoint, assuming that the model saved under checkpoint_path is a Foundation Model.
Parameters
Name | Type | Description | Default |
---|---|---|---|
checkpoint_path | str | Directory where all the checkpoint artifacts are stored. | required |
downstream_task | Task | One of the machine learning tasks defined in BaseModel. Possible values are RecommendationTask(), BinaryClassificationTask(), MultilabelClassificationTask(), MulticlassClassificationTask() | required |
target_fn | Callable[[Events, Events, Attributes, Dict], Union[Tensor, ndarray, Sketch]] | Target function for the specified task. | required |
pl_logger | Optional[Logger] | An instance of a PyTorch Lightning logger. | None |
loading_config | Optional[LoadingConfigParams] | A dictionary containing a mapping from datasource name (or from datasource name and mode) to the fields of DataSourceLoadingConfig. If provided, the listed parameters will be overwritten. Field datasource_cfg can't be changed. | None |
Returns
Name | Type | Description |
---|---|---|
MonadModuleImpl | MonadModuleImpl | Instance of monad module, loaded from the checkpoint. |
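The loading_config parameter overwrites only the fields it lists and refuses changes to datasource_cfg. The sketch below illustrates those override semantics on plain dictionaries. It is not the BaseModel implementation: the function apply_loading_config, the use of a (name, mode) tuple as the key for mode-specific entries, and the field name example_field are hypothetical stand-ins, and rejecting unknown datasource entries is an assumption of this sketch.

```python
from copy import deepcopy
from typing import Any, Dict, Hashable, Mapping

# Field that, per the parameter table above, cannot be changed.
PROTECTED_FIELD = "datasource_cfg"


def apply_loading_config(
    saved: Mapping[Hashable, Dict[str, Any]],
    overrides: Mapping[Hashable, Mapping[str, Any]],
) -> Dict[Hashable, Dict[str, Any]]:
    """Hypothetical helper: return a copy of `saved` with listed fields replaced.

    Keys are a datasource name, or (as assumed here) a (name, mode) tuple.
    Values map field names to their new values; unlisted fields are kept.
    """
    merged = deepcopy(dict(saved))  # never mutate the saved configuration
    for key, fields in overrides.items():
        if PROTECTED_FIELD in fields:
            raise ValueError(f"{PROTECTED_FIELD} cannot be changed (entry {key!r})")
        if key not in merged:
            # Assumption: unknown datasource entries are treated as errors.
            raise KeyError(f"Unknown datasource entry: {key!r}")
        merged[key].update(fields)
    return merged
```

For example, apply_loading_config(saved, {"transactions": {"example_field": 5}}) changes only example_field for the hypothetical "transactions" data source, while a mapping that lists datasource_cfg raises ValueError.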
Loading a trained Downstream Model.
```python
from monad.ui.module import load_from_checkpoint

trainer = load_from_checkpoint(
    checkpoint_path="<path/to/checkpoint>",
)
```
Creates MonadModuleImpl from MonadCheckpoint.
Parameters
Name | Type | Description | Default |
---|---|---|---|
checkpoint_path | str | Directory where all the checkpoint artifacts are stored. | required |
pl_logger | Optional[Logger] | An instance of PyTorch Lightning logger to use. | None |
loading_config | Optional[LoadingConfigParams] | A dictionary containing a mapping from datasource name (or from datasource name and mode) to the fields of DataSourceLoadingConfig. If provided, the listed parameters will be overwritten. Field datasource_cfg can't be changed. | None |
kwargs | | Data parameters to change. | {} |
Returns
Name | Type | Description |
---|---|---|
MonadModuleImpl | MonadModuleImpl | Instance of monad module, loaded from the checkpoint. |