Training Parameters
Scenario models use the same TrainingParams class as the foundation model, but with scenario-specific behavior. For the complete field-level reference, see Foundation Model Training Parameters.
Key Differences from Foundation Model
| Feature | Foundation Model | Scenario Model |
|---|---|---|
| Custom metrics | Not supported | Supported via metrics parameter |
| `checkpoint_dir` | Set in YAML | Commonly set in TrainingParams |
| `metric_to_monitor` | Not used | Select best model by custom metric |
| Typical epochs | 1 | 3-10+ |
Commonly Used Parameters
from pathlib import Path

training_params = TrainingParams(
    epochs=5,
    learning_rate=1e-4,
    devices=[0],
    checkpoint_dir=Path("./my_scenario_model"),
    metrics=[
        MetricParams(alias="auroc", metric_name="AUROC"),
    ],
    metric_to_monitor="auroc",  # alias of the metric defined above
    metric_monitoring_mode="max",  # higher AUROC is better
    early_stopping=EarlyStopping(patience=5, min_delta=0.001),
)
module.fit(training_params)
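To make the interplay of metric_to_monitor, metric_monitoring_mode, patience, and min_delta concrete, here is a minimal sketch of conventional early-stopping logic. It is not taken from the library: the function name `best_epoch_and_stop` is hypothetical, and the actual implementation may count patience or compare against min_delta differently.

```python
def best_epoch_and_stop(values, mode="max", patience=5, min_delta=0.001):
    """Return (best_epoch, stop_epoch) for per-epoch values of the monitored metric.

    stop_epoch is None when the patience budget is never exhausted.
    """
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")
    sign = 1.0 if mode == "max" else -1.0  # flip "min" into a maximization problem
    best_score = None
    best_epoch = None
    stale_epochs = 0
    for epoch, value in enumerate(values):
        score = sign * value
        # An epoch counts as an improvement only if it beats the best by more than min_delta.
        if best_score is None or score > best_score + min_delta:
            best_score, best_epoch, stale_epochs = score, epoch, 0
        else:
            stale_epochs += 1
            if stale_epochs >= patience:
                return best_epoch, epoch
    return best_epoch, None


auroc_per_epoch = [0.71, 0.74, 0.76, 0.7605, 0.759, 0.758]
best, stopped = best_epoch_and_stop(auroc_per_epoch, mode="max", patience=2)
print(best, stopped)  # prints: 2 4
```

In this sketch the gain from 0.76 to 0.7605 is smaller than min_delta, so epoch 2 remains the best epoch and training stops two epochs later.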
Available Metrics by Task Type
Binary Classification
| Metric Name | Description |
|---|---|
| `AUROC` | Area Under the ROC Curve. |
| `AveragePrecision` | Area Under the Precision-Recall Curve. |
| `Recall` | True positive rate. |
| `Precision` | Positive predictive value. |
| `F1Score` | Harmonic mean of precision and recall. |
| `Accuracy` | Overall correctness. |
metrics = [
    MetricParams(alias="auroc", metric_name="AUROC"),
    MetricParams(alias="avg_precision", metric_name="AveragePrecision"),
]
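For reference, the threshold-based metrics in the table above (Precision, Recall, F1Score, Accuracy) can be computed by hand from confusion-matrix counts. The sketch below uses plain Python and the textbook definitions; the helper name `binary_classification_report` is hypothetical and is not part of monad or torchmetrics.

```python
def binary_classification_report(y_true, y_pred):
    """Compute precision, recall, F1, and accuracy from 0/1 labels and 0/1 predictions."""
    if not y_true or len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    # Fall back to 0.0 when a denominator is empty (no predicted or no actual positives).
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = (tp + tn) / len(pairs)
    return {"precision": precision, "recall": recall, "f1": f1, "accuracy": accuracy}


report = binary_classification_report(
    y_true=[1, 0, 1, 1, 0, 0, 1, 0],
    y_pred=[1, 0, 0, 1, 0, 1, 1, 0],
)
print(report)
```

AUROC and AveragePrecision are different in kind: they are computed from predicted scores across all thresholds, so they cannot be derived from a single set of hard predictions like this.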
Multiple Targets Classification
Custom metrics from monad.metrics for tasks with multiple valid targets:
| Metric Name | Description | Kwargs |
|---|---|---|
| `MultipleTargetsRecall` | Recall across multiple valid targets. | None |
| `MultipleTargetsRecallPerClass` | Per-class recall for multiple targets. | {"num_classes": N} |
Multiclass Classification
| Metric Name | Description | Kwargs |
|---|---|---|
| `Accuracy` | Overall correctness. | {"task": "multiclass", "num_classes": N} |
| `F1Score` | F1 with macro averaging. | {"task": "multiclass", "num_classes": N, "average": "macro"} |
metrics = [
    MetricParams(
        alias="accuracy",
        metric_name="Accuracy",
        kwargs={"task": "multiclass", "num_classes": 3},
    ),
]
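To illustrate what "macro" averaging means for the multiclass F1Score, the following sketch computes a per-class F1 and then takes the unweighted mean over classes. It is a plain-Python illustration of the textbook definition with a hypothetical helper name; library implementations may treat classes that never occur differently.

```python
def macro_f1(y_true, y_pred, num_classes):
    """Unweighted mean of per-class F1 scores (one-vs-rest for each class)."""
    per_class = []
    for cls in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p == cls)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != cls and p == cls)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p != cls)
        denominator = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); score 0.0 for a class with no support and no predictions.
        per_class.append(2 * tp / denominator if denominator else 0.0)
    return sum(per_class) / num_classes


y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
print(round(macro_f1(y_true, y_pred, num_classes=3), 4))  # prints: 0.6556
```

Because every class contributes equally, macro averaging gives rare classes the same weight as frequent ones.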
Multilabel Classification
| Metric Name | Description | Kwargs |
|---|---|---|
| `F1Score` | F1 with micro averaging. | {"task": "multilabel", "num_labels": N, "average": "micro"} |
| `HammingDistance` | Fraction of incorrect labels. | {"task": "multilabel", "num_labels": N} |
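As a rough illustration of the two multilabel metrics above, the sketch below computes the Hamming distance (fraction of label slots predicted incorrectly) and a micro-averaged F1 (counts pooled over all labels before computing F1). It uses plain Python and textbook definitions; the helper names are hypothetical.

```python
def hamming_distance(y_true, y_pred):
    """Fraction of (sample, label) slots where prediction and target disagree."""
    slots = [(t, p) for true_row, pred_row in zip(y_true, y_pred)
             for t, p in zip(true_row, pred_row)]
    if not slots:
        raise ValueError("inputs must contain at least one label")
    return sum(1 for t, p in slots if t != p) / len(slots)


def micro_f1(y_true, y_pred):
    """F1 computed from true/false positives and false negatives pooled over all labels."""
    tp = fp = fn = 0
    for true_row, pred_row in zip(y_true, y_pred):
        for t, p in zip(true_row, pred_row):
            tp += int(t == 1 and p == 1)
            fp += int(t == 0 and p == 1)
            fn += int(t == 1 and p == 0)
    denominator = 2 * tp + fp + fn
    return 2 * tp / denominator if denominator else 0.0


# Two samples, three labels each (1 = label present).
y_true = [[1, 0, 1], [0, 1, 0]]
y_pred = [[1, 1, 1], [0, 0, 0]]
print(hamming_distance(y_true, y_pred), micro_f1(y_true, y_pred))
```

Here two of the six label slots are wrong, so the Hamming distance is 1/3, while the pooled counts (2 true positives, 1 false positive, 1 false negative) give a micro F1 of 2/3.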
Regression
RegressionTask requires num_targets (number of regression outputs) and optionally max_value (upper bound for target normalization).
| Metric Name | Description |
|---|---|
| `MeanSquaredError` | Mean squared error; take the square root of the result to obtain RMSE. |
| `MeanAbsoluteError` | Mean absolute error. |
| `R2Score` | Coefficient of determination. |
metrics = [
    MetricParams(alias="mse", metric_name="MeanSquaredError"),
    MetricParams(alias="mae", metric_name="MeanAbsoluteError"),
]
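The RMSE post-processing mentioned in the table amounts to taking the square root of the reported MSE. The sketch below shows this alongside the textbook definitions of MAE and R2 in plain Python; `regression_report` is a hypothetical helper, not a library function.

```python
import math


def regression_report(y_true, y_pred):
    """Compute MSE, RMSE, MAE, and R2 for two equal-length sequences of floats."""
    if not y_true or len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    n = len(y_true)
    errors = [p - t for t, p in zip(y_true, y_pred)]
    mse = sum(e * e for e in errors) / n
    mae = sum(abs(e) for e in errors) / n
    mean_true = sum(y_true) / n
    ss_total = sum((t - mean_true) ** 2 for t in y_true)
    ss_residual = sum(e * e for e in errors)
    # R2 is undefined when the targets are constant, so report NaN in that case.
    r2 = 1.0 - ss_residual / ss_total if ss_total else float("nan")
    return {"mse": mse, "rmse": math.sqrt(mse), "mae": mae, "r2": r2}


report = regression_report(y_true=[3.0, 5.0, 7.0, 9.0], y_pred=[2.0, 5.0, 8.0, 11.0])
print(report)
```

For this toy data the MSE is 1.5, so the RMSE is sqrt(1.5), roughly 1.22, in the same units as the target.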
Recommendation
Two task types are available:
- RecommendationTask: sketch-based recommendation for high-cardinality item catalogs.
- OneHotRecommendationTask: one-hot recommendation for low-cardinality item sets.
| Metric Name | Description | Kwargs |
|---|---|---|
| `HitRateAtK` | Hit rate at top-k. | {"k": 10} |
| `MeanAveragePrecisionAtK` | MAP at top-k. | {"k": 10} |
| `NDCGAtK` | Normalized Discounted Cumulative Gain at top-k. | {"k": 10} |
| `PrecisionAtK` | Precision at top-k. | {"k": 10} |
| `MeanReciprocalRank` | Mean reciprocal rank of the first relevant item. | {"k": 100} (default) |
metrics = [
    MetricParams(alias="hr10", metric_name="HitRateAtK", kwargs={"k": 10}),
    MetricParams(alias="ndcg10", metric_name="NDCGAtK", kwargs={"k": 10}),
]
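To show how the k argument shapes the ranking metrics above, here is a plain-Python sketch of hit rate, NDCG, and reciprocal rank for a single user, using the common binary-relevance definitions. The function names are hypothetical, and the metrics in monad.metrics may differ in details such as tie handling or how results are averaged over users.

```python
import math


def hit_rate_at_k(ranked_items, relevant_items, k):
    """1.0 if any of the top-k ranked items is relevant, else 0.0."""
    return 1.0 if any(item in relevant_items for item in ranked_items[:k]) else 0.0


def ndcg_at_k(ranked_items, relevant_items, k):
    """Binary-relevance NDCG: DCG of the top-k ranking divided by the ideal DCG."""
    dcg = sum(1.0 / math.log2(position + 2)
              for position, item in enumerate(ranked_items[:k])
              if item in relevant_items)
    ideal_hits = min(len(relevant_items), k)
    ideal_dcg = sum(1.0 / math.log2(position + 2) for position in range(ideal_hits))
    return dcg / ideal_dcg if ideal_dcg else 0.0


def reciprocal_rank_at_k(ranked_items, relevant_items, k):
    """1 / rank of the first relevant item within the top k, or 0.0 if there is none."""
    for position, item in enumerate(ranked_items[:k]):
        if item in relevant_items:
            return 1.0 / (position + 1)
    return 0.0


ranked = ["a", "b", "c", "d", "e"]  # model output, best first
relevant = {"c", "e"}               # ground-truth items for this user
for k in (2, 3):
    print(k, hit_rate_at_k(ranked, relevant, k),
          round(ndcg_at_k(ranked, relevant, k), 4),
          round(reciprocal_rank_at_k(ranked, relevant, k), 4))
```

With k=2 no relevant item is in the cutoff and all three values are 0; with k=3 the first relevant item appears at rank 3, giving a hit and a reciprocal rank of 1/3.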
Tip
Metrics are resolved by searching monad.metrics first, then torchmetrics. You can use any metric from either library.
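One practical consequence of this search order is that a name defined in both places resolves to the first match. The sketch below illustrates ordered name lookup with stand-in namespaces; it does not import monad or torchmetrics, and `resolve_metric` and the placeholder classes are hypothetical, so the real resolver may behave differently.

```python
from types import SimpleNamespace


def resolve_metric(metric_name, namespaces):
    """Return the first attribute called metric_name found in the ordered namespaces."""
    for namespace in namespaces:
        candidate = getattr(namespace, metric_name, None)
        if candidate is not None:
            return candidate
    raise ValueError(f"metric {metric_name!r} not found in any namespace")


# Placeholder classes standing in for real metric implementations.
class FirstRecall: ...
class SecondRecall: ...
class SecondAUROC: ...


first_namespace = SimpleNamespace(Recall=FirstRecall)  # searched first
second_namespace = SimpleNamespace(Recall=SecondRecall, AUROC=SecondAUROC)  # fallback
search_order = [first_namespace, second_namespace]

print(resolve_metric("Recall", search_order).__name__)  # prints: FirstRecall
print(resolve_metric("AUROC", search_order).__name__)   # prints: SecondAUROC
```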
MetricParams
Defines a metric by name from monad.metrics or torchmetrics.
| Parameter | Type | Default | Description |
|---|---|---|---|
| `alias` | `str` | required | Unique alias to identify the metric. Used in logs and for metric_to_monitor. |
| `metric_name` | `str` | required | Name of the metric class from monad.metrics or torchmetrics. |
| `kwargs` | `dict[str, Any]` | `{}` | Arguments passed to the metric constructor. |
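Since aliases must be unique and metric_to_monitor refers to a metric by alias, a small pre-flight check can catch typos before a long training run starts. The sketch below works on plain strings and is only an illustration: `validate_metric_aliases` is a hypothetical helper, and it assumes that the monitored metric must be one of the configured aliases, which the library may not require.

```python
from collections import Counter


def validate_metric_aliases(aliases, metric_to_monitor=None):
    """Raise ValueError on duplicate aliases or on a monitored alias that is not configured."""
    duplicates = sorted(name for name, count in Counter(aliases).items() if count > 1)
    if duplicates:
        raise ValueError(f"duplicate metric aliases: {duplicates}")
    # Assumption: the monitored metric is expected to be one of the configured aliases.
    if metric_to_monitor is not None and metric_to_monitor not in aliases:
        raise ValueError(f"metric_to_monitor={metric_to_monitor!r} is not a configured alias")


# Passes silently: aliases are unique and the monitored alias exists.
validate_metric_aliases(["auroc", "avg_precision"], metric_to_monitor="auroc")
```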
CustomMetric
Allows passing a pre-instantiated torchmetrics.Metric object directly, instead of resolving by name via MetricParams.
from monad.config import CustomMetric
from torchmetrics.classification import BinaryAUROC

metrics = [
    CustomMetric(alias="my_auroc", metric=BinaryAUROC()),
]
| Parameter | Type | Default | Description |
|---|---|---|---|
| `alias` | `str` | required | Unique alias to identify the metric. Used in logs and for metric_to_monitor. |
| `metric` | `torchmetrics.Metric` | required | A pre-instantiated torchmetrics.Metric instance. |
Multi-GPU Training
For larger datasets, use distributed training strategies:
from pathlib import Path

training_params = TrainingParams(
    epochs=5,
    learning_rate=1e-4,
    devices=[0, 1, 2, 3],  # four GPUs
    strategy="ddp",
    precision="bf16-mixed",
    checkpoint_dir=Path("./my_model"),
)
Note
"ddp" and "fsdp" strategies require multiple devices. When using "fsdp:A:B" (custom FSDP), you need at least A * B devices.
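The device requirement in the note can be checked before launching a job. The following sketch parses the strategy string and compares the required count with the devices list; `min_devices_for_strategy` and `check_devices` are hypothetical helpers, and treating any other strategy value as single-device is an assumption made only for this example.

```python
def min_devices_for_strategy(strategy):
    """Return the minimum number of devices implied by a strategy string."""
    if strategy in ("ddp", "fsdp"):
        return 2  # "multiple devices" is read here as at least two
    if strategy.startswith("fsdp:"):
        parts = strategy.split(":")
        if len(parts) != 3:
            raise ValueError(f"expected 'fsdp:A:B', got {strategy!r}")
        a, b = int(parts[1]), int(parts[2])
        if a < 1 or b < 1:
            raise ValueError("A and B must be positive integers")
        return a * b
    return 1  # assumption: any other strategy can run on a single device


def check_devices(strategy, devices):
    """Raise ValueError if the devices list is too short for the chosen strategy."""
    required = min_devices_for_strategy(strategy)
    if len(devices) < required:
        raise ValueError(
            f"strategy {strategy!r} needs at least {required} devices, got {len(devices)}"
        )


check_devices("ddp", devices=[0, 1, 2, 3])       # passes: 4 >= 2
check_devices("fsdp:2:2", devices=[0, 1, 2, 3])  # passes: 4 >= 2 * 2
print(min_devices_for_strategy("fsdp:2:4"))      # prints: 8
```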