Classify Customer Spending Tier

Task type: MulticlassClassificationTask
Industry: Retail

Loyalty programmes depend on accurately predicting which tier a customer will qualify for over the coming year. By forecasting spending tiers across four fiscal quarters, CRM teams can proactively offer tier-upgrade incentives, personalize rewards, and allocate retention budgets to customers on the cusp of a higher tier.

What makes this advanced? Fiscal quarter navigation plus one-hot encoding: the target function dynamically computes fiscal quarter boundaries, evaluates spending in each quarter, and returns a one-hot encoded class vector.


Prerequisites

Before writing a target function you need:

  • A trained foundation model built on event data that includes the relevant data sources.
  • The monad library installed in your environment.
  • Data source(s): transactions, with an amount column.

Target Function

The target function tells monad how to label each entity for training. It receives four arguments:

Argument | Type | Description
history | Events | All events before the temporal split.
future | Events | All events after the temporal split.
attributes | Attributes | Static entity attributes.
ctx | Dict | Context dictionary containing SPLIT_TIMESTAMP, data mode, etc.

For multiclass classification, the function must return one of:

  • A 1-D float32 array of size num_classes — one-hot encoded tier assignment.
  • None, to exclude this entity from training.

Full Example

Python
import numpy as np
from datetime import datetime, timedelta, timezone
from typing import Dict

from monad.batch import TRAINING_END_TIMESTAMP
from monad.ui.target_function import Events, Attributes
from monad.ui.target_function import SPLIT_TIMESTAMP

# === Configuration ===
CLASS_NAMES = ["NONE", "GOLD", "PLATINUM", "DIAMOND"]
GOLD_THRESHOLD = 1_000
PLATINUM_THRESHOLD = 5_000
DIAMOND_THRESHOLD = 10_000
TRANSACTION_DATA_SOURCE = "transactions"

def get_next_fiscal_quarter_start(dtm: datetime) -> datetime:
    """Find the start of the next fiscal quarter after dtm."""
    possible_starts = [
        datetime(year, month, 1, tzinfo=timezone.utc)
        for year in [dtm.year, dtm.year + 1]
        for month in [1, 4, 7, 10]
    ]
    return min(s for s in possible_starts if s > dtm)

def customer_spending_tier_target_fn(
    history: Events,
    future: Events,
    attributes: Attributes,
    ctx: Dict,
) -> np.ndarray | None:
    """Classify customer into spending tier based on 4 quarters."""

    # 1. Ensure enough data for 4 full quarters
    expected_end = datetime.fromtimestamp(ctx[SPLIT_TIMESTAMP], tz=timezone.utc)
    for _ in range(5):
        expected_end = get_next_fiscal_quarter_start(expected_end)
    expected_end = expected_end - timedelta(days=1)

    if ctx[TRAINING_END_TIMESTAMP] < expected_end.timestamp():
        return None

    # 2. Evaluate spending in each of the next 4 quarters
    def get_tier(quarter_transactions):
        total = quarter_transactions.sum(column="amount")
        if total > DIAMOND_THRESHOLD:
            return 3
        elif total > PLATINUM_THRESHOLD:
            return 2
        elif total > GOLD_THRESHOLD:
            return 1
        return 0

    quarter_tiers = []
    quarter_start = datetime.fromtimestamp(ctx[SPLIT_TIMESTAMP], tz=timezone.utc)
    for _ in range(4):
        quarter_start = get_next_fiscal_quarter_start(quarter_start)
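        # Exclusive end: the first instant of the following quarter, so the
        # quarter's last day is counted. This assumes interval_from covers
        # [start, start + duration).
        quarter_end = get_next_fiscal_quarter_start(quarter_start)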
        days = (quarter_end - quarter_start).days
        quarter_data = future.interval_from(
            quarter_start.timestamp(), timedelta(days=days)
        )[TRANSACTION_DATA_SOURCE]
        quarter_tiers.append(get_tier(quarter_data))

    # 3. Final tier = minimum across all quarters
    tier_idx = min(quarter_tiers)
    result = np.zeros(len(CLASS_NAMES), dtype=np.float32)
    result[tier_idx] = 1
    return result

Step-by-Step Breakdown

① Compute fiscal quarter boundaries

Python
def get_next_fiscal_quarter_start(dtm: datetime) -> datetime:
    possible_starts = [
        datetime(year, month, 1, tzinfo=timezone.utc)
        for year in [dtm.year, dtm.year + 1]
        for month in [1, 4, 7, 10]
    ]
    return min(s for s in possible_starts if s > dtm)

Generates all possible quarter start dates for the current and next year, then selects the earliest one after the given datetime. This handles year boundaries and any split timestamp position within a quarter.
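A quick check of the boundary cases, as a minimal sketch that repeats the helper so it runs on its own. The sample dates are illustrative and not taken from any dataset.

```python
from datetime import datetime, timezone


def get_next_fiscal_quarter_start(dtm: datetime) -> datetime:
    """Find the start of the next fiscal quarter strictly after dtm."""
    possible_starts = [
        datetime(year, month, 1, tzinfo=timezone.utc)
        for year in [dtm.year, dtm.year + 1]
        for month in [1, 4, 7, 10]
    ]
    return min(s for s in possible_starts if s > dtm)


# Illustrative timestamps (assumed, timezone-aware UTC).
mid_quarter = datetime(2023, 2, 15, tzinfo=timezone.utc)
year_end = datetime(2023, 11, 20, tzinfo=timezone.utc)
on_boundary = datetime(2023, 4, 1, tzinfo=timezone.utc)

next_after_mid = get_next_fiscal_quarter_start(mid_quarter)        # 2023-04-01
next_after_year_end = get_next_fiscal_quarter_start(year_end)      # 2024-01-01
next_after_boundary = get_next_fiscal_quarter_start(on_boundary)   # 2023-07-01
```

Because the comparison is strict, a timestamp that falls exactly on a quarter start moves to the following quarter. The helper also expects a timezone-aware datetime: comparing a naive datetime against the UTC candidates raises a TypeError.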

② Validate data availability

Python
expected_end = datetime.fromtimestamp(ctx[SPLIT_TIMESTAMP], tz=timezone.utc)
for _ in range(5):
    expected_end = get_next_fiscal_quarter_start(expected_end)
expected_end = expected_end - timedelta(days=1)

if ctx[TRAINING_END_TIMESTAMP] < expected_end.timestamp():
    return None

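The function needs four full fiscal quarters of future data. Starting from the split timestamp, the first boundary is the start of the first full quarter and the fifth boundary is the start of the quarter after the fourth, so stepping back one day gives the last day of the fourth quarter. If the training data ends before that date, the function returns None and the sample is excluded.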
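To make the five-boundary count concrete, the sketch below traces the loop for an assumed split timestamp of 15 February 2023. It uses only the standard library and a copy of the helper; the split date is hypothetical and no monad calls are involved.

```python
from datetime import datetime, timedelta, timezone


def get_next_fiscal_quarter_start(dtm: datetime) -> datetime:
    """Find the start of the next fiscal quarter strictly after dtm."""
    possible_starts = [
        datetime(year, month, 1, tzinfo=timezone.utc)
        for year in [dtm.year, dtm.year + 1]
        for month in [1, 4, 7, 10]
    ]
    return min(s for s in possible_starts if s > dtm)


# Hypothetical split timestamp, chosen for illustration only.
split = datetime(2023, 2, 15, tzinfo=timezone.utc)

boundaries = []
cursor = split
for _ in range(5):
    cursor = get_next_fiscal_quarter_start(cursor)
    boundaries.append(cursor)

# boundaries[0] starts the first full quarter;
# boundaries[4] starts the quarter that follows the fourth one.
expected_end = boundaries[-1] - timedelta(days=1)

for boundary in boundaries:
    print(boundary.date())
print("data required through:", expected_end.date())
```

With this split, the four labelled quarters run from 1 April 2023 to 31 March 2024, so the entity is kept only if the training data reaches 31 March 2024.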

③ Evaluate per-quarter spending

Python
quarter_data = future.interval_from(
    quarter_start.timestamp(), timedelta(days=days)
)[TRANSACTION_DATA_SOURCE]
quarter_tiers.append(get_tier(quarter_data))

Each quarter's transactions are sliced independently. The get_tier helper maps total spend to a tier index using the configured thresholds.
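The threshold mapping can be checked in isolation. The sketch below is a hypothetical stand-alone variant of get_tier, here called tier_for_total, that takes a plain number rather than an Events slice so that it runs without monad. Note that the comparisons are strict: a total exactly equal to a threshold stays in the lower tier.

```python
CLASS_NAMES = ["NONE", "GOLD", "PLATINUM", "DIAMOND"]
GOLD_THRESHOLD = 1_000
PLATINUM_THRESHOLD = 5_000
DIAMOND_THRESHOLD = 10_000


def tier_for_total(total: float) -> int:
    """Map a quarter's total spend to an index into CLASS_NAMES."""
    if total > DIAMOND_THRESHOLD:
        return 3
    elif total > PLATINUM_THRESHOLD:
        return 2
    elif total > GOLD_THRESHOLD:
        return 1
    return 0


# Illustrative totals, not real data.
sample_totals = [250.0, 1_000.0, 1_000.01, 7_500.0, 12_000.0]
sample_tiers = [CLASS_NAMES[tier_for_total(t)] for t in sample_totals]
print(sample_tiers)
```

If your programme awards a tier at the threshold itself, switch the comparisons to >=.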

④ Return one-hot encoded tier

Python
tier_idx = min(quarter_tiers)
result = np.zeros(len(CLASS_NAMES), dtype=np.float32)
result[tier_idx] = 1
return result

The final tier is the minimum across all four quarters — the customer must maintain the tier consistently. The result is one-hot encoded as required by MulticlassClassificationTask.
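As a small worked example with made-up per-quarter tiers, the sketch below shows how one weak quarter determines the label and how the one-hot vector maps back to a class name. Only NumPy is used; the input list is hypothetical.

```python
import numpy as np

CLASS_NAMES = ["NONE", "GOLD", "PLATINUM", "DIAMOND"]

# Hypothetical tiers for four consecutive quarters:
# PLATINUM, DIAMOND, GOLD, PLATINUM.
quarter_tiers = [2, 3, 1, 2]

# The weakest quarter decides the label.
tier_idx = min(quarter_tiers)
result = np.zeros(len(CLASS_NAMES), dtype=np.float32)
result[tier_idx] = 1

# Decode the one-hot vector back to a class name.
label = CLASS_NAMES[int(np.argmax(result))]
print(result, label)
```

Here the single GOLD quarter pulls the label down to GOLD even though the customer reached DIAMOND in another quarter.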


Training

Once the target function is defined, fine-tune a downstream model:

Python
from pathlib import Path
from monad.ui.config import TrainingParams, MetricParams, MetricMonitoringMode
from monad.config.early_stopping import EarlyStopping

from monad.ui.module import load_from_foundation_model, MulticlassClassificationTask

module = load_from_foundation_model(
    checkpoint_path=Path("./foundation_model"),
    downstream_task=MulticlassClassificationTask(class_names=CLASS_NAMES),
    target_fn=customer_spending_tier_target_fn,
)

training_params = TrainingParams(
    checkpoint_dir=Path("./<this_model>"),
    learning_rate=1e-4,
    epochs=20,
    devices=[0],
    metrics=[
        MetricParams(alias="accuracy", metric_name="Accuracy", kwargs={"task": "multiclass", "num_classes": len(CLASS_NAMES)}),
        MetricParams(alias="f1_macro", metric_name="F1Score", kwargs={"task": "multiclass", "num_classes": len(CLASS_NAMES), "average": "macro"}),
    ],
    metric_to_monitor="val_accuracy_0",
    metric_monitoring_mode=MetricMonitoringMode.MAX,
    early_stopping=EarlyStopping(min_delta=1e-4, patience=5),
)

module.fit(training_params, seed=42)

Evaluation

Python
from pathlib import Path
from datetime import datetime, timezone
from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, MetricParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    prediction_date=datetime(2024, 5, 1, tzinfo=timezone.utc),
    output_type=OutputType.DECODED,
    devices=[0],
    metrics=[
        MetricParams(alias="accuracy", metric_name="Accuracy"),
        MetricParams(alias="f1_macro", metric_name="F1Score"),
    ],
)

results = module.test(testing_params)

Prediction

Python
from pathlib import Path
from datetime import datetime, timezone
from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams, OutputType

module = load_from_checkpoint(Path("./<this_model>"))

testing_params = TestingParams(
    local_save_location=Path("./predictions.tsv"),
    output_type=OutputType.DECODED,
    prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
    devices=[0],
)

predictions = module.predict(testing_params)

Metric | Why it matters
Accuracy | Overall proportion of correct predictions.
Macro F1 | Balanced F1 across all classes.
Top-K Accuracy | Whether the true class is in the top K predictions.
Confusion Matrix | Reveals which classes are most often confused.
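For intuition about what these metrics measure, here is a NumPy-only sketch that computes them from hypothetical labels and predicted probabilities. It does not use monad's metric machinery; the function names and sample arrays are assumptions made for illustration.

```python
import numpy as np

NUM_CLASSES = 4  # NONE, GOLD, PLATINUM, DIAMOND


def confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, n: int) -> np.ndarray:
    """Rows are true classes, columns are predicted classes."""
    cm = np.zeros((n, n), dtype=np.int64)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm


def macro_f1(cm: np.ndarray) -> float:
    """Unweighted mean of per-class F1; an empty class scores 0."""
    tp = np.diag(cm).astype(float)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    denom = 2 * tp + fp + fn
    f1 = np.divide(2 * tp, denom, out=np.zeros_like(tp), where=denom > 0)
    return float(f1.mean())


def top_k_accuracy(y_true: np.ndarray, probs: np.ndarray, k: int) -> float:
    """Share of rows whose true class is among the k most probable classes."""
    top_k = np.argsort(probs, axis=1)[:, -k:]
    return float(np.mean([t in row for t, row in zip(y_true, top_k)]))


# Hypothetical true labels and predicted probabilities for six customers.
y_true = np.array([0, 1, 1, 2, 3, 3])
probs = np.array([
    [0.7, 0.2, 0.1, 0.0],
    [0.1, 0.6, 0.2, 0.1],
    [0.5, 0.3, 0.1, 0.1],
    [0.1, 0.2, 0.6, 0.1],
    [0.0, 0.1, 0.5, 0.4],
    [0.0, 0.1, 0.2, 0.7],
])
y_pred = probs.argmax(axis=1)

cm = confusion_matrix(y_true, y_pred, NUM_CLASSES)
accuracy = float(np.trace(cm) / cm.sum())
f1 = macro_f1(cm)
top2 = top_k_accuracy(y_true, probs, k=2)
print(cm, accuracy, f1, top2)
```

In this toy data the confusion matrix shows one GOLD customer predicted as NONE and one DIAMOND customer predicted as PLATINUM, while top-2 accuracy is perfect because each true class is among the two most probable.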

Production Tips

  1. Align fiscal quarters with your business calendar. The default uses Jan/Apr/Jul/Oct. If your fiscal year starts differently, adjust the get_next_fiscal_quarter_start logic.
  2. Consider a more lenient aggregate than the minimum. Taking the minimum is conservative: a single weak quarter drops the customer to a lower tier. Use the maximum, the mode, or a rounded average if your loyalty programme is more lenient.
  3. Watch for class imbalance. Diamond-tier customers are typically rare. Use stratified sampling or class weighting to ensure the model learns to distinguish high-value tiers.
  4. Use predictions for proactive tier management. Customers predicted to fall from Platinum to Gold can receive targeted spend-incentive campaigns before the quarter ends.
  5. Validate threshold sensitivity. Small changes to the Gold/Platinum/Diamond thresholds can significantly shift the label distribution. Run sensitivity analysis before committing to thresholds.
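Tips 1 and 2 can be sketched as follows. Both helpers are hypothetical variants and not part of monad: the first adds an assumed fiscal_year_start_month parameter, and the second offers the maximum and the mode as alternatives to the minimum. Ties in the mode are resolved toward the lower tier, which is an arbitrary choice made here.

```python
from collections import Counter
from datetime import datetime, timezone
from typing import List


def get_next_fiscal_quarter_start(
    dtm: datetime, fiscal_year_start_month: int = 1
) -> datetime:
    """Next quarter start after dtm for a fiscal year beginning in the given month."""
    # Quarter start months; a start month of 2 gives [2, 5, 8, 11].
    months = sorted((fiscal_year_start_month - 1 + 3 * i) % 12 + 1 for i in range(4))
    possible_starts = [
        datetime(year, month, 1, tzinfo=timezone.utc)
        for year in [dtm.year, dtm.year + 1]
        for month in months
    ]
    return min(s for s in possible_starts if s > dtm)


def aggregate_tiers(quarter_tiers: List[int], how: str = "min") -> int:
    """Combine per-quarter tier indices into one label index."""
    if how == "min":
        return min(quarter_tiers)
    if how == "max":
        return max(quarter_tiers)
    if how == "mode":
        counts = Counter(quarter_tiers)
        best = max(counts.values())
        # Assumption: break ties toward the lower tier.
        return min(t for t, c in counts.items() if c == best)
    raise ValueError(f"unknown aggregation: {how}")


# Fiscal year starting in February: quarters begin in Feb, May, Aug, Nov.
q = get_next_fiscal_quarter_start(
    datetime(2023, 12, 10, tzinfo=timezone.utc), fiscal_year_start_month=2
)

tiers = [2, 3, 1, 2]  # hypothetical per-quarter tiers
labels = {how: aggregate_tiers(tiers, how) for how in ("min", "max", "mode")}
print(q.date(), labels)
```

Whichever aggregate you choose, the label distribution changes with it, so rerun the class-balance and threshold checks from tips 3 and 5 afterwards.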