Scenario Model API

A scenario model fine-tunes the foundation model for a specific prediction task. Unlike the foundation model (configured via YAML), scenario models are configured entirely via the Python API.

load_from_foundation_model()

Creates a scenario model from a trained foundation model checkpoint.

from pathlib import Path
from monad.ui.module import load_from_foundation_model, BinaryClassificationTask

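module = load_from_foundation_model(
    # Must be the inner fm/ subdirectory of the foundation model output
    # (see "Which loading function to use").
    checkpoint_path=Path("./foundation_model"),
    downstream_task=BinaryClassificationTask(),
    # my_target_fn is a user-defined target function (see Workflow Example).
    target_fn=my_target_fn,
)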

Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| checkpoint_path | str \| Path | required | Path to the foundation model checkpoint directory. |
| downstream_task | Task | required | The prediction task type. See Task Types below. |
| target_fn | TargetFunction | required | Target function defining what to predict. See Target Function Reference. |
| with_head | bool | False | Whether to reuse the foundation model head (advanced). |
| pl_logger | Logger \| None | None | PyTorch Lightning logger instance. |
| worker_init_fn | Callable[[int], None] \| None | None | Custom worker initialization function for data loading. |
| predictions_to_include_fn | PredictionsFilteringFnType \| None | None | Function returning items/classes to include in predictions. Mutually exclusive with predictions_to_exclude_fn. |
| predictions_to_exclude_fn | PredictionsFilteringFnType \| None | None | Function returning items/classes to exclude from predictions. Mutually exclusive with predictions_to_include_fn. |
| split | TimeSplitOverride \| EntitySplitOverride \| None | None | Override the data split from the foundation model checkpoint. |
| **kwargs | | | Additional keyword arguments forwarded as data parameter overrides (e.g., data_params, query_optimization). See Loading Overrides. |

Note

predictions_to_include_fn and predictions_to_exclude_fn are only supported for RecommendationTask, MulticlassClassificationTask, and MultilabelClassificationTask.
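
The two constraints above (only one filter at a time, and only for certain task types) can be checked before calling the loader. The following is a minimal sketch, not part of the monad API: validate_prediction_filters is a hypothetical helper, and the task is identified by its class name as a string (an assumption made so the snippet runs without monad installed).

```python
from typing import Callable, Optional

# Task names taken from the note above. Matching on the class name is an
# assumption of this sketch, not how monad itself performs the check.
FILTERABLE_TASKS = {
    "RecommendationTask",
    "MulticlassClassificationTask",
    "MultilabelClassificationTask",
}


def validate_prediction_filters(
    task_name: str,
    include_fn: Optional[Callable] = None,
    exclude_fn: Optional[Callable] = None,
) -> None:
    """Hypothetical pre-flight check; raises ValueError on an invalid combination."""
    if include_fn is not None and exclude_fn is not None:
        raise ValueError(
            "predictions_to_include_fn and predictions_to_exclude_fn "
            "are mutually exclusive"
        )
    has_filter = include_fn is not None or exclude_fn is not None
    if has_filter and task_name not in FILTERABLE_TASKS:
        raise ValueError(f"prediction filters are not supported for {task_name}")
```

Calling it with the task's class name, for example type(downstream_task).__name__, before load_from_foundation_model() surfaces a misconfiguration early instead of during loading.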


load_from_checkpoint()

Loads a previously trained scenario model for resumed training, testing, or inference.

from pathlib import Path
from monad.ui.module import load_from_checkpoint

module = load_from_checkpoint(
    checkpoint_path=Path("./my_scenario_model"),
)

Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| checkpoint_path | str \| Path | required | Path to the scenario model checkpoint directory. |
| pl_logger | Logger \| None | None | PyTorch Lightning logger instance. |
| scoring | bool | False | Whether loading for inference (True) or resumed training (False). |
| worker_init_fn | Callable[[int], None] \| None | None | Custom worker initialization function. |
| split | TimeSplitOverride \| EntitySplitOverride \| None | None | Override the data split from the checkpoint. |
| predictions_to_include_fn | PredictionsFilteringFnType \| None | None | Items/classes to include in predictions. |
| predictions_to_exclude_fn | PredictionsFilteringFnType \| None | None | Items/classes to exclude from predictions. |
| **kwargs | | | Additional keyword arguments forwarded as data parameter overrides. See Loading Overrides. |

Which loading function to use

  • load_from_foundation_model(): Use when creating a new scenario model from a foundation model checkpoint. The checkpoint_path must point to the inner fm/ subdirectory (e.g., output_path/fm/), not the outer output directory.
  • load_from_checkpoint(): Use when loading an already-trained scenario model to resume training or to run test, predict, or interpret. Do not use load_from_foundation_model() for inference: it re-initializes the scenario head and discards the trained weights.
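
Because the foundation model checkpoint lives in the inner fm/ subdirectory, a small path check can catch the outer-directory mistake before loading. This is a minimal sketch that assumes only the directory layout described above; resolve_fm_checkpoint is a hypothetical helper, not a monad function.

```python
from pathlib import Path


def resolve_fm_checkpoint(output_path: Path) -> Path:
    """Return the inner fm/ directory for a foundation model output path.

    Accepts either the outer output directory or the fm/ directory itself.
    Hypothetical helper for illustration; not part of monad.
    """
    output_path = Path(output_path)
    # If the caller already passed .../fm, keep it; otherwise append fm/.
    candidate = output_path if output_path.name == "fm" else output_path / "fm"
    if not candidate.is_dir():
        raise FileNotFoundError(
            f"expected a foundation model checkpoint directory at {candidate}"
        )
    return candidate
```

The returned path can then be passed as checkpoint_path to load_from_foundation_model().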

Task Types

Each scenario model requires a task type that defines the prediction objective.

from monad.ui.module import (
    BinaryClassificationTask,
    MulticlassClassificationTask,
    MultilabelClassificationTask,
    RegressionTask,
    RecommendationTask,
    OneHotRecommendationTask,
)

| Task | Constructor | Description |
| --- | --- | --- |
| BinaryClassificationTask() | No parameters | Predict yes/no outcomes (e.g., churn, conversion). |
| MulticlassClassificationTask(class_names=[...]) | class_names: list[str] | Predict one of N mutually exclusive classes. |
| MultilabelClassificationTask(class_names=[...]) | class_names: list[str] | Predict multiple labels simultaneously. |
| RegressionTask(num_targets=..., max_value=2000000.0) | num_targets: int (required), max_value: float | Predict continuous values (e.g., spend, LTV). |
| RecommendationTask() | No parameters | Rank items for each entity. Default for high-cardinality catalogs. |
| OneHotRecommendationTask() | No parameters | Rank items for each entity. For low-cardinality item sets. |

Workflow Example

Complete training and inference pipeline:

from pathlib import Path
import numpy as np
from datetime import timedelta

from monad.ui.module import load_from_foundation_model, BinaryClassificationTask
from monad.config import TrainingParams, TestingParams, OutputType, MetricParams
from monad.targets import Events, Attributes
from monad.batch import SPLIT_TIMESTAMP
from monad.ui.target_function import has_incomplete_training_window


# 1. Define target function
def churn_target_fn(history: Events, future: Events, attributes: Attributes, ctx: dict):
    if history["transactions"].count() < 2:
        return None
    if has_incomplete_training_window(ctx, timedelta(days=30)):
        return None
    future_30d = future.interval_from(ctx[SPLIT_TIMESTAMP], timedelta(days=30))
    churned = 1 if future_30d["transactions"].count() == 0 else 0
    return np.array([churned], dtype=np.float32)


# 2. Load foundation model and define task
module = load_from_foundation_model(
    checkpoint_path=Path("./foundation_model"),
    downstream_task=BinaryClassificationTask(),
    target_fn=churn_target_fn,
)

# 3. Train
training_params = TrainingParams(
    epochs=5,
    learning_rate=1e-4,
    devices=[0],
    checkpoint_dir=Path("./churn_model"),
    metrics=[MetricParams(alias="auroc", metric_name="AUROC")],
)
module.fit(training_params)

# 4. Test on test split (requires test split in data config)
testing_params = TestingParams(
    output_type=OutputType.DECODED,
    devices=[0],
    local_save_location=Path("./predictions.tsv"),
    metrics=[MetricParams(alias="auroc", metric_name="AUROC")],
)
module.test(testing_params)

# 5. Predict
module.predict(testing_params)
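
The labeling rule inside churn_target_fn can be checked in isolation, without monad, on plain timestamps. The sketch below is an illustration only: churn_label is a hypothetical stand-in that mirrors the history-count filter and the 30-day future window, it omits the has_incomplete_training_window check, and treating the window as half-open, [split, split + 30 days), is an assumption about how interval_from behaves.

```python
from datetime import datetime, timedelta
from typing import Optional, Sequence


def churn_label(
    history: Sequence[datetime],
    future: Sequence[datetime],
    split: datetime,
    window: timedelta = timedelta(days=30),
    min_history: int = 2,
) -> Optional[int]:
    """Return 1 (churned), 0 (active), or None (entity skipped)."""
    # Skip entities with too little history, as in churn_target_fn.
    if len(history) < min_history:
        return None
    # Assumed half-open window: split <= timestamp < split + window.
    window_end = split + window
    active = any(split <= ts < window_end for ts in future)
    return 0 if active else 1


split = datetime(2024, 1, 1)
history = [datetime(2023, 12, 1), datetime(2023, 12, 15)]

print(churn_label(history, [datetime(2024, 1, 10)], split))  # 0: transaction inside the window
print(churn_label(history, [datetime(2024, 2, 15)], split))  # 1: no transaction inside the window
print(churn_label(history[:1], [], split))                   # None: fewer than 2 history events
```

Testing the rule this way makes it easier to confirm the label definition before committing to a training run.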

See Also