YAML Configuration Reference

The foundation model is configured entirely via a single YAML file. The path to this file is passed to pretrain() as the config_path parameter.

Top-Level Structure

data_sources: [...]                # Required — your data tables
data_params: {...}                 # Required — date ranges and splits
data_loader_params: {...}          # Optional — batch loading settings
training_params: {...}             # Optional — training hyperparameters
memory_constraining_params: {...}  # Optional — model size and memory limits
query_optimization: {...}          # Optional — query parallelization
calibration_params: {...}          # Optional — automatic DataLoader tuning
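
As a quick pre-flight check, the top-level layout above can be verified on an already-parsed configuration (a plain dict) before it reaches pretrain(). The sketch below is illustrative only: check_top_level is a hypothetical helper, not part of the library, and treating unknown top-level sections as errors is an assumption.

```python
# Hypothetical pre-flight check; not part of the library's API.
REQUIRED_SECTIONS = {"data_sources", "data_params"}
OPTIONAL_SECTIONS = {
    "data_loader_params",
    "training_params",
    "memory_constraining_params",
    "query_optimization",
    "calibration_params",
}


def check_top_level(config: dict) -> list:
    """Return a list of problems with the top-level keys (empty if none)."""
    present = set(config)
    problems = []
    for key in sorted(REQUIRED_SECTIONS - present):
        problems.append(f"missing required section: {key}")
    # Assumption: anything outside the documented sections is a mistake.
    for key in sorted(present - REQUIRED_SECTIONS - OPTIONAL_SECTIONS):
        problems.append(f"unknown section: {key}")
    return problems
```

For instance, a configuration that defines only data_sources would be reported as missing data_params.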

data_sources

A list of data source definitions. At least one event data source is required.

Data Source Fields

| Field | Type | Required | Description |
| --- | --- | --- | --- |
| type | str | Yes | Source type: "event", "main_entity_attribute", or "attribute". |
| name | str | Yes | Unique name for this data source. Referenced in target functions and joins. |
| data_location | object | Yes | Database type and connection parameters. See below. |
| main_entity_column | str | Yes (event, main_entity_attribute) | Column identifying the main entity (e.g., customer_id). |
| date_column | object | Yes (event only) | Timestamp column configuration. See Date Column Formats. |
| allowed_columns | list[str] | No | Only these columns will be used. Mutually exclusive with disallowed_columns. |
| disallowed_columns | list[str] | No | These columns will be excluded. Auto-populated with bijection columns in suggested_config.yaml. |
| column_type_overrides | dict | No | Override inferred column types (e.g., {price_ts: time_series}). Auto-populated with time-series candidates in suggested_config.yaml. |
| joined_data_sources | list | No | Attribute tables to join. See Joining Tables. |
| sql_lambdas | list | No | SQL-based computed columns. |
| where_condition | str | No | SQL WHERE clause to filter rows at query time. |
| shared_entities | list | No | Map entity columns across data sources for shared representations. |
| partition_column | str | No | Column used for data partitioning (for query optimization). |
| num_groups | int | No | Number of partition groups. |
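
The per-type requirements in the table can be expressed as a small check on one parsed data source entry. This is a minimal sketch under the assumption that the entry is a plain dict; check_data_source is a hypothetical helper and does not reflect how the library validates configurations internally.

```python
# Hypothetical helper that mirrors the "Required" column of the table above.
def check_data_source(source: dict) -> list:
    """Return a list of problems found in one data source entry."""
    problems = []
    kind = source.get("type")
    if kind not in {"event", "main_entity_attribute", "attribute"}:
        problems.append(f"unknown type: {kind!r}")
    # Fields required for every source type.
    for field in ("name", "data_location"):
        if field not in source:
            problems.append(f"missing required field: {field}")
    # Fields required only for some source types.
    if kind in {"event", "main_entity_attribute"} and "main_entity_column" not in source:
        problems.append("missing required field: main_entity_column")
    if kind == "event" and "date_column" not in source:
        problems.append("missing required field: date_column")
    # The two column filters cannot be combined.
    if "allowed_columns" in source and "disallowed_columns" in source:
        problems.append("allowed_columns and disallowed_columns are mutually exclusive")
    return problems
```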

data_location

| Field | Type | Description |
| --- | --- | --- |
| database_type | str | One of: parquet, snowflake, bigquery, databricks, hive, clickhouse, synapse. |
| connection_params | object | Database-specific parameters. See Data Connectors. |
| table_name | str | Table name in the database. |
| schema_name | str | Schema name (required for some databases). |

date_column

| Field | Type | Description |
| --- | --- | --- |
| name | str | Column name containing the event timestamp. |
| format | str | Date format string (e.g., "%Y-%m-%d", "unix"). |
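
To sanity-check a format value against a few sample rows before training, something like the following can help. It is a sketch under two assumptions that the reference does not state: that format strings follow strptime conventions, and that "unix" means seconds since the epoch. parse_event_timestamp is a hypothetical helper.

```python
from datetime import datetime, timezone


def parse_event_timestamp(value, fmt: str) -> datetime:
    """Parse one raw timestamp value using a date_column format string."""
    if fmt == "unix":
        # Assumption: "unix" means seconds since 1970-01-01 UTC.
        return datetime.fromtimestamp(float(value), tz=timezone.utc)
    # Assumption: other formats use strptime-style directives.
    return datetime.strptime(str(value), fmt)
```

A value that does not match the format raises ValueError, which is a cheap way to catch a mismatched format string early.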

Source Type: event

Timestamped behavioral data. Multiple event sources are supported. Each row represents a single event for an entity.

- type: event
  name: transactions
  data_location:
    database_type: parquet
    connection_params:
      path: "/data/transactions.parquet"
      cache_path: "/basemodel/db_cache/"
    table_name: transactions
  main_entity_column: customer_id
  date_column:
    name: t_dat
    format: "%Y-%m-%d"
  disallowed_columns: ["order_id"]
  joined_data_sources:
    - name: articles
      join_on:
        - [article_id, article_id]

Source Type: main_entity_attribute

Static or slowly-changing entity properties (one row per entity). Optional.

- type: main_entity_attribute
  name: customers
  data_location:
    database_type: parquet
    connection_params:
      path: "/data/customers.parquet"
      cache_path: "/basemodel/db_cache/"
    table_name: customers
  main_entity_column: customer_id

Source Type: attribute

Dimension tables used in joins with event data sources. Optional.

- type: attribute
  name: articles
  allowed_columns: ["product_type_name", "product_group_name",
                    "department_name", "section_name",
                    "colour_group_name", "perceived_colour_master_name"]
  data_location:
    database_type: parquet
    connection_params:
      path: "/data/articles.parquet"
      cache_path: "/basemodel/db_cache/"
    table_name: articles

SQL Lambdas

Compute derived columns using SQL expressions:

sql_lambdas:
  - alias: price_time_series
    expression: price
  - alias: total_price
    expression: "CAST({{ resolve_fn('price') }} * {{ resolve_fn('quantity') }} AS FLOAT)"

data_params

Controls the data range and how data is split for training/validation/testing.

| Field | Type | Required | Description |
| --- | --- | --- | --- |
| data_start_date | str | Yes | Earliest timestamp to include (e.g., "2018-09-20 00:00:00"). |
| split | object | Yes | Split configuration. See below. |

Entity Split

Splits entities randomly into training and validation sets. Recommended for most use cases.

data_params:
  data_start_date: "2018-09-20 00:00:00"
  split:
    type: entity
    training: 90
    validation: 10
    training_validation_end: "2020-09-04 00:00:00"
    test:
      start_date: "2020-09-05 00:00:00"
      end_date: "2020-09-22 00:00:00"

| Field | Type | Description |
| --- | --- | --- |
| type | "entity" | Entity-based split. |
| training | float | Percentage of entities for training (0.01–99.99). Accepts whole numbers (90) or fractional values (90.5). |
| validation | float | Percentage of entities for validation (0.01–99.99). training + validation must be ≤ 100. |
| training_validation_end | str | End date for training and validation data. |
| test.start_date | str | Start of the test window. |
| test.end_date | str | Start of the test window's end: end of the test window. |
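
To build intuition for what an entity split does, the sketch below assigns each entity to a set by hashing its ID into hundredths of a percent. This is an illustration only: the library's actual assignment method is not described here, and assign_entity is a hypothetical helper. It does enforce the documented ranges and the training + validation ≤ 100 rule.

```python
import hashlib
from typing import Optional


def assign_entity(entity_id: str, training: float, validation: float) -> Optional[str]:
    """Deterministically map an entity to 'training', 'validation', or None."""
    for pct in (training, validation):
        if not 0.01 <= pct <= 99.99:
            raise ValueError("percentages must be within 0.01-99.99")
    if training + validation > 100:
        raise ValueError("training + validation must be <= 100")
    digest = hashlib.sha256(entity_id.encode("utf-8")).digest()
    # Bucket in 0..9999, i.e. hundredths of a percent.
    bucket = int.from_bytes(digest[:8], "big") % 10_000
    if bucket < round(training * 100):
        return "training"
    if bucket < round((training + validation) * 100):
        return "validation"
    return None  # entity falls outside both sets
```

Because the assignment depends only on the entity ID, the same entity lands in the same set on every run, so no entity's history is shared between training and validation.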

Time Split

Splits data by time ranges. Each split (training, validation, test) has its own date window.

data_params:
  data_start_date: "2018-09-20 00:00:00"
  split:
    type: time
    training:
      start_date: "2018-09-20 00:00:00"
      end_date: "2020-06-30 00:00:00"
    validation:
      start_date: "2020-07-01 00:00:00"
      end_date: "2020-09-04 00:00:00"
    test:
      start_date: "2020-09-05 00:00:00"
      end_date: "2020-09-22 00:00:00"
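
The effect of the three windows can be illustrated by mapping a timestamp to the split whose window contains it. The sketch assumes both window bounds are inclusive, which the reference does not state, and split_for is a hypothetical helper.

```python
from datetime import datetime
from typing import Optional

FMT = "%Y-%m-%d %H:%M:%S"


def split_for(timestamp: str, split: dict) -> Optional[str]:
    """Return the split whose date window contains the timestamp, if any."""
    ts = datetime.strptime(timestamp, FMT)
    for name in ("training", "validation", "test"):
        start = datetime.strptime(split[name]["start_date"], FMT)
        end = datetime.strptime(split[name]["end_date"], FMT)
        # Assumption: both bounds are inclusive.
        if start <= ts <= end:
            return name
    return None


example_split = {
    "training": {"start_date": "2018-09-20 00:00:00", "end_date": "2020-06-30 00:00:00"},
    "validation": {"start_date": "2020-07-01 00:00:00", "end_date": "2020-09-04 00:00:00"},
    "test": {"start_date": "2020-09-05 00:00:00", "end_date": "2020-09-22 00:00:00"},
}
```

Under this reading, a timestamp such as "2020-06-30 12:00:00" belongs to no window, because the training window ends at midnight and validation starts the next day. If your events carry times of day, check that the window boundaries cover them as intended.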

Extra Columns

Columns available in the target function via .extra but not used as model features:

data_params:
  extra_columns:
    - data_source_name: transactions
      columns:
        - order_id

data_loader_params

Controls how data is batched and loaded during training.

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| batch_size | int | 256 | Number of entities per batch. Increase for faster training (watch GPU memory). |
| val_batch_size | int \| None | None | Batch size for validation. Defaults to batch_size if not set. Increase for faster validation without affecting training memory. |
| num_workers | int | 5 | Number of data-loading worker processes. Increase for faster data throughput. |
| pin_memory | bool | False | Whether to use pinned memory for faster CPU-to-GPU transfer. |
| drop_last | bool | False | Whether to drop the last incomplete batch. |
| prefetch_factor | int \| None | 2 | Number of batches to prefetch per worker. |
| pin_memory_device | str | "" | Device to pin memory to (e.g., "cuda"). Only relevant when pin_memory=True. |
| worker_init_fn | Callable \| None | None | Custom initialization function called in each data-loading worker. |

data_loader_params:
  batch_size: 256
  num_workers: 4
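
The interaction between batch_size and drop_last comes down to simple arithmetic. The sketch below assumes a single process and one entity per sample (with distributed strategies the per-device count would differ). batches_per_epoch is a hypothetical helper.

```python
import math


def batches_per_epoch(num_entities: int, batch_size: int = 256, drop_last: bool = False) -> int:
    """Number of batches needed to cover num_entities once."""
    if drop_last:
        # The final, incomplete batch is discarded.
        return num_entities // batch_size
    # The final batch may be smaller than batch_size.
    return math.ceil(num_entities / batch_size)
```

With 1,000 entities and the default batch size of 256, this gives 4 batches, or 3 when drop_last is enabled.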

training_params

Training hyperparameters. For the full reference, see Training Parameters.

training_params:
  learning_rate: 0.0003
  epochs: 1
  strategy: "ddp"
  devices: [0, 1]            # Default is "auto" (picks least-occupied GPU)
  precision: "bf16-mixed"
  limit_train_batches: 5    # Remove after validating setup
  limit_val_batches: 5      # Remove after validating setup

memory_constraining_params

Control model size and memory usage.

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dim | int | 2048 | Model hidden dimension. Reduce if GPU memory is insufficient. |
| num_layers | int | 4 | Number of hidden layers in the transformer model. |
| emde_quality | float | 1.0 | Feature density estimation quality. Lower values reduce memory at cost of accuracy. |

memory_constraining_params:
  hidden_dim: 2048
  num_layers: 4

Tip

If you encounter CUDA Out of Memory errors, try reducing hidden_dim to 1024 or 512, or reducing batch_size in data_loader_params.


query_optimization

Optimize database query performance for large datasets.

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| cleora_num_query_chunks | int | 1 | Split the fit/embedding (Cleora) query into N chunks for parallelization. Renamed from num_query_chunks. |
| data_loading_num_query_chunks | int | 1 | Split the data-loading query (train, validation, test, predict) into N chunks. Mid-epoch resume is unsupported when greater than 1. |
| num_cpus | int | 4 | Number of CPU cores for parallel query execution. |
| num_concurrent_features | int | 4 | Number of feature columns processed concurrently. |
| sampling_params | SamplingParams | SamplingParams() | Entity and history sampling limits for large datasets. Fields: num_entities, history_limit. |

query_optimization:
  cleora_num_query_chunks: 4
  data_loading_num_query_chunks: 1
  num_cpus: 8

Breaking change in 1.7

num_query_chunks was removed. Configurations that set it must migrate to cleora_num_query_chunks (fit phase) and data_loading_num_query_chunks (train/validation/test/predict). Unknown fields are rejected, so an old configuration will fail validation until updated.
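
A migration of an old query_optimization section might look like the sketch below. It carries the old value over to cleora_num_query_chunks only, leaving data_loading_num_query_chunks at its default. That mapping is an assumption based on the "Renamed from" note in the table, and migrate_query_optimization is a hypothetical helper, not a tool shipped with the library.

```python
def migrate_query_optimization(section: dict) -> dict:
    """Return a copy of the section with num_query_chunks migrated."""
    migrated = dict(section)  # leave the caller's dict untouched
    if "num_query_chunks" in migrated:
        old_value = migrated.pop("num_query_chunks")
        # Assumption: the old setting maps to the fit-phase (Cleora) query.
        # An explicitly set cleora_num_query_chunks takes precedence.
        migrated.setdefault("cleora_num_query_chunks", old_value)
    return migrated
```

Set data_loading_num_query_chunks separately if you also want chunked data loading, keeping in mind that mid-epoch resume is unsupported when it is greater than 1.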


calibration_params

Automatically finds the optimal DataLoader num_workers and prefetch_factor before foundation model training. When enabled, BaseModel benchmarks multiple configurations and applies the most efficient one.

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| enabled | bool | false | Enable DataLoader calibration before training. |
| candidate_workers | list[int] | [0,2,4,6,8,10,12,16,32] | Worker counts to benchmark. 0 is always tested as baseline. |
| prefetch_factors | list[int] | [2, 4] | Prefetch factor values to sweep for non-zero workers. |
| warmup_batches | int | 10 | Batches to skip before measuring (JIT/cache warm-up). |
| timeout_seconds | float | 20.0 | Max wall-clock seconds per configuration. |
| max_measure_batches | int | 200 | Max batches to measure per configuration. |
| cap_workers_by_ram | bool | true | Estimate per-worker buffer memory and cap workers to avoid OOM. |
| ram_safety_margin | float | 0.2 | Fraction of available RAM to keep free when computing the worker cap. |
| memory_probe_count | int | 10000 | Max unique entities to sample for memory estimation. |
| cpu_margin | float | 0.2 | Fraction of CPUs to reserve (e.g., 0.2 = use 80 % for workers). |
| efficient_threshold | float | 0.90 | Fraction of peak throughput a config must reach to be eligible; the cheapest eligible config is selected (0.90 = at least 90 % of peak). |
| plateau_threshold | float | 0.05 | Relative improvement below which a config is considered flat (early stopping). |
| plateau_min_configs | int | 3 | Min successful configs before plateau stopping can trigger. |
| max_consecutive_failures | int | 3 | Abort sweep after this many consecutive failures. |
| seed | int | 42 | Seed for reproducibility of calibration runs. |

calibration_params:
  enabled: true
  candidate_workers: [0, 4, 6, 8, 12, 14]
  prefetch_factors: [1, 2]
  warmup_batches: 10
  timeout_seconds: 15.0
  cap_workers_by_ram: true
  ram_safety_margin: 0.2
  memory_probe_count: 20000

Tip

Start with just enabled: true — the defaults work well for most setups. Customize candidate_workers only if you know your hardware limits, and adjust ram_safety_margin if you need tighter or looser memory headroom.
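
The role of efficient_threshold can be sketched as a selection over benchmark results: keep every configuration that reaches the threshold fraction of peak throughput, then take the cheapest. Defining "cheapest" as fewest workers, then smallest prefetch factor, is an assumption made for this illustration. pick_efficient is a hypothetical helper and the numbers in the usage note are made up.

```python
def pick_efficient(results, efficient_threshold: float = 0.90):
    """Pick (num_workers, prefetch_factor) from benchmark results.

    results: list of (num_workers, prefetch_factor, batches_per_second);
    prefetch_factor may be None for the 0-worker baseline.
    """
    peak = max(throughput for _, _, throughput in results)
    eligible = [r for r in results if r[2] >= efficient_threshold * peak]
    # Assumption: "cheapest" means fewest workers, then smallest prefetch.
    workers, prefetch, _ = min(eligible, key=lambda r: (r[0], r[1] or 0))
    return workers, prefetch
```

Given measurements of 40, 93, 100, and 101 batches per second for (0, None), (4, 2), (8, 2), and (8, 4), the sketch picks 4 workers with a prefetch factor of 2: it reaches more than 90 % of the peak with half the workers.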


Complete Example

A full, annotated configuration using Parquet with joins and entity split:

# ---- Data Sources ----
data_sources:

  # Events (mandatory, at least one)
  - type: event
    name: transactions
    data_location:
      database_type: parquet
      connection_params:
        path: "/data/transactions.parquet"
        cache_path: "/basemodel/db_cache/"
      table_name: transactions
    main_entity_column: customer_id
    date_column:
      name: t_dat
      format: "%Y-%m-%d"
    disallowed_columns: ["order_id"]
    sql_lambdas:
      - alias: price_time_series
        expression: price
    column_type_overrides:
      price_time_series: time_series
    joined_data_sources:
      - name: articles
        join_on:
          - [article_id, article_id]

  # Main entity attributes (optional)
  - type: main_entity_attribute
    name: customers
    data_location:
      database_type: parquet
      connection_params:
        path: "/data/customers.parquet"
        cache_path: "/basemodel/db_cache/"
      table_name: customers
    main_entity_column: customer_id

  # Attribute / dimension table (optional)
  - type: attribute
    name: articles
    allowed_columns: ["product_type_name", "product_group_name",
                      "department_name", "section_name",
                      "colour_group_name", "perceived_colour_master_name"]
    data_location:
      database_type: parquet
      connection_params:
        path: "/data/articles.parquet"
        cache_path: "/basemodel/db_cache/"
      table_name: articles

# ---- Data Parameters ----
data_params:
  data_start_date: "2018-09-20 00:00:00"
  split:
    type: entity
    training: 90
    validation: 10
    training_validation_end: "2020-09-04 00:00:00"
    test:
      start_date: "2020-09-05 00:00:00"
      end_date: "2020-09-22 00:00:00"

# ---- Data Loader ----
data_loader_params:
  batch_size: 256
  num_workers: 4

# ---- Training ----
training_params:
  learning_rate: 0.0003
  epochs: 3
  precision: "bf16-mixed"
  strategy: "ddp"
  devices: [0, 1]

# ---- Memory Constraints ----
memory_constraining_params:
  hidden_dim: 2048

# ---- DataLoader Calibration (optional) ----
calibration_params:
  enabled: true