Multilabel Classifier

When to use

Predict multiple labels that can apply simultaneously — which departments will the customer shop in, which product categories will they engage with?

For more use cases and complete solutions, see Multilabel Classification Recipes.

Multilabel or Multi-Class?

Aspect | Multi-Class | Multilabel
Labels per entity | Exactly 1 | 0 to many
Target return | Normalized scores (sum to 1) | Independent 0/1 flags
Example question | "Which one category?" | "Which categories (plural)?"
Output | Softmax probabilities | Independent sigmoid probabilities
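
The difference in the Output row can be seen on a toy example. The sketch below is plain NumPy and independent of BaseModel; the logits are made-up values chosen only to show that softmax scores compete for a total of 1, while sigmoid scores are computed per label and can all be high at once.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Multi-class: scores compete with each other and sum to 1."""
    shifted = logits - logits.max()  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()


def sigmoid(logits: np.ndarray) -> np.ndarray:
    """Multilabel: each score is an independent probability in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-logits))


# Hypothetical raw scores for three labels (illustrative values only).
logits = np.array([2.0, 1.5, -1.0])

multiclass_probs = softmax(logits)   # one distribution over all labels
multilabel_probs = sigmoid(logits)   # one probability per label

print("softmax:", multiclass_probs.round(3), "sum =", multiclass_probs.sum().round(3))
print("sigmoid:", multilabel_probs.round(3), "sum =", multilabel_probs.sum().round(3))
```

With these logits, two labels clear a 0.5 threshold under sigmoid, which is the "0 to many" behaviour a multilabel task needs.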

Training script

Python
import numpy as np
from pathlib import Path

from monad.ui.module import MultilabelClassificationTask
from monad.ui.config import TrainingParams
from monad.ui.module import load_from_foundation_model

The only task-specific import is MultilabelClassificationTask. The remaining imports (numpy, Path, TrainingParams, load_from_foundation_model) are common to every training script.

The target function returns an np.float32 array of independent 0/1 flags with shape=(num_labels,), or None to exclude the entity.

Python
flags, _ = future_txns.groupBy("department").exists(groups=TARGET_NAMES)
return flags  # shape matches len(TARGET_NAMES), values are 0.0 or 1.0

Unlike multi-class, the values do not need to sum to 1 — each label is an independent binary decision.
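
To make the expected return value concrete, here is a minimal sketch that builds such a flag vector by hand. It is plain NumPy and only imitates the shape and dtype described above; `multi_hot` is a hypothetical helper, not part of the monad API, and it is not how `exists` is implemented.

```python
import numpy as np

TARGET_NAMES = ["Electronics", "Fashion", "Home", "Sports", "Beauty"]


def multi_hot(observed, target_names):
    """Return a float32 vector with 1.0 where a target name was observed.

    Hypothetical helper for illustration only. Names that are not in
    `target_names` are ignored, and duplicates count once.
    """
    observed_set = set(observed)
    return np.array(
        [1.0 if name in observed_set else 0.0 for name in target_names],
        dtype=np.float32,
    )


# "Home" appears twice and "Garden" is not a target label.
flags = multi_hot(["Home", "Electronics", "Home", "Garden"], TARGET_NAMES)
print(flags, flags.dtype, flags.shape)
```

The position of each flag follows the order of `TARGET_NAMES`, so the list must stay identical between the target function and the task declaration.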

See Target Function for the full data-access API and Target Examples → Multilabel for ready-made patterns.

Every training script requires two configuration objects: a task that defines the prediction type, and TrainingParams that control metrics and training behavior.

Task declaration

Python
TARGET_NAMES = ["Electronics", "Fashion", "Home", "Sports", "Beauty"]

task = MultilabelClassificationTask(class_names=TARGET_NAMES)

Parameter | Required | Description
class_names | Yes | List of label names. Determines output size and label ordering.
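
Because `class_names` fixes the label ordering, a vector of per-label probabilities can be mapped back to names by position. The sketch below assumes the predictions arrive as a 1-D array ordered like `class_names`; the probabilities, the `decode` helper, and the 0.5 threshold are all hypothetical and not part of the library.

```python
import numpy as np

TARGET_NAMES = ["Electronics", "Fashion", "Home", "Sports", "Beauty"]


def decode(probs, class_names, threshold=0.5):
    """Return the names whose probability reaches the threshold.

    Hypothetical helper: assumes probs[i] belongs to class_names[i].
    """
    if len(probs) != len(class_names):
        raise ValueError("expected one probability per class name")
    return [name for name, p in zip(class_names, probs) if p >= threshold]


# Made-up per-label probabilities, ordered like TARGET_NAMES.
probs = np.array([0.91, 0.12, 0.64, 0.05, 0.40])
predicted = decode(probs, TARGET_NAMES)
print(predicted)
```

A single global threshold is the simplest choice; per-label thresholds may work better when label frequencies differ a lot.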

Training parameters

Configure training with TrainingParams. At minimum, provide the checkpoint directory:

Python
training_params = TrainingParams(
    checkpoint_dir=scenario_model_path,
)

For all available options and their defaults, see Training Parameters. The default metrics for this task are:

Metric | Alias | Monitoring
AUROC(task="multilabel", average=None) | val_auroc_0 | Maximize
AveragePrecision(task="multilabel", average=None) | val_average_precision_0 |

To add or replace metrics, see Custom Metrics.
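
Both default metrics are configured with average=None. The metric signatures resemble torchmetrics, although the source does not name the library; under that convention, average=None reports one score per label instead of a single averaged number. The sketch below illustrates the idea with a small pairwise AUROC in plain NumPy. It is not the implementation used during training, and the data are made up.

```python
import numpy as np


def auroc(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """AUROC as the fraction of (positive, negative) pairs ranked correctly.

    Ties count as half. Returns NaN when a label has only one class present.
    """
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    if len(pos) == 0 or len(neg) == 0:
        return float("nan")
    diff = pos[:, None] - neg[None, :]  # every positive against every negative
    return float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size)


def per_label_auroc(Y_true: np.ndarray, Y_score: np.ndarray) -> list:
    """One AUROC per label column: the effect of average=None."""
    return [auroc(Y_true[:, j], Y_score[:, j]) for j in range(Y_true.shape[1])]


# Four entities, two labels (illustrative values only).
Y_true = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
Y_score = np.array([[0.9, 0.2], [0.1, 0.8], [0.8, 0.3], [0.3, 0.6]])

scores = per_label_auroc(Y_true, Y_score)
print(scores)
```

Per-label scores make it easy to spot individual labels that the model ranks poorly, which a single averaged number would hide.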

Python
trainer = load_from_foundation_model(
    checkpoint_path="./foundation_model",
    downstream_task=task,
    target_fn=my_target_fn,
    training_params=TrainingParams(...),
)
trainer.fit()

See Scenario Model reference for all load_from_foundation_model options.

Full example

Complete training script from the onboarding package — adapt the paths, target names, and joined column to your data:

Python
from datetime import timedelta
from pathlib import Path
from typing import Dict

import numpy as np

from monad.batch import SPLIT_TIMESTAMP
from monad.ui.config import TrainingParams
from monad.ui.module import MultilabelClassificationTask, load_from_foundation_model
from monad.ui.target_function import Attributes, Events, has_incomplete_training_window
from monad.utils.sql import get_qualified_column_name


# --- Names & Paths -----------------------------------------------------------

# EDIT: path to the project directory, i.e. the parent of /fm, /features, /lightning_checkpoints etc.
project_dir = Path("/basemodel/projects/project_dir").resolve()
# EDIT: name of the scenario checkpoints directory; the script places it under <project_dir>/scenarios
scenario_name = "scenario_name"

# paths derived from project_dir
foundation_model_path = project_dir / "fm"
scenario_model_path = project_dir / "scenarios" / scenario_name


# --- Target Definition -------------------------------------------------------

# multilabel definition, for reference:
# each label corresponds to one entry in the TARGET_NAMES list
# labels are binary indicators of whether TARGET_ENTITY values appear in the future window
# the model predicts all labels that apply

# EDIT: set output labels of the multilabel classifier (one output per label)
TARGET_NAMES = [
    "Denim Trousers",
    "Swimwear",
    "Trousers",
    "Jersey Basic",
    "Ladies Sport Bottoms",
    "Basic 1",
    "Jersey fancy",
    "Blouse",
    "Shorts",
    "Trouser",
    "Ladies Sport Bras",
    "Casual Lingerie",
    "Expressive Lingerie",
    "Dress",
    "Dresses",
    "Tops Knitwear",
    "Skirt",
    "Nightwear",
    "Knitwear",
]

# EDIT: event data source relevant to the target (here: purchases of products)
TARGET_EVENT_TABLE = "transactions"

# EDIT: column used to match against TARGET_NAMES
# for joined columns, use get_qualified_column_name("col", ["joined_table"])
TARGET_ENTITY = get_qualified_column_name("department_name", ["articles"])

# EDIT: future window length (here: days) used to generate labels
TARGET_WINDOW_DAYS = 21


def target_fn(_history: Events, future: Events, _entity: Attributes, _ctx: Dict) -> np.ndarray | None:

    # skips entities whose remaining window is shorter than the target window
    if has_incomplete_training_window(_ctx, timedelta(days=TARGET_WINDOW_DAYS)):
        return None

    # trims the future to desired window
    future_window = future.interval_from(
        _ctx[SPLIT_TIMESTAMP],
        timedelta(days=TARGET_WINDOW_DAYS),
    )

    # for each label: did it appear at least once in the future window?
    y, _ = (
        future_window[TARGET_EVENT_TABLE]
        .groupBy(TARGET_ENTITY)
        .exists(groups=TARGET_NAMES)
    )

    # skips entities with no positive labels in the future window
    if y.sum() == 0:
        return None

    return np.array(y, dtype=np.float32)


# --- Training ----------------------------------------------------------------

# EDIT: hyperparameters - keep defaults unless experimenting
learning_rate = 0.0001
epochs = 3  # use 1 for smoke test

# EDIT: limited runs - use for smoke test, then comment out here and below
limit_train_batches = 5
limit_val_batches = 5

# EDIT: parallelised training - comment out to default to a single GPU
strategy = "ddp"
devices = [0, 1]  # list of GPU indices

# For more options refer to docs: https://docs.basemodel.ai/reference/trainingparams

task = MultilabelClassificationTask(class_names=TARGET_NAMES)

training_params = TrainingParams(
    checkpoint_dir=scenario_model_path,
    learning_rate=learning_rate,
    epochs=epochs,
    devices=devices,
    strategy=strategy,
    limit_train_batches=limit_train_batches,  # smoke test, comment out for full runs
    limit_val_batches=limit_val_batches,  # smoke test, comment out for full runs
)

trainer = load_from_foundation_model(
    checkpoint_path=foundation_model_path,
    downstream_task=task,
    target_fn=target_fn,
)

trainer.fit(training_params=training_params, overwrite=True)  # replace with resume=True for resumed training

This script is part of the onboarding package shipped with every BaseModel installation.
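
The labeling logic of target_fn in the full example has three steps: skip entities whose future window is incomplete, keep only events inside the window, and skip entities with no positive labels. The sketch below mimics these steps on plain Python data so the logic can be tested without BaseModel. Everything in it is an assumption made for illustration: `label_window` is a hypothetical function, events are modelled as (timestamp, label) tuples, and the window is taken as half-open, which may differ from how `interval_from` and `has_incomplete_training_window` treat boundaries.

```python
from datetime import datetime, timedelta

import numpy as np


def label_window(events, split_time, window, end_of_data, target_names):
    """Hypothetical stand-in for the labeling steps of target_fn.

    events: list of (timestamp, label) tuples for one entity.
    Returns a float32 flag vector, or None to exclude the entity.
    """
    window_end = split_time + window
    # Step 1: the window must be fully covered by the available data.
    if window_end > end_of_data:
        return None
    # Step 2: keep labels seen in [split_time, window_end) (assumed half-open).
    seen = {label for ts, label in events if split_time <= ts < window_end}
    y = np.array([float(name in seen) for name in target_names], dtype=np.float32)
    # Step 3: skip entities with no positive labels.
    return None if y.sum() == 0 else y


names = ["Dress", "Shorts", "Skirt"]
split = datetime(2024, 1, 1)
window = timedelta(days=21)
events = [
    (datetime(2024, 1, 5), "Shorts"),   # inside the window
    (datetime(2024, 1, 30), "Dress"),   # after the window ends
    (datetime(2023, 12, 20), "Skirt"),  # before the split point
]

y = label_window(events, split, window, datetime(2024, 3, 1), names)
print(y)
```

Dropping entities with no positive labels, as the full example does, means the model never sees an all-zero target; whether that is desirable depends on the use case.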