Customizing testing metrics

In this article, we show how to set alternative metrics during scenario model testing.

The list below contains the default metrics for each downstream task.

  • Binary classification: AUROC(num_labels=num_classes, task="binary", average=None), AveragePrecision(num_labels=num_classes, task="binary", average=None)
  • Multiclass classification: Precision(num_classes=num_classes, task="multiclass", average=None), Recall(num_classes=num_classes, task="multiclass", average=None)
  • Multi-label classification: AUROC(num_labels=num_classes, task="multilabel", average=None), AveragePrecision(num_labels=num_classes, task="multilabel", average=None)
  • Regression: MeanSquaredError(squared=False)
  • Recommendations: HitRateAtK(k=1), HitRateAtK(k=10), HitRateAtK(k=25), HitRateAtK(k=50), HitRateAtK(k=100), HitRateAtK(k=top_k), MeanAveragePrecisionAtK(k=12), MeanAveragePrecisionAtK(k=top_k), PrecisionAtK(k=10), PrecisionAtK(k=top_k), where top_k defaults to 12.
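For reference, the classification metrics above come from torchmetrics. The short sketch below is independent of monad and uses toy tensors as an assumption; it only shows what the binary-classification defaults AUROC and AveragePrecision compute.

import torch
from torchmetrics import AUROC, AveragePrecision

# Toy predicted probabilities and ground-truth labels for a binary task.
preds = torch.tensor([0.9, 0.2, 0.7, 0.4])
target = torch.tensor([1, 0, 1, 0])

auroc = AUROC(task="binary")
avg_precision = AveragePrecision(task="binary")

print(auroc(preds, target))          # tensor(1.) - perfect ranking on this toy data
print(avg_precision(preds, target))  # tensor(1.)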

To define custom metrics, use the metrics parameter of the TestingParams class.
Metrics should be passed as a list of MetricParams objects imported from the monad.ui.config module, as illustrated below.

Parameters
  • alias: str
    No default. Name assigned by the user to identify the metric.
  • metric_name: str
    Default: differs for each task. Name of the metric from BaseModel Metrics or torchmetrics (list available here).
  • kwargs: dict[str, Any]
    Default: empty dict. Arguments used to initialize the metric, the same ones that would be passed when instantiating the metric directly.
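To show how these three parameters fit together, here is a minimal sketch of a metrics list for a hypothetical multiclass scenario; the metric names come from torchmetrics, while the aliases, number of classes, and averaging choices are illustrative assumptions, not defaults.

from monad.ui.config import MetricParams

# Hypothetical metrics list for a multiclass scenario: a macro-averaged F1
# score and per-class recall, both resolved from torchmetrics by name.
custom_metrics = [
    MetricParams(
        alias="f1_macro",
        metric_name="F1Score",
        kwargs={"task": "multiclass", "num_classes": 5, "average": "macro"},
    ),
    MetricParams(
        alias="recall_per_class",
        metric_name="Recall",
        kwargs={"task": "multiclass", "num_classes": 5, "average": None},
    ),
]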

Example

The example below demonstrates how to use recall and precision when testing a binary classification model.

from datetime import datetime

from monad.config import TimeRange
from monad.ui.config import DataMode, MetricParams, OutputType, TestingParams
from monad.ui.module import load_from_checkpoint

# declare variables
checkpoint_path = "<path/to/downstream/model/checkpoints>" # location of scenario model checkpoints
save_path = "<path/to/predictions/predictions_and_ground_truth.tsv>" # location to store evaluation results
test_start_date = datetime(2023, 8, 1) # first day of test period
test_end_date = datetime(2023, 8, 22) # last day of test period

# load scenario model to instantiate testing module
testing_module = load_from_checkpoint(
    checkpoint_path=checkpoint_path,
    split={DataMode.TEST: TimeRange(start_date=test_start_date, end_date=test_end_date)},
)

# define testing parameters
testing_params = TestingParams(
    local_save_location=save_path,
    output_type=OutputType.DECODED,
    metrics=[
        MetricParams(alias="recall", metric_name="Recall", kwargs={"task": "binary"}),
        MetricParams(alias="precision", metric_name="Precision", kwargs={"task": "binary"}),
    ],
)

# run evaluation
testing_module.test(testing_params=testing_params)
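Assuming the predictions and ground truth are written to save_path as a tab-separated file (as the file name in this example suggests), the results can then be inspected with any TSV reader, for instance pandas; the exact column layout depends on the scenario.

import pandas as pd

# Load the evaluation output written by testing_module.test(); the exact
# columns depend on the scenario, so inspect them before further processing.
results = pd.read_csv(save_path, sep="\t")
print(results.head())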