Favorite Brand Prediction

Task type: MulticlassClassificationTask
Industry: Retail / Fashion

This recipe predicts which brand a customer is most likely to buy from in the near future. The output is a probability distribution across a defined set of target brands, making it useful for personalized brand recommendations, targeted marketing, and assortment planning.

How does this work? The model learns each customer's brand affinity from their transaction history. The target function counts purchases per brand in the future window and normalizes them into a probability distribution.


Prerequisites

Before writing a target function you need:

  • A trained foundation model built on event data that includes a purchases data source with a Brand column (or equivalent).
  • The monad library installed in your environment (for Python App).

Target Function

The target function tells the model how to label each customer for training. It receives four arguments:

Argument    Type        Description
history     Events      All events before the temporal split.
future      Events      All events after the temporal split.
attributes  Attributes  Static entity attributes.
ctx         Dict        Context dictionary containing SPLIT_TIMESTAMP, data mode, etc.

For multiclass classification, the function must return one of:

  • A 1-D float32 array of size num_classes: normalized purchase counts (a probability distribution) across the target brands.
  • None: exclude this customer from training (e.g., incomplete window).
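
The return contract above can be checked with a small helper before a target function is used for training. This is a minimal sketch using only NumPy. The name check_multiclass_target is hypothetical and not part of the monad library, and the sum-to-one check reflects this recipe's "probability distribution" wording rather than a documented library requirement.

```python
import numpy as np


def check_multiclass_target(value, num_classes, atol=1e-6):
    """Return True if `value` is None or a valid 1-D distribution over num_classes.

    Hypothetical helper for local sanity checks; not a monad API.
    """
    if value is None:  # None means "exclude this customer"
        return True
    arr = np.asarray(value)
    if arr.ndim != 1 or arr.shape[0] != num_classes:
        return False
    if not np.issubdtype(arr.dtype, np.floating):
        return False
    if not np.all(np.isfinite(arr)) or np.any(arr < 0):
        return False
    # Assumption: a valid label sums to 1 within floating-point tolerance
    return bool(np.isclose(arr.sum(), 1.0, atol=atol))


valid = np.array([0.0, 0.75, 0.0, 0.0, 0.0, 0.25], dtype=np.float32)
print(check_multiclass_target(valid, num_classes=6))                  # True
print(check_multiclass_target(None, num_classes=6))                   # True
print(check_multiclass_target(np.array([0.5, 0.5]), num_classes=6))   # False
```

Running a check like this on a handful of sampled customers can surface shape or normalization mistakes before a full training run.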

Full Example

Python
import numpy as np
from datetime import timedelta
from typing import Dict

from monad.ui.target_function import Events, Attributes
from monad.ui.target_function import SPLIT_TIMESTAMP
from monad.ui.target_function import has_incomplete_training_window



# === Configuration ===
TARGET_WINDOW_DAYS = 21
PURCHASE_DATA_SOURCE = "purchases"
BRAND_COLUMN = "brand"
TARGET_BRANDS = [
    "The North Face",
    "Adidas",
    "Tommy Hilfiger",
    "Hugo",
    "Lacoste",
    "Gap",
]


def favourite_brand_target_fn(
    history: Events,
    future: Events,
    attributes: Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    """Return a normalized brand-purchase distribution for the customer."""

    # 1. Ensure the training window is long enough
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if has_incomplete_training_window(ctx, target_window):
        return None

    # 2. Trim future events to the target window
    future = future.interval_from(ctx[SPLIT_TIMESTAMP], target_window)

    # 3. Count purchases per brand, normalized to a probability distribution
    brand_distribution, _ = (
        future[PURCHASE_DATA_SOURCE]
        .groupBy(BRAND_COLUMN)
        .count(normalize=True, groups=TARGET_BRANDS)
    )

    return brand_distribution
Python
def favourite_brand_target_fn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    """Return a normalized brand-purchase distribution for the customer."""

    # === Configuration ===
    TARGET_WINDOW_DAYS = 21
    PURCHASE_DATA_SOURCE = "purchases"
    BRAND_COLUMN = "brand"
    TARGET_BRANDS = [
        "The North Face",
        "Adidas",
        "Tommy Hilfiger",
        "Hugo",
        "Lacoste",
        "Gap",
    ]

    # 1. Ensure the training window is long enough
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if target_function.has_incomplete_training_window(ctx, target_window):
        return None

    # 2. Trim future events to the target window
    future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

    # 3. Count purchases per brand, normalized to a probability distribution
    brand_distribution, _ = (
        future[PURCHASE_DATA_SOURCE]
        .groupBy(BRAND_COLUMN)
        .count(normalize=True, groups=TARGET_BRANDS)
    )

    return brand_distribution

Step-by-Step Breakdown

① Validate the training window

Python
target_window = timedelta(days=TARGET_WINDOW_DAYS)
if has_incomplete_training_window(ctx, target_window):
    return None
Python
target_window = timedelta(days=TARGET_WINDOW_DAYS)
if target_function.has_incomplete_training_window(ctx, target_window):
    return None

Skips samples where the temporal split leaves fewer than 21 days (TARGET_WINDOW_DAYS) of observable future data. The check is bypassed automatically at test and prediction time.
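
Conceptually, the check compares the end of the target window with the end of the observed data. The sketch below illustrates that idea with plain datetime arithmetic. The function window_is_incomplete and the data_end value are hypothetical, and the real has_incomplete_training_window reads what it needs from ctx, so its internals may differ.

```python
from datetime import datetime, timedelta, timezone


def window_is_incomplete(split_ts, data_end, window):
    """True when less than `window` of observed data follows the split.

    Illustrative only; not the library's implementation.
    """
    return split_ts + window > data_end


data_end = datetime(2024, 6, 30, tzinfo=timezone.utc)  # hypothetical last observed event
window = timedelta(days=21)

early_split = datetime(2024, 6, 1, tzinfo=timezone.utc)   # 29 days of future data
late_split = datetime(2024, 6, 20, tzinfo=timezone.utc)   # only 10 days of future data

print(window_is_incomplete(early_split, data_end, window))  # False: keep the sample
print(window_is_incomplete(late_split, data_end, window))   # True: return None
```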

② Trim future events to the target window

Python
future = future.interval_from(ctx[SPLIT_TIMESTAMP], target_window)
Python
future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

Ensures a consistent evaluation horizon across all samples: only events inside the target window after the split contribute to the label.
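
The effect of trimming can be pictured as a timestamp filter. The following sketch uses plain Python tuples instead of the Events type. The helper trim_to_window is hypothetical, and the half-open interval [split, split + window) is an assumption, since the boundary handling of interval_from is not stated here.

```python
from datetime import datetime, timedelta, timezone


def trim_to_window(events, split_ts, window):
    """Keep (timestamp, brand) events with split_ts <= timestamp < split_ts + window."""
    end = split_ts + window
    return [(ts, brand) for ts, brand in events if split_ts <= ts < end]


split_ts = datetime(2024, 5, 1, tzinfo=timezone.utc)
events = [
    (split_ts + timedelta(days=2), "Adidas"),
    (split_ts + timedelta(days=20), "Gap"),
    (split_ts + timedelta(days=35), "Hugo"),  # outside the 21-day window
]

trimmed = trim_to_window(events, split_ts, timedelta(days=21))
print([brand for _, brand in trimmed])  # ['Adidas', 'Gap']
```

Without this step, customers with a longer observable future would contribute more purchases to their labels than customers whose split falls close to the end of the data.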

③ Count and normalize brand purchases

Python
brand_distribution, _ = (
    future[PURCHASE_DATA_SOURCE]
    .groupBy(BRAND_COLUMN)
    .count(normalize=True, groups=TARGET_BRANDS)
)

This is the core logic:

  • groupBy(BRAND_COLUMN) groups future purchase events by brand.
  • .count(normalize=True, groups=TARGET_BRANDS) counts events in each group, normalizes to sum to 1, and returns results only for the specified brands.
  • The return type is a tuple (np.ndarray, List[str]). We take only the array.
  • If a customer bought 3 Adidas items and 1 Gap item, the output would be [0.0, 0.75, 0.0, 0.0, 0.0, 0.25].

Note: groupBy().count() returns a float64 array rather than float32. The Task layer accepts it as-is, so no manual astype(np.float32) is required.
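
To make the arithmetic concrete, here is a NumPy-only sketch of counting and normalizing that reproduces the Adidas/Gap example. It is not the library's implementation. The function brand_distribution is hypothetical, it assumes normalization happens over the target brands only, and returning None for a customer with no target-brand purchases is an illustrative choice.

```python
import numpy as np

TARGET_BRANDS = ["The North Face", "Adidas", "Tommy Hilfiger", "Hugo", "Lacoste", "Gap"]


def brand_distribution(purchased_brands, target_brands):
    """Count purchases per target brand and normalize the counts to sum to 1."""
    index = {brand: i for i, brand in enumerate(target_brands)}
    counts = np.zeros(len(target_brands), dtype=np.float64)
    for brand in purchased_brands:
        if brand in index:  # assumption: brands outside the target list are ignored
            counts[index[brand]] += 1
    total = counts.sum()
    if total == 0:
        return None  # illustrative choice: nothing to normalize
    return counts / total


dist = brand_distribution(["Adidas", "Adidas", "Gap", "Adidas"], TARGET_BRANDS)
print(dist.tolist())  # [0.0, 0.75, 0.0, 0.0, 0.0, 0.25]
print(dist.dtype)     # float64
```

The position of each value follows the order of TARGET_BRANDS, so the same list must be used for class_names when the task is created.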


Training

Python
from pathlib import Path
from monad.ui.config import TrainingParams, MetricParams, MetricMonitoringMode
from monad.config.early_stopping import EarlyStopping

from monad.ui.module import load_from_foundation_model, MulticlassClassificationTask


module = load_from_foundation_model(
    checkpoint_path=Path("./foundation_model"),
    downstream_task=MulticlassClassificationTask(class_names=TARGET_BRANDS),
    target_fn=favourite_brand_target_fn,
)

training_params = TrainingParams(
    checkpoint_dir=Path("./<this_model>"),
    learning_rate=1e-4,
    epochs=20,
    devices=[0],
    metrics=[
        MetricParams(alias="accuracy", metric_name="Accuracy", kwargs={"task": "multiclass", "num_classes": len(TARGET_BRANDS)}),
        MetricParams(alias="f1_macro", metric_name="F1Score", kwargs={"task": "multiclass", "num_classes": len(TARGET_BRANDS), "average": "macro"}),
    ],
    metric_to_monitor="val_accuracy_0",
    metric_monitoring_mode=MetricMonitoringMode.MAX,
    early_stopping=EarlyStopping(min_delta=1e-4, patience=5),
)

module.fit(training_params, seed=42)

Evaluation

Python
from pathlib import Path
from datetime import datetime, timezone
from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, MetricParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    prediction_date=datetime(2024, 5, 1, tzinfo=timezone.utc),
    output_type=OutputType.DECODED,
    devices=[0],
    metrics=[
        MetricParams(alias="accuracy", metric_name="Accuracy"),
        MetricParams(alias="f1_macro", metric_name="F1Score"),
    ],
)

results = module.test(testing_params)

Prediction

Python
from pathlib import Path
from datetime import datetime, timezone
from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    local_save_location=Path("./predictions.tsv"),
    output_type=OutputType.DECODED,
    prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
    devices=[0],
)

predictions = module.predict(testing_params)

Variations

Top brand only (argmax)

If you only need the single most-likely brand rather than a distribution:

Python
brand_distribution, _ = (
    future[PURCHASE_DATA_SOURCE]
    .groupBy(BRAND_COLUMN)
    .count(normalize=False, groups=TARGET_BRANDS)
)
# One-hot encode the most purchased brand
result = np.zeros(len(TARGET_BRANDS), dtype=np.float32)
if brand_distribution.sum() > 0:
    result[np.argmax(brand_distribution)] = 1.0
    return result
return None  # No purchases — exclude
Python
def top_brand_target_fn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    # === Configuration ===
    TARGET_WINDOW_DAYS = 21
    PURCHASE_DATA_SOURCE = "purchases"
    BRAND_COLUMN = "brand"
    TARGET_BRANDS = [
        "The North Face", "Adidas", "Tommy Hilfiger",
        "Hugo", "Lacoste", "Gap",
    ]

    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if target_function.has_incomplete_training_window(ctx, target_window):
        return None
    future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

    brand_distribution, _ = (
        future[PURCHASE_DATA_SOURCE]
        .groupBy(BRAND_COLUMN)
        .count(normalize=False, groups=TARGET_BRANDS)
    )
    # One-hot encode the most purchased brand
    result = np.zeros(len(TARGET_BRANDS), dtype=np.float32)
    if brand_distribution.sum() > 0:
        result[np.argmax(brand_distribution)] = 1.0
        return result
    return None  # No purchases — exclude
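
One detail worth knowing about this variation: np.argmax returns the first index when several brands tie for the highest count, so the order of TARGET_BRANDS silently decides ties. The sketch below demonstrates this and shows one possible alternative that excludes tied customers. The helper one_hot_top_brand and its exclude_ties flag are hypothetical and only illustrate the design choice.

```python
import numpy as np

TARGET_BRANDS = ["The North Face", "Adidas", "Tommy Hilfiger", "Hugo", "Lacoste", "Gap"]


def one_hot_top_brand(counts, exclude_ties=False):
    """One-hot encode the most purchased brand; optionally drop tied customers."""
    counts = np.asarray(counts)
    if counts.sum() == 0:
        return None  # no purchases: exclude
    top = counts.max()
    if exclude_ties and np.count_nonzero(counts == top) > 1:
        return None  # ambiguous label: exclude
    result = np.zeros(len(counts), dtype=np.float32)
    result[np.argmax(counts)] = 1.0  # argmax picks the first index on ties
    return result


counts = np.array([0, 2, 0, 0, 0, 2])  # Adidas and Gap tied at 2 purchases
print(TARGET_BRANDS[int(np.argmax(counts))])          # Adidas
print(one_hot_top_brand(counts).tolist())             # [0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
print(one_hot_top_brand(counts, exclude_ties=True))   # None
```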

Brand affinity with history weighting

Blend historical and future purchases for a smoother signal:

Python
history_dist, _ = (
    history[PURCHASE_DATA_SOURCE]
    .groupBy(BRAND_COLUMN)
    .count(normalize=True, groups=TARGET_BRANDS)
)
future_dist, _ = (
    future[PURCHASE_DATA_SOURCE]
    .groupBy(BRAND_COLUMN)
    .count(normalize=True, groups=TARGET_BRANDS)
)
# 70% future, 30% history
blended = 0.7 * future_dist + 0.3 * history_dist
total = blended.sum()
if total == 0:
    return None  # No target-brand purchases in either window: exclude
return blended / total  # Re-normalize
Python
def brand_affinity_target_fn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    # === Configuration ===
    TARGET_WINDOW_DAYS = 21
    PURCHASE_DATA_SOURCE = "purchases"
    BRAND_COLUMN = "brand"
    TARGET_BRANDS = [
        "The North Face", "Adidas", "Tommy Hilfiger",
        "Hugo", "Lacoste", "Gap",
    ]

    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if target_function.has_incomplete_training_window(ctx, target_window):
        return None
    future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

    history_dist, _ = (
        history[PURCHASE_DATA_SOURCE]
        .groupBy(BRAND_COLUMN)
        .count(normalize=True, groups=TARGET_BRANDS)
    )
    future_dist, _ = (
        future[PURCHASE_DATA_SOURCE]
        .groupBy(BRAND_COLUMN)
        .count(normalize=True, groups=TARGET_BRANDS)
    )
    # 70% future, 30% history
    blended = 0.7 * future_dist + 0.3 * history_dist
    total = blended.sum()
    if total == 0:
        return None  # No target-brand purchases in either window: exclude
    return blended / total  # Re-normalize
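
The blending arithmetic can be tried in isolation with NumPy. This sketch assumes that a window with no target-brand purchases yields an all-zero vector, which is not confirmed for count(normalize=True). Under that assumption, re-normalization matters when only one window has purchases, and a zero total needs its own guard. The function blend is hypothetical.

```python
import numpy as np


def blend(future_dist, history_dist, future_weight=0.7):
    """Weighted mix of two brand distributions, re-normalized to sum to 1."""
    future_dist = np.asarray(future_dist, dtype=np.float64)
    history_dist = np.asarray(history_dist, dtype=np.float64)
    blended = future_weight * future_dist + (1.0 - future_weight) * history_dist
    total = blended.sum()
    if total == 0:
        return None  # nothing in either window: avoid dividing by zero
    return blended / total


future_dist = np.array([0.0, 0.5, 0.0, 0.0, 0.0, 0.5])
history_dist = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0])
empty = np.zeros(6)

print(blend(future_dist, history_dist))  # mix of both windows
print(blend(empty, history_dist))        # falls back to the history distribution
print(blend(empty, empty))               # None
```

Note that blending history into the label makes the target partly predictable from the model's own inputs, so the history weight is best kept small and validated against the pure-future target.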

Metric            Why it matters
Accuracy          Fraction of customers whose top-predicted brand matches their actual top brand.
F1 Score (macro)  Balances precision and recall across all brands; important when some brands are rare.
Top-k Accuracy    Checks if the true brand is among the top k predictions. Useful when near-misses are acceptable.
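
As an illustration of how accuracy and top-k accuracy relate, the sketch below computes both from a small matrix of predicted probabilities using NumPy. The probabilities and labels are made-up example values, and top_k_accuracy is a hypothetical helper rather than a metric shipped with the library.

```python
import numpy as np


def top_k_accuracy(probs, true_idx, k):
    """Fraction of rows whose true class is among the k highest-probability classes."""
    probs = np.asarray(probs)
    true_idx = np.asarray(true_idx)
    # argpartition places the k largest probabilities (in any order) in the first k slots
    top_k = np.argpartition(-probs, k - 1, axis=1)[:, :k]
    hits = (top_k == true_idx[:, None]).any(axis=1)
    return float(hits.mean())


# Made-up predictions for 4 customers over 3 brands
probs = np.array([
    [0.60, 0.30, 0.10],
    [0.20, 0.50, 0.30],
    [0.10, 0.20, 0.70],
    [0.40, 0.35, 0.25],
])
true_idx = np.array([0, 2, 2, 1])  # index of each customer's actual top brand

print(top_k_accuracy(probs, true_idx, k=1))  # 0.5 (plain accuracy)
print(top_k_accuracy(probs, true_idx, k=2))  # 1.0
```

With k=1 this reduces to plain accuracy; the gap between the two numbers shows how often the true brand was the runner-up.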

Production Tips

  1. Keep the brand list current. Remove discontinued brands and add new ones. A stale list trains the model on irrelevant categories.

  2. Balance class frequencies. If one brand dominates purchases, the model may over-predict it. Consider requiring a minimum purchase count per brand or using class weighting.

  3. Use predictions for personalization. Feed brand probabilities into email templates, app banners, or recommendation carousels to surface relevant brands per customer.

  4. Retrain after major assortment changes. Brand preferences shift when new collections launch or a brand exits your store.
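
For tip 2, a common starting point for class weighting is the inverse-frequency ("balanced") heuristic, where each class weight is N / (K * n_c) for N samples, K classes with data, and n_c samples in class c. The sketch below computes such weights with NumPy. The purchase counts are made up, inverse_frequency_weights is a hypothetical helper, and how (or whether) weights can be passed to the training configuration is not covered in this recipe.

```python
import numpy as np


def inverse_frequency_weights(class_counts):
    """Balanced class weights: N / (K * n_c); classes with no samples get weight 0."""
    counts = np.asarray(class_counts, dtype=np.float64)
    weights = np.zeros_like(counts)
    seen = counts > 0
    weights[seen] = counts[seen].sum() / (seen.sum() * counts[seen])
    return weights


# Made-up number of customers whose top brand is each target brand
counts = np.array([700, 100, 100, 50, 30, 20])
weights = inverse_frequency_weights(counts)

for n, w in zip(counts, weights):
    print(f"{n:>4} samples -> weight {w:.3f}")
```

Rare brands receive larger weights, which counteracts the tendency to over-predict the dominant brand. Very large weights on tiny classes can destabilize training, so capping them is often reasonable.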