Category Purchase Propensity
Task type: MultilabelClassificationTask
Industry: Retail / Grocery / FMCG
This recipe scores how likely each customer is to buy from each of several target product categories in the near future (the next 21 days in the example below). The training target is a binary vector indicating which departments the customer shops in during that window, and the resulting per-category scores enable personalized coupons, category-level promotions, and assortment recommendations.
Why multilabel? A grocery or retail shopper typically buys from multiple departments in a single trip (e.g., Fruits and Dairy and Bakery). Multilabel classification handles this naturally by producing an independent prediction per category.
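As a minimal illustration (plain Python, independent of the monad library; the four-category list and the basket are made up), a multilabel target is a multi-hot vector in which several categories can be 1 at once, whereas a multiclass target would have to pick a single department:

```python
# Illustrative subset of categories; the recipe itself uses TARGET_CATEGORIES.
CATEGORIES = ["Fruits", "Dairy", "Bakery", "Beverages"]


def multi_hot(basket_categories, categories=CATEGORIES):
    """Return one 0/1 flag per category: 1 if any item in the basket belongs to it."""
    present = set(basket_categories)
    return [1 if category in present else 0 for category in categories]


# One shopping trip touching three departments (Fruits appears twice but counts once).
basket = ["Fruits", "Dairy", "Fruits", "Bakery"]
labels = multi_hot(basket)
print(labels)  # [1, 1, 1, 0]
```

Each position is predicted independently, so a customer can be positive for any combination of departments.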
Prerequisites
Before writing a target function you need:
- A trained foundation model built on event data that includes a `purchases` data source.
- Product category information, available either:
  - directly in the `purchases` data source (denormalized), or
  - in a separate data source (e.g., `articles`), in which case you'll use `get_qualified_column_name` to reference it.
- The `monad` library installed in your environment (for Python App).
Target Function
| Argument | Type | Description |
|---|---|---|
| `history` | `Events` | All events before the temporal split. |
| `future` | `Events` | All events after the temporal split. |
| `attributes` | `Attributes` | Static entity attributes. |
| `ctx` | `Dict` | Context dictionary containing `SPLIT_TIMESTAMP`, data mode, etc. |
For multilabel classification, the function must return one of:
- A 1-D `float32` array of size `num_labels`, with a binary indicator (`0` or `1`) per category.
- `None` to exclude this customer (e.g., no future purchases).
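During development it can help to assert this contract before training. The helper below is hypothetical (`check_target` is not part of monad) and only encodes the rules stated above: `None`, or a 1-D `float32` array of length `num_labels` containing only 0s and 1s.

```python
import numpy as np


def check_target(value, num_labels):
    """Hypothetical dev-time helper: does `value` satisfy the multilabel return contract?"""
    if value is None:
        return True  # None is allowed: the customer is excluded
    arr = np.asarray(value)
    return (
        arr.ndim == 1                              # 1-D
        and arr.shape[0] == num_labels             # one entry per category
        and arr.dtype == np.float32                # float32 dtype
        and bool(np.isin(arr, (0.0, 1.0)).all())   # binary indicators only
    )


good = np.array([1, 0, 1], dtype=np.float32)
print(check_target(good, num_labels=3))                   # True
print(check_target(good.astype(np.int64), num_labels=3))  # False: wrong dtype
print(check_target(None, num_labels=3))                   # True
```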
Full Example
```python
import numpy as np
from datetime import timedelta
from typing import Dict

from monad.ui.target_function import Events, Attributes
from monad.ui.target_function import SPLIT_TIMESTAMP
from monad.ui.target_function import has_incomplete_training_window
from monad.ui.target_function import get_qualified_column_name

# === Configuration ===
TARGET_WINDOW_DAYS = 21
PURCHASE_DATA_SOURCE = "purchases"
TARGET_CATEGORIES = [
    "Fruits",
    "Dairy",
    "Bakery",
    "Meat and Poultry",
    "Snacks and Confectionery",
    "Beverages",
    "Canned and Packaged Foods",
    "Household Supplies",
    "Personal Care",
    "Cleaning Products",
]

# Use get_qualified_column_name when the category column lives in
# a separate data source (e.g., "articles") rather than directly in "purchases".
CATEGORY_COLUMN = get_qualified_column_name(
    column_name="department",
    data_sources_path=["articles"],
)


def category_propensity_target_fn(
    history: Events,
    future: Events,
    attributes: Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    """Score propensity to buy in each target category (1 = bought, 0 = did not)."""
    # 1. Ensure the training window is long enough
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if has_incomplete_training_window(ctx, target_window):
        return None

    # 2. Trim future events to the target window
    future = future.interval_from(ctx[SPLIT_TIMESTAMP], target_window)

    # 3. Check which categories the customer purchased from
    category_labels, _ = (
        future[PURCHASE_DATA_SOURCE]
        .groupBy(CATEGORY_COLUMN)
        .exists(groups=TARGET_CATEGORIES)
    )

    # 4. Exclude customers with no purchases in any target category
    if category_labels.sum() == 0:
        return None

    return category_labels
```
The same target function can also be written with the configuration inside the function body, referencing every helper through the `target_function` module:

```python
import numpy as np
from datetime import timedelta
from typing import Dict

from monad.ui import target_function


def category_propensity_target_fn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    """Score propensity to buy in each target category (1 = bought, 0 = did not)."""
    # === Configuration ===
    TARGET_WINDOW_DAYS = 21
    PURCHASE_DATA_SOURCE = "purchases"
    TARGET_CATEGORIES = [
        "Fruits",
        "Dairy",
        "Bakery",
        "Meat and Poultry",
        "Snacks and Confectionery",
        "Beverages",
        "Canned and Packaged Foods",
        "Household Supplies",
        "Personal Care",
        "Cleaning Products",
    ]
    # Use target_function.get_qualified_column_name when the category column lives in
    # a separate data source (e.g., "articles") rather than directly in "purchases".
    CATEGORY_COLUMN = target_function.get_qualified_column_name(
        column_name="department",
        data_sources_path=["articles"],
    )

    # 1. Ensure the training window is long enough
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if target_function.has_incomplete_training_window(ctx, target_window):
        return None

    # 2. Trim future events to the target window
    future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

    # 3. Check which categories the customer purchased from
    category_labels, _ = (
        future[PURCHASE_DATA_SOURCE]
        .groupBy(CATEGORY_COLUMN)
        .exists(groups=TARGET_CATEGORIES)
    )

    # 4. Exclude customers with no purchases in any target category
    if category_labels.sum() == 0:
        return None

    return category_labels
```
Step-by-Step Breakdown
① Validate the training window
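Returns `None` (skipping the sample) when the data after `SPLIT_TIMESTAMP` does not cover the full 21-day target window. Without this check, customers near the end of the dataset would receive `0` labels for categories they simply had no time to buy from yet.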
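The idea behind this check can be illustrated without the library. The sketch below is an assumption about the underlying logic, not monad's implementation, and `window_is_incomplete` is a hypothetical name: a sample is incomplete when the window starting at the split would extend past the last timestamp covered by the data.

```python
from datetime import datetime, timedelta, timezone


def window_is_incomplete(split_ts, target_window, data_end):
    """Hypothetical stand-in for the idea behind has_incomplete_training_window:
    True when the target window would extend past the last observed timestamp."""
    return split_ts + target_window > data_end


split = datetime(2024, 4, 20, tzinfo=timezone.utc)
data_end = datetime(2024, 5, 1, tzinfo=timezone.utc)  # last timestamp covered by the data

# Only 11 days are observed after the split, so a 21-day window is incomplete.
print(window_is_incomplete(split, timedelta(days=21), data_end))  # True
print(window_is_incomplete(split, timedelta(days=7), data_end))   # False
```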
② Trim future events
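`future.interval_from(ctx[SPLIT_TIMESTAMP], target_window)` keeps only the events in the 21 days that start at the split timestamp, so every customer's labels cover the same horizon.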
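Conceptually, this is a timestamp filter. Here is a plain-Python sketch using a hypothetical list-of-dicts event format and an assumed half-open `[start, start + window)` interval; check the monad documentation for the exact boundary behavior of `interval_from`.

```python
from datetime import datetime, timedelta, timezone


def trim_to_window(events, start, window):
    """Keep events with start <= timestamp < start + window (half-open interval assumed)."""
    end = start + window
    return [event for event in events if start <= event["timestamp"] < end]


split = datetime(2024, 5, 1, tzinfo=timezone.utc)
events = [
    {"timestamp": split + timedelta(days=2), "department": "Dairy"},
    {"timestamp": split + timedelta(days=20), "department": "Fruits"},
    {"timestamp": split + timedelta(days=30), "department": "Bakery"},  # after the window
]

kept = trim_to_window(events, split, timedelta(days=21))
print([event["department"] for event in kept])  # ['Dairy', 'Fruits']
```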
③ Detect category purchases
```python
category_labels, _ = (
    future[PURCHASE_DATA_SOURCE]
    .groupBy(CATEGORY_COLUMN)
    .exists(groups=TARGET_CATEGORIES)
)
```

- `groupBy(CATEGORY_COLUMN)` groups purchase events by department.
- `.exists(groups=TARGET_CATEGORIES)` returns a binary array with one entry per target category, in list order: `1` if the customer bought at least one item in that category, `0` otherwise.
- Example: `[1, 1, 0, 0, 1, 1, 0, 0, 0, 0]` means the customer bought Fruits, Dairy, Snacks and Confectionery, and Beverages.
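The following NumPy sketch reproduces the behavior described for `.exists(groups=TARGET_CATEGORIES)` on a plain list of departments, including the example vector above. It illustrates the semantics only and is not the monad implementation; note that departments outside `TARGET_CATEGORIES` (here a made-up "Pet Food") do not appear in the output.

```python
import numpy as np

TARGET_CATEGORIES = [
    "Fruits", "Dairy", "Bakery", "Meat and Poultry",
    "Snacks and Confectionery", "Beverages",
    "Canned and Packaged Foods", "Household Supplies",
    "Personal Care", "Cleaning Products",
]


def exists_per_category(purchased_departments, categories):
    """1.0 if at least one purchase falls in the category, else 0.0, in `categories` order."""
    bought = set(purchased_departments)
    return np.array([1.0 if c in bought else 0.0 for c in categories], dtype=np.float32)


purchases = ["Dairy", "Fruits", "Beverages", "Dairy", "Snacks and Confectionery", "Pet Food"]
labels = exists_per_category(purchases, TARGET_CATEGORIES)
print(labels.astype(int).tolist())  # [1, 1, 0, 0, 1, 1, 0, 0, 0, 0]
```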
Using get_qualified_column_name
When product attributes (like department) live in a separate data source rather than being denormalized into the purchase events, use get_qualified_column_name to create a qualified reference. This tells the model to resolve the column through the join path purchases → articles → department.
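To picture what the qualified reference resolves, here is a plain-Python sketch of the same lookup over in-memory tables. The `article_id` join key and the table contents are assumptions for illustration; in the recipe, `get_qualified_column_name` tells the model to perform this resolution itself.

```python
# Hypothetical tables: purchase events reference articles by an assumed "article_id" key.
articles = {
    "A1": {"department": "Dairy"},
    "A2": {"department": "Bakery"},
}
purchases = [
    {"article_id": "A1", "quantity": 2},
    {"article_id": "A2", "quantity": 1},
    {"article_id": "A1", "quantity": 1},
]


def resolve_department(purchase, articles):
    """Follow purchases -> articles -> department for one purchase event."""
    return articles[purchase["article_id"]]["department"]


departments = [resolve_department(p, articles) for p in purchases]
print(departments)  # ['Dairy', 'Bakery', 'Dairy']
```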
If your data is already denormalized (`department` is a column directly in `purchases`), you can skip the helper and reference the column by its plain name, e.g. `CATEGORY_COLUMN = "department"`.
④ Exclude inactive customers
Customers who made no purchases in any target category are excluded. They provide no positive signal and would dilute the training data.
Training
```python
from pathlib import Path

from monad.ui.config import TrainingParams, MetricParams, MetricMonitoringMode
from monad.config.early_stopping import EarlyStopping
from monad.ui.module import load_from_foundation_model, MultilabelClassificationTask

# TARGET_CATEGORIES and category_propensity_target_fn come from the Full Example above.
NUM_LABELS = len(TARGET_CATEGORIES)

module = load_from_foundation_model(
    checkpoint_path=Path("./foundation_model"),
    downstream_task=MultilabelClassificationTask(class_names=TARGET_CATEGORIES),
    target_fn=category_propensity_target_fn,
)

training_params = TrainingParams(
    checkpoint_dir=Path("./<this_model>"),
    learning_rate=1e-4,
    epochs=20,
    devices=[0],
    metrics=[
        MetricParams(alias="auroc", metric_name="AUROC", kwargs={"task": "multilabel", "num_labels": NUM_LABELS}),
        MetricParams(alias="auprc", metric_name="AveragePrecision", kwargs={"task": "multilabel", "num_labels": NUM_LABELS}),
        MetricParams(alias="f1", metric_name="F1Score", kwargs={"task": "multilabel", "num_labels": NUM_LABELS}),
    ],
    metric_to_monitor="val_auroc_0",
    metric_monitoring_mode=MetricMonitoringMode.MAX,
    early_stopping=EarlyStopping(min_delta=1e-4, patience=5),
)

module.fit(training_params, seed=42)
```
Evaluation
```python
from pathlib import Path
from datetime import datetime, timezone

from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, MetricParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    prediction_date=datetime(2024, 5, 1, tzinfo=timezone.utc),
    output_type=OutputType.DECODED,
    devices=[0],
    metrics=[
        MetricParams(alias="auroc", metric_name="AUROC"),
        MetricParams(alias="auprc", metric_name="AveragePrecision"),
        MetricParams(alias="f1", metric_name="F1Score"),
    ],
)

results = module.test(testing_params)
```
Prediction
```python
from pathlib import Path
from datetime import datetime, timezone

from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    local_save_location=Path("./predictions.tsv"),
    output_type=OutputType.DECODED,
    prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
    devices=[0],
)

predictions = module.predict(testing_params)
```
Variations
Denormalized data (simpler)
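If your category column lives directly in the `purchases` data source, replace the `get_qualified_column_name` call with the plain column name (e.g., `CATEGORY_COLUMN = "department"`); the rest of the target function stays the same.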
Category frequency instead of binary
Use `.count(normalize=True)` instead of `.exists()` to get a probability distribution over categories rather than binary flags. This target is suitable for `MulticlassClassificationTask`:

```python
import numpy as np
from datetime import timedelta
from typing import Dict

from monad.ui import target_function


def category_frequency_target_fn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    # === Configuration ===
    TARGET_WINDOW_DAYS = 21
    PURCHASE_DATA_SOURCE = "purchases"
    CATEGORY_COLUMN = target_function.get_qualified_column_name(
        column_name="department",
        data_sources_path=["articles"],
    )
    TARGET_CATEGORIES = [
        "Fruits", "Dairy", "Bakery", "Meat and Poultry",
        "Snacks and Confectionery", "Beverages",
        "Canned and Packaged Foods", "Household Supplies",
        "Personal Care", "Cleaning Products",
    ]

    # Skip samples whose 21-day window is not fully covered by the data
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if target_function.has_incomplete_training_window(ctx, target_window):
        return None

    # Keep only events inside the target window
    future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

    # Share of purchases per target category
    category_distribution, _ = (
        future[PURCHASE_DATA_SOURCE]
        .groupBy(CATEGORY_COLUMN)
        .count(normalize=True, groups=TARGET_CATEGORIES)
    )

    # Exclude customers with no purchases in any target category: there is no
    # distribution to learn from (guards against an all-zero or NaN result)
    total = float(category_distribution.sum())
    if not np.isfinite(total) or total == 0:
        return None

    return category_distribution
```
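For intuition, here is a NumPy sketch of the distribution this variation produces for a single customer. It assumes normalization is over purchases in the listed categories only; the exact monad semantics (for example, whether out-of-scope purchases count toward the denominator) may differ.

```python
import numpy as np


def normalized_counts(purchased_departments, categories):
    """Share of in-scope purchases per category, in `categories` order.
    Assumes normalization over target-category purchases only; returns all zeros
    when none of the purchases fall into `categories`."""
    counts = np.array(
        [purchased_departments.count(category) for category in categories],
        dtype=np.float32,
    )
    total = counts.sum()
    return counts / total if total > 0 else counts


cats = ["Fruits", "Dairy", "Bakery"]
# "Pet Food" is outside the target list, so it is ignored.
dist = normalized_counts(["Dairy", "Dairy", "Fruits", "Pet Food"], cats)
# Dairy gets 2/3 of the in-scope purchases, Fruits 1/3, Bakery 0.
print(dist)
```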
Recommended Metrics
| Metric | Why it matters |
|---|---|
| AUROC (per label) | Ranking quality for each category independently. |
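| AUPRC (per label) | More informative than AUROC for categories with low purchase rates, because it focuses on the positive class. |
| F1 Score (micro) | Balance of precision and recall, with counts pooled across all (customer, category) pairs. |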
| Hamming Loss | Fraction of category labels incorrectly predicted. |
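To make the micro-F1 and Hamming loss definitions concrete, here is a small NumPy computation on a made-up example of 3 customers by 4 categories, assuming predictions have already been thresholded to 0/1:

```python
import numpy as np

y_true = np.array([[1, 1, 0, 0],
                   [0, 1, 1, 0],
                   [1, 0, 0, 1]])
y_pred = np.array([[1, 0, 0, 0],
                   [0, 1, 1, 1],
                   [1, 0, 0, 1]])


def micro_f1(y_true, y_pred):
    """Pool true positives, false positives and false negatives over every cell."""
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    return 2 * tp / (2 * tp + fp + fn)


def hamming_loss(y_true, y_pred):
    """Fraction of (customer, category) cells where the prediction is wrong."""
    return np.mean(y_true != y_pred)


# 5 TP, 1 FP, 1 FN -> 10 / 12; 2 wrong cells out of 12 -> 2 / 12
print(f"micro-F1 = {micro_f1(y_true, y_pred):.4f}")          # micro-F1 = 0.8333
print(f"Hamming loss = {hamming_loss(y_true, y_pred):.4f}")  # Hamming loss = 0.1667
```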
Production Tips
- Personalize promotions by predicted categories. Send Dairy coupons to customers predicted to buy Dairy, not a one-size-fits-all newsletter.
- Monitor per-category performance. Some categories (Fruits, Dairy) are purchased by most customers and are easy to predict. Niche categories may need more training data.
- Combine with basket value predictions. Pair category propensity with spend regression to estimate not just what a customer will buy, but how much they will spend.
- Retrain after seasonal shifts. Demand patterns change with holidays and seasons, so retrain at least quarterly.
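For per-category monitoring, a minimal NumPy sketch like the one below reports each category's positive rate next to its AUROC, which makes rare or lagging categories easy to spot. The AUROC uses the rank-based Mann-Whitney formulation (ties count one half); the function names and example data are hypothetical, and in production a tested library metric is preferable.

```python
import numpy as np


def auroc(y_true, scores):
    """Probability that a random positive outscores a random negative (ties count 1/2)."""
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    if len(pos) == 0 or len(neg) == 0:
        return float("nan")  # undefined when a category has only one class
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))


def per_category_report(y_true, scores, names):
    """Return {category: (positive rate, AUROC)} for each label column."""
    return {
        name: (float(y_true[:, j].mean()), float(auroc(y_true[:, j], scores[:, j])))
        for j, name in enumerate(names)
    }


names = ["Fruits", "Personal Care"]
y_true = np.array([[1, 0], [1, 0], [1, 1], [0, 0]])
scores = np.array([[0.9, 0.2], [0.8, 0.1], [0.7, 0.3], [0.4, 0.35]])
report = per_category_report(y_true, scores, names)
for name, (rate, auc) in report.items():
    print(f"{name}: positive rate {rate:.2f}, AUROC {auc:.3f}")
# Fruits: positive rate 0.75, AUROC 1.000
# Personal Care: positive rate 0.25, AUROC 0.667
```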