Custom Metrics

Each task type ships with sensible default metrics (listed on the individual Model Configuration pages). You can add extra metrics, replace the defaults, or change which metric is monitored for early stopping and checkpointing.

Adding Metrics

Pass a list of MetricParams or CustomMetric objects to TrainingParams.metrics.

MetricParams

Use MetricParams when the metric is available by name in BaseModel's predefined set or in TorchMetrics:

Python
from monad.ui.config import TrainingParams, MetricParams

training_params = TrainingParams(
    checkpoint_dir="./model",
    epochs=1,
    metrics=[
        MetricParams(alias="auroc", metric_name="AUROC", kwargs={"task": "binary", "average": None}),
        MetricParams(alias="recall", metric_name="Recall", kwargs={"task": "binary"}),
    ],
)

| Field | Description |
| --- | --- |
| alias | Name you assign to identify the metric in logs |
| metric_name | Name from BaseModel's predefined metrics or TorchMetrics |
| kwargs | Arguments passed to the metric constructor |

CustomMetric

Use CustomMetric when you need a fully initialized TorchMetrics instance:

Python
from monad.ui.config import CustomMetric, TrainingParams
from torchmetrics.classification import F1Score

training_params = TrainingParams(
    checkpoint_dir="./model",
    epochs=1,
    metrics=[
        CustomMetric(alias="f1", metric=F1Score(task="binary")),
    ],
)

| Field | Description |
| --- | --- |
| alias | Name you assign to identify the metric in logs |
| metric | An initialized torchmetrics.Metric instance |

Monitoring a Metric

Control which metric determines the best checkpoint and early stopping:

Python
from monad.ui.config import TrainingParams, MetricParams, MetricMonitoringMode

training_params = TrainingParams(
    checkpoint_dir="./model",
    epochs=3,
    metrics=[
        MetricParams(alias="auroc", metric_name="AUROC", kwargs={"task": "binary", "average": None}),
        MetricParams(alias="recall", metric_name="Recall", kwargs={"task": "binary"}),
    ],
    metric_to_monitor="val_recall_0",
    metric_monitoring_mode=MetricMonitoringMode.MAX,
)

| Parameter | Description |
| --- | --- |
| metric_to_monitor | Validation metric alias (prefixed val_) to track |
| metric_monitoring_mode | MetricMonitoringMode.MAX (higher is better) or MetricMonitoringMode.MIN (lower is better) |
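
To make the effect of the monitoring mode concrete, here is a minimal, library-independent sketch of how a monitored value could drive best-checkpoint selection and early stopping. It is not the monad implementation: `Mode`, `is_improvement`, `best_epoch`, and the `patience` argument are hypothetical names introduced only for illustration.

```python
from enum import Enum
from typing import List, Optional, Tuple


class Mode(Enum):
    """Hypothetical stand-in for MetricMonitoringMode (not imported from monad)."""
    MAX = "max"  # higher is better, e.g. recall or AUROC
    MIN = "min"  # lower is better, e.g. a loss


def is_improvement(current: float, best: Optional[float], mode: Mode) -> bool:
    """Return True when `current` beats the best value seen so far."""
    if best is None:  # the first epoch is always the best so far
        return True
    return current > best if mode is Mode.MAX else current < best


def best_epoch(history: List[float], mode: Mode, patience: int) -> Tuple[int, int]:
    """Return (index of the best epoch, index of the epoch where training stops).

    `patience` (an assumed parameter) is the number of consecutive epochs
    without improvement that are tolerated before stopping.
    """
    best: Optional[float] = None
    best_idx, stale = -1, 0
    for epoch, value in enumerate(history):
        if is_improvement(value, best, mode):
            best, best_idx, stale = value, epoch, 0  # would save a checkpoint here
        else:
            stale += 1
            if stale >= patience:
                return best_idx, epoch  # stop early
    return best_idx, len(history) - 1


# Monitored validation recall per epoch, with higher values being better.
recall_history = [0.61, 0.70, 0.68, 0.69, 0.74]
best_idx, stop_idx = best_epoch(recall_history, Mode.MAX, patience=2)
```

With these example values, the best checkpoint comes from epoch index 1 and training stops at epoch index 3, so the later 0.74 is never reached. Choosing the wrong mode (for example MIN for recall) would instead keep the worst-scoring checkpoint.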

Predefined Metrics

These metrics are available by name via MetricParams.metric_name:

| Metric name | Task | Description |
| --- | --- | --- |
| MultipleTargetsRecall | Multiclass | Fraction of correct class predictions |
| MultipleTargetsRecallPerClass | Multiclass | Recall per class |
| PrecisionAtK | Recommendation | Fraction of top-K items that are relevant |
| MeanAveragePrecisionAtK | Recommendation | Average precision within top-K, averaged over entities |
| HitRateAtK | Recommendation | Fraction of entities where top-K contains at least one hit |
| MeanReciprocalRank | Recommendation | How early the first relevant item appears |
| NDCGAtK | Recommendation | Ranking quality for top-K, normalized to [0, 1] |

Any metric from TorchMetrics can also be used via either MetricParams or CustomMetric.
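
As a rough guide to what the recommendation metrics listed above measure, the following plain-Python sketch computes each quantity for a single entity using the textbook definitions with binary relevance. The function names are hypothetical, and the predefined metrics may differ in details such as how average precision is normalized (here by min(number of relevant items, K)) or how ties and empty relevance sets are handled.

```python
import math
from typing import Sequence, Set


def precision_at_k(ranked: Sequence[str], relevant: Set[str], k: int) -> float:
    """Fraction of the top-k recommended items that are relevant."""
    return sum(item in relevant for item in ranked[:k]) / k


def hit_at_k(ranked: Sequence[str], relevant: Set[str], k: int) -> float:
    """1.0 if the top-k list contains at least one relevant item, else 0.0."""
    return float(any(item in relevant for item in ranked[:k]))


def reciprocal_rank(ranked: Sequence[str], relevant: Set[str]) -> float:
    """1 / rank of the first relevant item, or 0.0 if none appears."""
    for rank, item in enumerate(ranked, start=1):
        if item in relevant:
            return 1.0 / rank
    return 0.0


def average_precision_at_k(ranked: Sequence[str], relevant: Set[str], k: int) -> float:
    """Sum of precision@i at each relevant position i <= k, normalized
    by min(len(relevant), k) (one common convention, assumed here)."""
    hits, total = 0, 0.0
    for i, item in enumerate(ranked[:k], start=1):
        if item in relevant:
            hits += 1
            total += hits / i
    denom = min(len(relevant), k)
    return total / denom if denom else 0.0


def ndcg_at_k(ranked: Sequence[str], relevant: Set[str], k: int) -> float:
    """DCG of the top-k list divided by the DCG of an ideal ordering."""
    dcg = sum(
        1.0 / math.log2(i + 1)
        for i, item in enumerate(ranked[:k], start=1)
        if item in relevant
    )
    ideal = sum(1.0 / math.log2(i + 1) for i in range(1, min(len(relevant), k) + 1))
    return dcg / ideal if ideal else 0.0


# Two entities: (ranked recommendations, set of truly relevant items).
entities = [
    (["a", "b", "c", "d"], {"b", "d"}),
    (["x", "y", "z"], {"x"}),
]
k = 3
# The dataset-level metrics average the per-entity values.
hit_rate = sum(hit_at_k(r, rel, k) for r, rel in entities) / len(entities)
mrr = sum(reciprocal_rank(r, rel) for r, rel in entities) / len(entities)
map_at_k = sum(average_precision_at_k(r, rel, k) for r, rel in entities) / len(entities)
```

For the first entity, only "b" (rank 2) falls inside the top 3, so precision@3 is 1/3 and the reciprocal rank is 1/2. The second entity ranks its only relevant item first, which gives the maximum value for every metric except precision@3.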