Next Basket Recommendation
Task type: RecommendationTask
Industry: Retail / E-commerce
This recipe predicts which products a customer will purchase next. It uses the sketch representation — a compact, trainable summary of likely items — to produce ranked product recommendations. The output is ideal for "Recommended for you" carousels, personalized emails, and checkout cross-sells.
What is a sketch?
A Sketch is a lazy representation for recommendation targets. It encodes a weighted set of items that the model learns to predict. At inference time, the model outputs product scores that can be ranked into a top-N recommendation list.
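Turning scores into a ranked list is ordinary sorting. The sketch below is plain Python, not the monad API. It assumes the model's output for one customer is available as a mapping from item ID to score (the IDs and scores are made up), and it breaks ties by item ID so the order is deterministic.

```python
from typing import Dict, List


def top_n(scores: Dict[str, float], n: int) -> List[str]:
    """Rank item IDs by descending score and keep the first n.

    Hypothetical helper for illustration; ties are broken by item ID.
    """
    ranked = sorted(scores.items(), key=lambda kv: (-kv[1], kv[0]))
    return [item for item, _ in ranked[:n]]


# Made-up model output for one customer (item ID -> score).
scores = {"a": 0.12, "b": 0.41, "c": 0.41, "d": 0.05}
print(top_n(scores, 3))  # ['b', 'c', 'a']
```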
Prerequisites
Before writing a target function you need:
- A trained foundation model built on event data that includes a transactions data source with a product ID column (e.g., article_id).
- The monad library installed in your environment (for the Python App).
Target Function
| Argument | Type | Description |
|---|---|---|
| history | Events | All events before the temporal split. |
| future | Events | All events after the temporal split. |
| attributes | Attributes | Static entity attributes. |
| ctx | Dict | Context dictionary containing SPLIT_TIMESTAMP, data mode, etc. |
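To make the history/future contract concrete, here is a plain-Python sketch, not monad code. It assumes events are simple (timestamp, article_id) records and that an event exactly at the split counts as future. The Event type and temporal_split helper are hypothetical, and the library's actual boundary handling may differ.

```python
from datetime import datetime, timezone
from typing import List, NamedTuple, Tuple


class Event(NamedTuple):
    timestamp: datetime
    article_id: str


def temporal_split(events: List[Event], split: datetime) -> Tuple[List[Event], List[Event]]:
    """Partition events into (history, future) around a split timestamp.

    Assumption: events exactly at the split are placed in `future`.
    """
    history = [e for e in events if e.timestamp < split]
    future = [e for e in events if e.timestamp >= split]
    return history, future


split = datetime(2024, 5, 1, tzinfo=timezone.utc)
events = [
    Event(datetime(2024, 4, 20, tzinfo=timezone.utc), "a"),
    Event(datetime(2024, 5, 1, tzinfo=timezone.utc), "b"),
    Event(datetime(2024, 5, 9, tzinfo=timezone.utc), "c"),
]
history, future = temporal_split(events, split)
print([e.article_id for e in history], [e.article_id for e in future])  # ['a'] ['b', 'c']
```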
For recommendation tasks, the function must return one of:
- A Sketch object: a weighted set of target items.
- None: exclude this customer (e.g., when there are no future transactions).
Full Example
```python
import numpy as np
from datetime import timedelta
from typing import Dict

from monad.ui.target_function import Events, Attributes, Sketch, sketch, sequential_decay
from monad.ui.target_function import SPLIT_TIMESTAMP
from monad.ui.target_function import has_incomplete_training_window

# === Configuration ===
TRANSACTION_DATA_SOURCE = "transactions"
PRODUCT_ID_COLUMN = "article_id"
GAMMA = 0  # Decay factor: 0 = next basket only, higher = more future baskets


def next_basket_target_fn(
    history: Events,
    future: Events,
    attributes: Attributes,
    ctx: Dict,
) -> Sketch:
    """Recommend products the customer is likely to buy next."""
    # 1. Access future transactions
    future_transactions = future[TRANSACTION_DATA_SOURCE]
    # 2. Extract target item IDs
    article_ids = future_transactions[PRODUCT_ID_COLUMN]
    # 3. Compute training weights using sequential decay
    training_weights = sequential_decay(future_transactions, gamma=GAMMA)
    # 4. Return a sketch
    return sketch(article_ids, training_weights)
```
Step-by-Step Breakdown
① Access future transactions
Retrieves all transaction events after the temporal split. Unlike classification recipes, we typically use all future events rather than trimming to a fixed window, since the decay function handles temporal weighting.
② Extract item IDs
Selects the product ID column from the future transactions. This returns a ModalityEvents object that the sketch function accepts.
③ Apply sequential decay
sequential_decay assigns weights to future events based on their temporal order:
- gamma=0: Only the next basket (the first future transaction) receives weight. All later transactions get weight 0.
- gamma=0.5: Each successive basket gets 50% of the previous basket's weight, so influence decays gradually.
- gamma=1.0: All future baskets are weighted equally.
This is the key lever for controlling how "far ahead" the recommendation looks.
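A minimal sketch of the weighting rule described above. It assumes weight = gamma ** k for the k-th future basket (k starting at 0) and assumes that events sharing a timestamp form one basket. The helper name sequential_decay_weights is hypothetical, and the library's sequential_decay may group baskets or normalize weights differently.

```python
from datetime import datetime
from typing import List, Optional


def sequential_decay_weights(timestamps: List[datetime], gamma: float) -> List[float]:
    """Weight each event by gamma ** k, where k is the 0-based index of its basket.

    Assumptions: timestamps are sorted ascending, and events with the same
    timestamp belong to the same basket. Python evaluates 0.0 ** 0 as 1.0,
    so with gamma = 0 only the first basket keeps a non-zero weight.
    """
    weights: List[float] = []
    basket_index = -1
    previous: Optional[datetime] = None
    for ts in timestamps:
        if ts != previous:  # a new timestamp starts a new basket
            basket_index += 1
            previous = ts
        weights.append(gamma ** basket_index)
    return weights


t1, t2, t3 = datetime(2024, 5, 2), datetime(2024, 5, 9), datetime(2024, 5, 20)
stamps = [t1, t1, t2, t3]  # two items in the first basket
for g in (0.0, 0.5, 1.0):
    print(g, sequential_decay_weights(stamps, g))
# 0.0 [1.0, 1.0, 0.0, 0.0]
# 0.5 [1.0, 1.0, 0.5, 0.25]
# 1.0 [1.0, 1.0, 1.0, 1.0]
```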
④ Return a sketch
Creates a Sketch object from the item IDs and their weights. The Sketch is a lazy, memory-efficient representation used internally during training.
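One plausible way to picture what a sketch holds is a mapping from item ID to summed weight. The hypothetical helper weighted_item_set below makes that assumption. It collapses parallel ID and weight lists, drops zero-weight entries, and returns None when nothing is left, which matches the recipe's convention for excluding a customer. This is an illustration only: the real Sketch is lazy, and its aggregation rules are not documented here.

```python
from collections import defaultdict
from typing import Dict, List, Optional


def weighted_item_set(item_ids: List[str], weights: List[float]) -> Optional[Dict[str, float]]:
    """Collapse parallel ID/weight lists into {item_id: total_weight}.

    Hypothetical helper. Returns None when no item keeps a positive weight,
    mirroring the recipe's "return None to exclude this customer" rule.
    """
    if len(item_ids) != len(weights):
        raise ValueError("item_ids and weights must have the same length")
    totals: Dict[str, float] = defaultdict(float)
    for item, w in zip(item_ids, weights):
        if w > 0:  # zero-weight events (e.g. later baskets with gamma=0) carry no signal
            totals[item] += w
    return dict(totals) or None


print(weighted_item_set(["a", "b", "a", "c"], [1.0, 1.0, 0.5, 0.0]))  # {'a': 1.5, 'b': 1.0}
print(weighted_item_set([], []))  # None
```

Summing repeated purchases means an item bought twice within the horizon counts more than one bought once. Whether that is desirable is a modeling choice.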
Training
```python
from pathlib import Path

from monad.ui.config import TrainingParams, MetricParams, MetricMonitoringMode
from monad.config.early_stopping import EarlyStopping
from monad.ui.module import load_from_foundation_model, RecommendationTask

module = load_from_foundation_model(
    checkpoint_path=Path("./foundation_model"),
    downstream_task=RecommendationTask(),
    target_fn=next_basket_target_fn,
)

training_params = TrainingParams(
    checkpoint_dir=Path("./<this_model>"),
    learning_rate=1e-4,
    epochs=20,
    devices=[0],
    metrics=[
        MetricParams(alias="ndcg", metric_name="NDCGAtK", kwargs={"k": 10}),
        MetricParams(alias="hitrate", metric_name="HitRateAtK", kwargs={"k": 10}),
        MetricParams(alias="recall", metric_name="MultipleTargetsRecall", kwargs={"k": 10}),
    ],
    metric_to_monitor="val_ndcg_0",
    metric_monitoring_mode=MetricMonitoringMode.MAX,
    early_stopping=EarlyStopping(min_delta=1e-4, patience=5),
)

module.fit(training_params, seed=42)
```
Evaluation
```python
from pathlib import Path
from datetime import datetime, timezone

from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, MetricParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    prediction_date=datetime(2024, 5, 1, tzinfo=timezone.utc),
    output_type=OutputType.DECODED,
    devices=[0],
    top_k=10,
    metrics=[
        MetricParams(alias="ndcg", metric_name="NDCG"),
        MetricParams(alias="hitrate", metric_name="HitRate"),
        MetricParams(alias="recall", metric_name="Recall"),
    ],
)

results = module.test(testing_params)
```
Prediction
```python
from pathlib import Path
from datetime import datetime, timezone

from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    local_save_location=Path("./predictions.tsv"),
    output_type=OutputType.DECODED,
    prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
    devices=[0],
)

predictions = module.predict(testing_params)
```
Variations
Broader future horizon
Use a higher gamma to consider multiple future baskets, not just the next one:
```python
from typing import Dict

from monad.ui import target_function


def next_basket_broad_target_fn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> target_function.Sketch:
    # === Configuration ===
    TRANSACTION_DATA_SOURCE = "transactions"
    PRODUCT_ID_COLUMN = "article_id"
    GAMMA = 0.5  # Each successive basket gets half the weight of the previous

    future_transactions = future[TRANSACTION_DATA_SOURCE]
    article_ids = future_transactions[PRODUCT_ID_COLUMN]
    training_weights = target_function.sequential_decay(future_transactions, gamma=GAMMA)
    return target_function.sketch(article_ids, training_weights)
```
Exclude previously purchased items
See the Product Acquisition recipe for a variation that filters out repeat purchases.
Recommended Metrics
| Metric | Why it matters |
|---|---|
| NDCG | Measures ranking quality — are the most relevant items ranked highest? |
| Hit Rate | Fraction of customers for whom at least one recommended item was actually purchased. |
| Recall@K | Fraction of actually purchased items that appear in the top-K recommendations. |
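For reference, the textbook per-customer definitions with binary relevance look like this. Under binary relevance, an item counts as relevant if it was purchased after the split. The functions are hypothetical helpers, not monad's NDCG, HitRate, or Recall metrics, whose exact definitions (for example, graded relevance from sketch weights) may differ. Dataset-level scores are typically the mean of these values over customers.

```python
import math
from typing import List, Set


def hit_rate_at_k(recommended: List[str], purchased: Set[str], k: int) -> float:
    """1.0 if any of the top-k recommendations was purchased, else 0.0."""
    return 1.0 if any(item in purchased for item in recommended[:k]) else 0.0


def recall_at_k(recommended: List[str], purchased: Set[str], k: int) -> float:
    """Share of purchased items that appear in the top-k recommendations."""
    if not purchased:
        return 0.0
    return len(set(recommended[:k]) & purchased) / len(purchased)


def ndcg_at_k(recommended: List[str], purchased: Set[str], k: int) -> float:
    """Binary-relevance NDCG: discounted gain of hits divided by the ideal ranking's gain."""
    dcg = sum(
        1.0 / math.log2(rank + 2)
        for rank, item in enumerate(recommended[:k])
        if item in purchased
    )
    ideal_hits = min(len(purchased), k)
    idcg = sum(1.0 / math.log2(rank + 2) for rank in range(ideal_hits))
    return dcg / idcg if idcg > 0 else 0.0


recommended = ["a", "b", "c", "d"]  # model's ranked list for one customer
purchased = {"b", "e"}              # items actually bought after the split
print(hit_rate_at_k(recommended, purchased, 3))        # 1.0
print(recall_at_k(recommended, purchased, 3))          # 0.5
print(round(ndcg_at_k(recommended, purchased, 3), 4))  # 0.3869
```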
Production Tips
- Choose gamma based on your use case. For "buy it again" features, use gamma=0 (next basket). For discovery-oriented recommendations, use gamma=0.3–0.5 to include items from future baskets.
- Filter by availability. Post-process recommendations to remove out-of-stock items before showing them to customers.
- A/B test recommendation depth. Showing the top 5 vs. top 20 items has different UX trade-offs. Test what works for your platform.
- Retrain frequently. Product catalogs change fast. Retrain weekly or biweekly to keep recommendations fresh.
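A sketch of the availability filter. It assumes you have the ranked item IDs for a customer and a set of in-stock IDs from your inventory system (both hypothetical here). The filter keeps the model's order and stops once k items survive.

```python
from itertools import islice
from typing import Iterable, List, Set


def filter_available(ranked: Iterable[str], in_stock: Set[str], k: int) -> List[str]:
    """Drop out-of-stock items from a ranked list, preserving order, and keep the first k.

    Hypothetical post-processing helper; `in_stock` would come from your inventory system.
    """
    available = (item for item in ranked if item in in_stock)
    return list(islice(available, k))


ranked = ["a", "b", "c", "d", "e"]
print(filter_available(ranked, in_stock={"a", "c", "d", "e"}, k=3))  # ['a', 'c', 'd']
```

Because filtering can remove items, it can help to score more candidates than you display so the list still has k entries afterwards.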