
Working with checkpoints

Loading and navigating model checkpoints for later use

⚠️

Check This First!

This article refers to BaseModel accessed via Docker container. Please refer to Snowflake Native App section if you are using BaseModel as SF GUI application.

Model checkpoints

When training either a foundation model or a scenario model, BaseModel enables checkpointing when the checkpoint_dir field is specified in the training configuration.

Checkpoint Contents

A checkpoint directory typically contains:

  • best_model.ckpt

    • Contains the weights from the best performing checkpoint (based on validation metrics).

    • Default metric_to_monitor if unspecified (per task):

      MultiLabelClassification: AUROC
      MultiClassClassification: Precision
      BinaryClassification: AUROC
      Regression: Mean Squared Error
      Recommendation: HitRate@10
    • For foundation models, this metric is auto-managed by BaseModel.

  • data.yaml – Configuration for data sources

  • dataloader.yaml – Data loader configuration

  • task.yaml – Task name and associated parameters

  • target_fn.bin – Serialized user-defined target function (applicable to scenario model checkpoints only)

  • lightning_checkpoints/ – raw checkpoint files from each epoch; best_model.ckpt is selected from these after training.
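
Putting this together, a typical scenario model checkpoint directory might look like the sketch below (illustrative only; the per-epoch file names inside lightning_checkpoints/ are an assumption, and target_fn.bin is absent for foundation models):

```
checkpoint_dir/
├── best_model.ckpt          # best-performing weights (by monitored metric)
├── data.yaml                # data source configuration
├── dataloader.yaml          # data loader configuration
├── task.yaml                # task name and parameters
├── target_fn.bin            # scenario model checkpoints only
└── lightning_checkpoints/   # raw per-epoch checkpoints
    ├── epoch=0.ckpt         # file names here are illustrative
    └── epoch=1.ckpt
```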


Loading from a checkpoint

Two loading scenarios are supported:

  1. Load a trained foundation model for scenario model fine-tuning.
  2. Load a trained scenario model for evaluation, inference, or further training.

Loading a trained foundation model

Use load_from_foundation_model() to initialize a trainer using a foundation model checkpoint. This ensures correct adaptation of the network for downstream tasks:


from monad.ui.module import load_from_foundation_model

trainer = load_from_foundation_model(
    checkpoint_path="<path/to/store/pretrain/artifacts>",
    downstream_task=...,
    target_fn=...,
)
Parameters
  • checkpoint_path: str | Path
    No default, required.
    Directory where the foundation model checkpoint artifacts are stored.

  • downstream_task: Task
    No default, required.
    Specifies the downstream task type, i.e., OneHotRecommendationTask(), RecommendationTask(), BinaryClassificationTask(), MultilabelClassificationTask(), MulticlassClassificationTask(), or RegressionTask().

  • target_fn: TargetFunction
    No default, required.
    User-defined function that computes training targets for the specified downstream task.

  • with_head: bool
    Default: False.
    Whether to use the last layer from the foundation model. May improve the quality of recommendation tasks.

  • pl_logger: Optional[Logger]
    Default: None.
    A PyTorch Lightning logger instance.

  • predictions_to_include_fn: Optional[PredictionsFilteringFnType]
    Optional. Default: None.
    A function that returns items/classes the predictions should be narrowed to for each entity. Mutually exclusive with predictions_to_exclude_fn.

  • predictions_to_exclude_fn: Optional[PredictionsFilteringFnType]
    Optional. Default: None.
    A function that returns items/classes that should be excluded from the predictions for each entity. Mutually exclusive with predictions_to_include_fn.

  • split: Optional[TimeSplitOverride | EntitySplitOverride]
    Default: None.
    Optional override for the split configuration set during the previous run.

    • TimeSplitOverride: a dict[DataMode, TimeRange], where
      DataMode ∈ {TRAIN, VALIDATION, TEST, PREDICT} and TimeRange has start_date: datetime and end_date: datetime. To override/add a test period, set DataMode.TEST and provide dates, e.g.: {DataMode.TEST: TimeRange(start_date=datetime(2023, 8, 1), end_date=datetime(2023, 8, 22))}

    • EntitySplitOverride: a Dict with fields
      training: int (percentage), validation: int (percentage), training_validation_end: datetime. The first two fields define train/validation percentages; training_validation_end marks the end of training & validation, leaving the remaining time for the test period.
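
As a concrete sketch of the split parameter, an EntitySplitOverride can be written as a plain dict with the three fields described above. (A TimeSplitOverride would instead map DataMode values to TimeRange objects, both provided by BaseModel; their import paths are not shown here.)

```python
from datetime import datetime

# EntitySplitOverride: 80% of entities for training, 20% for validation;
# everything after training_validation_end is left for the test period.
# In this example the two percentages together cover all entities.
entity_split = {
    "training": 80,
    "validation": 20,
    "training_validation_end": datetime(2023, 8, 1),
}
```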


Loading a fine-tuned scenario model

Use load_from_checkpoint() to restore a trained scenario model for evaluation, inference, or continued training.

The example below demonstrates the most basic usage:


from monad.ui.module import load_from_checkpoint

trainer = load_from_checkpoint(
    checkpoint_path="<path/to/checkpoint>",
)
Parameters
  • checkpoint_path: str
    No default, required.
    Directory where all the checkpoint artifacts are stored.

  • pl_logger: Optional[Logger]
    Default: None.
    A PyTorch Lightning logger instance.

  • scoring: bool
    Default: False.
    Set to True to perform inference; leave as False to resume training.

  • predictions_to_include_fn: Optional[PredictionsFilteringFnType]
    Optional. Default: None.
    A function that returns items/classes the predictions should be narrowed to for each entity. Mutually exclusive with predictions_to_exclude_fn.

  • predictions_to_exclude_fn: Optional[PredictionsFilteringFnType]
    Optional. Default: None.
    A function that returns items/classes that should be excluded from the predictions for each entity. Mutually exclusive with predictions_to_include_fn.

  • split: Optional[TimeSplitOverride | EntitySplitOverride]
    Default: None.
    Optional override for the split configuration set during the previous run.

    • TimeSplitOverride: a dict[DataMode, TimeRange], where
      DataMode ∈ {TRAIN, VALIDATION, TEST, PREDICT} and TimeRange has start_date: datetime and end_date: datetime. To override/add a test period, set DataMode.TEST and provide dates, e.g.: {DataMode.TEST: TimeRange(start_date=datetime(2023, 8, 1), end_date=datetime(2023, 8, 22))}

    • EntitySplitOverride: a Dict with fields
      training: int (percentage), validation: int (percentage), training_validation_end: datetime. The first two fields define train/validation percentages; training_validation_end marks the end of training & validation, leaving the remaining time for the test period.

    • Notes: This parameter is handy when you need to override or add split information (e.g., if the test period was not defined in the pretrain configuration file).
      However, when running test or prediction with models trained using an entity-based split, use prediction_date from TestingParams instead.

  • kwargs
    Default: {}.
    Additional data parameters to override.