Customer Churn Prediction

Task type: BinaryClassificationTask
Industry: Retail / E-commerce

Churn prediction identifies customers who are likely to stop buying within a defined future window. This lets your team act early — trigger retention campaigns, offer incentives, or flag accounts for customer success outreach.

What does "churn" mean here? A customer is considered churned if they make zero purchases in the next N days. You control the window length.


Prerequisites

Before writing a target function you need:

  • A trained foundation model built on event data that includes customer transactions (e.g., a purchases data source).
  • The monad library installed in your environment (for Python App).

Target Function

The target function tells the model how to label each customer for training. It receives four arguments:

| Argument | Type | Description |
| --- | --- | --- |
| history | Events | All events before the temporal split. |
| future | Events | All events after the temporal split. |
| attributes | Attributes | Static entity attributes. |
| ctx | Dict | Context dictionary containing SPLIT_TIMESTAMP, data mode, etc. |

The function must return one of:

  • np.array([1], dtype=np.float32): customer churned
  • np.array([0], dtype=np.float32): customer active
  • None: exclude this customer from training (e.g., insufficient history)
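
The return contract above can be checked before training with a small helper. This is a minimal sketch: check_label is a hypothetical name and is not part of monad; it only encodes the three allowed return values listed above.

```python
import numpy as np


def check_label(label):
    """Return True if `label` is None or a 1-D float32 array holding 0 or 1.

    Hypothetical helper for illustration; not part of the monad library.
    """
    if label is None:
        return True  # None means "exclude this customer"
    return bool(
        isinstance(label, np.ndarray)
        and label.dtype == np.float32
        and label.shape == (1,)
        and label[0] in (0.0, 1.0)
    )


ok = check_label(np.array([1], dtype=np.float32))
wrong_dtype = check_label(np.array([1], dtype=np.float64))
```

Running a check like this on a handful of sample outputs can catch dtype or shape mistakes early, for example returning a plain Python float instead of an array.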

Full Example

Python
import numpy as np
from datetime import timedelta
from typing import Dict

from monad.ui.target_function import Events, Attributes
from monad.ui.target_function import SPLIT_TIMESTAMP
from monad.ui.target_function import has_incomplete_training_window



# === Configuration ===
TARGET_WINDOW_DAYS = 21          # How far into the future to look
PURCHASE_DATA_SOURCE = "purchases"  # Name of your transaction data source


def churn_target_fn(
    history: Events,
    future: Events,
    attributes: Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    """Label a customer as churned (1) or active (0)."""

    # 1. Ensure the training window is long enough
    #    If the split is too close to the end of available data,
    #    we cannot observe a full target window — skip this sample.
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if has_incomplete_training_window(ctx, target_window):
        return None

    # 2. Trim future events to the target window only
    future = future.interval_from(ctx[SPLIT_TIMESTAMP], target_window)

    # 3. Exclude inactive customers (no purchase history at all)
    if history[PURCHASE_DATA_SOURCE].count() == 0:
        return None

    # 4. Apply churn logic
    #    Churned = no purchases in the future window
    churned = 1 if future[PURCHASE_DATA_SOURCE].count() == 0 else 0

    return np.array([churned], dtype=np.float32)
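
The same function with module-qualified names. This variant assumes target_function, np, timedelta, and Dict are already imported, and it keeps the configuration inside the function body:

Python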
def churn_target_fn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    """Label a customer as churned (1) or active (0)."""

    # === Configuration ===
    TARGET_WINDOW_DAYS = 21          # How far into the future to look
    PURCHASE_DATA_SOURCE = "purchases"  # Name of your transaction data source

    # 1. Ensure the training window is long enough
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if target_function.has_incomplete_training_window(ctx, target_window):
        return None

    # 2. Trim future events to the target window only
    future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

    # 3. Exclude inactive customers (no purchase history at all)
    if history[PURCHASE_DATA_SOURCE].count() == 0:
        return None

    # 4. Apply churn logic
    churned = 1 if future[PURCHASE_DATA_SOURCE].count() == 0 else 0

    return np.array([churned], dtype=np.float32)

Step-by-Step Breakdown

① Validate the training window

Python
target_window = timedelta(days=TARGET_WINDOW_DAYS)
if has_incomplete_training_window(ctx, target_window):
    return None

With module-qualified names:

Python
target_window = timedelta(days=TARGET_WINDOW_DAYS)
if target_function.has_incomplete_training_window(ctx, target_window):
    return None

During training, monad creates multiple temporal splits. Some splits may land too close to the end of the dataset, leaving less than TARGET_WINDOW_DAYS of future data. has_incomplete_training_window returns True in those cases so you can skip them. This check only applies during training — it is automatically bypassed at test/prediction time.
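
To illustrate the idea, the check can be thought of as a comparison between the end of the target window and the end of the available data. The sketch below is an assumption about the concept only, not monad's implementation; window_is_incomplete and data_end are hypothetical names.

```python
from datetime import datetime, timedelta, timezone


def window_is_incomplete(split, target_window, data_end):
    """True when less than `target_window` of data follows `split`.

    Conceptual stand-in only; monad's own check reads what it needs from ctx.
    """
    return split + target_window > data_end


data_end = datetime(2024, 6, 30, tzinfo=timezone.utc)  # hypothetical last event time
window = timedelta(days=21)

early_split = datetime(2024, 6, 1, tzinfo=timezone.utc)   # 29 days of future data
late_split = datetime(2024, 6, 20, tzinfo=timezone.utc)   # only 10 days of future data

early_incomplete = window_is_incomplete(early_split, window, data_end)
late_incomplete = window_is_incomplete(late_split, window, data_end)
```

Skipping the late split matters because a customer who would have purchased on day 15 cannot be observed in 10 days of data and would otherwise be mislabeled as churned.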

② Trim future events to the target window

Python
future = future.interval_from(ctx[SPLIT_TIMESTAMP], target_window)

With module-qualified names:

Python
future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

future initially contains all events after the split. We narrow it down to exactly the window we care about (e.g., 21 days). This ensures the label reflects a consistent time horizon.
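
The effect of trimming can be sketched with plain timestamps. This example assumes a half-open interval [split, split + window); how interval_from treats the boundaries is not stated here, so treat that detail as an assumption. events_in_window is a hypothetical helper, not a monad function.

```python
from datetime import datetime, timedelta, timezone


def events_in_window(timestamps, split, target_window):
    """Keep timestamps inside [split, split + target_window)."""
    end = split + target_window
    return [t for t in timestamps if split <= t < end]


split = datetime(2024, 5, 1, tzinfo=timezone.utc)
purchases = [
    split + timedelta(days=3),
    split + timedelta(days=20),
    split + timedelta(days=45),  # falls outside a 21-day window
]

kept = events_in_window(purchases, split, timedelta(days=21))
late_only = events_in_window(purchases[2:], split, timedelta(days=21))
```

A customer whose only future purchase is on day 45 has an empty 21-day window, so they would be labeled churned after trimming but active without it.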

③ Exclude entities that should not be labeled

Python
if history[PURCHASE_DATA_SOURCE].count() == 0:
    return None

Customers with no purchase history cannot meaningfully churn — they were never active. Returning None excludes them from both training and evaluation.

Tip: Adjust this filter to match your business definition. For example, you might require at least 2 purchases: history["purchases"].count() < 2.

④ Apply the churn label

Python
churned = 1 if future[PURCHASE_DATA_SOURCE].count() == 0 else 0
return np.array([churned], dtype=np.float32)

If the customer made zero purchases in the target window, they are labeled as churned (1). Otherwise, they are active (0). The result must be a 1-D float32 NumPy array of size 1.


Training

Once the target function is defined, fine-tune a downstream model:

Python
from pathlib import Path
from monad.ui.config import TrainingParams, MetricParams, MetricMonitoringMode
from monad.config.early_stopping import EarlyStopping

from monad.ui.module import load_from_foundation_model, BinaryClassificationTask

# Load foundation model and attach the churn task
module = load_from_foundation_model(
    checkpoint_path=Path("./foundation_model"),
    downstream_task=BinaryClassificationTask(),
    target_fn=churn_target_fn,
)

training_params = TrainingParams(
    checkpoint_dir=Path("./<this_model>"),
    learning_rate=1e-4,
    epochs=20,
    devices=[0],
    metrics=[
        MetricParams(alias="auroc", metric_name="AUROC", kwargs={"task": "binary"}),
        MetricParams(alias="auprc", metric_name="AveragePrecision", kwargs={"task": "binary"}),
        MetricParams(alias="recall", metric_name="Recall", kwargs={"task": "binary"}),
        MetricParams(alias="precision", metric_name="Precision", kwargs={"task": "binary"}),
    ],
    metric_to_monitor="val_auroc_0",
    metric_monitoring_mode=MetricMonitoringMode.MAX,
    early_stopping=EarlyStopping(min_delta=1e-4, patience=5),
)

module.fit(training_params, seed=42)
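
The early_stopping settings above combine a patience with a minimum improvement on the monitored metric. The sketch below shows the usual logic for a metric where higher is better; it is a generic illustration, and the exact counting rule monad applies is an assumption. stopping_epoch is a hypothetical name.

```python
def stopping_epoch(scores, patience, min_delta):
    """Return the 0-based epoch at which training would stop, or None.

    `scores` are per-epoch values of a metric where higher is better.
    An epoch counts as an improvement only if it beats the best score
    so far by more than `min_delta`.
    """
    best = float("-inf")
    epochs_without_improvement = 0
    for epoch, score in enumerate(scores):
        if score > best + min_delta:
            best = score
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Hypothetical validation AUROC values; 0.76005 is too small a gain over 0.76.
val_auroc = [0.70, 0.74, 0.76, 0.76005, 0.759, 0.76, 0.758, 0.7595]
stop_at = stopping_epoch(val_auroc, patience=5, min_delta=1e-4)
```

With patience=5, training in this example stops well before the configured 20 epochs once the metric plateaus.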

Evaluation

Python
from pathlib import Path
from datetime import datetime, timezone
from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, MetricParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    prediction_date=datetime(2024, 5, 1, tzinfo=timezone.utc),
    output_type=OutputType.DECODED,
    devices=[0],
    metrics=[
        MetricParams(alias="auroc", metric_name="AUROC", kwargs={"task": "binary"}),
        MetricParams(alias="auprc", metric_name="AveragePrecision", kwargs={"task": "binary"}),
        MetricParams(alias="recall", metric_name="Recall", kwargs={"task": "binary"}),
    ],
)

results = module.test(testing_params)

Prediction

Python
from pathlib import Path
from datetime import datetime, timezone
from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    local_save_location=Path("./predictions.tsv"),
    output_type=OutputType.DECODED,
    prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
    devices=[0],
)

predictions = module.predict(testing_params)

Variations

Below are alternative churn definitions. Swap them into the target function to match your business case.

Spending-decline churn

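Label a customer as churned if their spend in the target window falls below 30% of their total historical spend, i.e., a drop of 70% or more. The example compares against the total over all of history, not a per-window average, so customers with long histories are more likely to be labeled churned; normalize history_spend to the window length if you want a like-for-like comparison.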

Python
def spending_decline_churn(
    history: Events, future: Events, attributes: Attributes, ctx: Dict
) -> np.ndarray | None:
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if has_incomplete_training_window(ctx, target_window):
        return None
    future = future.interval_from(ctx[SPLIT_TIMESTAMP], target_window)

    history_spend = history["purchases"].sum(column="price")
    future_spend = future["purchases"].sum(column="price")

    if history_spend == 0:
        return None

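    # Share of total historical spend seen in the target window;
    # below 0.3 means spending dropped by 70% or more.
    decline_ratio = future_spend / history_spend
    churned = 1 if decline_ratio < 0.3 else 0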
    return np.array([churned], dtype=np.float32)

With module-qualified names:

Python
def spending_decline_churn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    TARGET_WINDOW_DAYS = 21
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if target_function.has_incomplete_training_window(ctx, target_window):
        return None
    future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

    history_spend = history["purchases"].sum(column="price")
    future_spend = future["purchases"].sum(column="price")

    if history_spend == 0:
        return None

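    # Share of total historical spend seen in the target window;
    # below 0.3 means spending dropped by 70% or more.
    decline_ratio = future_spend / history_spend
    churned = 1 if decline_ratio < 0.3 else 0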
    return np.array([churned], dtype=np.float32)
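
The code above divides spend in the target window by total historical spend, so the ratio shrinks as a customer's history gets longer. One possible refinement is to compare spend per day instead. The sketch below uses plain numbers; spend_ratio and history_days are hypothetical, and how to obtain the length of a customer's history from monad is left open.

```python
def spend_ratio(history_spend, history_days, future_spend, window_days):
    """Future spend per day divided by historical spend per day.

    Returns None when there is no usable history.
    """
    if history_spend <= 0 or history_days <= 0 or window_days <= 0:
        return None
    history_rate = history_spend / history_days
    future_rate = future_spend / window_days
    return future_rate / history_rate


# Hypothetical customer: 1200 spent over 365 days, then 42 in a 21-day window.
ratio = spend_ratio(1200.0, 365, 42.0, 21)
churned = 1 if ratio < 0.3 else 0

# The unnormalized ratio for the same customer is far below the 0.3 cutoff.
unnormalized = 42.0 / 1200.0
```

Here the customer spends at roughly 61% of their historical daily rate and stays active under the normalized rule, while the unnormalized ratio of 0.035 would label them churned.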

Multi-source activity churn

Consider a customer churned only if they have zero activity across all data sources — not just purchases.

Python
def activity_churn(
    history: Events, future: Events, attributes: Attributes, ctx: Dict
) -> np.ndarray | None:
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if has_incomplete_training_window(ctx, target_window):
        return None
    future = future.interval_from(ctx[SPLIT_TIMESTAMP], target_window)

    if history["purchases"].count() == 0:
        return None

    total_future_activity = (
        future["purchases"].count()
        + future["page_views"].count()
        + future["logins"].count()
    )
    churned = 1 if total_future_activity == 0 else 0
    return np.array([churned], dtype=np.float32)

With module-qualified names:

Python
def activity_churn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    TARGET_WINDOW_DAYS = 21
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if target_function.has_incomplete_training_window(ctx, target_window):
        return None
    future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

    if history["purchases"].count() == 0:
        return None

    total_future_activity = (
        future["purchases"].count()
        + future["page_views"].count()
        + future["logins"].count()
    )
    churned = 1 if total_future_activity == 0 else 0
    return np.array([churned], dtype=np.float32)

Metrics

| Metric | Why it matters |
| --- | --- |
| AUROC | Measures overall ranking quality: how well the model separates churners from non-churners. |
| AUPRC | More informative than AUROC when churn is rare (imbalanced classes). |
| Recall | Proportion of actual churners that the model catches. Prioritize this if missing a churner is costly. |
| Precision | Proportion of predicted churners who truly churn. Prioritize this if retention actions are expensive. |
| F1 Score | Harmonic mean of precision and recall, giving a single balanced metric. |
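
As a worked illustration of the threshold-dependent metrics, precision, recall, and F1 can be computed directly from true and predicted labels. The data below is made up for the example, and precision_recall_f1 is a hypothetical helper, not a monad function.

```python
def precision_recall_f1(y_true, y_pred):
    """Compute precision, recall, and F1 for binary labels (1 = churned)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    return precision, recall, f1


# Made-up labels: 4 actual churners, 3 predicted churners.
y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_pred = [1, 1, 0, 1, 0, 0, 0, 0]

precision, recall, f1 = precision_recall_f1(y_true, y_pred)
```

In this example two of the three flagged customers truly churn (precision 2/3), while only two of the four churners are caught (recall 1/2), giving an F1 of 4/7.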

Production Tips

  1. Tune your decision threshold. The default 0.5 is rarely optimal. Plot precision-recall curves and choose a threshold that matches your cost trade-off (e.g., a lower threshold catches more churners at the cost of more false alarms).

  2. Choose an appropriate target window. A 7-day window reacts fast but is noisy; a 90-day window is smoother but gives you less time to act. Start with 21–30 days and adjust based on your purchase cycle.

  3. Segment and validate. Check churn rates by customer segment (new vs. returning, high-value vs. low-value) to confirm the model behaves sensibly across cohorts.

  4. Retrain periodically. Churn patterns shift with seasons, promotions, and market changes. Retrain monthly or whenever model performance degrades.

  5. Use interpretability. Monad supports interpretability analysis to surface why a customer is predicted to churn — useful for designing targeted retention offers.
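
For tip 1, one simple way to choose a threshold is to assign a cost to each kind of error and pick the candidate with the lowest total cost on a validation set. This is a minimal sketch under assumed costs; best_threshold, the scores, and the cost values are hypothetical and should be replaced with your own validation predictions and business figures.

```python
def best_threshold(scores, labels, cost_fn, cost_fp, candidates):
    """Return the candidate threshold with the lowest total error cost.

    cost_fn: cost of missing a churner (false negative)
    cost_fp: cost of a wasted retention action (false positive)
    """
    def total_cost(threshold):
        cost = 0.0
        for score, label in zip(scores, labels):
            predicted = 1 if score >= threshold else 0
            if label == 1 and predicted == 0:
                cost += cost_fn
            elif label == 0 and predicted == 1:
                cost += cost_fp
        return cost

    return min(candidates, key=total_cost)


# Made-up validation scores (predicted churn probability) and true labels.
scores = [0.9, 0.8, 0.45, 0.35, 0.3, 0.2, 0.1]
labels = [1, 1, 1, 0, 1, 0, 0]
candidates = [0.1, 0.3, 0.5, 0.7]

# Missing a churner is 10x worse than a wasted offer: a low threshold wins.
recall_focused = best_threshold(scores, labels, 10.0, 1.0, candidates)

# Offers are expensive: a higher threshold wins.
precision_focused = best_threshold(scores, labels, 1.0, 10.0, candidates)
```

The same data yields different thresholds depending on the cost ratio, which is why the default of 0.5 is only a starting point.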