
Training Parameters

Scenario models use the same TrainingParams class as the foundation model, but with scenario-specific behavior. For the complete field-level reference, see Foundation Model Training Parameters.

from monad.config import TrainingParams, MetricParams, EarlyStopping

Key Differences from Foundation Model

| Feature | Foundation Model | Scenario Model |
|---|---|---|
| Custom metrics | Not supported | Supported via the metrics parameter |
| checkpoint_dir | Set in YAML | Commonly set in TrainingParams |
| metric_to_monitor | Not used | Selects the best model by a custom metric |
| Typical epochs | 1 | 3-10+ |

Commonly Used Parameters

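from pathlib import Path

training_params = TrainingParams(
    epochs=5,
    learning_rate=1e-4,
    devices=[0],
    checkpoint_dir=Path("./my_scenario_model"),
    metrics=[
        MetricParams(alias="auroc", metric_name="AUROC"),
    ],
    metric_to_monitor="auroc",  # must match one of the metric aliases above
    metric_monitoring_mode="max",  # higher AUROC is better
    early_stopping=EarlyStopping(patience=5, min_delta=0.001),
)

# module is the scenario model module created earlier (not shown on this page).
module.fit(training_params)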
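The interplay of metric_to_monitor, metric_monitoring_mode, and EarlyStopping(patience, min_delta) is easiest to see on a sequence of per-epoch scores. The sketch below is a pure-Python model of the patience rule used by PyTorch Lightning's EarlyStopping callback, plus best-checkpoint selection. It is an assumption that monad applies exactly this rule, and track_training is a hypothetical helper, not part of monad.

```python
def track_training(scores, mode="max", patience=5, min_delta=0.001):
    """Return (best_epoch, stopped_epoch) for a list of per-epoch metric values.

    Hypothetical helper modeled on patience-based early stopping; not monad code.
    stopped_epoch is None if training would run through every epoch.
    """
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")
    sign = 1.0 if mode == "max" else -1.0  # flip so larger is better in both modes
    best_for_stopping = float("-inf")  # last value that counted as an improvement
    best_for_checkpoint = float("-inf")  # best value seen so far
    best_epoch = None
    wait = 0
    for epoch, score in enumerate(scores):
        signed = sign * score
        # Checkpointing keeps whichever epoch has the best monitored value.
        if signed > best_for_checkpoint:
            best_for_checkpoint = signed
            best_epoch = epoch
        # Early stopping only resets patience on gains larger than min_delta.
        if signed - min_delta > best_for_stopping:
            best_for_stopping = signed
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch
    return best_epoch, None


print(track_training([0.70, 0.75, 0.7505, 0.74, 0.751], mode="max", patience=2))  # (2, 3)
```

Note the asymmetry in this sketch: the best checkpoint moves to epoch 2 because 0.7505 beats 0.75, but that gain is smaller than min_delta, so it does not reset the patience counter.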

Available Metrics by Task Type

Binary Classification

| Metric Name | Description |
|---|---|
| AUROC | Area Under the ROC Curve. |
| AveragePrecision | Area Under the Precision-Recall Curve. |
| Recall | True positive rate. |
| Precision | Positive predictive value. |
| F1Score | Harmonic mean of precision and recall. |
| Accuracy | Overall correctness. |

metrics = [
    MetricParams(alias="auroc", metric_name="AUROC"),
    MetricParams(alias="avg_precision", metric_name="AveragePrecision"),
]
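The two ranking metrics in this table are easy to misread, so here is a small pure-Python sketch of what AUROC and AveragePrecision measure for binary labels. These are textbook definitions for illustration only, not the monad.metrics or torchmetrics implementations, which operate on tensors and handle ties and thresholds more carefully.

```python
def auroc(scores, labels):
    """Probability that a random positive is scored above a random negative.

    Ties count as half a correct ordering. O(P * N), so for illustration only.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def average_precision(scores, labels):
    """Step-wise area under the precision-recall curve (assumes no tied scores)."""
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    n_pos = sum(labels)
    if n_pos == 0:
        raise ValueError("AveragePrecision needs at least one positive")
    hits, total = 0, 0.0
    for rank, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            total += hits / rank  # precision at the rank where recall increases
    return total / n_pos


scores, labels = [0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]
print(auroc(scores, labels))  # 0.75
print(round(average_precision(scores, labels), 3))  # 0.833
```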

Multiple Targets Classification

Custom metrics from monad.metrics for tasks with multiple valid targets:

| Metric Name | Description | Kwargs |
|---|---|---|
| MultipleTargetsRecall | Recall across multiple valid targets. | |
| MultipleTargetsRecallPerClass | Per-class recall for multiple targets. | {"num_classes": N} |

Multiclass Classification

| Metric Name | Description | Kwargs |
|---|---|---|
| Accuracy | Overall correctness. | {"task": "multiclass", "num_classes": N} |
| F1Score | F1 with macro averaging. | {"task": "multiclass", "num_classes": N, "average": "macro"} |

metrics = [
    MetricParams(
        alias="accuracy",
        metric_name="Accuracy",
        kwargs={"task": "multiclass", "num_classes": 3},
    ),
]
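With "average": "macro", every class contributes equally regardless of how often it occurs, which is why macro F1 can sit well below accuracy on imbalanced data. A minimal pure-Python sketch of both definitions follows; it is illustrative only, since the torchmetrics classes operate on tensors.

```python
def accuracy(y_true, y_pred):
    """Fraction of samples whose predicted class equals the true class."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true, y_pred, num_classes):
    """Unweighted mean of per-class F1 scores ("average": "macro").

    Classes with no true and no predicted samples score 0.0 here; libraries
    differ on this edge case.
    """
    pairs = list(zip(y_true, y_pred))
    f1_per_class = []
    for c in range(num_classes):
        tp = sum(t == c and p == c for t, p in pairs)
        fp = sum(t != c and p == c for t, p in pairs)
        fn = sum(t == c and p != c for t, p in pairs)
        denom = 2 * tp + fp + fn
        f1_per_class.append(2 * tp / denom if denom else 0.0)
    return sum(f1_per_class) / num_classes


y_true, y_pred = [0, 1, 2, 2], [0, 2, 2, 2]
print(accuracy(y_true, y_pred))  # 0.75
print(round(macro_f1(y_true, y_pred, num_classes=3), 4))  # 0.6
```

Here class 1 is never predicted, so its F1 is 0 and pulls the macro average (0.6) below accuracy (0.75).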

Multilabel Classification

| Metric Name | Description | Kwargs |
|---|---|---|
| F1Score | F1 with micro averaging. | {"task": "multilabel", "num_labels": N, "average": "micro"} |
| HammingDistance | Fraction of incorrect labels. | {"task": "multilabel", "num_labels": N} |
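In the multilabel setting each sample carries a 0/1 vector of num_labels entries. The pure-Python sketch below shows the two definitions in the table, assuming predictions have already been thresholded to 0/1 (the torchmetrics classes apply a probability threshold themselves, 0.5 by default). The function names are illustrative, not part of monad.

```python
def hamming_distance(y_true, y_pred):
    """Fraction of individual label slots that are wrong, over all samples."""
    wrong = total = 0
    for t_row, p_row in zip(y_true, y_pred):
        for t, p in zip(t_row, p_row):
            wrong += t != p
            total += 1
    return wrong / total


def micro_f1(y_true, y_pred):
    """F1 from TP/FP/FN pooled over every (sample, label) slot."""
    tp = fp = fn = 0
    for t_row, p_row in zip(y_true, y_pred):
        for t, p in zip(t_row, p_row):
            tp += t == 1 and p == 1
            fp += t == 0 and p == 1
            fn += t == 1 and p == 0
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


y_true = [[1, 0, 1], [0, 1, 0]]  # two samples, three labels each
y_pred = [[1, 0, 0], [0, 1, 1]]
print(round(hamming_distance(y_true, y_pred), 4))  # 0.3333
print(round(micro_f1(y_true, y_pred), 4))  # 0.6667
```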

Regression

RegressionTask requires num_targets (number of regression outputs) and optionally max_value (upper bound for target normalization).

from monad.ui.module import RegressionTask

task = RegressionTask(num_targets=1, max_value=1000.0)

| Metric Name | Description |
|---|---|
| MeanSquaredError | Mean Squared Error (use for RMSE via post-processing). |
| MeanAbsoluteError | Mean Absolute Error. |
| R2Score | Coefficient of determination. |

metrics = [
    MetricParams(alias="mse", metric_name="MeanSquaredError"),
    MetricParams(alias="mae", metric_name="MeanAbsoluteError"),
]
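The table notes that RMSE comes from post-processing MeanSquaredError; that step is just the square root of the logged MSE. Below is a small pure-Python sketch of all four quantities for a single target. It is illustrative only and does not model the max_value normalization that RegressionTask applies.

```python
import math


def regression_metrics(y_true, y_pred):
    """MSE, RMSE (square root of MSE), MAE and R^2 for a single regression target."""
    n = len(y_true)
    errors = [p - t for t, p in zip(y_true, y_pred)]
    ss_res = sum(e * e for e in errors)  # residual sum of squares
    mean_t = sum(y_true) / n
    ss_tot = sum((t - mean_t) ** 2 for t in y_true)  # total sum of squares
    if ss_tot == 0:
        raise ValueError("R^2 is undefined when all targets are equal")
    mse = ss_res / n
    return {
        "mse": mse,
        "rmse": math.sqrt(mse),  # the post-processing step for RMSE
        "mae": sum(abs(e) for e in errors) / n,
        "r2": 1.0 - ss_res / ss_tot,
    }


m = regression_metrics([3.0, 5.0, 7.0], [2.0, 5.0, 9.0])
print(round(m["rmse"], 4), m["mae"], m["r2"])  # 1.291 1.0 0.375
```

torchmetrics' own MeanSquaredError also accepts squared=False to report RMSE directly; whether passing that through kwargs works here depends on the class resolving to the torchmetrics version, which is worth checking.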

Recommendation

Two task types are available:

  • RecommendationTask: sketch-based recommendation for high-cardinality item catalogs.
  • OneHotRecommendationTask: one-hot recommendation for low-cardinality item sets.

| Metric Name | Description | Kwargs |
|---|---|---|
| HitRateAtK | Hit rate at top-k. | {"k": 10} |
| MeanAveragePrecisionAtK | MAP at top-k. | {"k": 10} |
| NDCGAtK | Normalized Discounted Cumulative Gain at top-k. | {"k": 10} |
| PrecisionAtK | Precision at top-k. | {"k": 10} |
| MeanReciprocalRank | Mean reciprocal rank of the first relevant item. | {"k": 100} (default) |

metrics = [
    MetricParams(alias="hr10", metric_name="HitRateAtK", kwargs={"k": 10}),
    MetricParams(alias="ndcg10", metric_name="NDCGAtK", kwargs={"k": 10}),
]
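For a single user, given the model's ranked list of item ids and the set of items that user actually interacted with, the per-user quantities behind HitRateAtK, NDCGAtK, and MeanReciprocalRank can be sketched as below. This assumes binary relevance; the reported metrics are averages of these values over users, and the exact tie-handling and averaging inside monad.metrics are not covered here. The item ids are made up.

```python
import math


def hit_rate_at_k(ranked, relevant, k):
    """1.0 if any relevant item appears in the top-k, else 0.0 (one user)."""
    return 1.0 if any(item in relevant for item in ranked[:k]) else 0.0


def ndcg_at_k(ranked, relevant, k):
    """Binary-relevance NDCG@k: DCG of the ranking divided by the ideal DCG."""
    dcg = sum(1.0 / math.log2(i + 2) for i, item in enumerate(ranked[:k]) if item in relevant)
    ideal = sum(1.0 / math.log2(i + 2) for i in range(min(len(relevant), k)))
    return dcg / ideal if ideal else 0.0


def reciprocal_rank(ranked, relevant, k=100):
    """1 / rank of the first relevant item within the top-k, else 0.0."""
    for i, item in enumerate(ranked[:k]):
        if item in relevant:
            return 1.0 / (i + 1)
    return 0.0


ranked = ["a", "b", "c", "d"]  # hypothetical item ids, best first
relevant = {"c"}
print(hit_rate_at_k(ranked, relevant, k=2), hit_rate_at_k(ranked, relevant, k=3))  # 0.0 1.0
print(ndcg_at_k(ranked, relevant, k=3))  # 0.5
print(round(reciprocal_rank(ranked, relevant), 4))  # 0.3333
```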

Tip

Metrics are resolved by searching monad.metrics first, then torchmetrics. You can use any metric from either library.
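The lookup order described in this tip amounts to "first match wins". A minimal sketch under that assumption: resolve_metric_class is a hypothetical helper, the two SimpleNamespace objects are stand-ins for the real modules, and SharedMetric is an invented name present in both. One consequence of this order is that a class in monad.metrics shadows a torchmetrics class with the same name.

```python
from types import SimpleNamespace


def resolve_metric_class(metric_name, *libraries):
    """Return metric_name from the first library that defines it (first match wins)."""
    for library in libraries:
        metric_cls = getattr(library, metric_name, None)
        if metric_cls is not None:
            return metric_cls
    raise ValueError(f"Metric {metric_name!r} not found in any library")


# Hypothetical stand-ins for monad.metrics and torchmetrics; SharedMetric is an
# invented name that exists in both, to show which one wins.
fake_monad_metrics = SimpleNamespace(
    MultipleTargetsRecall="monad:MultipleTargetsRecall",
    SharedMetric="monad:SharedMetric",
)
fake_torchmetrics = SimpleNamespace(
    R2Score="torchmetrics:R2Score",
    SharedMetric="torchmetrics:SharedMetric",
)

for name in ("MultipleTargetsRecall", "R2Score", "SharedMetric"):
    print(name, "->", resolve_metric_class(name, fake_monad_metrics, fake_torchmetrics))
```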

MetricParams

Defines a metric by name from monad.metrics or torchmetrics.

| Parameter | Type | Default | Description |
|---|---|---|---|
| alias | str | required | Unique alias to identify the metric. Used in logs and for metric_to_monitor. |
| metric_name | str | required | Name of the metric class from monad.metrics or torchmetrics. |
| kwargs | dict[str, Any] | {} | Arguments passed to the metric constructor. |
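Because alias doubles as the key for logging and for metric_to_monitor, both constraints can be checked before a long training run. The sketch below uses a hypothetical MetricSpec dataclass in place of MetricParams (so it runs without monad installed) and a hypothetical check_metrics helper; whether TrainingParams already performs this validation itself is not stated here.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class MetricSpec:
    """Hypothetical stand-in mirroring MetricParams' fields (alias, metric_name, kwargs)."""

    alias: str
    metric_name: str
    # default_factory gives each instance its own empty dict instead of a shared one.
    kwargs: dict[str, Any] = field(default_factory=dict)


def check_metrics(metrics, metric_to_monitor=None):
    """Raise ValueError on duplicate aliases or a monitored alias that does not exist."""
    aliases = [m.alias for m in metrics]
    duplicates = sorted({a for a in aliases if aliases.count(a) > 1})
    if duplicates:
        raise ValueError(f"Duplicate metric aliases: {duplicates}")
    if metric_to_monitor is not None and metric_to_monitor not in aliases:
        raise ValueError(f"metric_to_monitor={metric_to_monitor!r} matches no alias")


metrics = [MetricSpec("auroc", "AUROC"), MetricSpec("avg_precision", "AveragePrecision")]
check_metrics(metrics, metric_to_monitor="auroc")  # passes silently
```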

CustomMetric

Allows passing a pre-instantiated torchmetrics.Metric object directly, instead of resolving by name via MetricParams.

from monad.config import CustomMetric
from torchmetrics.classification import BinaryAUROC

metrics = [
    CustomMetric(alias="my_auroc", metric=BinaryAUROC()),
]

| Parameter | Type | Default | Description |
|---|---|---|---|
| alias | str | required | Unique alias to identify the metric. Used in logs and for metric_to_monitor. |
| metric | torchmetrics.Metric | required | A pre-instantiated torchmetrics.Metric instance. |

Multi-GPU Training

For larger datasets, use distributed training strategies:

from pathlib import Path

training_params = TrainingParams(
    epochs=5,
    learning_rate=1e-4,
    devices=[0, 1, 2, 3],  # GPU indices 0-3
    strategy="ddp",  # distributed data parallel
    precision="bf16-mixed",  # bfloat16 mixed precision
    checkpoint_dir=Path("./my_model"),
)

Note

"ddp" and "fsdp" strategies require multiple devices. When using "fsdp:A:B" (custom FSDP), you need at least A * B devices.