Favorite Color Prediction

Task type: MulticlassClassificationTask
Industry: Retail / Fashion

This recipe predicts a customer's preferred product color based on their purchase history. The output is a probability distribution across a defined set of colors, useful for personalized product recommendations, homepage merchandising, and targeted email campaigns.

How does this work? The target function counts how many items of each color the customer purchases in the future window and normalizes the counts into a distribution. The model then learns to predict that distribution from historical behavior.
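As a rough illustration of the count-then-normalize idea, the sketch below reproduces it in plain NumPy, outside the monad API. The helper name `color_distribution` and the sample purchases are hypothetical. Two behaviors are assumptions of this sketch rather than documented library behavior: colors outside the target list are ignored before normalizing, and a customer with no matching purchases gets an all-zero array.

```python
from collections import Counter
from typing import List, Sequence

import numpy as np


def color_distribution(purchased: Sequence[str], target_colors: List[str]) -> np.ndarray:
    """Hypothetical helper: normalized purchase counts over target_colors."""
    counts = Counter(purchased)
    raw = np.array([counts.get(c, 0) for c in target_colors], dtype=np.float64)
    total = raw.sum()
    # Assumption of this sketch: return all zeros when nothing matches.
    return raw / total if total > 0 else raw


colors = ["Red", "Blue", "Teal"]
dist = color_distribution(["Red", "Red", "Blue", "Teal"], colors)
print(dict(zip(colors, dist.tolist())))  # {'Red': 0.5, 'Blue': 0.25, 'Teal': 0.25}
```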


Prerequisites

Before writing a target function you need:

  • A trained foundation model built on event data that includes a transactions data source with a Colour column (or equivalent).
  • The monad library installed in your environment (for Python App).

Target Function

Argument    Type        Description
history     Events      All events before the temporal split.
future      Events      All events after the temporal split.
attributes  Attributes  Static entity attributes.
ctx         Dict        Context dictionary containing SPLIT_TIMESTAMP, data mode, etc.

For multiclass classification, the function must return one of:

  • A 1-D float32 array of size num_classes — normalized purchase counts across the target colors.
  • None, which excludes this customer from training.
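When developing a target function locally, it can help to check return values against this contract before training. The helper below is a hypothetical sanity check, not part of the monad library. It accepts any floating dtype, on the assumption that the library handles the final conversion to float32.

```python
from typing import Optional

import numpy as np


def check_target(target: Optional[np.ndarray], num_classes: int) -> bool:
    """Hypothetical sanity check for a multiclass target value."""
    if target is None:
        return True  # None means "exclude this customer"
    arr = np.asarray(target)
    if arr.ndim != 1 or arr.shape[0] != num_classes:
        return False
    if not np.issubdtype(arr.dtype, np.floating):
        return False
    # Values must be finite and non-negative to be read as normalized counts.
    return bool(np.all(np.isfinite(arr)) and np.all(arr >= 0))


print(check_target(np.array([0.5, 0.5, 0.0], dtype=np.float32), 3))  # True
print(check_target(np.array([[0.5, 0.5, 0.0]]), 3))                  # False (2-D)
print(check_target(None, 3))                                         # True
```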

Full Example

Python
import numpy as np
from datetime import timedelta
from typing import Dict

from monad.ui.target_function import Events, Attributes
from monad.ui.target_function import SPLIT_TIMESTAMP
from monad.ui.target_function import has_incomplete_training_window



# === Configuration ===
TARGET_WINDOW_DAYS = 21
TRANSACTION_DATA_SOURCE = "transactions"
COLOR_COLUMN = "Colour"
TARGET_COLORS = [
    "Red", "Blue", "Green", "Yellow", "Purple", "Orange",
    "Pink", "Brown", "Black", "White", "Gray", "Teal",
]


def favourite_color_target_fn(
    history: Events,
    future: Events,
    attributes: Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    """Return a normalized color-purchase distribution for the customer."""

    # 1. Ensure the training window is long enough
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if has_incomplete_training_window(ctx, target_window):
        return None

    # 2. Trim future events to the target window
    future = future.interval_from(ctx[SPLIT_TIMESTAMP], target_window)

    # 3. Count purchases per color, normalized to a probability distribution
    color_distribution, _ = (
        future[TRANSACTION_DATA_SOURCE]
        .groupBy(COLOR_COLUMN)
        .count(normalize=True, groups=TARGET_COLORS)
    )

    return color_distribution
Python
def favourite_color_target_fn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    """Return a normalized color-purchase distribution for the customer."""

    # === Configuration ===
    TARGET_WINDOW_DAYS = 21
    TRANSACTION_DATA_SOURCE = "transactions"
    COLOR_COLUMN = "Colour"
    TARGET_COLORS = [
        "Red", "Blue", "Green", "Yellow", "Purple", "Orange",
        "Pink", "Brown", "Black", "White", "Gray", "Teal",
    ]

    # 1. Ensure the training window is long enough
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if target_function.has_incomplete_training_window(ctx, target_window):
        return None

    # 2. Trim future events to the target window
    future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

    # 3. Count purchases per color, normalized to a probability distribution
    color_distribution, _ = (
        future[TRANSACTION_DATA_SOURCE]
        .groupBy(COLOR_COLUMN)
        .count(normalize=True, groups=TARGET_COLORS)
    )

    return color_distribution

Step-by-Step Breakdown

① Validate the training window

Python
target_window = timedelta(days=TARGET_WINDOW_DAYS)
if has_incomplete_training_window(ctx, target_window):
    return None
Python
target_window = timedelta(days=TARGET_WINDOW_DAYS)
if target_function.has_incomplete_training_window(ctx, target_window):
    return None

Skips samples whose split timestamp leaves less than a full target window of future data, so every training target covers the same number of days.
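The internals of `has_incomplete_training_window` are not shown in this recipe. Conceptually, the check can be pictured as comparing the end of the target window with the last date covered by the data. The sketch below assumes exactly that; the function name `window_is_incomplete` and the `data_end` argument are hypothetical.

```python
from datetime import datetime, timedelta, timezone


def window_is_incomplete(split: datetime, data_end: datetime, window: timedelta) -> bool:
    """Hypothetical check: True when the data ends before split + window."""
    return split + window > data_end


data_end = datetime(2024, 6, 1, tzinfo=timezone.utc)
window = timedelta(days=21)

early_split = datetime(2024, 5, 1, tzinfo=timezone.utc)
late_split = datetime(2024, 5, 20, tzinfo=timezone.utc)

print(window_is_incomplete(early_split, data_end, window))  # False
print(window_is_incomplete(late_split, data_end, window))   # True
```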

② Trim future events

Python
future = future.interval_from(ctx[SPLIT_TIMESTAMP], target_window)
Python
future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

Narrows future events to the 21 days following the split, so every customer's target covers the same horizon.
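To make the trimming step concrete, here is a plain-Python stand-in that filters a list of timestamps. It assumes a half-open interval (start included, end excluded); whether `interval_from` treats its boundaries the same way is not stated in this recipe. The function name `trim_to_window` is hypothetical.

```python
from datetime import datetime, timedelta, timezone
from typing import List


def trim_to_window(timestamps: List[datetime], start: datetime, window: timedelta) -> List[datetime]:
    """Hypothetical stand-in for interval_from: keep start <= t < start + window."""
    end = start + window
    return [t for t in timestamps if start <= t < end]


split = datetime(2024, 5, 1, tzinfo=timezone.utc)
events = [split + timedelta(days=d) for d in (0, 5, 20, 21, 40)]
kept = trim_to_window(events, split, timedelta(days=21))
print([(t - split).days for t in kept])  # [0, 5, 20]
```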

③ Count and normalize color purchases

Python
color_distribution, _ = (
    future[TRANSACTION_DATA_SOURCE]
    .groupBy(COLOR_COLUMN)
    .count(normalize=True, groups=TARGET_COLORS)
)
  • groupBy(COLOR_COLUMN) groups transactions by color.
  • .count(normalize=True, groups=TARGET_COLORS) counts per group, normalizes to sum to 1, and returns results for the specified colors only.
  • Returns a tuple (np.ndarray, List[str]) — we take only the array.
  • If a customer bought 2 Black items and 1 Blue item, the output array would have 0.67 at the Black position and 0.33 at Blue, with zeros elsewhere.

Note: groupBy().count() produces a float64 array. The Task layer accepts it as-is — no manual astype(np.float32) is required.
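The worked example above (2 Black items, 1 Blue item) can be checked with a few lines of NumPy. This is an illustration on a shortened, hypothetical color list, not a call into the monad API. The explicit cast at the end is shown only to illustrate the dtype difference mentioned in the note.

```python
import numpy as np

target_colors = ["Red", "Blue", "Black"]  # shortened list for illustration
counts = np.array([0, 1, 2], dtype=np.float64)  # 1 Blue, 2 Black

distribution = counts / counts.sum()
print(np.round(distribution, 2).tolist())  # [0.0, 0.33, 0.67]

# Optional explicit cast; the recipe states this is not required.
as_float32 = distribution.astype(np.float32)
print(as_float32.dtype)  # float32

# The customer's top color is the position with the largest share.
print(target_colors[int(np.argmax(distribution))])  # Black
```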


Training

Python
from pathlib import Path
from monad.ui.config import TrainingParams, MetricParams, MetricMonitoringMode
from monad.config.early_stopping import EarlyStopping

from monad.ui.module import load_from_foundation_model, MulticlassClassificationTask


module = load_from_foundation_model(
    checkpoint_path=Path("./foundation_model"),
    downstream_task=MulticlassClassificationTask(class_names=TARGET_COLORS),
    target_fn=favourite_color_target_fn,
)

training_params = TrainingParams(
    checkpoint_dir=Path("./<this_model>"),
    learning_rate=1e-4,
    epochs=20,
    devices=[0],
    metrics=[
        MetricParams(alias="accuracy", metric_name="Accuracy", kwargs={"task": "multiclass", "num_classes": len(TARGET_COLORS)}),
        MetricParams(alias="f1_macro", metric_name="F1Score", kwargs={"task": "multiclass", "num_classes": len(TARGET_COLORS), "average": "macro"}),
    ],
    metric_to_monitor="val_accuracy_0",
    metric_monitoring_mode=MetricMonitoringMode.MAX,
    early_stopping=EarlyStopping(min_delta=1e-4, patience=5),
)

module.fit(training_params, seed=42)

Evaluation

Python
from pathlib import Path
from datetime import datetime, timezone
from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, MetricParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    prediction_date=datetime(2024, 5, 1, tzinfo=timezone.utc),
    output_type=OutputType.DECODED,
    devices=[0],
    metrics=[
        MetricParams(alias="accuracy", metric_name="Accuracy"),
        MetricParams(alias="f1_macro", metric_name="F1Score"),
    ],
)

results = module.test(testing_params)

Prediction

Python
from pathlib import Path
from datetime import datetime, timezone
from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    local_save_location=Path("./predictions.tsv"),
    output_type=OutputType.DECODED,
    prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
    devices=[0],
)

predictions = module.predict(testing_params)

Variations

Seasonal color preferences

Shorten the target window and retrain per season to capture seasonal shifts (e.g., darker tones in winter, brighter in summer):

Python
TARGET_WINDOW_DAYS = 14  # Shorter for seasonal sensitivity
Python
def seasonal_color_target_fn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    # === Configuration ===
    TARGET_WINDOW_DAYS = 14  # Shorter for seasonal sensitivity
    TRANSACTION_DATA_SOURCE = "transactions"
    COLOR_COLUMN = "Colour"
    TARGET_COLORS = [
        "Red", "Blue", "Green", "Yellow", "Purple", "Orange",
        "Pink", "Brown", "Black", "White", "Gray", "Teal",
    ]

    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if target_function.has_incomplete_training_window(ctx, target_window):
        return None
    future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

    color_distribution, _ = (
        future[TRANSACTION_DATA_SOURCE]
        .groupBy(COLOR_COLUMN)
        .count(normalize=True, groups=TARGET_COLORS)
    )
    return color_distribution

Exclude customers with no color diversity

Skip customers who only buy one color — they provide a trivial signal:

Python
# Add after step 3:
if np.count_nonzero(color_distribution) <= 1:
    return None
Python
def diverse_color_target_fn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    # === Configuration ===
    TARGET_WINDOW_DAYS = 21
    TRANSACTION_DATA_SOURCE = "transactions"
    COLOR_COLUMN = "Colour"
    TARGET_COLORS = [
        "Red", "Blue", "Green", "Yellow", "Purple", "Orange",
        "Pink", "Brown", "Black", "White", "Gray", "Teal",
    ]

    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if target_function.has_incomplete_training_window(ctx, target_window):
        return None
    future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

    color_distribution, _ = (
        future[TRANSACTION_DATA_SOURCE]
        .groupBy(COLOR_COLUMN)
        .count(normalize=True, groups=TARGET_COLORS)
    )

    # Exclude customers who only buy one color
    if np.count_nonzero(color_distribution) <= 1:
        return None

    return color_distribution
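To see which customers the diversity filter drops, the sketch below applies the same `np.count_nonzero(...) <= 1` test to three hand-written distributions. Note that the condition also excludes an all-zero array. Whether the library actually returns all zeros for a customer with no target-color purchases is an assumption here, not something this recipe confirms.

```python
import numpy as np

examples = {
    "mixed": np.array([0.0, 0.33, 0.67]),
    "single_color": np.array([0.0, 0.0, 1.0]),
    "no_target_purchases": np.zeros(3),  # assumed shape of an "empty" result
}

decisions = {}
for name, dist in examples.items():
    excluded = bool(np.count_nonzero(dist) <= 1)
    decisions[name] = "excluded" if excluded else "kept"
    print(name, decisions[name])
# mixed kept
# single_color excluded
# no_target_purchases excluded
```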

Metric            Why it matters
Accuracy          Fraction of customers whose top-predicted color matches their actual top color.
F1 Score (macro)  Balances performance across all colors, important when some colors are rare.
Top-3 Accuracy    Checks if the true top color is among the model's top 3 predictions.
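For intuition about what these three metrics measure, the following plain NumPy sketch computes them on a small made-up batch of predictions. It is not the library's metric implementation. The probabilities and labels are fabricated for illustration, and the convention of scoring an absent class as 0 in macro F1 is an assumption that other implementations may handle differently. Here `true_idx` stands for the index of each customer's actual top color.

```python
import numpy as np


def top_k_accuracy(probs: np.ndarray, true_idx: np.ndarray, k: int) -> float:
    """Share of rows whose true class is among the k highest-scoring classes."""
    top_k = np.argsort(-probs, axis=1)[:, :k]
    return float(np.mean(np.any(top_k == true_idx[:, None], axis=1)))


def macro_f1(pred_idx: np.ndarray, true_idx: np.ndarray, num_classes: int) -> float:
    """Unweighted mean of per-class F1; a class with no hits at all scores 0."""
    scores = []
    for c in range(num_classes):
        tp = int(np.sum((pred_idx == c) & (true_idx == c)))
        fp = int(np.sum((pred_idx == c) & (true_idx != c)))
        fn = int(np.sum((pred_idx != c) & (true_idx == c)))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return float(np.mean(scores))


# Made-up predictions for 4 customers over 4 colors.
probs = np.array([
    [0.60, 0.30, 0.10, 0.00],
    [0.10, 0.20, 0.30, 0.40],
    [0.40, 0.35, 0.20, 0.05],
    [0.25, 0.25, 0.30, 0.20],
])
true_idx = np.array([0, 2, 1, 3])
pred_idx = probs.argmax(axis=1)

print(float(np.mean(pred_idx == true_idx)))       # 0.25
print(top_k_accuracy(probs, true_idx, k=3))       # 0.75
print(round(macro_f1(pred_idx, true_idx, 4), 3))  # 0.167
```

In this toy batch the model rarely gets the single top color right, yet the true color is usually within its top three guesses, which is the gap Top-3 Accuracy is meant to expose.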

Production Tips

  1. Merge similar colors. If "Gray" and "Silver" or "Beige" and "Cream" appear as separate values in your data, consider merging them to reduce noise.

  2. Use predictions in product ranking. When a customer searches or browses, boost products in their preferred colors.

  3. Combine with brand preferences. Color + brand predictions together create a powerful personalization signal.

  4. Retrain after assortment refreshes. New collections may introduce new colors or retire old ones — keep TARGET_COLORS in sync.
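Tips 1 and 4 both come down to keeping raw color values and TARGET_COLORS aligned. One possible approach, sketched below, is to normalize raw values through an alias map and then report any canonical colors missing from the target list. The alias map, the helper names, and the sample catalog values are all hypothetical; in practice the merge would likely need to happen in the data itself, before the foundation model is trained.

```python
from typing import Dict, Iterable, List, Set

# Hypothetical alias map: raw catalog value -> canonical target color.
COLOR_ALIASES: Dict[str, str] = {"Silver": "Gray", "Grey": "Gray", "Cream": "Beige"}


def canonical_color(raw: str) -> str:
    """Map a raw color value to its canonical name (identity if no alias)."""
    cleaned = raw.strip().title()
    return COLOR_ALIASES.get(cleaned, cleaned)


def unknown_colors(raw_values: Iterable[str], target_colors: List[str]) -> Set[str]:
    """Canonical colors seen in the data that are missing from target_colors."""
    return {canonical_color(v) for v in raw_values} - set(target_colors)


targets = ["Black", "Gray", "Beige"]
catalog = ["silver", "Grey ", "Cream", "Black", "Lilac"]
print(sorted(unknown_colors(catalog, targets)))  # ['Lilac']
```

A non-empty result after an assortment refresh is a signal to extend TARGET_COLORS or the alias map and retrain.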