Customizing testing metrics
In this article, we show how to set alternative metrics during scenario model testing.
The table below contains default metrics for each downstream task.
Task | Metric |
---|---|
Binary classification | |
Multiclass classification | |
Multi-label classification | |
Regression | |
Recommendations | |
To define custom metrics, use the `metrics` parameter of the `TestingParams` class. Metrics are passed as a list of `MetricParams` objects imported from the `monad.ui.config` module.
Parameter | Default | Description |
---|---|---|
`alias: str` | No default | Name assigned by the user to identify the metric. |
`metric_name: str` | Differs for each task | Name of the metric from BaseModel Metrics or `torchmetrics` (list available here). |
`kwargs: dict[str, Any]` | `dict` | Arguments used to initialize the metric; pass the same arguments you would use when instantiating the metric directly. |
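Because `kwargs` are forwarded to the metric's constructor, a metric that needs extra arguments is configured the same way you would configure it in `torchmetrics` directly. The sketch below is illustrative rather than taken from the library docs; the alias, class count, and averaging mode are arbitrary example values.

```python
from monad.ui.config import MetricParams

# Illustrative sketch: a macro-averaged F1 score for a 10-class problem.
# The kwargs mirror what you would pass to the torchmetrics constructor,
# i.e. torchmetrics.F1Score(task="multiclass", num_classes=10, average="macro").
macro_f1 = MetricParams(
    alias="macro_f1",
    metric_name="F1Score",
    kwargs={"task": "multiclass", "num_classes": 10, "average": "macro"},
)
```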
Example
The example below demonstrates how to use recall and precision when testing a binary classification model.
```python
from datetime import datetime

from monad.config import TimeRange
from monad.ui.config import DataMode, MetricParams, OutputType, TestingParams
from monad.ui.module import load_from_checkpoint

# declare variables
checkpoint_path = "<path/to/downstream/model/checkpoints>"  # location of scenario model checkpoints
save_path = "<path/to/predictions/predictions_and_ground_truth.tsv>"  # location to store evaluation results
test_start_date = datetime(2023, 8, 1)  # first day of test period
test_end_date = datetime(2023, 8, 22)  # last day of test period

# load scenario model to instantiate testing module
testing_module = load_from_checkpoint(
    checkpoint_path=checkpoint_path,
    split={DataMode.TEST: TimeRange(start_date=test_start_date, end_date=test_end_date)},
)

# define testing parameters
testing_params = TestingParams(
    local_save_location=save_path,
    output_type=OutputType.DECODED,
    metrics=[
        MetricParams(alias="recall", metric_name="Recall", kwargs={"task": "binary"}),
        MetricParams(alias="precision", metric_name="Precision", kwargs={"task": "binary"}),
    ],
)

# run evaluation
testing_module.test(testing_params=testing_params)
```
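Once `test()` completes, the decoded predictions and ground truth are written to `save_path` as a TSV. If you want to inspect that output, here is a minimal sketch using pandas; the column layout depends on the task and model and is not specified in this article, so check the header before processing further.

```python
import pandas as pd

# Load the evaluation output written by test(); column names depend on
# the task and model, so inspect them before further processing.
results = pd.read_csv(save_path, sep="\t")
print(results.columns.tolist())
print(results.head())
```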