Customer Churn Prediction
Task type: BinaryClassificationTask
Industry: Retail / E-commerce
Churn prediction identifies customers who are likely to stop buying within a defined future window. This lets your team act early — trigger retention campaigns, offer incentives, or flag accounts for customer success outreach.
What does "churn" mean here? A customer is considered churned if they make zero purchases in the next N days. You control the window length.
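Here is the definition written as a minimal plain-Python sketch that does not depend on monad. The function name `is_churned` is hypothetical. The half-open window `[as_of, as_of + N days)` is an assumption, and monad's own window boundaries may differ.

```python
from datetime import datetime, timedelta, timezone


def is_churned(purchase_times, as_of, window_days):
    """Return 1 if no purchase falls in [as_of, as_of + window_days), else 0."""
    window_end = as_of + timedelta(days=window_days)
    bought_in_window = any(as_of <= t < window_end for t in purchase_times)
    return 0 if bought_in_window else 1


# Example: one old purchase and one purchase 10 days after the as-of date.
as_of = datetime(2024, 5, 1, tzinfo=timezone.utc)
purchases = [as_of - timedelta(days=30), as_of + timedelta(days=10)]
print(is_churned(purchases, as_of, window_days=21))  # 0: bought within 21 days
print(is_churned(purchases, as_of, window_days=7))   # 1: no purchase within 7 days
```

The window length is the main control: with the same purchases, a shorter window labels more customers as churned.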
Prerequisites
Before writing a target function you need:
- A trained foundation model built on event data that includes customer transactions (e.g., a `purchases` data source).
- The monad library installed in your environment (for Python App).
Target Function
The target function tells the model how to label each customer for training. It receives four arguments:
| Argument | Type | Description |
|---|---|---|
| `history` | `Events` | All events before the temporal split. |
| `future` | `Events` | All events after the temporal split. |
| `attributes` | `Attributes` | Static entity attributes. |
| `ctx` | `Dict` | Context dictionary containing `SPLIT_TIMESTAMP`, data mode, etc. |
The function must return one of:
- `np.array([1], dtype=np.float32)`: customer churned
- `np.array([0], dtype=np.float32)`: customer active
- `None`: exclude this customer from training (e.g., insufficient history)
Full Example
```python
import numpy as np
from datetime import timedelta
from typing import Dict
from monad.ui.target_function import Events, Attributes
from monad.ui.target_function import SPLIT_TIMESTAMP
from monad.ui.target_function import has_incomplete_training_window

# === Configuration ===
TARGET_WINDOW_DAYS = 21  # How far into the future to look
PURCHASE_DATA_SOURCE = "purchases"  # Name of your transaction data source


def churn_target_fn(
    history: Events,
    future: Events,
    attributes: Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    """Label a customer as churned (1) or active (0)."""
    # 1. Ensure the training window is long enough.
    #    If the split is too close to the end of available data,
    #    we cannot observe a full target window, so skip this sample.
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if has_incomplete_training_window(ctx, target_window):
        return None

    # 2. Trim future events to the target window only
    future = future.interval_from(ctx[SPLIT_TIMESTAMP], target_window)

    # 3. Exclude inactive customers (no purchase history at all)
    if history[PURCHASE_DATA_SOURCE].count() == 0:
        return None

    # 4. Apply churn logic
    #    Churned = no purchases in the future window
    churned = 1 if future[PURCHASE_DATA_SOURCE].count() == 0 else 0
    return np.array([churned], dtype=np.float32)
```
The same function written with the module-qualified `target_function` namespace, with the configuration defined inside the function:

```python
import numpy as np
from datetime import timedelta
from typing import Dict
from monad.ui import target_function


def churn_target_fn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    """Label a customer as churned (1) or active (0)."""
    # === Configuration ===
    TARGET_WINDOW_DAYS = 21  # How far into the future to look
    PURCHASE_DATA_SOURCE = "purchases"  # Name of your transaction data source

    # 1. Ensure the training window is long enough
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if target_function.has_incomplete_training_window(ctx, target_window):
        return None

    # 2. Trim future events to the target window only
    future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

    # 3. Exclude inactive customers (no purchase history at all)
    if history[PURCHASE_DATA_SOURCE].count() == 0:
        return None

    # 4. Apply churn logic
    churned = 1 if future[PURCHASE_DATA_SOURCE].count() == 0 else 0
    return np.array([churned], dtype=np.float32)
```
Step-by-Step Breakdown
① Validate the training window
During training, monad creates multiple temporal splits. Some splits may land too close to the end of the dataset, leaving less than TARGET_WINDOW_DAYS of future data. has_incomplete_training_window returns True in those cases so you can skip them. This check only applies during training — it is automatically bypassed at test/prediction time.
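Conceptually, the check compares dates. The sketch below illustrates that idea and is not monad's implementation. `window_is_incomplete` and `data_end` (the timestamp of the last available event) are hypothetical names.

```python
from datetime import datetime, timedelta, timezone


def window_is_incomplete(split_time, target_window, data_end):
    """True when the target window would extend past the last available data."""
    # Hypothetical stand-in for has_incomplete_training_window:
    # a training sample is only usable if the full window fits before data_end.
    return split_time + target_window > data_end


data_end = datetime(2024, 6, 30, tzinfo=timezone.utc)
window = timedelta(days=21)

early_split = datetime(2024, 6, 1, tzinfo=timezone.utc)
late_split = datetime(2024, 6, 20, tzinfo=timezone.utc)

print(window_is_incomplete(early_split, window, data_end))  # False: June 22 <= June 30
print(window_is_incomplete(late_split, window, data_end))   # True: July 11 > June 30
```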
② Trim future events to the target window
future initially contains all events after the split. We narrow it down to exactly the window we care about (e.g., 21 days). This ensures the label reflects a consistent time horizon.
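The trimming step can be sketched in plain Python. This version represents events as hypothetical `(timestamp, source)` tuples. It also assumes a half-open interval `[split, split + window)`, and the documentation does not confirm how `interval_from` treats the end boundary.

```python
from datetime import datetime, timedelta, timezone


def trim_to_window(events, start, window):
    """Keep only (timestamp, source) events with start <= timestamp < start + window."""
    end = start + window
    return [(ts, src) for ts, src in events if start <= ts < end]


split = datetime(2024, 5, 1, tzinfo=timezone.utc)
future_events = [
    (split + timedelta(days=3), "purchases"),
    (split + timedelta(days=20), "page_views"),
    (split + timedelta(days=45), "purchases"),  # outside a 21-day window
]
trimmed = trim_to_window(future_events, split, timedelta(days=21))
print(len(trimmed))  # 2
```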
③ Exclude entities that should not be labeled
Customers with no purchase history cannot meaningfully churn — they were never active. Returning None excludes them from both training and evaluation.
Tip: Adjust this filter to match your business definition. For example, to require at least 2 purchases, return None when `history[PURCHASE_DATA_SOURCE].count() < 2`.
④ Apply the churn label
```python
churned = 1 if future[PURCHASE_DATA_SOURCE].count() == 0 else 0
return np.array([churned], dtype=np.float32)
```
If the customer made zero purchases in the target window, they are labeled as churned (1). Otherwise, they are active (0). The result must be a 1-D float32 NumPy array of size 1.
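Without the monad event API, steps ③ and ④ reduce to a decision on two counts. This plain-Python sketch uses hypothetical names (`label_from_counts`, `min_history`). With `min_history=1` it reproduces the `count() == 0` check, and with `min_history=2` it applies the stricter filter from the tip. The real target function would wrap the 1 or 0 in `np.array([...], dtype=np.float32)`.

```python
from typing import Optional


def label_from_counts(history_count: int, future_count: int, min_history: int = 1) -> Optional[int]:
    """Mirror steps 3 and 4: None = excluded, 1 = churned, 0 = active."""
    if history_count < min_history:
        return None  # step 3: not enough purchase history to be labeled
    return 1 if future_count == 0 else 0  # step 4: no purchases in the window


print(label_from_counts(0, 0))                 # None: never purchased
print(label_from_counts(3, 0))                 # 1: churned
print(label_from_counts(3, 2))                 # 0: active
print(label_from_counts(1, 0, min_history=2))  # None: below the stricter threshold
```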
Training
Once the target function is defined, fine-tune a downstream model:
```python
from pathlib import Path
from monad.ui.config import TrainingParams, MetricParams, MetricMonitoringMode
from monad.config.early_stopping import EarlyStopping
from monad.ui.module import load_from_foundation_model, BinaryClassificationTask

# Load foundation model and attach the churn task
module = load_from_foundation_model(
    checkpoint_path=Path("./foundation_model"),
    downstream_task=BinaryClassificationTask(),
    target_fn=churn_target_fn,
)

training_params = TrainingParams(
    checkpoint_dir=Path("./<this_model>"),
    learning_rate=1e-4,
    epochs=20,
    devices=[0],
    metrics=[
        MetricParams(alias="auroc", metric_name="AUROC", kwargs={"task": "binary"}),
        MetricParams(alias="auprc", metric_name="AveragePrecision", kwargs={"task": "binary"}),
        MetricParams(alias="recall", metric_name="Recall", kwargs={"task": "binary"}),
        MetricParams(alias="precision", metric_name="Precision", kwargs={"task": "binary"}),
    ],
    metric_to_monitor="val_auroc_0",
    metric_monitoring_mode=MetricMonitoringMode.MAX,
    early_stopping=EarlyStopping(min_delta=1e-4, patience=5),
)

module.fit(training_params, seed=42)
```
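`EarlyStopping(min_delta=1e-4, patience=5)` together with `MetricMonitoringMode.MAX` appears to follow the common patience pattern. The sketch below shows that usual rule in plain Python. Its semantics are an assumption: for example, it is not confirmed whether monad's comparison is strict or whether `min_delta` is absolute. The per-epoch scores are illustrative values, not real results.

```python
def early_stop_epoch(scores, min_delta=1e-4, patience=5):
    """Return the 0-based epoch at which training would stop, or None if it never stops.

    Assumes MAX mode: a score counts as an improvement only if it beats the
    best score so far by more than min_delta.
    """
    best = float("-inf")
    epochs_without_improvement = 0
    for epoch, score in enumerate(scores):
        if score > best + min_delta:
            best = score
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Illustrative validation AUROC per epoch; the last improvement is at epoch 3.
val_auroc = [0.70, 0.74, 0.76, 0.761, 0.7605, 0.76, 0.759, 0.758, 0.757]
print(early_stop_epoch(val_auroc))  # 8: five epochs (4 to 8) without improvement
```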
Evaluation
```python
from pathlib import Path
from datetime import datetime, timezone
from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, MetricParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    prediction_date=datetime(2024, 5, 1, tzinfo=timezone.utc),
    output_type=OutputType.DECODED,
    devices=[0],
    metrics=[
        # Same task kwargs as in the training configuration
        MetricParams(alias="auroc", metric_name="AUROC", kwargs={"task": "binary"}),
        MetricParams(alias="auprc", metric_name="AveragePrecision", kwargs={"task": "binary"}),
        MetricParams(alias="recall", metric_name="Recall", kwargs={"task": "binary"}),
    ],
)

results = module.test(testing_params)
```
Prediction
```python
from pathlib import Path
from datetime import datetime, timezone
from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    local_save_location=Path("./predictions.tsv"),
    output_type=OutputType.DECODED,
    prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
    devices=[0],
)

predictions = module.predict(testing_params)
```
Variations
Below are alternative churn definitions. To use one, pass it as `target_fn` in place of `churn_target_fn` when loading the model.
Spending-decline churn
Label a customer as churned if their spending in the target window is below 30% of their total historical spending (a decline of more than 70%).
```python
# Reuses the imports and TARGET_WINDOW_DAYS from the full example above.
def spending_decline_churn(
    history: Events, future: Events, attributes: Attributes, ctx: Dict
) -> np.ndarray | None:
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if has_incomplete_training_window(ctx, target_window):
        return None
    future = future.interval_from(ctx[SPLIT_TIMESTAMP], target_window)

    history_spend = history["purchases"].sum(column="price")
    future_spend = future["purchases"].sum(column="price")
    # Skip customers with no historical spend (also avoids division by zero)
    if history_spend == 0:
        return None

    # Churned = window spend below 30% of total historical spend
    decline_ratio = future_spend / history_spend
    churned = 1 if decline_ratio < 0.3 else 0
    return np.array([churned], dtype=np.float32)
```
The same variation written with the module-qualified `target_function` namespace:

```python
import numpy as np
from datetime import timedelta
from typing import Dict
from monad.ui import target_function


def spending_decline_churn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    TARGET_WINDOW_DAYS = 21
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if target_function.has_incomplete_training_window(ctx, target_window):
        return None
    future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

    history_spend = history["purchases"].sum(column="price")
    future_spend = future["purchases"].sum(column="price")
    # Skip customers with no historical spend (also avoids division by zero)
    if history_spend == 0:
        return None

    # Churned = window spend below 30% of total historical spend
    decline_ratio = future_spend / history_spend
    churned = 1 if decline_ratio < 0.3 else 0
    return np.array([churned], dtype=np.float32)
```
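`history_spend` sums the customer's entire history, while `future_spend` covers only `TARGET_WINDOW_DAYS`. As a result, a long-tenured customer whose spending has not changed can still fall below the 0.3 cutoff. One option is to compare daily spending rates instead. The plain-Python sketch below shows this approach. `normalized_spend_ratio` and its `history_days` input (the length of the observed history) are hypothetical, and this sketch does not cover how to obtain the history length from monad's `Events`.

```python
from typing import Optional


def normalized_spend_ratio(
    history_spend: float, history_days: float, future_spend: float, future_days: float
) -> Optional[float]:
    """Ratio of the future daily spending rate to the historical daily rate."""
    if history_spend <= 0 or history_days <= 0 or future_days <= 0:
        return None  # no baseline to compare against; exclude the customer
    history_rate = history_spend / history_days
    future_rate = future_spend / future_days
    return future_rate / history_rate


# Steady customer: 10 per day over a 365-day history and over a 21-day window.
raw_ratio = 210.0 / 3650.0
print(raw_ratio < 0.3)  # True: the raw ratio would label this customer as churned
print(normalized_spend_ratio(3650.0, 365, 210.0, 21))  # 1.0: spending rate unchanged

# Declining customer: 10 per day historically, 2 per day in the window.
ratio = normalized_spend_ratio(3650.0, 365, 42.0, 21)
print(1 if ratio is not None and ratio < 0.3 else 0)  # 1: churned under the 0.3 cutoff
```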
Multi-source activity churn
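Consider a customer churned only if they show zero activity across several data sources (here `purchases`, `page_views`, and `logins`), not just purchases. Replace these names with the data sources in your foundation model.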
```python
# Reuses the imports and TARGET_WINDOW_DAYS from the full example above.
def activity_churn(
    history: Events, future: Events, attributes: Attributes, ctx: Dict
) -> np.ndarray | None:
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if has_incomplete_training_window(ctx, target_window):
        return None
    future = future.interval_from(ctx[SPLIT_TIMESTAMP], target_window)

    if history["purchases"].count() == 0:
        return None

    total_future_activity = (
        future["purchases"].count()
        + future["page_views"].count()
        + future["logins"].count()
    )
    churned = 1 if total_future_activity == 0 else 0
    return np.array([churned], dtype=np.float32)
```
The same variation written with the module-qualified `target_function` namespace:

```python
import numpy as np
from datetime import timedelta
from typing import Dict
from monad.ui import target_function


def activity_churn(
    history: target_function.Events,
    future: target_function.Events,
    attributes: target_function.Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    TARGET_WINDOW_DAYS = 21
    target_window = timedelta(days=TARGET_WINDOW_DAYS)
    if target_function.has_incomplete_training_window(ctx, target_window):
        return None
    future = future.interval_from(ctx[target_function.SPLIT_TIMESTAMP], target_window)

    if history["purchases"].count() == 0:
        return None

    total_future_activity = (
        future["purchases"].count()
        + future["page_views"].count()
        + future["logins"].count()
    )
    churned = 1 if total_future_activity == 0 else 0
    return np.array([churned], dtype=np.float32)
```
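Keeping the list of sources in one place makes the definition easier to change. The plain-Python sketch below works on per-source event counts stored in a dict. `ACTIVITY_SOURCES`, `total_activity`, and `activity_churn_label` are hypothetical names. Treating a missing source as zero is a design choice for this sketch, not documented monad behavior.

```python
from typing import Dict, Iterable

# Hypothetical source names; replace with the data sources in your foundation model.
ACTIVITY_SOURCES = ("purchases", "page_views", "logins")


def total_activity(counts: Dict[str, int], sources: Iterable[str] = ACTIVITY_SOURCES) -> int:
    """Sum event counts over the chosen sources, treating a missing source as zero."""
    return sum(counts.get(source, 0) for source in sources)


def activity_churn_label(future_counts: Dict[str, int]) -> int:
    """1 = churned (no activity in any listed source), 0 = active."""
    return 1 if total_activity(future_counts) == 0 else 0


print(activity_churn_label({"purchases": 0, "page_views": 4}))  # 0: still browsing
print(activity_churn_label({"purchases": 0}))                    # 1: no activity at all
```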
Recommended Metrics
| Metric | Why it matters |
|---|---|
| AUROC | Measures overall ranking quality — how well the model separates churners from non-churners. |
| AUPRC | More informative than AUROC when churn is rare (imbalanced classes). |
| Recall | Proportion of actual churners that the model catches. Prioritize this if missing a churner is costly. |
| Precision | Proportion of predicted churners who truly churn. Prioritize this if retention actions are expensive. |
| F1 Score | Harmonic mean of precision and recall — a single balanced metric. |
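The following from-scratch sketch computes precision, recall, F1, and AUROC on toy data so you can see what each metric measures. The labels and scores are made up, and AUPRC is omitted for brevity. Use this only for building intuition: the metrics configured through `MetricParams` are computed by the library, which may handle ties and edge cases differently.

```python
from typing import List, Tuple


def precision_recall_f1(
    labels: List[int], scores: List[float], threshold: float = 0.5
) -> Tuple[float, float, float]:
    """Precision, recall, and F1 for binary labels (1 = churned) at a score threshold."""
    preds = [1 if s >= threshold else 0 for s in scores]
    tp = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 1)
    fp = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def auroc(labels: List[int], scores: List[float]) -> float:
    """Probability that a random churner scores higher than a random non-churner (ties = 0.5)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


labels = [1, 0, 1, 0, 0, 1]
scores = [0.9, 0.4, 0.35, 0.2, 0.6, 0.7]
print(precision_recall_f1(labels, scores))  # each value is about 0.667 (2 TP, 1 FP, 1 FN)
print(auroc(labels, scores))                # 7/9, about 0.778
```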
Production Tips
- Tune your decision threshold. The default 0.5 is rarely optimal. Plot precision-recall curves and choose a threshold that matches your cost trade-off (e.g., a lower threshold catches more churners at the cost of more false alarms).
- Choose an appropriate target window. A 7-day window reacts fast but is noisy; a 90-day window is smoother but gives you less time to act. Start with 21–30 days and adjust based on your purchase cycle.
- Segment and validate. Check churn rates by customer segment (new vs. returning, high-value vs. low-value) to confirm the model behaves sensibly across cohorts.
- Retrain periodically. Churn patterns shift with seasons, promotions, and market changes. Retrain monthly or whenever model performance degrades.
- Use interpretability. Monad supports interpretability analysis to surface why a customer is predicted to churn, which is useful for designing targeted retention offers.
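One way to turn the cost trade-off into a concrete threshold is to score each candidate threshold by its total misclassification cost on a validation set. The plain-Python sketch below takes this approach. `best_threshold` is a hypothetical helper, and the per-error costs `cost_fn` (missed churner) and `cost_fp` (unnecessary retention action) are inputs you would estimate for your business. The labels and scores are toy values.

```python
from typing import Sequence


def best_threshold(
    labels: Sequence[int],
    scores: Sequence[float],
    cost_fn: float,
    cost_fp: float,
    candidates: Sequence[float],
) -> float:
    """Pick the candidate threshold with the lowest total misclassification cost."""

    def total_cost(threshold: float) -> float:
        cost = 0.0
        for y, s in zip(labels, scores):
            predicted_churn = s >= threshold
            if y == 1 and not predicted_churn:
                cost += cost_fn  # missed churner
            elif y == 0 and predicted_churn:
                cost += cost_fp  # unnecessary retention action
        return cost

    return min(candidates, key=total_cost)


labels = [1, 0, 1, 0, 0, 1]
scores = [0.9, 0.4, 0.35, 0.2, 0.6, 0.7]
candidates = [0.3, 0.5, 0.7]

# Missing a churner is expensive: a low threshold wins.
print(best_threshold(labels, scores, cost_fn=10.0, cost_fp=1.0, candidates=candidates))  # 0.3
# Retention actions are expensive: a high threshold wins.
print(best_threshold(labels, scores, cost_fn=1.0, cost_fp=10.0, candidates=candidates))  # 0.7
```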