Multi-Class Classifier
When to use
Predict one category from a fixed set — which product department will the customer buy from, which segment do they belong to?
For more use cases and complete solutions, see Multiclass Classification Recipes.
Multi-Class or Multilabel?
| Aspect | Multi-Class | Multilabel |
|---|---|---|
| Labels per entity | Exactly 1 | 0 to many |
| Target return | Normalized scores (sum to 1) | Independent 0/1 flags |
| Example question | "Which one category?" | "Which categories (plural)?" |
| Output | Softmax probabilities | Independent sigmoid probabilities |
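The difference between the two output types in the table can be illustrated with plain NumPy. This is only a sketch of the underlying math, not the library's code; the `logits` values and the helper names `softmax` and `sigmoid` are made up for illustration.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Scores that compete with each other and sum to 1 (multi-class)."""
    shifted = logits - np.max(logits)  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()


def sigmoid(logits: np.ndarray) -> np.ndarray:
    """Independent per-class probabilities (multilabel)."""
    return 1.0 / (1.0 + np.exp(-logits))


# Hypothetical raw model outputs for three classes
logits = np.array([2.0, 1.0, 0.1], dtype=np.float32)

multiclass_probs = softmax(logits)  # sums to 1, so exactly one class dominates
multilabel_probs = sigmoid(logits)  # each value lies in (0, 1); the sum is unconstrained
```

With softmax, raising one class's score lowers the others, which matches the "exactly one label" assumption. With sigmoid, several classes can be likely at the same time.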
Training script
import numpy as np
from pathlib import Path
from monad.ui.module import MulticlassClassificationTask
from monad.ui.config import TrainingParams
from monad.ui.module import load_from_foundation_model
The only task-specific import is MulticlassClassificationTask. The remaining imports (numpy, Path, TrainingParams, load_from_foundation_model) are common to every training script.
For targets based on joined tables, also import get_qualified_column_name from monad.utils.sql, as the full example below does.
The target function returns an np.float32 array of normalized scores with shape (num_classes,) that sum to 1, or None to exclude the entity.
counts, _ = future_txns.groupBy("category").count(normalize=True, groups=CLASS_NAMES)
return counts # shape matches len(CLASS_NAMES), sums to 1
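For columns from joined tables, use get_qualified_column_name to resolve the column name. The full example below calls get_qualified_column_name("perceived_colour_master_name", ["articles"]) for a column joined from the articles table.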
See Target Function for the full data-access API and Target Examples → Multi-Class for ready-made patterns.
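To make the expected return value concrete, here is a minimal plain-NumPy sketch of a normalized-count target. It does not use the library's data-access API. The helper `normalized_counts` is hypothetical, and it assumes that values outside the class list are ignored and that an entity with no matching events is excluded by returning None.

```python
from collections import Counter

import numpy as np

CLASS_NAMES = ["Electronics", "Fashion", "Home", "Sports", "Beauty"]


def normalized_counts(values, class_names):
    """Hypothetical helper: share of each class among the observed values.

    Returns a float32 vector aligned with class_names, or None when no
    value matches any class (assumed to mean "exclude this entity").
    """
    counts = Counter(v for v in values if v in class_names)
    total = sum(counts.values())
    if total == 0:
        return None
    return np.array(
        [counts[name] / total for name in class_names],
        dtype=np.float32,
    )


# Hypothetical categories bought by one entity in the future window
future_categories = ["Fashion", "Fashion", "Home", "Electronics"]
target = normalized_counts(future_categories, CLASS_NAMES)
```

The vector's positions follow the order of CLASS_NAMES, so the same list should be passed to the task declaration.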
Every training script requires two configuration objects: a task that defines the prediction type, and TrainingParams that control metrics and training behavior.
Task declaration
CLASS_NAMES = ["Electronics", "Fashion", "Home", "Sports", "Beauty"]
task = MulticlassClassificationTask(class_names=CLASS_NAMES)
| Parameter | Required | Description |
|---|---|---|
| class_names | Yes | List of category labels. Determines output size and label ordering. |
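Because class_names fixes both the output size and the label order, a score vector can be mapped back to labels by position. The sketch below uses a made-up `scores` array; how predictions are actually exposed by the library is not covered in this section, so treat the variable names as assumptions.

```python
import numpy as np

CLASS_NAMES = ["Electronics", "Fashion", "Home", "Sports", "Beauty"]

# Hypothetical score vector for one entity, one value per class
scores = np.array([0.12, 0.55, 0.20, 0.05, 0.08], dtype=np.float32)

# The vector length must match the number of declared classes
assert scores.shape == (len(CLASS_NAMES),)

# Position i in the vector corresponds to CLASS_NAMES[i]
predicted_label = CLASS_NAMES[int(np.argmax(scores))]

# Labels ordered from most to least likely
ranked_labels = [CLASS_NAMES[i] for i in np.argsort(-scores)]
```

Reordering CLASS_NAMES between training and interpretation would silently mislabel the scores, so it helps to keep the list in one place.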
Training parameters
Configure training with TrainingParams. At minimum, provide the checkpoint directory through the checkpoint_dir argument, as the full example below does.
For all available options and their defaults, see Training Parameters. The default metrics for this task are:
| Metric | Alias | Monitoring |
|---|---|---|
| MultipleTargetsRecall | val_multiple_targets_recall | Maximize |
| MultipleTargetsRecallPerClass | per-class recall | — |
To add or replace metrics, see Custom Metrics.
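For intuition about recall in a multi-class setting, the sketch below computes textbook per-class recall on dominant-class predictions. It is an assumption that this resembles what the default metrics report; the exact definitions of MultipleTargetsRecall and MultipleTargetsRecallPerClass are not given here, and the helper `per_class_recall` is hypothetical.

```python
import numpy as np


def per_class_recall(true_idx, pred_idx, num_classes):
    """Fraction of entities of each true class that were predicted as that class.

    Classes with no true examples get NaN, since recall is undefined for them.
    """
    recalls = np.full(num_classes, np.nan)
    for c in range(num_classes):
        mask = true_idx == c
        if mask.any():
            recalls[c] = np.mean(pred_idx[mask] == c)
    return recalls


# Hypothetical dominant-class indices for six validation entities
true_idx = np.array([0, 0, 1, 1, 1, 2])
pred_idx = np.array([0, 1, 1, 1, 0, 2])

recalls = per_class_recall(true_idx, pred_idx, num_classes=4)
macro_recall = np.nanmean(recalls)  # average over classes that have examples
```

Per-class values make it easier to spot rare classes that the model never predicts, which an aggregate figure can hide.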
trainer = load_from_foundation_model(
    checkpoint_path="./foundation_model",
    downstream_task=task,
    target_fn=my_target_fn,
    training_params=TrainingParams(...),
)
trainer.fit()
See Scenario Model reference for all load_from_foundation_model options.
Full example
Complete training script from the onboarding package — adapt the paths, target names, and joined column to your data:
from datetime import timedelta
from pathlib import Path
from typing import Dict
import numpy as np
from monad.batch import SPLIT_TIMESTAMP
from monad.ui.config import TrainingParams
from monad.ui.module import MulticlassClassificationTask, load_from_foundation_model
from monad.ui.target_function import Attributes, Events, has_incomplete_training_window
from monad.utils.sql import get_qualified_column_name
# --- Names & Paths -----------------------------------------------------------
# EDIT: provide path to project directory, PARENT to /fm, /features, /lightning_checkpoints etc.
project_dir = Path("/basemodel/projects/project_dir").resolve()
# EDIT: define name for scenario checkpoints directory; the script will put it under the same parent directory as fm
scenario_name = "scenario_name"
# derive the foundation model and scenario paths from the project directory
foundation_model_path = project_dir / "fm"
scenario_model_path = project_dir / "scenarios" / scenario_name
# --- Target Definition -------------------------------------------------------
# multiclass definition, for reference:
# each class corresponds to one entry in the TARGET_NAMES list
# labels are normalized counts of TARGET_ENTITY (= target column) values in the future window
# the model predicts the dominant class
# EDIT: output classes of the multiclass classifier (one class predicted per entity)
TARGET_NAMES = [
    "Black",
    "Blue",
    "White",
    "Pink",
    "Grey",
    "Red",
    "Beige",
    "Green",
    "Khaki green",
    "Yellow",
    "Orange",
    "Brown",
    "Metal",
    "Turquoise",
    "Mole",
    "Lilac Purple",
]
# EDIT: event data source relevant to the target (here: purchases of products)
TARGET_EVENT_TABLE = "transactions"
# EDIT: column used to match against TARGET_NAMES
# for joined columns, use get_qualified_column_name("col", ["joined_table"])
TARGET_ENTITY = get_qualified_column_name("perceived_colour_master_name", ["articles"])
# EDIT: future window length (here: days) used to generate labels
TARGET_WINDOW_DAYS = 21
def target_fn(_history: Events, future: Events, _entity: Attributes, _ctx: Dict) -> np.ndarray | None:
    # filters out users with too short remaining window
    if has_incomplete_training_window(_ctx, timedelta(days=TARGET_WINDOW_DAYS)):
        return None
    # trims the future to desired window
    future_window = future.interval_from(
        _ctx[SPLIT_TIMESTAMP],
        timedelta(days=TARGET_WINDOW_DAYS),
    )
    # multiclass target as a distribution over classes (normalized counts)
    y, _ = (
        future_window[TARGET_EVENT_TABLE]
        .groupBy(TARGET_ENTITY)
        .count(normalize=True, groups=TARGET_NAMES)
    )
    # skips entities with no signal in the future window
    if y.sum() == 0:
        return None
    return np.array(y, dtype=np.float32)
# --- Training ----------------------------------------------------------------
# EDIT: metaparams - keep default unless experimenting
learning_rate = 0.0001
epochs = 3 # use 1 for smoke test
# EDIT: limited runs - use for smoke test, then comment out here and below
limit_train_batches = 5
limit_val_batches = 5
# EDIT: parallelised training - comment out to default to a single GPU
strategy = "ddp"
devices = [0, 1] # list GPU indices
# For more options refer to docs: https://docs.basemodel.ai/reference/trainingparams
task = MulticlassClassificationTask(class_names=TARGET_NAMES)
training_params = TrainingParams(
    checkpoint_dir=scenario_model_path,
    learning_rate=learning_rate,
    epochs=epochs,
    devices=devices,
    strategy=strategy,
    limit_train_batches=limit_train_batches,  # smoke test, comment out for full runs
    limit_val_batches=limit_val_batches,  # smoke test, comment out for full runs
)
trainer = load_from_foundation_model(
    checkpoint_path=foundation_model_path,
    downstream_task=task,
    target_fn=target_fn,
)
trainer.fit(training_params=training_params, overwrite=True) # replace with resume=True for resumed training
This script is part of the onboarding package shipped with every BaseModel installation.
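The label window used by target_fn in the script above can be pictured with plain Python. This is a conceptual sketch only, not the library's implementation: the helper `events_in_window` is hypothetical, and it assumes a half-open window that includes the split time and excludes the end. The actual boundary handling of interval_from is not stated in this section.

```python
from datetime import datetime, timedelta

TARGET_WINDOW_DAYS = 21


def events_in_window(events, split_time, window):
    """Keep events with split_time <= timestamp < split_time + window (assumed bounds)."""
    end = split_time + window
    return [e for e in events if split_time <= e["timestamp"] < end]


# Hypothetical split point and purchase events for one entity
split_time = datetime(2024, 1, 1)
events = [
    {"timestamp": datetime(2023, 12, 30), "colour": "Black"},  # before the split: history
    {"timestamp": datetime(2024, 1, 5), "colour": "Blue"},     # inside the window
    {"timestamp": datetime(2024, 1, 21), "colour": "Black"},   # inside the window
    {"timestamp": datetime(2024, 2, 1), "colour": "Red"},      # after the window ends
]

in_window = events_in_window(events, split_time, timedelta(days=TARGET_WINDOW_DAYS))
window_colours = [e["colour"] for e in in_window]
```

Only the events inside the window contribute to the normalized counts, so a longer window yields denser labels at the cost of excluding more entities near the end of the data.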