Target Function Reference

The target function defines what a scenario model predicts. It receives historical and future events for each entity and returns a label, value, or sketch — or None to exclude the entity.

Signature

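from __future__ import annotations  # defer annotation evaluation; the import path of Sketch is not shown in this reference

import numpy as np
from monad.targets import Events, Attributes

def target_fn(
    history: Events,
    future: Events,
    attributes: Attributes,
    ctx: dict,
) -> np.ndarray | Sketch | None:
    ...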

| Argument | Type | Description |
|---|---|---|
| history | Events | All events before the split point. This is what the model sees as input. |
| future | Events | All events after the split point. Use this to compute the prediction target. |
| attributes | Attributes | Static entity properties (e.g., customer demographics). |
| ctx | dict | Context dictionary with SPLIT_TIMESTAMP, MODE, TRAINING_END_TIMESTAMP. |
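
To make the history/future division concrete, here is a minimal pure-Python sketch of partitioning one entity's events at a split timestamp. The `split_events` helper and the tuple layout are hypothetical stand-ins, not part of the library, and the treatment of an event that falls exactly on the split point is an assumption.

```python
from typing import List, Tuple

Event = Tuple[float, str]  # (unix timestamp, payload); hypothetical layout


def split_events(events: List[Event], split_ts: float) -> Tuple[List[Event], List[Event]]:
    """Partition events into (history, future) around split_ts.

    Assumption: events strictly before split_ts are history, the rest are future.
    """
    history = [e for e in events if e[0] < split_ts]
    future = [e for e in events if e[0] >= split_ts]
    return history, future


events = [(100.0, "view"), (200.0, "purchase"), (300.0, "view")]
history, future = split_events(events, split_ts=250.0)
```

The target function only ever sees the two halves; the model itself is trained on the history half.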

Return Types by Task

| Task | Return Type | Example |
|---|---|---|
| Binary Classification | np.array([0 or 1], dtype=np.float32) | np.array([1.0], dtype=np.float32) |
| Multiclass Classification | np.array([class_index], dtype=np.float32) | np.array([2.0], dtype=np.float32) |
| Multilabel Classification | np.array([0,1,0,...], dtype=np.float32) | np.array([1,0,1], dtype=np.float32) |
| Regression | np.array([value], dtype=np.float32) | np.array([1500.0], dtype=np.float32) |
| Recommendation | Sketch | sketch(items, weights) |
| Recommendation with filtering | (Sketch, mask) tuple | (sketch(items, weights), mask) |
| Skip entity | None | return None |

Note

Return None when an entity should be excluded from training — e.g., insufficient history, no future events, incomplete time windows, or invalid entities (test accounts, bots).
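
The multilabel row in the table above expects a multi-hot vector with one position per label. The following is a minimal sketch of that encoding in plain Python; `LABELS` and `multi_hot` are hypothetical names used for illustration only.

```python
from typing import Iterable, List

LABELS = ["Electronics", "Fashion", "Home"]  # hypothetical label vocabulary


def multi_hot(observed: Iterable[str], labels: List[str] = LABELS) -> List[float]:
    """Return 1.0 for every label that appears in observed, else 0.0."""
    seen = set(observed)
    return [1.0 if label in seen else 0.0 for label in labels]


vector = multi_hot(["Home", "Electronics", "Home"])
```

In a real target function the list would be wrapped as np.array(vector, dtype=np.float32) to match the return type in the table.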


Events API

Access event data sources by name:

txns = history["transactions"]   # Returns DataSourceEvents
txns = future["transactions"]

Aggregation Methods

Methods on DataSourceEvents:

| Method | Return Type | Description |
|---|---|---|
| .count() | int | Number of events. |
| .sum(column, ignore_nan=True) | float | Sum of values. column: name or callable. |
| .mean(column, ignore_nan=True) | float | Mean of values. column: name or callable. |
| .min(column, ignore_nan=True) | float | Minimum value. column: name or callable. |
| .max(column, ignore_nan=True) | float | Maximum value. column: name or callable. |

Lambda expressions are also supported for computed aggregations:

total = txns.sum(lambda data: data["price"] * data["quantity"])
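
The effect of ignore_nan can be illustrated without the library. This is a sketch under assumptions: `nan_sum` and `nan_mean` are hypothetical stand-ins, and returning NaN for an input with no usable values is a choice made here, not documented library behavior.

```python
import math
from typing import List


def nan_sum(values: List[float], ignore_nan: bool = True) -> float:
    """Sum values; when ignore_nan is True, NaN entries are skipped."""
    if ignore_nan:
        values = [v for v in values if not math.isnan(v)]
    return sum(values)


def nan_mean(values: List[float], ignore_nan: bool = True) -> float:
    """Mean of values; returns NaN when nothing is left to average (assumption)."""
    kept = [v for v in values if not math.isnan(v)] if ignore_nan else list(values)
    return sum(kept) / len(kept) if kept else float("nan")


prices = [10.0, float("nan"), 30.0]
total = nan_sum(prices)                          # NaN entry skipped
poisoned = nan_sum(prices, ignore_nan=False)     # NaN propagates
average = nan_mean(prices)
```

With ignore_nan disabled, a single missing value turns the whole aggregate into NaN, which then flows into the target array.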

Column Access

# Access column data (returns ModalityEvents)
product_ids = txns["product_id"]
product_ids.events     # np.ndarray of values
product_ids.timestamps # np.ndarray of timestamps

# Timestamps for the data source
txns.timestamps  # np.ndarray of unix timestamps

Extra Columns

Columns declared in extra_columns (not used as features) are accessible via .extra:

order_id = txns.extra["order_id"]  # np.ndarray

Filtering

# Filter by column value with lambda
app_txns = txns.filter("channel", lambda x: x == "APP")
expensive = txns.filter("price", lambda x: x > 100)

# Filter using a callable expression
high_value = txns.filter(
    lambda data: data["price"] * data["quantity"],
    lambda x: x > 500,
)

# Filter by exact match (single value or list)
app_txns = txns.where_eq("channel", "APP")
app_web = txns.where_eq("channel", ["APP", "WEB"])

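The first argument of filter (the by parameter) accepts a column name (str) or a callable that computes values from the data mapping. The second argument is the predicate applied to each resulting value.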
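
The name-or-callable dispatch can be sketched in plain Python. `filter_rows` and the dict-of-lists layout are hypothetical stand-ins; in the library the callable receives array-like columns, so element-wise arithmetic works directly, whereas this sketch has to zip the lists.

```python
from typing import Callable, Dict, List, Union

Rows = Dict[str, List]  # column name -> list of values; hypothetical stand-in


def filter_rows(data: Rows, by: Union[str, Callable[[Rows], List]], condition: Callable) -> Rows:
    """Keep the rows whose `by` value satisfies condition."""
    # A string selects an existing column; a callable computes one value per row.
    values = data[by] if isinstance(by, str) else by(data)
    keep = [i for i, v in enumerate(values) if condition(v)]
    return {col: [vals[i] for i in keep] for col, vals in data.items()}


txns = {"price": [50, 200, 120], "quantity": [1, 3, 2], "channel": ["APP", "WEB", "APP"]}
expensive = filter_rows(txns, "price", lambda x: x > 100)
high_value = filter_rows(
    txns,
    lambda d: [p * q for p, q in zip(d["price"], d["quantity"])],
    lambda x: x > 500,
)
```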

Time Windows

from datetime import timedelta
from monad.batch import SPLIT_TIMESTAMP

# Future window: 30 days from split point
future_30d = future.interval_from(ctx[SPLIT_TIMESTAMP], timedelta(days=30))

# History window: last 90 days before split
recent = history.interval_from(ctx[SPLIT_TIMESTAMP], timedelta(days=-90))

# Explicit time range
window = history.interval_between(
    start=ctx[SPLIT_TIMESTAMP] - 30 * 86400,  # 30 days, in seconds
    end=ctx[SPLIT_TIMESTAMP],
    include="start",  # "start" or "end"
)

Note

interval_from and interval_between are available on both Events (all sources) and DataSourceEvents (single source).
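
The window arithmetic behind a positive or negative timedelta can be sketched as follows. `window_from` is a hypothetical stand-in operating on raw unix timestamps, and the boundary handling (lower bound inclusive, upper bound exclusive) is an assumption rather than documented behavior.

```python
from datetime import timedelta
from typing import List


def window_from(timestamps: List[float], start: float, delta: timedelta) -> List[float]:
    """Select timestamps between start and start + delta.

    A negative delta looks backwards from start.
    Assumption: lower bound inclusive, upper bound exclusive.
    """
    end = start + delta.total_seconds()
    lo, hi = min(start, end), max(start, end)
    return [t for t in timestamps if lo <= t < hi]


DAY = 86400.0
split = 100 * DAY
timestamps = [5 * DAY, 40 * DAY, 95 * DAY, 110 * DAY, 140 * DAY]

recent = window_from(timestamps, split, timedelta(days=-90))   # 90 days before the split
upcoming = window_from(timestamps, split, timedelta(days=30))  # 30 days after the split
```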

Last Basket

Get events from the most recent timestamp:

last_basket = history["transactions"].get_last_basket()
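
As an illustration of "events from the most recent timestamp", the sketch below selects every item that shares the maximum timestamp. `last_basket_items` is a hypothetical helper on plain lists, not the library method.

```python
from typing import List


def last_basket_items(timestamps: List[float], items: List[str]) -> List[str]:
    """Return the items whose timestamp equals the most recent timestamp."""
    if not timestamps:
        return []
    latest = max(timestamps)
    return [item for t, item in zip(timestamps, items) if t == latest]


basket = last_basket_items([1.0, 2.0, 5.0, 5.0], ["a", "b", "c", "d"])
```

Items bought in the same transaction typically share a timestamp, which is why the result can hold more than one event.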

GroupBy

grouped = txns.groupBy("category")
# or multiple columns:
grouped = txns.groupBy(["category", "channel"])

Methods on EventsGroupBy:

| Method | Returns | Description |
|---|---|---|
| .count() | (np.ndarray, list[str]) | Count per group and group names. |
| .sum(target) | (np.ndarray, list[str]) | Sum per group and group names. target: column name or callable. Optional ignore_nan=True. |
| .mean(target) | (np.ndarray, list[str]) | Mean per group and group names. target: column name or callable. Optional ignore_nan=True. |
| .min(target) | (np.ndarray, list[str]) | Minimum per group and group names. target: column name or callable. Optional ignore_nan=True. |
| .max(target) | (np.ndarray, list[str]) | Maximum per group and group names. target: column name or callable. Optional ignore_nan=True. |
| .exists(groups=[...]) | (np.ndarray, list[str]) | Existence flags per group and group names. |
| .apply(func, ...) | (np.ndarray, list[str]) | Custom aggregation per group. |

# Get counts by category
counts, names = txns.groupBy("category").count()

# With explicit group ordering
counts, names = txns.groupBy("category").count(
    groups=["Electronics", "Fashion", "Home"]
)

# Check existence
exists, names = txns.groupBy("category").exists(
    groups=["Electronics", "Fashion", "Home"]
)

# Custom aggregation
medians, names = txns.groupBy("category").apply(
    func=np.median,
    default_value=0.0,
    target="price",
    groups=["Electronics", "Fashion"],
)
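
The role of groups and default_value in the custom aggregation can be sketched in plain Python. `group_apply` is a hypothetical stand-in; sorting the group names when groups is omitted is an assumption, since this reference does not state the default ordering.

```python
from collections import defaultdict
from statistics import median
from typing import Callable, List, Optional, Sequence, Tuple


def group_apply(
    keys: Sequence[str],
    values: Sequence[float],
    func: Callable[[List[float]], float],
    groups: Optional[Sequence[str]] = None,
    default_value: float = 0.0,
) -> Tuple[List[float], List[str]]:
    """Aggregate values per key; groups fixes the output order and fills gaps."""
    buckets = defaultdict(list)
    for key, value in zip(keys, values):
        buckets[key].append(value)
    # Assumption: without explicit groups, names are returned in sorted order.
    names = list(groups) if groups is not None else sorted(buckets)
    results = [func(buckets[n]) if n in buckets else default_value for n in names]
    return results, names


categories = ["Fashion", "Electronics", "Fashion"]
prices = [10.0, 99.0, 30.0]
medians, names = group_apply(
    categories, prices, median, groups=["Electronics", "Fashion", "Home"]
)
```

Passing an explicit groups list keeps the output length and order fixed across entities, which matters when the result becomes a multilabel or regression target vector.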

Attributes API

Access entity attributes by data source name and column:

# Access a specific attribute value
segment = attributes["customers"]["segment"].attribute  # str or numeric value
age = attributes["customers"]["age"].attribute

# Extra columns on attributes
signup_source = attributes["customers"].extra["signup_source"]

Properties of ModalityAttribute:

| Property | Type | Description |
|---|---|---|
| .attribute | Any | The actual attribute value. |
| .column_name | str | Column name. |
| .dataset_name | str | Data source name. |

Context Dictionary

from monad.batch import SPLIT_TIMESTAMP, MODE, TRAINING_END_TIMESTAMP
from monad.config import DataMode

| Key | Type | Description |
|---|---|---|
| SPLIT_TIMESTAMP | float | Unix timestamp of the split point dividing history and future. |
| MODE | DataMode | Current processing mode: DataMode.TRAIN, DataMode.VALIDATION, DataMode.TEST, or DataMode.PREDICT. |
| TRAINING_END_TIMESTAMP | float | Unix timestamp marking the end of the training period. |

split_time = ctx[SPLIT_TIMESTAMP]
if ctx[MODE] == DataMode.TRAIN:
    ...  # Training-specific logic goes here
training_end = ctx[TRAINING_END_TIMESTAMP]

Helper Functions

has_incomplete_training_window()

Check if the training window is too short for the required target window.

from monad.ui.target_function import has_incomplete_training_window
from datetime import timedelta

# Inside a target function:
if has_incomplete_training_window(ctx, timedelta(days=30)):
    return None  # Skip: the full future window cannot be observed

| Parameter | Type | Description |
|---|---|---|
| ctx | dict | The context dictionary. |
| required_length | timedelta | Minimum required future window length. |
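
One plausible reading of this check is that the window is incomplete when the split point plus the required length runs past the end of the training period. The sketch below encodes that reading; `window_is_incomplete` is a hypothetical stand-in that takes the two timestamps directly, and the library's helper may apply additional conditions (for example, depending on the mode).

```python
from datetime import timedelta


def window_is_incomplete(split_ts: float, training_end_ts: float, required_length: timedelta) -> bool:
    """True when [split_ts, split_ts + required_length] extends past training_end_ts."""
    return split_ts + required_length.total_seconds() > training_end_ts


DAY = 86400.0
early_split = window_is_incomplete(0.0, 30 * DAY, timedelta(days=30))      # window fits exactly
late_split = window_is_incomplete(10 * DAY, 30 * DAY, timedelta(days=30))  # only 20 days observable
```

Skipping such entities avoids labeling a customer as churned merely because the data ends before the target window does.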

verify_target()

Validate a target function against real data before training.

from monad.ui.target_function import verify_target
from monad.ui.module import BinaryClassificationTask

results = verify_target(
    target_fn=my_target_fn,
    fm_checkpoint_path="./foundation_model",
    task=BinaryClassificationTask(),
    num_percentage_entities=1,  # Test on 1% of entities
)

| Parameter | Type | Default | Description |
|---|---|---|---|
| target_fn | TargetFunction | required | The target function to validate. |
| fm_checkpoint_path | str \| Path | required | Foundation model checkpoint path. |
| task | Task | required | Task type matching the target function. |
| data_params_overrides | DataParams \| None | None | Override data parameters for validation. |
| num_percentage_entities | int | 1 | Percentage of entities to evaluate on. |
| percentage_nones_allowed | int | 90 | Maximum percentage of None returns allowed. |
| log_every_n_steps | int \| None | None | Logging frequency. |
| limit | int \| None | None | Maximum number of entities to evaluate. |
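
To get a feel for percentage_nones_allowed before running the verifier, the share of excluded entities can be computed by hand on a sample of target outputs. This is a rough sketch: `none_percentage` and `within_none_budget` are hypothetical helpers, and treating the threshold as inclusive is an assumption about how the limit is enforced.

```python
from typing import Optional, Sequence


def none_percentage(results: Sequence[Optional[object]]) -> float:
    """Percentage of results that are None (0.0 for an empty sequence)."""
    if not results:
        return 0.0
    return 100.0 * sum(r is None for r in results) / len(results)


def within_none_budget(results: Sequence[Optional[object]], percentage_nones_allowed: int = 90) -> bool:
    """Assumption: the limit is inclusive, so exactly 90% Nones still passes."""
    return none_percentage(results) <= percentage_nones_allowed


sample = [None, [1.0], None, [0.0]]  # stand-in for collected target outputs
share = none_percentage(sample)
ok = within_none_budget(sample)
```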

Sketch Functions (Recommendations)

For recommendation tasks, the target function returns a Sketch object representing items and their weights.

from monad.ui.target_function import sketch, sequential_decay, sketch_filtering_mask
from monad.targets.recommendation import time_decay

sketch()

Create a sketch from items and weights.

items = future["transactions"]["product_id"]
weights = np.ones(len(items), dtype=np.float32)
return sketch(items, weights)

sequential_decay()

Compute position-based decay weights. Events that occur earlier in the future window receive higher weight.

weights = sequential_decay(future["transactions"], gamma=0.5)

| Parameter | Type | Default | Description |
|---|---|---|---|
| events | DataSourceEvents | required | The events to compute weights for. |
| gamma | float | 0.0 | Decay factor. 0.0 = uniform weights, 1.0 = full decay (only first event matters). |
| init_weights | ModalityEvents \| None | None | Initial weights to scale by. |
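
This reference does not give the exact decay formula. A geometric decay of (1 - gamma) per position is one formula consistent with the two documented endpoints (0.0 gives uniform weights, 1.0 leaves only the first event), so the sketch below uses it as an assumption. `geometric_position_weights` is a hypothetical name.

```python
from typing import List


def geometric_position_weights(n_events: int, gamma: float) -> List[float]:
    """Weight for position i is (1 - gamma) ** i (assumed formula)."""
    return [(1.0 - gamma) ** i for i in range(n_events)]


uniform = geometric_position_weights(3, gamma=0.0)
halving = geometric_position_weights(3, gamma=0.5)
first_only = geometric_position_weights(3, gamma=1.0)
```

Under this assumption, gamma=0.5 halves the weight with each later event, so the third event counts a quarter as much as the first.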

time_decay()

Compute time-based decay weights. Events closer in time to the split point receive higher weight.

weights = time_decay(future["transactions"], daily_decay=0.1)

| Parameter | Type | Default | Description |
|---|---|---|---|
| events | DataSourceEvents | required | The events to compute weights for. |
| daily_decay | float | 0.0 | Percentage weight decay per day. |
| init_weights | ModalityEvents \| None | None | Initial weights to scale by. |

sketch_filtering_mask()

Create a mask to exclude items the entity has already interacted with (e.g., already purchased products).

# Exclude previously purchased items from recommendations
future_items = future["transactions"]["product_id"]
weights = sequential_decay(future["transactions"], gamma=0.5)
mask = sketch_filtering_mask(history["transactions"]["product_id"])
future_sketch = sketch(future_items, weights)
return (future_sketch, mask)  # Return a tuple for filtered recommendations
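
Conceptually, the mask records which items the entity has already seen so they can be dropped from the ranked output. The sketch below shows that idea on plain lists; `build_seen_mask` and `drop_seen` are hypothetical helpers, and how the library represents and applies the mask internally is not described in this reference.

```python
from typing import Iterable, List, Set


def build_seen_mask(history_items: Iterable[str]) -> Set[str]:
    """Collect the distinct items the entity has already interacted with."""
    return set(history_items)


def drop_seen(ranked_items: List[str], seen: Set[str]) -> List[str]:
    """Remove already-seen items while preserving the ranking order."""
    return [item for item in ranked_items if item not in seen]


seen = build_seen_mask(["p1", "p3", "p1"])
recommendations = drop_seen(["p3", "p2", "p1", "p4"], seen)
```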

Complete Examples

Binary Classification — Churn Prediction

import numpy as np
from datetime import timedelta
from monad.targets import Events, Attributes
from monad.batch import SPLIT_TIMESTAMP
from monad.ui.target_function import has_incomplete_training_window

TARGET_WINDOW_DAYS = 30

def churn_target_fn(history: Events, future: Events, attributes: Attributes, ctx: dict):
    if history["transactions"].count() < 2:
        return None
    if has_incomplete_training_window(ctx, timedelta(days=TARGET_WINDOW_DAYS)):
        return None

    future_window = future.interval_from(ctx[SPLIT_TIMESTAMP], timedelta(days=TARGET_WINDOW_DAYS))
    churned = 1 if future_window["transactions"].count() == 0 else 0
    return np.array([churned], dtype=np.float32)

Multiclass Classification — Next Category

import numpy as np
from monad.targets import Events, Attributes

CATEGORIES = ["Electronics", "Fashion", "Home", "Sports", "Beauty"]

def next_category_target(history: Events, future: Events, attributes: Attributes, ctx: dict):
    if history["transactions"].count() < 2:
        return None
    if future["transactions"].count() == 0:
        return None

    first_category = future["transactions"]["category"].events[0]
    if first_category not in CATEGORIES:
        return None
    return np.array([CATEGORIES.index(first_category)], dtype=np.float32)

Regression — Customer Lifetime Value

import numpy as np
from datetime import timedelta
from monad.targets import Events, Attributes
from monad.batch import SPLIT_TIMESTAMP
from monad.ui.target_function import has_incomplete_training_window

def ltv_target(history: Events, future: Events, attributes: Attributes, ctx: dict):
    if history["transactions"].count() < 3:
        return None
    if has_incomplete_training_window(ctx, timedelta(days=30)):
        return None

    future_30d = future.interval_from(ctx[SPLIT_TIMESTAMP], timedelta(days=30))
    total_spend = future_30d["transactions"].sum("price")
    return np.array([float(total_spend)], dtype=np.float32)

Recommendation — Product Recommendations

from monad.targets import Events, Attributes
from monad.ui.target_function import sketch, sequential_decay

def products_target(history: Events, future: Events, attributes: Attributes, ctx: dict):
    if history["transactions"].count() == 0:
        return None
    if future["transactions"].count() == 0:
        return None

    future_txns = future["transactions"]
    weights = sequential_decay(future_txns, gamma=0.5)
    return sketch(future_txns["product_id"], weights)

Recommendation — With Filtering Mask

from monad.targets import Events, Attributes
from monad.ui.target_function import sketch, sequential_decay, sketch_filtering_mask

def filtered_products_target(history: Events, future: Events, attributes: Attributes, ctx: dict):
    if history["transactions"].count() == 0:
        return None
    if future["transactions"].count() == 0:
        return None

    future_txns = future["transactions"]
    weights = sequential_decay(future_txns, gamma=0.5)

    # Exclude previously purchased items
    mask = sketch_filtering_mask(history["transactions"]["product_id"])
    return (sketch(future_txns["product_id"], weights), mask)

Recommendation — Train vs Eval Behavior

from monad.targets import Events, Attributes
from monad.batch import MODE
from monad.config import DataMode
from monad.ui.target_function import sketch, sequential_decay

def next_basket_target(history: Events, future: Events, attributes: Attributes, ctx: dict):
    if history["transactions"].count() == 0:
        return None
    if future["transactions"].count() == 0:
        return None

    future_txns = future["transactions"]

    if ctx[MODE] == DataMode.TRAIN:
        # During training: consider all future purchases equally
        weights = sequential_decay(future_txns, gamma=0)
    else:
        # During eval: focus on the next purchase only
        weights = sequential_decay(future_txns, gamma=1)

    return sketch(future_txns["product_id"], weights)