Binary Classifier

When to use

Predict yes/no outcomes — will the customer churn, convert, or commit fraud?

For more use cases and complete solutions, see Binary Classification Recipes.

Training script

Python
import numpy as np
from pathlib import Path

from monad.ui.module import BinaryClassificationTask
from monad.ui.config import TrainingParams
from monad.ui.module import load_from_foundation_model

The only task-specific import is BinaryClassificationTask. The remaining imports (numpy, Path, TrainingParams, load_from_foundation_model) are common to every training script.

Your target function must return a single-element np.float32 array containing 0 (negative) or 1 (positive), or None to exclude the entity.

Python
return np.array([1 if churned else 0], dtype=np.float32)

See Target Function for the full data-access API and Target Examples → Binary for ready-made patterns.
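
As a rough illustration of the return contract described above, the sketch below labels churn from plain Python datetimes instead of the BaseModel event objects. The helper name `churn_label`, its arguments, and the half-open window `[split, split + window)` are assumptions made for this example and are not part of the monad API.

```python
from datetime import datetime, timedelta
from typing import List, Optional

import numpy as np


def churn_label(
    history: List[datetime],
    future: List[datetime],
    split: datetime,
    window: timedelta,
) -> Optional[np.ndarray]:
    """Hypothetical helper: [1.0] for churn, [0.0] otherwise, None to skip the entity."""
    # No history: exclude the entity, as a target function does by returning None.
    if not history:
        return None
    # Assumption: the target window is half-open, [split, split + window).
    active = any(split <= ts < split + window for ts in future)
    # Single-element float32 array: 1 = churned (positive), 0 = still active (negative).
    return np.array([0 if active else 1], dtype=np.float32)


split = datetime(2024, 1, 1)
window = timedelta(days=21)
past = [datetime(2023, 12, 20)]

print(churn_label(past, [datetime(2024, 1, 10)], split, window))  # event inside the window
print(churn_label(past, [datetime(2024, 2, 15)], split, window))  # event only after the window
print(churn_label([], [datetime(2024, 1, 10)], split, window))    # no history, entity excluded
```

The three calls cover the three possible outcomes: a negative label, a positive label, and an excluded entity.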

Every training script requires two configuration objects: a task that defines the prediction type, and TrainingParams that control metrics and training behavior.

Task declaration

Python
task = BinaryClassificationTask()

BinaryClassificationTask has no required constructor parameters.

Training parameters

Configure training with TrainingParams. At minimum, provide the checkpoint directory:

Python
training_params = TrainingParams(
    checkpoint_dir=scenario_model_path,
)

For all available options and their defaults, see Training Parameters. The default metrics for this task are:

| Metric | Alias | Monitoring |
| --- | --- | --- |
| AUROC(task="binary", average=None) | val_auroc_0 | Maximize |
| AveragePrecision(task="binary", average=None) | val_average_precision_0 | |

To add or replace metrics, see Custom Metrics.
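
To give a feel for what the two default metrics measure, here is a minimal pure-Python sketch. It is not the implementation used by the library; the function names `auroc` and `average_precision` are chosen for this example, and the average-precision version assumes there are no tied scores.

```python
from typing import Sequence


def auroc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """Probability that a random positive is scored above a random negative."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative label")
    # Each positive/negative pair counts 1 if ranked correctly, 0.5 on a tie.
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def average_precision(labels: Sequence[int], scores: Sequence[float]) -> float:
    """Mean of the precision values taken at the rank of each positive example."""
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    hits = 0
    total = 0.0
    for rank, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            total += hits / rank  # precision at this cut-off
    if hits == 0:
        raise ValueError("average precision needs at least one positive label")
    return total / hits


labels = [1, 0, 1, 0]
scores = [0.9, 0.8, 0.4, 0.1]
print(auroc(labels, scores))                        # 0.75
print(round(average_precision(labels, scores), 3))  # 0.833
```

Both values lie between 0 and 1, and higher is better, which is consistent with val_auroc_0 being monitored with the Maximize mode.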

Python
trainer = load_from_foundation_model(
    checkpoint_path="./foundation_model",
    downstream_task=task,
    target_fn=my_target_fn,
    training_params=TrainingParams(...),
)
trainer.fit()

See Scenario Model reference for all load_from_foundation_model options.

Full example

Complete training script from the onboarding package — adapt the paths, event table, and target window to your data:

Python
from datetime import timedelta
from pathlib import Path
from typing import Dict

import numpy as np

from monad.batch import SPLIT_TIMESTAMP
from monad.ui.config import TrainingParams
from monad.ui.module import BinaryClassificationTask, load_from_foundation_model
from monad.ui.target_function import Attributes, Events, has_incomplete_training_window


# --- Names & Paths -----------------------------------------------------------

# EDIT: path to the project directory, i.e. the PARENT of /fm, /features, /lightning_checkpoints, etc.
project_dir = Path("/basemodel/projects/project_dir").resolve()
# EDIT: name of the scenario checkpoints directory; the script places it under project_dir / "scenarios"
scenario_name = "scenario_name"

# derive the model paths from the project directory
foundation_model_path = project_dir / "fm"
scenario_model_path = project_dir / "scenarios" / scenario_name


# --- Target Function ---------------------------------------------------------

# churn definition, for reference:
# 1: churned (no events in the future window)
# 0: not churned (at least one event in the future window)

# EDIT: provide target details
TARGET_EVENT_TABLE = "transactions" # the event data source the target logic is based on (here: product purchases)
TARGET_WINDOW_DAYS = 21 # how far to look into the future, in days (other time units are possible)


def target_fn(history: Events, future: Events, _entity: Attributes, _ctx: Dict) -> np.ndarray | None:

    # filter out users whose remaining window is shorter than the target window
    if has_incomplete_training_window(_ctx, timedelta(days=TARGET_WINDOW_DAYS)):
        return None

    # filters out users with no history
    if history[TARGET_EVENT_TABLE].count() == 0:
        return None

    # trim the future events to the desired window
    future_window = future.interval_from(
        _ctx[SPLIT_TIMESTAMP],
        timedelta(days=TARGET_WINDOW_DAYS),
    )

    # churn label
    y = 0 if future_window[TARGET_EVENT_TABLE].count() > 0 else 1
    return np.array([y], dtype=np.float32)


# --- Training ----------------------------------------------------------------

# EDIT: hyperparameters - keep the defaults unless experimenting
learning_rate = 0.0001
epochs = 3 # use 1 for smoke test

# EDIT: limited runs - use for smoke test, then comment out here and below
limit_train_batches = 5
limit_val_batches = 5

# EDIT: parallelised training - comment out to default to a single GPU
strategy = "ddp"
devices = [0, 1] # list GPU indices

# For more options refer to docs: https://docs.basemodel.ai/reference/trainingparams

task = BinaryClassificationTask()

training_params = TrainingParams(
    checkpoint_dir=scenario_model_path,
    learning_rate=learning_rate,
    epochs=epochs,
    devices=devices,
    strategy=strategy,
    limit_train_batches=limit_train_batches, # smoke test, comment out for full runs
    limit_val_batches=limit_val_batches, # smoke test, comment out for full runs
)

trainer = load_from_foundation_model(
    checkpoint_path=foundation_model_path,
    downstream_task=task,
    target_fn=target_fn,
)

trainer.fit(training_params=training_params, overwrite=True) # to resume training, replace overwrite=True with resume=True

This script is part of the onboarding package shipped with every BaseModel installation.
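
Instead of commenting the smoke-test limits in and out by hand, one option is to collect the keyword arguments in a small helper and switch them with a flag. The sketch below is only a suggestion: `build_training_kwargs` and the `smoke_test` flag are hypothetical names, and the parameter names are taken from the script above.

```python
from pathlib import Path
from typing import Any, Dict


def build_training_kwargs(checkpoint_dir: Path, smoke_test: bool) -> Dict[str, Any]:
    """Hypothetical helper: gathers TrainingParams keyword arguments in one place."""
    kwargs: Dict[str, Any] = {
        "checkpoint_dir": checkpoint_dir,
        "learning_rate": 0.0001,
        "epochs": 1 if smoke_test else 3,  # 1 epoch for a smoke test, as in the script
    }
    if smoke_test:
        # Limit the run to a handful of batches so it finishes quickly.
        kwargs["limit_train_batches"] = 5
        kwargs["limit_val_batches"] = 5
    return kwargs


kwargs = build_training_kwargs(Path("scenarios") / "scenario_name", smoke_test=True)
print(sorted(kwargs))
# In the training script these would be passed on as TrainingParams(**kwargs).
```

With `smoke_test=False` the two batch limits are left out entirely, so the full-run configuration matches the script with those lines commented out.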