Category Purchase Propensity

Task type: MultilabelClassificationTask
Industry: Retail / Grocery / FMCG

This recipe scores how likely each customer is to buy from each of several target product categories in the near future. The training target is a binary vector marking which departments the customer shopped in during the target window; the resulting per-category scores support personalized coupons, category-level promotions, and assortment recommendations.

Why multilabel? A grocery or retail shopper typically buys from multiple departments in a single trip (e.g., Fruits and Dairy and Bakery). Multilabel classification handles this naturally by producing an independent prediction per category.
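
To make the distinction concrete, the sketch below contrasts independent per-label scores (sigmoid) with mutually exclusive class scores (softmax). It is a NumPy-only illustration of the general idea, not a description of how the model computes its outputs internally; the logits are made-up values.

```python
import numpy as np


def sigmoid(logits: np.ndarray) -> np.ndarray:
    """Independent per-label probabilities, as used in multilabel setups."""
    return 1.0 / (1.0 + np.exp(-logits))


def softmax(logits: np.ndarray) -> np.ndarray:
    """Mutually exclusive class probabilities, as used in multiclass setups."""
    shifted = logits - logits.max()  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()


# Hypothetical raw scores for three departments: Fruits, Dairy, Bakery.
logits = np.array([2.0, 1.5, -1.0])

multilabel_probs = sigmoid(logits)  # each value lies in (0, 1); the sum is unconstrained
multiclass_probs = softmax(logits)  # values sum to 1, so categories compete
```

With sigmoid scores, Fruits and Dairy can both be likely at the same time, which matches a shopper who visits several departments in one trip.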


Prerequisites

Before writing a target function you need:

  • A trained foundation model built on event data that includes a purchases data source.
  • Product category information available either:
      • Directly in the purchases data source (denormalized), or
      • In a separate data source (e.g., articles), in which case you'll use get_qualified_column_name to reference it.
  • The monad library installed in your environment (for Python App).

Target Function

| Argument | Type | Description |
| --- | --- | --- |
| history | Events | All events before the temporal split. |
| future | Events | All events after the temporal split. |
| attributes | Attributes | Static entity attributes. |
| ctx | Dict | Context dictionary containing SPLIT_TIMESTAMP, data mode, etc. |

For multilabel classification, the function must return one of:

  • A 1-D float32 array of size num_labels — binary indicators (0 or 1) per category.
  • None, to exclude this customer (e.g., no future purchases).
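
While developing a target function, a small check of this return contract can catch shape or dtype mistakes early. The helper below is a hypothetical utility written for this recipe (it is not part of the monad library) and uses only NumPy.

```python
from typing import Optional

import numpy as np


def validate_multilabel_target(target: Optional[np.ndarray], num_labels: int) -> bool:
    """Return True if `target` is None or a 1-D float32 array of 0/1 flags of size num_labels."""
    if target is None:
        return True  # None means "exclude this customer"
    return (
        isinstance(target, np.ndarray)
        and target.ndim == 1
        and target.shape[0] == num_labels
        and target.dtype == np.float32
        and bool(np.isin(target, (0.0, 1.0)).all())
    )
```

For example, `validate_multilabel_target(np.array([1, 0, 1], dtype=np.float32), 3)` passes, while the same values stored as float64 do not.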

Full Example

Python
import numpy as np
from datetime import timedelta
from typing import Dict

from monad.ui.target_function import Events, Attributes
from monad.ui.target_function import SPLIT_TIMESTAMP
from monad.ui.target_function import has_incomplete_training_window

from monad.ui.target_function import get_qualified_column_name


# === Configuration ===
TARGET_WINDOW_DAYS = 21
PURCHASE_DATA_SOURCE = "purchases"
TARGET_CATEGORIES = [
    "Fruits",
    "Dairy",
    "Bakery",
    "Meat and Poultry",
    "Snacks and Confectionery",
    "Beverages",
    "Canned and Packaged Foods",
    "Household Supplies",
    "Personal Care",
    "Cleaning Products",
]

# Use get_qualified_column_name when the category column lives in
# a separate data source (e.g., "articles") rather than directly in "purchases".
CATEGORY_COLUMN = get_qualified_column_name(
    column_name="department",
    data_sources_path=["articles"],
)


def category_propensity_target_fn(
    history: Events,
    future: Events,
    attributes: Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    """Score propensity to buy in each target category (1 = bought, 0 = did not)."""

    # 1. Ensure the training window is long enough
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if has_incomplete_training_window(ctx, target_window):
        return None

    # 2. Trim future events to the target window
    future = future.interval_from(ctx[SPLIT_TIMESTAMP], target_window)

    # 3. Check which categories the customer purchased from
    category_labels, _ = (
        future[PURCHASE_DATA_SOURCE]
        .groupBy(CATEGORY_COLUMN)
        .exists(groups=TARGET_CATEGORIES)
    )

    # 4. Exclude customers with no purchases in any target category
    if category_labels.sum() == 0:
        return None

    return category_labels

Python (same function, with helpers referenced through the target_function namespace)
def category_propensity_target_fn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    """Score propensity to buy in each target category (1 = bought, 0 = did not)."""

    # === Configuration ===
    TARGET_WINDOW_DAYS = 21
    PURCHASE_DATA_SOURCE = "purchases"
    TARGET_CATEGORIES = [
        "Fruits",
        "Dairy",
        "Bakery",
        "Meat and Poultry",
        "Snacks and Confectionery",
        "Beverages",
        "Canned and Packaged Foods",
        "Household Supplies",
        "Personal Care",
        "Cleaning Products",
    ]

    # Use target_function.get_qualified_column_name when the category column lives in
    # a separate data source (e.g., "articles") rather than directly in "purchases".
    CATEGORY_COLUMN = target_function.get_qualified_column_name(
        column_name="department",
        data_sources_path=["articles"],
    )

    # 1. Ensure the training window is long enough
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if target_function.has_incomplete_training_window(ctx, target_window):
        return None

    # 2. Trim future events to the target window
    future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

    # 3. Check which categories the customer purchased from
    category_labels, _ = (
        future[PURCHASE_DATA_SOURCE]
        .groupBy(CATEGORY_COLUMN)
        .exists(groups=TARGET_CATEGORIES)
    )

    # 4. Exclude customers with no purchases in any target category
    if category_labels.sum() == 0:
        return None

    return category_labels

Step-by-Step Breakdown

① Validate the training window

Python
target_window = timedelta(days=TARGET_WINDOW_DAYS)
if has_incomplete_training_window(ctx, target_window):
    return None

Python (target_function namespace variant)
target_window = timedelta(days=TARGET_WINDOW_DAYS)
if target_function.has_incomplete_training_window(ctx, target_window):
    return None

Returns None, which skips the sample, when there is insufficient future data to cover the target window.
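
The intent of this check can be sketched with plain datetimes. The exact criteria applied by has_incomplete_training_window are not described in this recipe, so the function below, including its name and its data_end argument, is an assumption used only to illustrate the idea.

```python
from datetime import datetime, timedelta, timezone


def window_is_incomplete(split_ts: datetime, target_window: timedelta, data_end: datetime) -> bool:
    """True when the target window would extend past the last observed data (assumed rule)."""
    return split_ts + target_window > data_end


data_end = datetime(2024, 5, 1, tzinfo=timezone.utc)
late_split = datetime(2024, 4, 20, tzinfo=timezone.utc)   # only 11 days of future data
early_split = datetime(2024, 4, 1, tzinfo=timezone.utc)   # 30 days of future data

window = timedelta(days=21)
skip_late = window_is_incomplete(late_split, window, data_end)    # window does not fit
skip_early = window_is_incomplete(early_split, window, data_end)  # window fits
```

A sample split too close to the end of the data would otherwise get labels of 0 simply because the purchases have not been observed yet.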

② Trim future events

Python
future = future.interval_from(ctx[SPLIT_TIMESTAMP], target_window)

Python (target_function namespace variant)
future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

Narrows the future events to the 21-day target window that starts at the split timestamp.
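
Conceptually, the trimming step is a timestamp filter. The sketch below works on a plain list of datetimes rather than an Events object, and it assumes a half-open window that includes the start and excludes the end; how interval_from treats the boundaries is not stated here.

```python
from datetime import datetime, timedelta, timezone
from typing import List


def trim_to_window(timestamps: List[datetime], start: datetime, window: timedelta) -> List[datetime]:
    """Keep timestamps in [start, start + window). Boundary handling is an assumption."""
    end = start + window
    return [ts for ts in timestamps if start <= ts < end]


split_ts = datetime(2024, 4, 1, tzinfo=timezone.utc)
event_times = [split_ts + timedelta(days=d) for d in (1, 20, 21, 30)]

# Only the events on day 1 and day 20 fall inside the 21-day window.
in_window = trim_to_window(event_times, split_ts, timedelta(days=21))
```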

③ Detect category purchases

Python
category_labels, _ = (
    future[PURCHASE_DATA_SOURCE]
    .groupBy(CATEGORY_COLUMN)
    .exists(groups=TARGET_CATEGORIES)
)

  • groupBy(CATEGORY_COLUMN) groups purchase events by department.
  • .exists(groups=TARGET_CATEGORIES) returns a binary array ordered like TARGET_CATEGORIES: 1 if the customer bought at least one item in that category, 0 otherwise.
  • Example: [1, 1, 0, 0, 1, 1, 0, 0, 0, 0] means the customer bought Fruits, Dairy, Snacks and Confectionery, and Beverages.
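
The labels produced in this step can be reproduced on a plain list of department names. The two helpers below are illustrative stand-ins written with NumPy, not the library's groupBy/exists implementation; the purchase list is made up.

```python
from typing import List

import numpy as np

TARGET_CATEGORIES = [
    "Fruits", "Dairy", "Bakery", "Meat and Poultry",
    "Snacks and Confectionery", "Beverages",
    "Canned and Packaged Foods", "Household Supplies",
    "Personal Care", "Cleaning Products",
]


def exists_labels(purchased_departments: List[str], groups: List[str]) -> np.ndarray:
    """1.0 for each group that appears at least once among the purchases, else 0.0."""
    seen = set(purchased_departments)
    return np.array([1.0 if g in seen else 0.0 for g in groups], dtype=np.float32)


def decode_labels(labels: np.ndarray, groups: List[str]) -> List[str]:
    """Map a binary label vector back to category names."""
    return [g for g, flag in zip(groups, labels) if flag == 1]


# Hypothetical departments from one customer's purchases in the target window.
# "Pet Food" is not a target category, so it does not affect the labels.
purchases = ["Dairy", "Fruits", "Beverages", "Dairy", "Snacks and Confectionery", "Pet Food"]
labels = exists_labels(purchases, TARGET_CATEGORIES)
bought = decode_labels(labels, TARGET_CATEGORIES)
```

Repeated purchases in the same department (Dairy appears twice) still yield a single 1, since the label records presence rather than frequency.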

Using get_qualified_column_name

Python
CATEGORY_COLUMN = get_qualified_column_name(
    column_name="department",
    data_sources_path=["articles"],
)

Python (target_function namespace variant)
CATEGORY_COLUMN = target_function.get_qualified_column_name(
    column_name="department",
    data_sources_path=["articles"],
)

When product attributes (like department) live in a separate data source rather than being denormalized into the purchase events, use get_qualified_column_name to create a qualified reference. This tells the model to resolve the column through the join path purchases → articles → department.
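
The join path can be pictured as a lookup from each purchase to its article record. The sketch below uses plain dictionaries; the article_id key and the sample records are hypothetical, and the library performs this resolution for you when given a qualified column name.

```python
from typing import Dict, List, Optional

# Hypothetical normalized data: purchases reference articles by id.
purchases = [{"article_id": "a1"}, {"article_id": "a2"}, {"article_id": "a1"}]
articles = {
    "a1": {"department": "Dairy"},
    "a2": {"department": "Bakery"},
}


def resolve_department(purchase: Dict[str, str], articles: Dict[str, Dict[str, str]]) -> Optional[str]:
    """Follow purchases -> articles -> department; None if the article is unknown."""
    article = articles.get(purchase["article_id"])
    return article["department"] if article is not None else None


departments: List[Optional[str]] = [resolve_department(p, articles) for p in purchases]
```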

If your data is already denormalized (department is a column directly in purchases), you can simplify to:

Python
CATEGORY_COLUMN = "department"

④ Exclude inactive customers

Python
if category_labels.sum() == 0:
    return None

Customers who made no purchases in any target category are excluded. They provide no positive signal and would dilute the training data.
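
Before training, it can be useful to measure how many samples this rule drops and how often each category is positive among the rest. The helper below is a hypothetical diagnostic that operates on a list of target-function outputs (arrays or None); it is not part of the monad library.

```python
from typing import List, Optional, Tuple

import numpy as np


def summarize_targets(targets: List[Optional[np.ndarray]]) -> Tuple[float, np.ndarray]:
    """Return (fraction of samples excluded, per-category positive rate among kept samples)."""
    kept = [t for t in targets if t is not None]
    if not kept:
        raise ValueError("need at least one non-excluded target")
    excluded_fraction = 1.0 - len(kept) / len(targets)
    prevalence = np.stack(kept).mean(axis=0)
    return excluded_fraction, prevalence


# Made-up outputs for four customers and three categories.
targets = [
    np.array([1, 0, 1], dtype=np.float32),
    None,
    np.array([1, 1, 0], dtype=np.float32),
    None,
]
excluded_fraction, prevalence = summarize_targets(targets)
```

A very high excluded fraction suggests the target window or the category list may be too narrow for the customer base.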


Training

Python
from pathlib import Path
from monad.ui.config import TrainingParams, MetricParams, MetricMonitoringMode
from monad.config.early_stopping import EarlyStopping

from monad.ui.module import load_from_foundation_model, MultilabelClassificationTask


module = load_from_foundation_model(
    checkpoint_path=Path("./foundation_model"),
    downstream_task=MultilabelClassificationTask(class_names=TARGET_CATEGORIES),
    target_fn=category_propensity_target_fn,
)

training_params = TrainingParams(
    checkpoint_dir=Path("./<this_model>"),
    learning_rate=1e-4,
    epochs=20,
    devices=[0],
    metrics=[
        # num_labels must match the number of target categories.
        MetricParams(alias="auroc", metric_name="AUROC", kwargs={"task": "multilabel", "num_labels": len(TARGET_CATEGORIES)}),
        MetricParams(alias="auprc", metric_name="AveragePrecision", kwargs={"task": "multilabel", "num_labels": len(TARGET_CATEGORIES)}),
        MetricParams(alias="f1", metric_name="F1Score", kwargs={"task": "multilabel", "num_labels": len(TARGET_CATEGORIES)}),
    ],
    metric_to_monitor="val_auroc_0",
    metric_monitoring_mode=MetricMonitoringMode.MAX,
    early_stopping=EarlyStopping(min_delta=1e-4, patience=5),
)

module.fit(training_params, seed=42)

Evaluation

Python
from pathlib import Path
from datetime import datetime, timezone
from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, MetricParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    prediction_date=datetime(2024, 5, 1, tzinfo=timezone.utc),
    output_type=OutputType.DECODED,
    devices=[0],
    metrics=[
        MetricParams(alias="auroc", metric_name="AUROC"),
        MetricParams(alias="auprc", metric_name="AveragePrecision"),
        MetricParams(alias="f1", metric_name="F1Score"),
    ],
)

results = module.test(testing_params)

Prediction

Python
from pathlib import Path
from datetime import datetime, timezone
from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    local_save_location=Path("./predictions.tsv"),
    output_type=OutputType.DECODED,
    prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
    devices=[0],
)

predictions = module.predict(testing_params)

Variations

Denormalized data (simpler)

If your category column lives directly in the purchases data source:

Python
CATEGORY_COLUMN = "department"  # No get_qualified_column_name needed

Category frequency instead of binary

Use .count(normalize=True) instead of .exists() to get a probability distribution rather than binary flags. This is suitable for MulticlassClassificationTask:

Python
category_distribution, _ = (
    future[PURCHASE_DATA_SOURCE]
    .groupBy(CATEGORY_COLUMN)
    .count(normalize=True, groups=TARGET_CATEGORIES)
)
return category_distribution

Python (full function, target_function namespace variant)
def category_frequency_target_fn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    # === Configuration ===
    TARGET_WINDOW_DAYS = 21
    PURCHASE_DATA_SOURCE = "purchases"
    CATEGORY_COLUMN = target_function.get_qualified_column_name(
        column_name="department",
        data_sources_path=["articles"],
    )
    TARGET_CATEGORIES = [
        "Fruits", "Dairy", "Bakery", "Meat and Poultry",
        "Snacks and Confectionery", "Beverages",
        "Canned and Packaged Foods", "Household Supplies",
        "Personal Care", "Cleaning Products",
    ]

    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if target_function.has_incomplete_training_window(ctx, target_window):
        return None
    future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

    category_distribution, _ = (
        future[PURCHASE_DATA_SOURCE]
        .groupBy(CATEGORY_COLUMN)
        .count(normalize=True, groups=TARGET_CATEGORIES)
    )
    return category_distribution
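
As a rough picture of the frequency target, the sketch below computes normalized counts over a plain list of department names. It assumes normalization over the target groups only; whether count(normalize=True) also accounts for purchases outside TARGET_CATEGORIES is not stated in this recipe.

```python
from collections import Counter
from typing import List

import numpy as np


def normalized_counts(purchased_departments: List[str], groups: List[str]) -> np.ndarray:
    """Share of purchases per group, normalized over the listed groups (assumed behavior)."""
    counts = Counter(purchased_departments)
    raw = np.array([counts.get(g, 0) for g in groups], dtype=np.float32)
    total = raw.sum()
    return raw / total if total > 0 else raw  # all zeros when nothing was bought


distribution = normalized_counts(
    ["Dairy", "Dairy", "Fruits", "Bakery"],
    ["Fruits", "Dairy", "Bakery"],
)
```

Unlike the binary flags, this target distinguishes a customer who buys Dairy twice from one who buys it once.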

| Metric | Why it matters |
| --- | --- |
| AUROC (per label) | Ranking quality for each category independently. |
| AUPRC (per label) | Better than AUROC for categories with low purchase rates. |
| F1 Score (micro) | Overall balance across all categories. |
| Hamming Loss | Fraction of category labels incorrectly predicted. |
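
For intuition, Hamming loss and micro-averaged F1 can be computed by hand on thresholded predictions. The NumPy sketch below follows the standard definitions and uses a made-up two-customer, three-category example; it is independent of the metric implementations configured in the training parameters.

```python
import numpy as np


def hamming_loss(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Fraction of individual (customer, category) labels predicted incorrectly."""
    return float(np.mean(y_true != y_pred))


def micro_f1(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """F1 computed from true/false positives pooled over all categories."""
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    denom = 2 * tp + fp + fn
    return float(2 * tp / denom) if denom > 0 else 0.0


# Rows are customers, columns are categories.
y_true = np.array([[1, 1, 0], [0, 1, 0]])
y_pred = np.array([[1, 0, 0], [0, 1, 1]])

loss = hamming_loss(y_true, y_pred)  # 2 of 6 labels are wrong
f1 = micro_f1(y_true, y_pred)        # tp=2, fp=1, fn=1
```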

Production Tips

  1. Personalize promotions by predicted categories. Send Dairy coupons to customers predicted to buy Dairy, not a one-size-fits-all newsletter.

  2. Monitor per-category performance. Some categories (Fruits, Dairy) are purchased by most customers and are easy to predict. Niche categories may need more training data.

  3. Combine with basket value predictions. Pair category propensity with spend regression to estimate not just what a customer will buy, but how much they will spend.

  4. Retrain after seasonal shifts. Demand patterns change with holidays and seasons — retrain quarterly at minimum.
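
Tips 1 and 3 can be sketched as a small post-processing step on predicted scores. Everything below is illustrative: the scores, per-category thresholds, and spend estimates are made-up values, and the helper name is hypothetical.

```python
from typing import List

import numpy as np

CATEGORIES = ["Fruits", "Dairy", "Bakery"]


def coupon_categories(scores: np.ndarray, thresholds: np.ndarray, categories: List[str]) -> List[str]:
    """Categories whose predicted score meets that category's threshold."""
    return [c for c, s, t in zip(categories, scores, thresholds) if s >= t]


scores = np.array([0.91, 0.35, 0.62])          # hypothetical propensity scores for one customer
thresholds = np.array([0.80, 0.50, 0.60])      # hypothetical per-category cutoffs
predicted_spend = np.array([12.0, 8.0, 5.0])   # hypothetical output of a separate spend model

selected = coupon_categories(scores, thresholds, CATEGORIES)
expected_spend = scores * predicted_spend      # propensity-weighted spend per category
```

Per-category thresholds let you be stricter for categories that almost everyone buys and more lenient for niche ones.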