Favorite Color Prediction
Task type: MulticlassClassificationTask
Industry: Retail / Fashion
This recipe predicts a customer's preferred product color based on their purchase history. The output is a probability distribution across a defined set of colors, useful for personalized product recommendations, homepage merchandising, and targeted email campaigns.
How does this work? The target function counts how many items of each color the customer purchases in the future window (the 21 days after the temporal split) and normalizes the counts into a distribution. The model then learns to predict that distribution from historical behavior.
Prerequisites
Before writing a target function you need:
- A trained foundation model built on event data that includes a `transactions` data source with a `Colour` column (or equivalent).
- The `monad` library installed in your environment (for Python App).
Target Function
| Argument | Type | Description |
|---|---|---|
| `history` | `Events` | All events before the temporal split. |
| `future` | `Events` | All events after the temporal split. |
| `attributes` | `Attributes` | Static entity attributes. |
| `ctx` | `Dict` | Context dictionary containing `SPLIT_TIMESTAMP`, data mode, etc. |
For multiclass classification, the function must return one of:
- A 1-D `float32` array of size `num_classes`: normalized purchase counts across the target colors.
- `None`: exclude this customer from training.
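The return contract above can be checked offline before training. The sketch below is hypothetical: `check_target` is not a monad function, and it assumes that every non-`None` target must be a non-negative floating-point vector of length `num_classes` that sums to 1. It accepts any float dtype, not only `float32`.

```python
import numpy as np


def check_target(value, num_classes: int) -> bool:
    """Hypothetical sanity check for a multiclass target-function return value.

    Assumes a valid value is either None (customer excluded) or a 1-D float
    array of length num_classes holding a probability distribution.
    """
    if value is None:
        return True  # None means "exclude this customer"
    arr = np.asarray(value)
    if arr.ndim != 1 or arr.shape[0] != num_classes:
        return False
    if not np.issubdtype(arr.dtype, np.floating):
        return False
    # Entries must be non-negative and sum to 1 (within floating-point tolerance)
    return bool(np.all(arr >= 0) and np.isclose(arr.sum(), 1.0))


if __name__ == "__main__":
    good = np.array([0.5, 0.25, 0.25], dtype=np.float32)
    print(check_target(good, 3))                  # True
    print(check_target(None, 3))                  # True
    print(check_target(np.array([1, 0, 0]), 3))   # False: integer dtype
```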
Full Example
```python
import numpy as np
from datetime import timedelta
from typing import Dict

from monad.ui.target_function import Events, Attributes
from monad.ui.target_function import SPLIT_TIMESTAMP
from monad.ui.target_function import has_incomplete_training_window

# === Configuration ===
TARGET_WINDOW_DAYS = 21
TRANSACTION_DATA_SOURCE = "transactions"
COLOR_COLUMN = "Colour"
TARGET_COLORS = [
    "Red", "Blue", "Green", "Yellow", "Purple", "Orange",
    "Pink", "Brown", "Black", "White", "Gray", "Teal",
]


def favourite_color_target_fn(
    history: Events,
    future: Events,
    attributes: Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    """Return a normalized color-purchase distribution for the customer."""
    # 1. Ensure the training window is long enough
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if has_incomplete_training_window(ctx, target_window):
        return None

    # 2. Trim future events to the target window
    future = future.interval_from(ctx[SPLIT_TIMESTAMP], target_window)

    # 3. Count purchases per color, normalized to a probability distribution
    color_distribution, _ = (
        future[TRANSACTION_DATA_SOURCE]
        .groupBy(COLOR_COLUMN)
        .count(normalize=True, groups=TARGET_COLORS)
    )
    return color_distribution
```
The same function can also be written against the `target_function` module, with the configuration constants defined inside the function body:

```python
import numpy as np
from datetime import timedelta
from typing import Dict

from monad.ui import target_function


def favourite_color_target_fn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    """Return a normalized color-purchase distribution for the customer."""
    # === Configuration ===
    TARGET_WINDOW_DAYS = 21
    TRANSACTION_DATA_SOURCE = "transactions"
    COLOR_COLUMN = "Colour"
    TARGET_COLORS = [
        "Red", "Blue", "Green", "Yellow", "Purple", "Orange",
        "Pink", "Brown", "Black", "White", "Gray", "Teal",
    ]

    # 1. Ensure the training window is long enough
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if target_function.has_incomplete_training_window(ctx, target_window):
        return None

    # 2. Trim future events to the target window
    future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

    # 3. Count purchases per color, normalized to a probability distribution
    color_distribution, _ = (
        future[TRANSACTION_DATA_SOURCE]
        .groupBy(COLOR_COLUMN)
        .count(normalize=True, groups=TARGET_COLORS)
    )
    return color_distribution
```
Step-by-Step Breakdown
① Validate the training window
Returns `None`, excluding the sample, when the split leaves less than a full target window of future data.
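How `has_incomplete_training_window` decides is not documented in this recipe. A minimal sketch of the idea, assuming the check compares the end of the target window with the last timestamp available in the data (`window_is_incomplete` and `data_end` are hypothetical names, not monad APIs):

```python
from datetime import datetime, timedelta, timezone


def window_is_incomplete(split: datetime, window: timedelta, data_end: datetime) -> bool:
    """Hypothetical stand-in for has_incomplete_training_window.

    Assumes a window is complete only if split + window does not run
    past the last timestamp available in the data.
    """
    return split + window > data_end


data_end = datetime(2024, 5, 31, tzinfo=timezone.utc)
# May 1 + 21 days = May 22, inside the data: complete
print(window_is_incomplete(datetime(2024, 5, 1, tzinfo=timezone.utc), timedelta(days=21), data_end))   # False
# May 20 + 21 days = June 10, past the data: incomplete
print(window_is_incomplete(datetime(2024, 5, 20, tzinfo=timezone.utc), timedelta(days=21), data_end))  # True
```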
② Trim future events
Keeps only the events in the 21 days after `SPLIT_TIMESTAMP`, so every sample is labeled over the same horizon.
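Conceptually, `interval_from(start, duration)` keeps the events between the split and the end of the target window. The sketch below models events as plain `(timestamp, colour)` tuples and assumes a half-open interval `[split, split + window)`; whether monad includes the end boundary is not stated in this recipe, and `trim_to_window` is a hypothetical helper.

```python
from datetime import datetime, timedelta, timezone
from typing import List, Tuple

Event = Tuple[datetime, str]  # (timestamp, colour): simplified stand-in for Events


def trim_to_window(events: List[Event], split: datetime, window: timedelta) -> List[Event]:
    """Keep events in the half-open interval [split, split + window) (assumed semantics)."""
    end = split + window
    return [(ts, colour) for ts, colour in events if split <= ts < end]


split = datetime(2024, 5, 1, tzinfo=timezone.utc)
events = [
    (datetime(2024, 4, 30, tzinfo=timezone.utc), "Teal"),   # before the split: dropped
    (datetime(2024, 5, 3, tzinfo=timezone.utc), "Black"),
    (datetime(2024, 5, 21, tzinfo=timezone.utc), "Blue"),
    (datetime(2024, 5, 22, tzinfo=timezone.utc), "Red"),    # exactly split + 21 days: dropped
    (datetime(2024, 6, 15, tzinfo=timezone.utc), "Green"),  # after the window: dropped
]
print([c for _, c in trim_to_window(events, split, timedelta(days=21))])  # ['Black', 'Blue']
```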
③ Count and normalize color purchases
```python
color_distribution, _ = (
    future[TRANSACTION_DATA_SOURCE]
    .groupBy(COLOR_COLUMN)
    .count(normalize=True, groups=TARGET_COLORS)
)
```
- `groupBy(COLOR_COLUMN)` groups transactions by color.
- `.count(normalize=True, groups=TARGET_COLORS)` counts per group, normalizes the counts to sum to 1, and returns results for the specified colors only.
- Returns a tuple `(np.ndarray, List[str])`; we keep only the array.
- If a customer bought 2 Black items and 1 Blue item, the output array would have `0.67` at the Black position and `0.33` at Blue, with zeros elsewhere.
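The counting and normalization step, including the 2 Black / 1 Blue example, can be reproduced in plain Python and NumPy. This is an approximation, not monad's implementation: `normalized_color_counts` is a hypothetical helper, and it assumes colours outside `groups` are ignored and that a customer with no matching purchases gets an all-zero array.

```python
from collections import Counter
from typing import List, Tuple

import numpy as np

TARGET_COLORS = [
    "Red", "Blue", "Green", "Yellow", "Purple", "Orange",
    "Pink", "Brown", "Black", "White", "Gray", "Teal",
]


def normalized_color_counts(colours: List[str], groups: List[str]) -> Tuple[np.ndarray, List[str]]:
    """Approximate groupBy(...).count(normalize=True, groups=...) for a list of colours.

    Assumptions: colours not in `groups` are ignored; no purchases -> all zeros.
    """
    counts = Counter(colours)
    arr = np.array([counts.get(g, 0) for g in groups], dtype=np.float64)
    total = arr.sum()
    if total > 0:
        arr /= total  # normalize so the shares sum to 1
    return arr, list(groups)  # (array, names) mirrors the documented tuple


dist, names = normalized_color_counts(["Black", "Black", "Blue"], TARGET_COLORS)
print(round(float(dist[names.index("Black")]), 2), round(float(dist[names.index("Blue")]), 2))  # 0.67 0.33
```

Positions in the array follow the order of `TARGET_COLORS`, which is why the same list is passed as `class_names` when training.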
Note: `groupBy().count()` produces a `float64` array. The Task layer accepts it as-is, so no manual `astype(np.float32)` is required.
Training
```python
from pathlib import Path

from monad.ui.config import TrainingParams, MetricParams, MetricMonitoringMode
from monad.config.early_stopping import EarlyStopping
from monad.ui.module import load_from_foundation_model, MulticlassClassificationTask

# TARGET_COLORS and favourite_color_target_fn come from the target function code above
module = load_from_foundation_model(
    checkpoint_path=Path("./foundation_model"),
    downstream_task=MulticlassClassificationTask(class_names=TARGET_COLORS),
    target_fn=favourite_color_target_fn,
)

training_params = TrainingParams(
    checkpoint_dir=Path("./<this_model>"),
    learning_rate=1e-4,
    epochs=20,
    devices=[0],
    metrics=[
        MetricParams(
            alias="accuracy",
            metric_name="Accuracy",
            kwargs={"task": "multiclass", "num_classes": len(TARGET_COLORS)},
        ),
        MetricParams(
            alias="f1_macro",
            metric_name="F1Score",
            kwargs={"task": "multiclass", "num_classes": len(TARGET_COLORS), "average": "macro"},
        ),
    ],
    metric_to_monitor="val_accuracy_0",
    metric_monitoring_mode=MetricMonitoringMode.MAX,
    early_stopping=EarlyStopping(min_delta=1e-4, patience=5),
)

module.fit(training_params, seed=42)
```
Evaluation
```python
from pathlib import Path
from datetime import datetime, timezone

from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, MetricParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    prediction_date=datetime(2024, 5, 1, tzinfo=timezone.utc),
    output_type=OutputType.DECODED,
    devices=[0],
    metrics=[
        MetricParams(alias="accuracy", metric_name="Accuracy"),
        MetricParams(alias="f1_macro", metric_name="F1Score"),
    ],
)

results = module.test(testing_params)
```
Prediction
```python
from pathlib import Path
from datetime import datetime, timezone

from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    local_save_location=Path("./predictions.tsv"),
    output_type=OutputType.DECODED,
    prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
    devices=[0],
)

predictions = module.predict(testing_params)
```
Variations
Seasonal color preferences
Shorten the target window and retrain per season to capture seasonal shifts (e.g., darker tones in winter, brighter in summer):
```python
import numpy as np
from datetime import timedelta
from typing import Dict

from monad.ui import target_function


def seasonal_color_target_fn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    # === Configuration ===
    TARGET_WINDOW_DAYS = 14  # Shorter for seasonal sensitivity
    TRANSACTION_DATA_SOURCE = "transactions"
    COLOR_COLUMN = "Colour"
    TARGET_COLORS = [
        "Red", "Blue", "Green", "Yellow", "Purple", "Orange",
        "Pink", "Brown", "Black", "White", "Gray", "Teal",
    ]

    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if target_function.has_incomplete_training_window(ctx, target_window):
        return None

    future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

    color_distribution, _ = (
        future[TRANSACTION_DATA_SOURCE]
        .groupBy(COLOR_COLUMN)
        .count(normalize=True, groups=TARGET_COLORS)
    )
    return color_distribution
```
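Retraining per season means grouping split or prediction dates by season. A minimal sketch, assuming Northern Hemisphere meteorological seasons (winter is December to February, spring March to May, and so on); `meteorological_season` is a hypothetical helper, and a retailer with its own seasonal calendar would substitute that instead.

```python
from datetime import date

# Assumed mapping: Northern Hemisphere meteorological seasons by month number.
_SEASON_BY_MONTH = {
    12: "winter", 1: "winter", 2: "winter",
    3: "spring", 4: "spring", 5: "spring",
    6: "summer", 7: "summer", 8: "summer",
    9: "autumn", 10: "autumn", 11: "autumn",
}


def meteorological_season(d: date) -> str:
    """Return the meteorological season a date falls into (Northern Hemisphere)."""
    return _SEASON_BY_MONTH[d.month]


print(meteorological_season(date(2024, 6, 1)))  # summer
```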
Exclude customers with no color diversity
Skip customers who buy at most one distinct color in the target window, since they provide a trivial signal:
```python
import numpy as np
from datetime import timedelta
from typing import Dict

from monad.ui import target_function


def diverse_color_target_fn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    # === Configuration ===
    TARGET_WINDOW_DAYS = 21
    TRANSACTION_DATA_SOURCE = "transactions"
    COLOR_COLUMN = "Colour"
    TARGET_COLORS = [
        "Red", "Blue", "Green", "Yellow", "Purple", "Orange",
        "Pink", "Brown", "Black", "White", "Gray", "Teal",
    ]

    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if target_function.has_incomplete_training_window(ctx, target_window):
        return None

    future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

    color_distribution, _ = (
        future[TRANSACTION_DATA_SOURCE]
        .groupBy(COLOR_COLUMN)
        .count(normalize=True, groups=TARGET_COLORS)
    )

    # Exclude customers with a non-zero share in at most one target color
    if np.count_nonzero(color_distribution) <= 1:
        return None
    return color_distribution
```
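The exclusion rule relies on `np.count_nonzero`, which counts the non-zero color shares. The sketch below applies the same rule to a few hand-made arrays (`is_trivial` is a hypothetical name). Note that an all-zero array is excluded too, assuming that is what the count returns for a customer with no purchases in the window.

```python
import numpy as np


def is_trivial(color_distribution: np.ndarray) -> bool:
    """Same rule as diverse_color_target_fn: at most one non-zero color share."""
    return np.count_nonzero(color_distribution) <= 1


single = np.array([0.0, 1.0, 0.0])  # every purchase in one color
mixed = np.array([0.5, 0.5, 0.0])   # two colors
empty = np.zeros(3)                 # no purchases counted in the window
print(is_trivial(single), is_trivial(mixed), is_trivial(empty))  # True False True
```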
Recommended Metrics
| Metric | Why it matters |
|---|---|
| Accuracy | Fraction of customers whose top-predicted color matches their actual top color. |
| F1 Score (macro) | Balances performance across all colors, important when some colors are rare. |
| Top-3 Accuracy | Checks if the true top color is among the model's top 3 predictions. |
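To sanity-check exported predictions offline, the accuracy and top-3 accuracy definitions in the table can be computed with NumPy. This is a sketch, not how monad computes its metrics: `top_k_accuracy` is a hypothetical helper, and it takes the argmax of each target distribution as the true top color (NumPy's `argmax` returns the first index on ties, i.e. the earlier color in `TARGET_COLORS`).

```python
import numpy as np


def top_k_accuracy(pred: np.ndarray, target: np.ndarray, k: int = 1) -> float:
    """Fraction of rows whose true top color is among the k highest predicted scores.

    pred, target: arrays of shape (n_customers, n_colors).
    """
    true_top = target.argmax(axis=1)                 # (n,) index of the true top color
    top_k = np.argsort(-pred, axis=1)[:, :k]         # (n, k) indices of the k largest scores
    hits = (top_k == true_top[:, None]).any(axis=1)  # True where the label is in the top k
    return float(hits.mean())


pred = np.array([[0.6, 0.3, 0.1], [0.2, 0.3, 0.5]])
target = np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
print(top_k_accuracy(pred, target, k=1), top_k_accuracy(pred, target, k=2))  # 0.5 1.0
```

With the 12 colors in `TARGET_COLORS`, `k=1` corresponds to Accuracy and `k=3` to Top-3 Accuracy.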
Production Tips
- Merge similar colors. If "Gray" and "Silver" or "Beige" and "Cream" appear as separate values in your data, consider merging them to reduce noise.
- Use predictions in product ranking. When a customer searches or browses, boost products in their preferred colors.
- Combine with brand preferences. Color and brand predictions together create a stronger personalization signal.
- Retrain after assortment refreshes. New collections may introduce new colors or retire old ones, so keep `TARGET_COLORS` in sync.
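Merging similar colors amounts to mapping raw values onto canonical names before they reach the `Colour` column. A minimal sketch with a hypothetical alias map: `COLOR_ALIASES` and `canonical_colour` are illustrative names, the aliases are examples only, and running this while preparing the transactions data for the foundation model is an assumed place in the pipeline. The canonical names should match `TARGET_COLORS`.

```python
from typing import Dict

# Hypothetical alias map: raw catalogue values -> canonical names used in TARGET_COLORS.
COLOR_ALIASES: Dict[str, str] = {
    "Silver": "Gray",
    "Grey": "Gray",
    "Charcoal": "Gray",
    "Navy": "Blue",
}


def canonical_colour(raw: str, aliases: Dict[str, str] = COLOR_ALIASES) -> str:
    """Normalize casing/whitespace, then map to a canonical name (unchanged if no alias)."""
    cleaned = raw.strip().title()
    return aliases.get(cleaned, cleaned)


print([canonical_colour(c) for c in ["Silver", "grey", "Red"]])  # ['Gray', 'Gray', 'Red']
```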