
Predict Time to New Category Purchase

Task type: RegressionTask
Industry: Retail

Category expansion is a key indicator of deepening customer engagement. When a customer ventures into a new product category, it signals growing trust in the brand and higher lifetime-value potential. By predicting how soon a customer will explore a new category, merchandising teams can accelerate the transition with targeted cross-category recommendations and introductory offers.

What makes this advanced? Set-based tracking: the function builds a set of the customer's historical categories, then scans future events for the first purchase in a category outside that set.


Prerequisites

Before writing a target function you need:

  • A trained foundation model built on event data that includes the relevant data sources.
  • The monad library installed in your environment.
  • Data source(s): transactions with a category column

Target Function

The target function tells monad how to label each entity for training. It receives four arguments:

Argument    Type        Description
history     Events      All events before the temporal split.
future      Events      All events after the temporal split.
attributes  Attributes  Static entity attributes.
ctx         Dict        Context dictionary containing SPLIT_TIMESTAMP, data mode, etc.

For regression tasks, the function must return one of:

  • np.array([value], dtype=np.float32): the continuous target value (here, days until the first new-category purchase).
  • None: exclude this entity (e.g., no new-category purchase was found).

Full Example

Python
import numpy as np
from typing import Dict, Optional

from monad.ui.target_function import Events, Attributes
from monad.ui.target_function import SPLIT_TIMESTAMP

from monad.constants import SECONDS_PER_DAY

# === Configuration ===
TRANSACTION_DATA_SOURCE = "transactions"

def time_to_new_category_target_fn(
    history: Events,
    future: Events,
    attributes: Attributes,
    ctx: Dict,
) -> Optional[np.ndarray]:
    """Predict days until first purchase in a new category."""

    split_ts = ctx[SPLIT_TIMESTAMP]

    # Every category the customer purchased from before the split.
    known_categories = set(
        history[TRANSACTION_DATA_SOURCE]["category"].events
    )

    # Walk future purchases in order; the first unseen category defines the label.
    for category, ts in zip(
        future[TRANSACTION_DATA_SOURCE]["category"].events,
        future[TRANSACTION_DATA_SOURCE].timestamps,
    ):
        if category not in known_categories:
            # Whole days from the split to the first new-category purchase.
            days = (ts - split_ts) // SECONDS_PER_DAY
            return np.array([days], dtype=np.float32)

    # No new category in the available future data: exclude this entity.
    return None
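
Because Events objects come from monad, the labeling rule is awkward to unit-test on its own. One option, sketched below, is to mirror the rule in a plain function over lists so it can be checked offline. It assumes timestamps are Unix epoch seconds (as the division by SECONDS_PER_DAY implies), and days_to_first_new_category is a hypothetical helper name, not part of monad.

```python
from typing import Hashable, Optional, Sequence

import numpy as np

SECONDS_PER_DAY = 86_400  # assumed to match monad.constants.SECONDS_PER_DAY


def days_to_first_new_category(
    history_categories: Sequence[Hashable],
    future_categories: Sequence[Hashable],
    future_timestamps: Sequence[float],
    split_ts: float,
) -> Optional[np.ndarray]:
    """Hypothetical pure-Python mirror of the target function's labeling rule."""
    known = set(history_categories)
    for category, ts in zip(future_categories, future_timestamps):
        if category not in known:
            days = (ts - split_ts) // SECONDS_PER_DAY
            return np.array([days], dtype=np.float32)
    return None  # no new category observed: entity excluded


# Example: split at t=0; "toys" is new and first appears 3.5 days later.
label = days_to_first_new_category(
    history_categories=["grocery", "apparel", "grocery"],
    future_categories=["grocery", "toys", "garden"],
    future_timestamps=[1 * SECONDS_PER_DAY, 3.5 * SECONDS_PER_DAY, 9 * SECONDS_PER_DAY],
    split_ts=0,
)
print(label)  # [3.]
```

Keeping the rule in a pure function like this makes edge cases (empty history, same-day purchases, no new category) easy to pin down with ordinary tests before wiring it into the real target function.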

Step-by-Step Breakdown

① Build the set of known categories

Python
known_categories = set(
    history[TRANSACTION_DATA_SOURCE]["category"].events
)

All categories the customer has purchased from in their history are collected into a set. This represents the customer's current category repertoire — any category not in this set would be a "new" category.
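
One edge case follows from this step: if a customer has no transactions in history, the set is empty and their very first future purchase counts as "new", so the label becomes time to first purchase rather than time to category expansion. The original function does not guard against this. A minimal sketch of an optional guard is below; known_category_set is a hypothetical helper, and returning None to exclude such customers is an assumption about how you want to treat them.

```python
from typing import Hashable, Iterable, Optional, Set


def known_category_set(history_categories: Iterable[Hashable]) -> Optional[Set[Hashable]]:
    """Collect the customer's category repertoire; None means 'no history, exclude'.

    The None branch is an optional guard, not part of the original target function.
    """
    known = set(history_categories)  # duplicates collapse: one entry per category
    return known or None  # empty set -> None, so the caller can skip the entity


print(known_category_set(["grocery", "grocery", "apparel"]) == {"grocery", "apparel"})  # True
print(known_category_set([]))  # None
```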

② Scan future events for the first new category

Python
for category, ts in zip(
    future[TRANSACTION_DATA_SOURCE]["category"].events,
    future[TRANSACTION_DATA_SOURCE].timestamps,
):
    if category not in known_categories:
        days = (ts - split_ts) // SECONDS_PER_DAY
        return np.array([days], dtype=np.float32)

Future transactions are iterated in chronological order. The first transaction in a category not seen in history triggers a return. Floor division by SECONDS_PER_DAY (86,400) converts the difference between the purchase timestamp and the split timestamp into whole days, truncating any partial day, so a purchase 23 hours after the split is labeled 0. The loop exits on the first match; any later new-category purchases are ignored.
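
If you cannot rely on future events arriving in chronological order (an assumption worth checking against your data source), sorting by timestamp first keeps the "first new category" well defined. The sketch below, built around a hypothetical helper first_new_category_days and epoch-second timestamps, combines that defensive sort with the same floor-division rule to show how partial days are truncated.

```python
from typing import Hashable, Optional, Sequence

SECONDS_PER_DAY = 86_400  # assumed equal to monad.constants.SECONDS_PER_DAY


def first_new_category_days(
    known: set,
    future_categories: Sequence[Hashable],
    future_timestamps: Sequence[float],
    split_ts: float,
) -> Optional[float]:
    """Return whole days from split to the earliest new-category event, or None."""
    # Sort by timestamp only, so ties never compare category values.
    events = sorted(zip(future_timestamps, future_categories), key=lambda e: e[0])
    for ts, category in events:
        if category not in known:
            # Floor division truncates partial days: 23 hours -> 0, 47 hours -> 1.
            return float((ts - split_ts) // SECONDS_PER_DAY)
    return None


split = 1_700_000_000  # arbitrary epoch-seconds split for illustration
# Out-of-order input: the "toys" event at +23h is earlier than "garden" at +5 days.
print(first_new_category_days(
    {"grocery"},
    ["garden", "toys"],
    [split + 5 * SECONDS_PER_DAY, split + 23 * 3600],
    split,
))  # 0.0
```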

③ Handle no-new-category case

Python
return None

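If the customer only purchases from categories they have bought from before (or makes no purchases at all after the split), the function returns None and the sample is excluded. This is a censored observation: the customer may explore new categories later, but the available data does not capture it. Because censored customers are dropped rather than modeled, the training labels over-represent customers who expand quickly, which tends to bias predictions toward shorter times.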
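
A quick diagnostic for that bias is to measure what fraction of entities the target function excludes before training. The sketch below assumes you have already collected one label per customer (None for excluded ones); the labels shown are hypothetical and censoring_summary is an illustrative helper.

```python
from typing import Optional, Sequence


def censoring_summary(labels: Sequence[Optional[float]]) -> dict:
    """Summarize how many entities the target function excludes (label None)."""
    observed = [x for x in labels if x is not None]
    n_total = len(labels)
    n_censored = n_total - len(observed)
    return {
        "n_total": n_total,
        "n_censored": n_censored,
        "censored_fraction": n_censored / n_total if n_total else 0.0,
        # Mean over observed labels only; biased low if many customers are censored.
        "mean_observed_days": sum(observed) / len(observed) if observed else None,
    }


# Hypothetical labels: None marks a customer with no new-category purchase.
print(censoring_summary([3.0, None, 12.0, None, None, 0.0]))
# {'n_total': 6, 'n_censored': 3, 'censored_fraction': 0.5, 'mean_observed_days': 5.0}
```

A high censored fraction is a signal that the regression target only describes a minority of customers, which is one reason to pair it with a classifier as suggested under Production Tips.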

④ Note: no explicit window validation

This target function does not use has_incomplete_training_window because there is no fixed target window. The function searches all available future data for the first new-category event. This is appropriate for time-to-event targets where the event may occur at any point in the future.
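
Even without a fixed window, the largest label a customer can receive is bounded by how much data follows the split. The sketch below computes that bound for two splits against a hypothetical data end date; how the end of the data is exposed in practice (for example via ctx) is not stated in the source, so the value here is an assumption.

```python
from datetime import datetime, timezone

SECONDS_PER_DAY = 86_400


def max_observable_days(split_ts: float, data_end_ts: float) -> int:
    """Largest label the target function can emit for a given split (whole days)."""
    return int((data_end_ts - split_ts) // SECONDS_PER_DAY)


# Hypothetical data end date and split dates, for illustration only.
data_end = datetime(2024, 6, 30, tzinfo=timezone.utc).timestamp()
for split in (datetime(2024, 1, 1, tzinfo=timezone.utc), datetime(2024, 5, 31, tzinfo=timezone.utc)):
    print(split.date(), max_observable_days(split.timestamp(), data_end))
# 2024-01-01 181
# 2024-05-31 30
```

Later splits therefore produce systematically shorter labels and more censoring, which is worth keeping in mind when comparing metrics across prediction dates.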


Training

Once the target function is defined, fine-tune a downstream model:

Python
from pathlib import Path
from monad.ui.config import TrainingParams, MetricParams, MetricMonitoringMode
from monad.config.early_stopping import EarlyStopping

from monad.ui.module import load_from_foundation_model, RegressionTask

module = load_from_foundation_model(
    checkpoint_path=Path("./foundation_model"),
    downstream_task=RegressionTask(num_targets=1),
    target_fn=time_to_new_category_target_fn,
)

training_params = TrainingParams(
    checkpoint_dir=Path("./<this_model>"),
    learning_rate=1e-4,
    epochs=20,
    devices=[0],
    metrics=[
        MetricParams(alias="mae", metric_name="MeanAbsoluteError"),
        MetricParams(alias="mse", metric_name="MeanSquaredError"),
        MetricParams(alias="r2", metric_name="R2Score"),
    ],
    metric_to_monitor="val_mae_0",
    metric_monitoring_mode=MetricMonitoringMode.MIN,
    early_stopping=EarlyStopping(min_delta=1e-4, patience=5),
)

module.fit(training_params, seed=42)

Evaluation

Python
from pathlib import Path
from datetime import datetime, timezone
from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, MetricParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    prediction_date=datetime(2024, 5, 1, tzinfo=timezone.utc),
    output_type=OutputType.DECODED,
    devices=[0],
    metrics=[
        MetricParams(alias="mae", metric_name="MeanAbsoluteError"),
        MetricParams(alias="mse", metric_name="MeanSquaredError"),
        MetricParams(alias="r2", metric_name="R2Score"),
    ],
)

results = module.test(testing_params)

Prediction

Python
from pathlib import Path
from datetime import datetime, timezone
from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    local_save_location=Path("./predictions.tsv"),
    output_type=OutputType.DECODED,
    prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
    devices=[0],
)

predictions = module.predict(testing_params)

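Metric  Why it matters
MAE     Average absolute error, in days. Easy to interpret and less sensitive to outliers than squared-error metrics.
RMSE    Square root of MSE (the mse metric configured above), in days. Penalises large errors more heavily than MAE.
R²      Proportion of variance in the target explained by the model.
MAPE    Percentage-based error, useful for comparing across scales. Undefined when the true label is 0 days, which this target produces for new-category purchases within a day of the split.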
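
The evaluation configuration above requests MAE, MSE and R² through monad's metric names; RMSE and MAPE are not configured there. For offline checks on exported predictions, a NumPy sketch of all four is below. regression_report is a hypothetical helper, and skipping rows whose true label is 0 for MAPE is one common convention, not something the source prescribes.

```python
import numpy as np


def regression_report(y_true, y_pred) -> dict:
    """MAE, RMSE, R² and MAPE for day-valued targets (MAPE skips zero targets)."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    err = y_pred - y_true
    mae = np.mean(np.abs(err))
    rmse = np.sqrt(np.mean(err ** 2))
    ss_res = np.sum(err ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")
    nonzero = y_true != 0  # MAPE is undefined where the true label is 0 days
    if nonzero.any():
        mape = np.mean(np.abs(err[nonzero] / y_true[nonzero])) * 100
    else:
        mape = float("nan")
    return {"mae": float(mae), "rmse": float(rmse), "r2": float(r2), "mape_pct": float(mape)}


print(regression_report([0, 10, 20], [2, 8, 20]))
```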

Production Tips

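  1. Define "new category" at the right granularity. A category hierarchy (department > category > subcategory) offers different levels of novelty. Any purchase in a new department is also a purchase in a new subcategory, so new-subcategory events happen at least as often and at least as soon, which yields more uncensored samples and shorter labels than the department level.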
  2. Add a maximum horizon. Without a window cap, the target can range from 0 to hundreds of days. Consider capping at 180 or 365 days and excluding samples beyond that to reduce label variance.
  3. Pair with a classification model. Combine this regression model with a binary classifier that predicts whether the customer will explore a new category at all. Use the regression prediction only when the classifier says yes.
  4. Use predictions for cross-category campaigns. Customers predicted to explore a new category within 7-14 days are prime targets for introductory offers in adjacent categories.
  5. Monitor category coverage. Customers with very few known categories have more opportunities for new-category purchases. Segment predictions by category breadth for fairer evaluation.
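
Tips 2 and 3 can be sketched in a few lines. The cap mirrors the target function's None convention, so in practice it would sit inside the target function as an extra check before returning; the classifier probability p_explore, the 0.5 threshold and both helper names are hypothetical.

```python
from typing import Optional

import numpy as np

MAX_HORIZON_DAYS = 180  # tip 2: one of the caps suggested in the text (180 or 365)


def capped_label(days: Optional[float], max_days: int = MAX_HORIZON_DAYS) -> Optional[np.ndarray]:
    """Exclude labels beyond the horizon, mirroring the target function's None convention."""
    if days is None or days > max_days:
        return None
    return np.array([days], dtype=np.float32)


def gated_prediction(p_explore: float, predicted_days: float, threshold: float = 0.5) -> Optional[float]:
    """Tip 3: use the regression output only when a (hypothetical) classifier says 'yes'."""
    return predicted_days if p_explore >= threshold else None


print(capped_label(42.0), capped_label(400.0))  # [42.] None
print(gated_prediction(0.8, 9.5), gated_prediction(0.2, 9.5))  # 9.5 None
```

The threshold trades coverage for precision: raising it sends fewer customers to cross-category campaigns, but those who remain are more likely to actually explore a new category.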