Recommendation

When to use

Predict ranked item lists — which products should be recommended, which content should be shown?

For more use cases and complete solutions, see Recommendation Recipes.

Training script

Python
import numpy as np
from pathlib import Path

from monad.ui.module import RecommendationTask, OneHotRecommendationTask
from monad.ui.target_function import sketch, sequential_decay
from monad.ui.config import TrainingParams
from monad.ui.module import load_from_foundation_model

The task-specific imports are RecommendationTask (or OneHotRecommendationTask) and the sketch helpers (sketch, sequential_decay or time_decay). The remaining imports (numpy, Path, TrainingParams, load_from_foundation_model) are common to every training script.

Recommendation targets return a Sketch — a weighted set of items.

Python
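def my_target_fn(history, future, attributes, ctx):
    # items in the target window, weighted by basket position
    items = future["transactions"]["product_id"]
    weights = sequential_decay(future["transactions"], gamma=0.5)
    return sketch(items, weights)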

Return None to exclude the entity.

Two decay helpers control how item importance varies over time:

| Function | Behavior |
| --- | --- |
| sequential_decay(events, gamma) | Weight by position. gamma=0 → first basket only; gamma=1 → equal weights. |
| time_decay(events, daily_decay) | Weight by elapsed time. daily_decay=0.1 → weights drop to 10 % after one day. |
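
To build intuition for the two helpers, here is a minimal pure-Python sketch of weighting schemes that match the endpoints described above. The formulas (`gamma ** position` and `daily_decay ** days`) and the function names `sequential_weights` and `time_weights` are assumptions made for illustration. They are not taken from the library, whose actual implementation may differ.

```python
from typing import List


def sequential_weights(n_baskets: int, gamma: float) -> List[float]:
    """Assumed model: the i-th future basket (0-based) gets weight gamma ** i."""
    return [gamma ** i for i in range(n_baskets)]


def time_weights(days_elapsed: List[float], daily_decay: float) -> List[float]:
    """Assumed model: an event d days into the future gets weight daily_decay ** d."""
    return [daily_decay ** d for d in days_elapsed]


print(sequential_weights(3, 0.0))  # [1.0, 0.0, 0.0] -> first basket only
print(sequential_weights(3, 1.0))  # [1.0, 1.0, 1.0] -> equal weights
print(sequential_weights(3, 0.5))  # [1.0, 0.5, 0.25]
print(time_weights([0, 1, 2], 0.1))  # roughly 1, 0.1, 0.01
```

Under this reading, gamma values close to 0 focus the target on the next basket, while values close to 1 spread it across the whole target window.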

See Target Function → Sketches for the full sketch API and Target Examples → Recommendation for ready-made patterns.

Every training script requires two configuration objects: a task that defines the prediction type, and TrainingParams that control metrics and training behavior.

Task declaration

BaseModel offers two recommendation task classes:

Python
# Default — use this in most cases (products, articles, content)
task = RecommendationTask()

# Specialty variant — for low-cardinality item sets
task = OneHotRecommendationTask()

| Variant | When to use |
| --- | --- |
| RecommendationTask | Default. High-cardinality catalogs: product IDs, article IDs, content IDs. Typical for retail and e-commerce. |
| OneHotRecommendationTask | Special case. Low-cardinality item sets: brands, categories, departments. Useful in industries with small fixed catalogs or specialty retail with curated assortments. Supports loss masking. |

Neither variant requires constructor parameters.

Training parameters

Configure training with TrainingParams. At minimum, provide the checkpoint directory:

Python
training_params = TrainingParams(
    checkpoint_dir=scenario_model_path,
)

For all available options and their defaults, see Training Parameters. The default metrics for this task are:

| Metric | Alias | Monitoring |
| --- | --- | --- |
| HitRateAtK(top_k=12) | val_HR@10_0 | Maximize |
| MeanAveragePrecisionAtK(top_k=12) | | |
| PrecisionAtK(top_k=12) | | |

To add or replace metrics, see Custom Metrics.
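
For readers unfamiliar with these ranking metrics, the following pure-Python sketch shows the textbook per-entity definitions. It is illustrative only: the function names are hypothetical, the normalisation of average precision (dividing by `min(k, len(relevant))`) is one common convention, and the library's own metric classes may compute things differently. The ranked list is assumed to contain no duplicates.

```python
from typing import Sequence, Set


def hit_rate_at_k(ranked: Sequence[str], relevant: Set[str], k: int) -> float:
    """1.0 if at least one relevant item appears in the top k, else 0.0."""
    return 1.0 if any(item in relevant for item in ranked[:k]) else 0.0


def precision_at_k(ranked: Sequence[str], relevant: Set[str], k: int) -> float:
    """Share of the top-k slots occupied by relevant items."""
    hits = sum(1 for item in ranked[:k] if item in relevant)
    return hits / k


def average_precision_at_k(ranked: Sequence[str], relevant: Set[str], k: int) -> float:
    """Average of precision@rank over the ranks where a relevant item appears."""
    if not relevant:
        return 0.0
    hits = 0
    score = 0.0
    for rank, item in enumerate(ranked[:k], start=1):
        if item in relevant:
            hits += 1
            score += hits / rank
    return score / min(k, len(relevant))


ranked = ["a", "b", "c", "d"]   # model output, best first
relevant = {"b", "d"}           # items the entity actually interacted with

print(hit_rate_at_k(ranked, relevant, k=3))           # 1.0
print(precision_at_k(ranked, relevant, k=3))          # 1 hit in 3 slots
print(average_precision_at_k(ranked, relevant, k=3))  # 0.25
```

Mean average precision is the mean of `average_precision_at_k` over all evaluated entities.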

Python
trainer = load_from_foundation_model(
    checkpoint_path="./foundation_model",
    downstream_task=task,
    target_fn=my_target_fn,
    training_params=TrainingParams(...),
)
trainer.fit()

You can also pass filtering functions to narrow or exclude items from predictions at loading time.

See Scenario Model reference for all load_from_foundation_model options.

Filtering predictions

You can narrow or exclude items from predictions at the entity level by passing filtering functions to load_from_foundation_model:

Python
trainer = load_from_foundation_model(
    checkpoint_path="./foundation_model",
    downstream_task=RecommendationTask(),
    target_fn=product_reco_target,
    predictions_to_exclude_fn=lambda events, attrs, ctx: (
        list(events["transactions"]["product_id"].events)  # exclude already purchased
    ),
)

| Parameter | Description |
| --- | --- |
| predictions_to_include_fn | Return items predictions should be narrowed to. |
| predictions_to_exclude_fn | Return items to remove from predictions. |

The two parameters are mutually exclusive. For more loading-time overrides, see Loading Overrides.
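
The effect of the two filters on a single entity's ranked list can be pictured with the toy function below. This is a conceptual sketch, not the library's code: `filter_ranked` and its arguments are hypothetical names, and the way BaseModel applies the filters internally is not described here.

```python
from typing import Iterable, List, Optional


def filter_ranked(
    ranked: List[str],
    include: Optional[Iterable[str]] = None,
    exclude: Optional[Iterable[str]] = None,
) -> List[str]:
    """Narrow a ranked list to `include`, or drop `exclude`, keeping the order."""
    if include is not None and exclude is not None:
        # mirrors the rule that the two options are mutually exclusive
        raise ValueError("include and exclude are mutually exclusive")
    if include is not None:
        allowed = set(include)
        return [item for item in ranked if item in allowed]
    if exclude is not None:
        blocked = set(exclude)
        return [item for item in ranked if item not in blocked]
    return list(ranked)


ranked = ["a", "b", "c", "d"]
print(filter_ranked(ranked, exclude=["b"]))       # ['a', 'c', 'd']
print(filter_ranked(ranked, include=["d", "a"]))  # ['a', 'd']
```

In both cases the model's ordering is preserved; only the candidate set changes.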

Full example

Complete training script from the onboarding package — adapt the paths, item column, and decay parameter to your data:

Python
from pathlib import Path
from typing import Dict

from monad.ui.config import TrainingParams
from monad.ui.module import RecommendationTask, load_from_foundation_model
from monad.ui.target_function import Attributes, Events, Sketch, sequential_decay, sketch


# --- Names & Paths -----------------------------------------------------------

# EDIT: provide path to project directory, PARENT to /fm, /features, /lightning_checkpoints etc.
project_dir = Path("/basemodel/projects/project_dir").resolve()
# EDIT: define name for scenario checkpoints directory; the script will put it under the same parent directory as fm
scenario_name = "scenario_name"

# creating the relative paths
foundation_model_path = project_dir / "fm"
scenario_model_path = project_dir / "scenarios" / scenario_name


# --- Target Definition -------------------------------------------------------

# recommendation definition, for reference:
# the target is a weighted sketch of items appearing in future events
# weights can apply temporal decay so nearer events matter more
# the model predicts a ranked list of items

# EDIT: target details
TARGET_EVENT_TABLE = "transactions"  # event data source used to build the target
TARGET_ITEM = "article_id"  # column with item ids to recommend

# EDIT: temporal decay multiplier
# 0: only the next basket is considered
# 1: all future baskets count equally
# values between 0 and 1 apply exponential decay by basket position
GAMMA = 0


def target_fn(_history: Events, future: Events, _entity: Attributes, _ctx: Dict) -> Sketch | None:

    future_transactions = future[TARGET_EVENT_TABLE]
    items = future_transactions[TARGET_ITEM]

    # recommendation target: weighted sketch of future items
    weights = sequential_decay(future_transactions, gamma=GAMMA)
    y = sketch(items, weights)

    return y


# --- Training ----------------------------------------------------------------

# EDIT: metaparams - keep default unless experimenting
learning_rate = 3e-5
epochs = 3  # use 1 for smoke test

# EDIT: limited runs - use for smoke test, then comment out here and below
limit_train_batches = 5
limit_val_batches = 5

# EDIT: parallelised training - comment out to default to a single GPU
strategy = "ddp"
devices = [0, 1]  # list GPU indices

# For more options refer to docs: https://docs.basemodel.ai/reference/trainingparams

task = RecommendationTask()

training_params = TrainingParams(
    checkpoint_dir=scenario_model_path,
    learning_rate=learning_rate,
    epochs=epochs,
    devices=devices,
    strategy=strategy,
    limit_train_batches=limit_train_batches,  # smoke test, comment out for full runs
    limit_val_batches=limit_val_batches,  # smoke test, comment out for full runs
)

trainer = load_from_foundation_model(
    checkpoint_path=foundation_model_path,
    downstream_task=task,
    target_fn=target_fn,
)

trainer.fit(training_params=training_params, overwrite=True)  # replace with resume=True for resumed training

This script is part of the onboarding package shipped with every BaseModel installation.

Advanced patterns

Train vs Eval Mode

Read the data mode from ctx when you need different labeling logic during training and evaluation, for example targeting only the next basket during training but scoring against all future purchases during evaluation:

Python
from monad.batch import MODE
from monad.config import DataMode
from monad.ui.target_function import sequential_decay, sketch

def target_fn(history, future, attributes, ctx):
    # next basket only while training, all future baskets equally otherwise
    gamma = 0 if ctx[MODE] == DataMode.TRAIN else 1
    weights = sequential_decay(future["transactions"], gamma=gamma)
    return sketch(future["transactions"]["product_id"], weights)

Filtering Masks

Exclude items already seen in history from the training loss. This is useful for acquisition-focused recommendations where you only want to recommend new items. Loss masking applies to OneHotRecommendationTask only.

Python
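from monad.ui.target_function import sequential_decay, sketch, sketch_filtering_mask

def target_fn(history, future, attributes, ctx):
    items = future["transactions"]["product_id"]
    weights = sequential_decay(future["transactions"], gamma=0.5)
    # mask items that already appear in the entity's history
    mask = sketch_filtering_mask(history["transactions"]["product_id"])
    return (sketch(items, weights), mask)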

Sketch Arithmetic

Combine sketches from different data sources or time windows into a single recommendation target, or scale weights to adjust the relative importance of item groups:

Python
combined = sketch_a + sketch_b   # merge sketches
scaled = sketch_a * 2.0          # scale weights

This is useful when purchase data is sparse and you want to mix in weaker signals like product views. See the hybrid sketch example in Target Examples.
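
As a mental model for mixing signals, the toy example below treats a sketch as a plain dictionary from item to weight. This is an assumption for illustration only: the real Sketch type is not a dict, and the helper names `merge` and `scale` are hypothetical. The example assumes that merging sums the weights of items present in both inputs, which is one plausible reading of `sketch_a + sketch_b`.

```python
from typing import Dict

WeightedSet = Dict[str, float]  # toy stand-in for a sketch: item -> weight


def merge(a: WeightedSet, b: WeightedSet) -> WeightedSet:
    """Combine two weighted sets, summing weights of shared items (assumed)."""
    out = dict(a)
    for item, weight in b.items():
        out[item] = out.get(item, 0.0) + weight
    return out


def scale(a: WeightedSet, factor: float) -> WeightedSet:
    """Multiply every weight by the same factor."""
    return {item: weight * factor for item, weight in a.items()}


purchases = {"sku_1": 1.0, "sku_2": 0.5}  # strong but sparse signal
views = {"sku_2": 1.0, "sku_3": 1.0}      # weaker, denser signal

# down-weight views before mixing them into the purchase target
target = merge(purchases, scale(views, 0.25))
print(target)  # {'sku_1': 1.0, 'sku_2': 0.75, 'sku_3': 0.25}
```

The scaling factor controls how much the weaker signal contributes relative to purchases.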