
Multi-Class Classifier

When to use

Predict one category from a fixed set — which product department will the customer buy from, which segment do they belong to?

For more use cases and complete solutions, see Multiclass Classification Recipes.

Multi-Class or Multilabel?

| Aspect | Multi-Class | Multilabel |
|---|---|---|
| Labels per entity | Exactly 1 | 0 to many |
| Target return | Normalized scores (sum to 1) | Independent 0/1 flags |
| Example question | "Which one category?" | "Which categories (plural)?" |
| Output | Softmax probabilities | Independent sigmoid probabilities |
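To see the difference in the "Output" row, the sketch below applies softmax and sigmoid to the same raw scores. It is a NumPy-only illustration with made-up logits; it does not show how BaseModel implements its output heads.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    # subtract the max for numerical stability; the result sums to 1
    shifted = logits - logits.max()
    exp = np.exp(shifted)
    return exp / exp.sum()


def sigmoid(logits: np.ndarray) -> np.ndarray:
    # each score is squashed independently into (0, 1)
    return 1.0 / (1.0 + np.exp(-logits))


# hypothetical raw scores for five classes
logits = np.array([2.0, 1.0, 0.1, -1.0, 0.5], dtype=np.float32)

multi_class = softmax(logits)  # one distribution over all classes
multilabel = sigmoid(logits)   # one independent probability per class

print(multi_class.round(3), multi_class.sum())
print(multilabel.round(3), multilabel.sum())
```

Softmax makes the classes compete for a single unit of probability, which matches "exactly one label per entity"; sigmoid scores can all be high at once, which is why their sum is not constrained.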

Training script

Python
import numpy as np
from pathlib import Path

from monad.ui.module import MulticlassClassificationTask
from monad.ui.config import TrainingParams
from monad.ui.module import load_from_foundation_model

The only task-specific import is MulticlassClassificationTask. The remaining imports (numpy, Path, TrainingParams, load_from_foundation_model) are common to every training script.

For targets based on joined tables, also import get_qualified_column_name:

Python
from monad.utils.sql import get_qualified_column_name

Your target function must return a np.float32 array of normalized scores with shape (num_classes,) that sums to 1, or None to exclude the entity from training. Scores must follow the order of class_names.

Python
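counts, _ = future_txns.groupBy("category").count(normalize=True, groups=CLASS_NAMES)
if counts.sum() == 0:  # no matching events in the window: exclude the entity
    return None
return np.array(counts, dtype=np.float32)  # one score per class in CLASS_NAMES order, sums to 1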
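While developing a target function, it can help to check its output against the multi-class contract (np.float32, shape (num_classes,), sums to 1, or None). The helper below is hypothetical and not part of monad; it assumes only NumPy and additionally rejects negative scores, an assumption consistent with normalized counts.

```python
import numpy as np


def check_multiclass_target(y, num_classes: int, atol: float = 1e-5) -> None:
    """Hypothetical helper: raise if y violates the multi-class target contract."""
    if y is None:
        return  # None is allowed: the entity is excluded from training
    if not isinstance(y, np.ndarray) or y.dtype != np.float32:
        raise TypeError("target must be a np.float32 array")
    if y.shape != (num_classes,):
        raise ValueError(f"expected shape ({num_classes},), got {y.shape}")
    if (y < 0).any():
        raise ValueError("scores must be non-negative")
    if not np.isclose(y.sum(), 1.0, atol=atol):
        raise ValueError(f"scores must sum to 1, got {y.sum():.6f}")


CLASS_NAMES = ["Electronics", "Fashion", "Home", "Sports", "Beauty"]

# e.g. 3 of 4 future purchases in Fashion, 1 in Home
good = np.array([0.0, 0.75, 0.25, 0.0, 0.0], dtype=np.float32)
check_multiclass_target(good, len(CLASS_NAMES))  # passes silently
check_multiclass_target(None, len(CLASS_NAMES))  # passes: entity excluded
```

An all-zero vector (an entity with no events in the window) fails the sum check, which is why the snippet above returns None in that case.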

For columns from joined tables, use get_qualified_column_name to resolve the column name:

Python
TARGET_ENTITY = get_qualified_column_name("colour_name", ["articles"])

See Target Function for the full data-access API and Target Examples → Multi-Class for ready-made patterns.

Every training script requires two configuration objects: a task that defines the prediction type, and TrainingParams that control metrics and training behavior.

Task declaration

Python
CLASS_NAMES = ["Electronics", "Fashion", "Home", "Sports", "Beauty"]

task = MulticlassClassificationTask(class_names=CLASS_NAMES)

| Parameter | Required | Description |
|---|---|---|
| class_names | Yes | List of category labels. Determines the output size and the label order; target scores must use the same order. |
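Because class_names fixes the label order, a score vector can be decoded back into names by position. A minimal sketch, assuming scores is one entity's prediction as a NumPy array in class_names order (how predictions are obtained from a trained model is not covered here); top_class and ranked_classes are hypothetical helpers.

```python
import numpy as np

CLASS_NAMES = ["Electronics", "Fashion", "Home", "Sports", "Beauty"]


def top_class(scores: np.ndarray, class_names: list[str]) -> str:
    # position i of the score vector belongs to class_names[i]
    if scores.shape != (len(class_names),):
        raise ValueError("one score per class name is required")
    return class_names[int(np.argmax(scores))]


def ranked_classes(scores: np.ndarray, class_names: list[str]) -> list[tuple[str, float]]:
    # all classes sorted from most to least likely
    order = np.argsort(scores)[::-1]
    return [(class_names[i], float(scores[i])) for i in order]


# hypothetical prediction for one customer
scores = np.array([0.10, 0.55, 0.20, 0.05, 0.10], dtype=np.float32)
print(top_class(scores, CLASS_NAMES))  # Fashion
print(ranked_classes(scores, CLASS_NAMES)[:2])
```

Reordering or renaming class_names after training would silently change which position means which class, so treat the list as fixed once a scenario model is trained.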

Training parameters

Configure training with TrainingParams. At minimum, provide the checkpoint directory:

Python
training_params = TrainingParams(
    checkpoint_dir=scenario_model_path,
)

For all available options and their defaults, see Training Parameters. The default metrics for this task are:

| Metric | Alias | Monitoring |
|---|---|---|
| MultipleTargetsRecall | val_multiple_targets_recall | Maximize |
| MultipleTargetsRecallPerClass | per-class recall | |
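For intuition about per-class recall, the sketch below computes the generic textbook version: for each class, the share of entities whose dominant target class is that class and whose highest-scoring prediction matches it. This is an assumption-labeled illustration in NumPy only; BaseModel's MultipleTargetsRecall metrics are defined by the library and may differ, for example in how they treat targets spread over several classes.

```python
import numpy as np


def per_class_recall(y_true: np.ndarray, y_pred: np.ndarray, num_classes: int) -> np.ndarray:
    """Generic recall per class, using the dominant class of each row as the label."""
    true_cls = y_true.argmax(axis=1)  # dominant class in the target distribution
    pred_cls = y_pred.argmax(axis=1)  # class with the highest predicted score
    recall = np.full(num_classes, np.nan)
    for c in range(num_classes):
        mask = true_cls == c
        if mask.any():  # recall is undefined for classes absent from the targets
            recall[c] = (pred_cls[mask] == c).mean()
    return recall


# hypothetical targets (normalized scores) and predictions for four entities
y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 1, 0], [0.6, 0.4, 0]])
y_pred = np.array([[0.7, 0.2, 0.1], [0.3, 0.6, 0.1], [0.5, 0.4, 0.1], [0.2, 0.7, 0.1]])

recall = per_class_recall(y_true, y_pred, num_classes=3)
print(recall)  # [0.5 0.5 nan]
```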

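To add or replace metrics, see Custom Metrics.

With the task, target function, and training parameters defined, load the foundation model and start training: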

Python
trainer = load_from_foundation_model(
    checkpoint_path="./foundation_model",
    downstream_task=task,
    target_fn=my_target_fn,
)
trainer.fit(training_params=training_params)

See Scenario Model reference for all load_from_foundation_model options.

Full example

Complete training script from the onboarding package — adapt the paths, target names, and joined column to your data:

Python
from datetime import timedelta
from pathlib import Path
from typing import Dict

import numpy as np

from monad.batch import SPLIT_TIMESTAMP
from monad.ui.config import TrainingParams
from monad.ui.module import MulticlassClassificationTask, load_from_foundation_model
from monad.ui.target_function import Attributes, Events, has_incomplete_training_window
from monad.utils.sql import get_qualified_column_name


# --- Names & Paths -----------------------------------------------------------

# EDIT: provide path to the project directory, i.e. the PARENT of /fm, /features, /lightning_checkpoints etc.
project_dir = Path("/basemodel/projects/project_dir").resolve()
# EDIT: define name for the scenario checkpoints directory; the script puts it under project_dir/scenarios/
scenario_name = "scenario_name"

# derive the model paths from project_dir
foundation_model_path = project_dir / "fm"
scenario_model_path = project_dir / "scenarios" / scenario_name


# --- Target Definition -------------------------------------------------------

# multiclass definition, for reference:
# each class corresponds to one entry in the TARGET_NAMES list
# labels are normalized counts of TARGET_ENTITY (= target column) values in the future window
# the model predicts the dominant class

# EDIT: output classes of the multiclass classifier (one class predicted per entity)
TARGET_NAMES = [
    "Black",
    "Blue",
    "White",
    "Pink",
    "Grey",
    "Red",
    "Beige",
    "Green",
    "Khaki green",
    "Yellow",
    "Orange",
    "Brown",
    "Metal",
    "Turquoise",
    "Mole",
    "Lilac Purple",
]

# EDIT: event data source relevant to the target (here: purchases of products)
TARGET_EVENT_TABLE = "transactions"

# EDIT: column used to match against TARGET_NAMES
# for joined columns, use get_qualified_column_name("col", ["joined_table"])
TARGET_ENTITY = get_qualified_column_name("perceived_colour_master_name", ["articles"])

# EDIT: future window length (here: days) used to generate labels
TARGET_WINDOW_DAYS = 21


def target_fn(_history: Events, future: Events, _entity: Attributes, _ctx: Dict) -> np.ndarray | None:

    # filters out users with too short remaining window
    if has_incomplete_training_window(_ctx, timedelta(days=TARGET_WINDOW_DAYS)):
        return None

    # trims the future to desired window
    future_window = future.interval_from(
        _ctx[SPLIT_TIMESTAMP],
        timedelta(days=TARGET_WINDOW_DAYS),
    )

    # multiclass target as a distribution over classes (normalized counts)
    y, _ = (
        future_window[TARGET_EVENT_TABLE]
        .groupBy(TARGET_ENTITY)
        .count(normalize=True, groups=TARGET_NAMES)
    )

    # skips entities with no signal in the future window
    if y.sum() == 0:
        return None

    return np.array(y, dtype=np.float32)


# --- Training ----------------------------------------------------------------

# EDIT: hyperparameters - keep defaults unless experimenting
learning_rate = 0.0001
epochs = 3  # use 1 for smoke test

# EDIT: limited runs - use for smoke test, then comment out here and below
limit_train_batches = 5
limit_val_batches = 5

# EDIT: parallelised training - comment out here and below to default to a single GPU
strategy = "ddp"
devices = [0, 1]  # list GPU indices

# For more options refer to docs: https://docs.basemodel.ai/reference/trainingparams

task = MulticlassClassificationTask(class_names=TARGET_NAMES)

training_params = TrainingParams(
    checkpoint_dir=scenario_model_path,
    learning_rate=learning_rate,
    epochs=epochs,
    devices=devices,
    strategy=strategy,
    limit_train_batches=limit_train_batches,  # smoke test, comment out for full runs
    limit_val_batches=limit_val_batches,  # smoke test, comment out for full runs
)

trainer = load_from_foundation_model(
    checkpoint_path=foundation_model_path,
    downstream_task=task,
    target_fn=target_fn,
)

trainer.fit(training_params=training_params, overwrite=True)  # replace with resume=True for resumed training

This script is part of the onboarding package shipped with every BaseModel installation.
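To reason about what target_fn produces for a single customer without a BaseModel installation, the labelling logic can be mimicked on plain (timestamp, colour) tuples. This is a hypothetical sketch, not the monad API: emulate_target is an illustrative name, and it assumes that has_incomplete_training_window means "the window extends past the end of the data", that interval_from keeps events in the half-open range [split, split + window), and that values not listed in the class names are dropped before normalizing.

```python
from datetime import datetime, timedelta

import numpy as np

CLASS_NAMES = ["Black", "Blue", "White"]
WINDOW = timedelta(days=21)


def emulate_target(events, split_ts: datetime, data_end: datetime, class_names=CLASS_NAMES):
    """Hypothetical stand-in for target_fn, operating on (timestamp, colour) tuples."""
    # assumed analogue of has_incomplete_training_window: skip cut-off windows
    if split_ts + WINDOW > data_end:
        return None
    # assumed analogue of future.interval_from(split_ts, WINDOW)
    in_window = [colour for ts, colour in events if split_ts <= ts < split_ts + WINDOW]
    # assumed analogue of groupBy(...).count(normalize=True, groups=...):
    # values outside class_names are ignored, the rest are normalized
    counts = np.array([in_window.count(name) for name in class_names], dtype=np.float32)
    if counts.sum() == 0:
        return None  # no signal in the future window
    return counts / counts.sum()


split = datetime(2024, 1, 1)
events = [
    (datetime(2023, 12, 20), "Red"),  # before the split: history, ignored
    (datetime(2024, 1, 3), "Black"),
    (datetime(2024, 1, 10), "Black"),
    (datetime(2024, 1, 15), "Blue"),
    (datetime(2024, 1, 18), "Pink"),  # not in CLASS_NAMES, ignored
    (datetime(2024, 2, 1), "White"),  # after the 21-day window, ignored
]
y = emulate_target(events, split, data_end=datetime(2024, 3, 1))
print(y)  # scores for Black, Blue, White: 2/3, 1/3, 0
```

The two None branches mirror the two exclusion rules in target_fn: entities whose future window is incomplete, and entities with no matching events inside it.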