Working with checkpoints
Loading and navigating model checkpoints for later use

Check this first! This article refers to BaseModel accessed via a Docker container. Please refer to the Snowflake Native App section if you are using BaseModel as a Snowflake GUI application.
Model checkpoints
When training either a foundation model or a scenario model, BaseModel enables checkpointing when the checkpoint_dir field is specified in the training configuration.
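As a sketch, enabling checkpointing might look like the following training-config fragment (the surrounding structure is hypothetical; only the checkpoint_dir field is taken from this article, and the path is a placeholder):

```yaml
# Hypothetical training-config fragment; only checkpoint_dir is
# documented here, and the path is a placeholder.
checkpoint_dir: /path/to/store/pretrain/artifacts
```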
Checkpoint Contents
A checkpoint directory typically contains:
- best_model.ckpt – the weights from the best performing checkpoint (based on validation metrics).
  - Default metric_to_monitor if unspecified:

    Task | Metric
    ---|---
    MultiLabelClassification | AUROC
    MultiClassClassification | Precision
    BinaryClassification | AUROC
    Regression | Mean Squared Error
    Recommendation | HitRate@10

  - For foundation models, this metric is auto-managed by BaseModel.
- data.yaml – configuration for data sources
- dataloader.yaml – data loader configuration
- task.yaml – task name and associated parameters
- target_fn.bin – serialized user-defined target function (scenario model checkpoints only)
- lightning_checkpoints/ – raw checkpoint files from each epoch; best_model.ckpt is selected from these after training.
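To make the layout concrete, here is a small sketch that builds and inspects a mock directory containing the artifacts listed above (the directory and files are fabricated purely for illustration; only the artifact names come from this article):

```python
# Sketch: build and inspect a mock checkpoint directory following the
# layout documented above. The directory and files are fabricated here
# purely for illustration.
import tempfile
from pathlib import Path

def list_checkpoint_artifacts(checkpoint_dir: Path) -> set[str]:
    """Return the names of the top-level artifacts in a checkpoint directory."""
    return {p.name for p in checkpoint_dir.iterdir()}

# Recreate the documented layout in a temporary directory.
root = Path(tempfile.mkdtemp())
for name in ("best_model.ckpt", "data.yaml", "dataloader.yaml",
             "task.yaml", "target_fn.bin"):
    (root / name).touch()
(root / "lightning_checkpoints").mkdir()

artifacts = list_checkpoint_artifacts(root)
```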
Loading from checkpoint
Two loading scenarios are supported:
- Load a trained foundation model for scenario model fine-tuning.
- Load a trained scenario model for evaluation, inference, or further training.
Loading a trained foundation model
Use load_from_foundation_model()
to initialize a trainer using a foundation model checkpoint. This ensures correct adaptation of the network for downstream tasks:
```python
from monad.ui.module import load_from_foundation_model

trainer = load_from_foundation_model(
    checkpoint_path="<path/to/store/pretrain/artifacts>",
    downstream_task=...,
    target_fn=...,
)
```
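The downstream_task and target_fn arguments are left elided above. As a purely hypothetical sketch of what a target function might compute for a binary-classification downstream task (the event structure and the churn rule here are assumptions for illustration, not BaseModel's actual TargetFunction interface):

```python
# Hypothetical sketch of a target function for a binary-classification
# downstream task. The event structure and the churn rule are
# assumptions for illustration, not BaseModel's TargetFunction API.
from datetime import datetime

CUTOFF = datetime(2023, 8, 1)

def churn_target_fn(events: list[dict]) -> float:
    """Label an entity 1.0 (churned) if it made no purchase after CUTOFF."""
    purchased_after = any(
        e["type"] == "purchase" and e["timestamp"] >= CUTOFF for e in events
    )
    return 0.0 if purchased_after else 1.0

label = churn_target_fn([
    {"type": "purchase", "timestamp": datetime(2023, 7, 15)},
    {"type": "page_view", "timestamp": datetime(2023, 8, 10)},
])
```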
Parameters

- checkpoint_path: str | Path
  Required, no default. Directory where the foundation model checkpoint artifacts are stored.
- downstream_task: Task
  Required, no default. Specifies the downstream task type, i.e., OneHotRecommendationTask(), RecommendationTask(), BinaryClassificationTask(), MultilabelClassificationTask(), MulticlassClassificationTask(), RegressionTask().
- target_fn: TargetFunction
  Required, no default. User-defined function that computes training targets for the specified downstream task.
- with_head: bool
  Default: False. Whether to use the last layer from the foundation model. May improve the quality of recommendation tasks.
- pl_logger: Optional[Logger]
  Default: None. A PyTorch Lightning logger instance.
- predictions_to_include_fn: PredictionsFilteringFnType
  Optional, default: None. A function that returns the items/classes the predictions should be narrowed to for each entity. Mutually exclusive with predictions_to_exclude_fn.
- predictions_to_exclude_fn: PredictionsFilteringFnType
  Optional, default: None. A function that returns the items/classes that should be excluded from the predictions for each entity. Mutually exclusive with predictions_to_include_fn.
- split: TimeSplitOverride | EntitySplitOverride
  Optional, default: None. Override for the split configuration set during the previous run.
  - TimeSplitOverride: a dict[DataMode, TimeRange], where DataMode ∈ {TRAIN, VALIDATION, TEST, PREDICT} and TimeRange has start_date: datetime and end_date: datetime. To override or add a test period, set DataMode.TEST and provide dates, e.g. {DataMode.TEST: TimeRange(start_date=datetime(2023, 8, 1), end_date=datetime(2023, 8, 22))}.
  - EntitySplitOverride: a dict with fields training: int (percentage), validation: int (percentage), and training_validation_end: datetime. The first two fields define the train/validation percentages; training_validation_end marks the end of training and validation, leaving the remaining time for the test period.
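The two override shapes can be sketched with stand-in definitions (DataMode and TimeRange here are assumed approximations of BaseModel's types, written out only to show the structure):

```python
# Stand-in definitions that mirror the split-override shapes described
# above. DataMode and TimeRange are assumed approximations of
# BaseModel's types, written out only to show the structure.
from dataclasses import dataclass
from datetime import datetime
from enum import Enum, auto

class DataMode(Enum):
    TRAIN = auto()
    VALIDATION = auto()
    TEST = auto()
    PREDICT = auto()

@dataclass
class TimeRange:
    start_date: datetime
    end_date: datetime

# TimeSplitOverride: override/add a test period.
time_split = {
    DataMode.TEST: TimeRange(
        start_date=datetime(2023, 8, 1),
        end_date=datetime(2023, 8, 22),
    )
}

# EntitySplitOverride: percentage split plus a train/validation cutoff.
entity_split = {
    "training": 80,
    "validation": 20,
    "training_validation_end": datetime(2023, 8, 1),
}
```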
Loading a fine-tuned scenario model
Use load_from_checkpoint()
to restore a trained scenario model for evaluation, inference, or continued training.
The example below demonstrates the most basic usage:
```python
from monad.ui.module import load_from_checkpoint

trainer = load_from_checkpoint(
    checkpoint_path="<path/to/checkpoint>",
)
```
Parameters

- checkpoint_path: str
  Required, no default. Directory where all the checkpoint artifacts are stored.
- pl_logger: Optional[Logger]
  Default: None. A PyTorch Lightning logger instance to use.
- scoring: bool
  Default: False. Leave as False if the intention is to resume training, or set to True to perform inference.
- predictions_to_include_fn: PredictionsFilteringFnType
  Optional, default: None. A function that returns the items/classes the predictions should be narrowed to for each entity. Mutually exclusive with predictions_to_exclude_fn.
- predictions_to_exclude_fn: PredictionsFilteringFnType
  Optional, default: None. A function that returns the items/classes that should be excluded from the predictions for each entity. Mutually exclusive with predictions_to_include_fn.
- split: TimeSplitOverride | EntitySplitOverride
  Optional, default: None. Override for the split configuration set during the previous run.
  - TimeSplitOverride: a dict[DataMode, TimeRange], where DataMode ∈ {TRAIN, VALIDATION, TEST, PREDICT} and TimeRange has start_date: datetime and end_date: datetime. To override or add a test period, set DataMode.TEST and provide dates, e.g. {DataMode.TEST: TimeRange(start_date=datetime(2023, 8, 1), end_date=datetime(2023, 8, 22))}.
  - EntitySplitOverride: a dict with fields training: int (percentage), validation: int (percentage), and training_validation_end: datetime. The first two fields define the train/validation percentages; training_validation_end marks the end of training and validation, leaving the remaining time for the test period.

  Note: this parameter is handy when you need to override or add split information (e.g., if the test period was not defined in the pretrain configuration file). However, when running test or prediction with models trained using an entity-based split, use prediction_date from TestingParams instead.
- kwargs
  Default: {}. Data parameters to change.
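A hypothetical sketch of a filtering function such as predictions_to_exclude_fn might accept (the entity structure and the exact PredictionsFilteringFnType signature are assumptions for illustration):

```python
# Hypothetical sketch of a predictions-filtering function. The real
# PredictionsFilteringFnType signature is defined by BaseModel; the
# entity dict used here is an assumption for illustration.
def exclude_already_purchased(entity: dict) -> list[str]:
    """Exclude items the entity has already purchased from its predictions."""
    return entity["purchased_items"]

excluded = exclude_already_purchased({"purchased_items": ["sku_12", "sku_34"]})
```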