Scenario Model API
A scenario model fine-tunes the foundation model for a specific prediction task. Unlike the foundation model (configured via YAML), scenario models are configured entirely via the Python API.
load_from_foundation_model()
Creates a scenario model from a trained foundation model checkpoint.
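```python
from pathlib import Path
from monad.ui.module import load_from_foundation_model, BinaryClassificationTask

# my_target_fn is a target function you define yourself;
# see Target Function Reference and the Workflow Example below.
module = load_from_foundation_model(
    checkpoint_path=Path("./foundation_model"),
    downstream_task=BinaryClassificationTask(),
    target_fn=my_target_fn,
)
```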
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `checkpoint_path` | `str \| Path` | required | Path to the foundation model checkpoint directory (the inner `fm/` subdirectory, not the outer output directory). |
| `downstream_task` | `Task` | required | The prediction task type. See Task Types below. |
| `target_fn` | `TargetFunction` | required | Target function defining what to predict. See Target Function Reference. |
| `with_head` | `bool` | `False` | Whether to reuse the foundation model head (advanced). |
| `pl_logger` | `Logger \| None` | `None` | PyTorch Lightning logger instance. |
| `worker_init_fn` | `Callable[[int], None] \| None` | `None` | Custom worker initialization function for data loading. |
| `predictions_to_include_fn` | `PredictionsFilteringFnType \| None` | `None` | Function returning items/classes to include in predictions. Mutually exclusive with `predictions_to_exclude_fn`. |
| `predictions_to_exclude_fn` | `PredictionsFilteringFnType \| None` | `None` | Function returning items/classes to exclude from predictions. Mutually exclusive with `predictions_to_include_fn`. |
| `split` | `TimeSplitOverride \| EntitySplitOverride \| None` | `None` | Override the data split from the foundation model checkpoint. |
| `**kwargs` | | | Additional keyword arguments forwarded as data parameter overrides (e.g., `data_params`, `query_optimization`). See Loading Overrides. |
Note
`predictions_to_include_fn` and `predictions_to_exclude_fn` are only supported for `RecommendationTask`, `MulticlassClassificationTask`, and `MultilabelClassificationTask`.
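The two constraints above (the filter functions are mutually exclusive, and only three task types support them) can be checked before any data is loaded. This is a minimal sketch, not part of `monad`: `check_prediction_filters` and `FilterFn` are hypothetical names, the real `PredictionsFilteringFnType` signature is not documented here, and the task is identified by its class name string so the sketch runs without the library installed.

```python
from typing import Any, Callable, Optional

# Hypothetical stand-in; the actual PredictionsFilteringFnType signature is
# not documented in this section.
FilterFn = Callable[..., Any]

# Task types that, per the note above, accept prediction filters.
FILTERABLE_TASKS = {
    "RecommendationTask",
    "MulticlassClassificationTask",
    "MultilabelClassificationTask",
}


def check_prediction_filters(
    include_fn: Optional[FilterFn],
    exclude_fn: Optional[FilterFn],
    task_name: str,
) -> None:
    """Raise ValueError if the filter arguments break the documented rules."""
    if include_fn is None and exclude_fn is None:
        return  # no filtering requested, nothing to validate
    if include_fn is not None and exclude_fn is not None:
        raise ValueError(
            "predictions_to_include_fn and predictions_to_exclude_fn "
            "are mutually exclusive"
        )
    if task_name not in FILTERABLE_TASKS:
        raise ValueError(f"prediction filtering is not supported for {task_name}")
```

Failing early like this surfaces a misconfiguration before an expensive load starts; whether `monad` itself raises `ValueError` in these cases is not stated here.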
load_from_checkpoint()
Loads a previously trained scenario model for resumed training, testing, or inference.
```python
from pathlib import Path
from monad.ui.module import load_from_checkpoint

module = load_from_checkpoint(
    checkpoint_path=Path("./my_scenario_model"),
)
```
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `checkpoint_path` | `str \| Path` | required | Path to the scenario model checkpoint directory. |
| `pl_logger` | `Logger \| None` | `None` | PyTorch Lightning logger instance. |
| `scoring` | `bool` | `False` | Whether the model is loaded for inference (`True`) or for resumed training (`False`). |
| `worker_init_fn` | `Callable[[int], None] \| None` | `None` | Custom worker initialization function for data loading. |
| `split` | `TimeSplitOverride \| EntitySplitOverride \| None` | `None` | Override the data split from the checkpoint. |
| `predictions_to_include_fn` | `PredictionsFilteringFnType \| None` | `None` | Function returning items/classes to include in predictions. |
| `predictions_to_exclude_fn` | `PredictionsFilteringFnType \| None` | `None` | Function returning items/classes to exclude from predictions. |
| `**kwargs` | | | Additional keyword arguments forwarded as data parameter overrides. See Loading Overrides. |
Which loading function to use
- `load_from_foundation_model()`: Use when creating a new scenario model from a foundation model checkpoint. The `checkpoint_path` must point to the inner `fm/` subdirectory (e.g., `output_path/fm/`), not the outer output directory.
- `load_from_checkpoint()`: Use when loading an already-trained scenario model for resumed training, test, predict, or interpret. Using `load_from_foundation_model()` for inference re-initializes the scenario head and discards the trained weights.
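Pointing `checkpoint_path` at the outer output directory instead of `fm/` is an easy mistake when paths come from configuration. The sketch below assumes only the layout described above (an outer output directory containing an `fm/` checkpoint directory); `resolve_fm_checkpoint` is a hypothetical helper, not part of `monad`, that accepts either path and returns the inner one.

```python
from pathlib import Path


def resolve_fm_checkpoint(output_path: Path) -> Path:
    """Return the inner fm/ checkpoint directory (hypothetical helper).

    Accepts either the outer foundation model output directory or the
    fm/ directory itself.
    """
    output_path = Path(output_path)
    if output_path.name == "fm":
        return output_path  # already the inner directory, returned as-is
    inner = output_path / "fm"
    if not inner.is_dir():
        raise FileNotFoundError(f"expected a foundation model checkpoint at {inner}")
    return inner
```

Its return value would then be passed as `checkpoint_path` to `load_from_foundation_model()`. The helper only inspects the directory name and whether `fm/` exists; it does not validate the checkpoint contents.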
Task Types
Each scenario model requires a task type that defines the prediction objective.
```python
from monad.ui.module import (
    BinaryClassificationTask,
    MulticlassClassificationTask,
    MultilabelClassificationTask,
    RegressionTask,
    RecommendationTask,
    OneHotRecommendationTask,
)
```
| Task | Constructor parameters | Description |
|---|---|---|
| `BinaryClassificationTask()` | None | Predict yes/no outcomes (e.g., churn, conversion). |
| `MulticlassClassificationTask(class_names=[...])` | `class_names: list[str]` | Predict one of N mutually exclusive classes. |
| `MultilabelClassificationTask(class_names=[...])` | `class_names: list[str]` | Predict any subset of N labels simultaneously (labels are not mutually exclusive). |
| `RegressionTask(num_targets=..., max_value=2000000.0)` | `num_targets: int` (required), `max_value: float` (default `2000000.0`) | Predict continuous values (e.g., spend, LTV). |
| `RecommendationTask()` | None | Rank items for each entity. The default choice for high-cardinality item catalogs. |
| `OneHotRecommendationTask()` | None | Rank items for each entity. Intended for low-cardinality item sets. |
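The table above amounts to a small decision procedure. The sketch below encodes it as a lookup so the choice can be made explicit in a pipeline; `suggest_task` and its `target_kind` vocabulary are hypothetical, it returns class names as strings rather than constructing `monad` objects, and `low_cardinality_threshold=100` is an assumed cut-off, since this page does not say where "low cardinality" ends.

```python
from typing import Optional

# Non-ranking targets map directly to one task type each.
_DIRECT = {
    "yes_no": "BinaryClassificationTask",
    "one_of_n": "MulticlassClassificationTask",
    "many_of_n": "MultilabelClassificationTask",
    "continuous": "RegressionTask",
}


def suggest_task(
    target_kind: str,
    num_items: Optional[int] = None,
    low_cardinality_threshold: int = 100,  # assumed cut-off, not documented
) -> str:
    """Map a target description to a task class name from the table."""
    if target_kind in _DIRECT:
        return _DIRECT[target_kind]
    if target_kind == "ranking":
        if num_items is not None and num_items <= low_cardinality_threshold:
            return "OneHotRecommendationTask"
        # Unknown or large catalogs fall back to the documented default.
        return "RecommendationTask"
    raise ValueError(f"unknown target kind: {target_kind!r}")
```

For example, `suggest_task("ranking", num_items=20)` returns `"OneHotRecommendationTask"`, while a ranking target with an unknown catalog size falls back to `"RecommendationTask"`.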
Workflow Example
Complete training and inference pipeline:
```python
from pathlib import Path
import numpy as np
from datetime import timedelta
from monad.ui.module import load_from_foundation_model, BinaryClassificationTask
from monad.config import TrainingParams, TestingParams, OutputType, MetricParams
from monad.targets import Events, Attributes
from monad.batch import SPLIT_TIMESTAMP
from monad.ui.target_function import has_incomplete_training_window

# 1. Define target function
def churn_target_fn(history: Events, future: Events, attributes: Attributes, ctx: dict):
    if history["transactions"].count() < 2:
        return None
    if has_incomplete_training_window(ctx, timedelta(days=30)):
        return None
    future_30d = future.interval_from(ctx[SPLIT_TIMESTAMP], timedelta(days=30))
    churned = 1 if future_30d["transactions"].count() == 0 else 0
    return np.array([churned], dtype=np.float32)

# 2. Load foundation model and define task
module = load_from_foundation_model(
    checkpoint_path=Path("./foundation_model"),
    downstream_task=BinaryClassificationTask(),
    target_fn=churn_target_fn,
)

# 3. Train
training_params = TrainingParams(
    epochs=5,
    learning_rate=1e-4,
    devices=[0],
    checkpoint_dir=Path("./churn_model"),
    metrics=[MetricParams(alias="auroc", metric_name="AUROC")],
)
module.fit(training_params)

# 4. Test on test split (requires test split in data config)
testing_params = TestingParams(
    output_type=OutputType.DECODED,
    devices=[0],
    local_save_location=Path("./predictions.tsv"),
    metrics=[MetricParams(alias="auroc", metric_name="AUROC")],
)
module.test(testing_params)

# 5. Predict
module.predict(testing_params)
```
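Because the target function decides both which entities receive a label and what that label is, it can help to check the labeling rule in isolation before training. The sketch below is a hypothetical, library-free mirror of `churn_target_fn`: `churn_label` takes plain lists of transaction timestamps instead of `Events`, replaces `has_incomplete_training_window` with a boolean flag, and assumes `interval_from` selects the half-open window [split, split + 30 days). None of these names are part of `monad`.

```python
from datetime import datetime, timedelta
from typing import Optional

import numpy as np


def churn_label(
    history_txn_times: list[datetime],
    future_txn_times: list[datetime],
    split_ts: datetime,
    training_window_complete: bool = True,
    horizon: timedelta = timedelta(days=30),
) -> Optional[np.ndarray]:
    """Plain-Python mirror of churn_target_fn's labeling rule (hypothetical)."""
    if len(history_txn_times) < 2:
        return None  # too little history: entity is skipped, not labeled
    if not training_window_complete:
        return None  # stands in for has_incomplete_training_window(...)
    end = split_ts + horizon
    # Assumed half-open window [split_ts, split_ts + horizon).
    n_future = sum(1 for t in future_txn_times if split_ts <= t < end)
    churned = 1 if n_future == 0 else 0
    return np.array([churned], dtype=np.float32)
```

A customer with two past transactions and none in the 30 days after the split is labeled `1.0` (churned); a purchase on day 45 still counts as churn, because it falls outside the window.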
See Also
- Training Parameters: `TrainingParams` for scenario models
- Target Function Reference: How to write target functions
- Testing Parameters: `TestingParams` and `OutputType`
- Interpretability: Model interpretation and attribution