Training Parameters

Parameters that control the foundation model training process. These are set in the training_params section of the YAML config or passed as a TrainingParams object in Python.

from monad.config import TrainingParams

TrainingParams

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| epochs | int | 1 | Number of epochs to train the model for. |
| learning_rate | float | 0.0001 | Learning rate. |
| check_val_every_n_steps | int \| None | None | Run validation every N training steps. Disables validation on epoch end when set. |
| check_val_every_n_epochs | int \| None | 1 | Run validation every N epochs. |
| limit_train_batches | int \| None | None | Limit number of training batches per epoch. Useful for quick validation of setup. |
| limit_val_batches | int \| None | None | Limit number of validation batches per epoch. |
| loss | Callable \| None | None | Custom loss function. |
| checkpoint_dir | str \| Path \| None | None | Directory to store model checkpoints. |
| metric_to_monitor | str \| None | None | Metric alias to monitor for best-model selection. |
| metric_monitoring_mode | MetricMonitoringMode \| None | None | Whether to minimize ("min") or maximize ("max") the monitored metric. |
| gradient_clip_val | int \| float \| None | None | Value above which gradients are clipped. |
| checkpoint_every_n_steps | int \| None | None | Enable intra-epoch checkpointing at every N steps. |
| early_stopping | EarlyStopping \| None | None | Early stopping configuration. See EarlyStopping below. |
| precision | Literal[...] | "bf16-mixed" / "16-mixed" | Float precision for training. Defaults to "bf16-mixed" on GPUs with bfloat16 support, otherwise "16-mixed". See Precision Values below. |
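
The interplay between check_val_every_n_steps and check_val_every_n_epochs can be easier to see in a small simulation. The sketch below is not monad code: `validation_points` is a hypothetical helper, and it assumes that steps are counted globally across epochs, which the table above does not state.

```python
from typing import List, Optional


def validation_points(
    steps_per_epoch: int,
    epochs: int,
    every_n_steps: Optional[int] = None,
    every_n_epochs: Optional[int] = 1,
) -> List[int]:
    """Return the global step numbers (1-based) after which validation runs.

    Assumption: steps are counted across epochs, not reset per epoch.
    """
    points = []
    for epoch in range(1, epochs + 1):
        for step_in_epoch in range(1, steps_per_epoch + 1):
            global_step = (epoch - 1) * steps_per_epoch + step_in_epoch
            if every_n_steps is not None:
                # Step-based validation replaces the end-of-epoch check.
                if global_step % every_n_steps == 0:
                    points.append(global_step)
            elif (
                every_n_epochs is not None
                and step_in_epoch == steps_per_epoch
                and epoch % every_n_epochs == 0
            ):
                # Epoch-based validation runs after the last step of the epoch.
                points.append(global_step)
    return points


# Two epochs of four steps each.
default_schedule = validation_points(4, 2)
step_schedule = validation_points(4, 2, every_n_steps=3)
```

With the defaults, validation runs at the end of each epoch (steps 4 and 8 here). With `every_n_steps=3` it runs at steps 3 and 6, and the end-of-epoch runs disappear.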

Inherited Parameters

These are inherited from the base parameter class and shared with TestingParams.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| devices | list[int] \| int \| "auto" | "auto" | GPU devices to use. "auto" selects the least-occupied GPU automatically, falling back to CPU if none are available. A positive int specifies count; a list[int] specifies device indices; -1 uses all available GPUs. |
| accelerator | "cpu" \| "gpu" | "gpu" | Accelerator type. |
| strategy | str \| None | None | Distributed training strategy: None (PyTorch Lightning default), "auto", "ddp", "fsdp", or "fsdp:A:B" (A = data parallelism, B = tensor parallelism). |
| nccl_timeout | timedelta \| None | None | Timeout for NCCL collective operations in distributed training (DDP/FSDP). A bare number is interpreted as seconds. Ignored on a single device. |
| rank_sync_timeout | timedelta \| None | None | Timeout for a dedicated per-step rank-synchronization barrier that absorbs data-loading skew between ranks, independently of nccl_timeout (which then only needs to cover the gradient sync). A bare number is interpreted as seconds. Ignored on a single device. |
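
To make the "fsdp:A:B" notation and the "bare number means seconds" rule concrete, here is a minimal sketch of how such values could be interpreted. `parse_strategy` and `as_timeout` are hypothetical helpers written for illustration; they are not part of monad, and the library's own validation may differ.

```python
from datetime import timedelta
from typing import Optional, Tuple, Union

KNOWN_STRATEGIES = {"auto", "ddp", "fsdp"}


def parse_strategy(
    strategy: Optional[str],
) -> Tuple[Optional[str], Optional[int], Optional[int]]:
    """Split a strategy string into (name, data_parallel, tensor_parallel)."""
    if strategy is None:
        # None leaves the choice to the PyTorch Lightning default.
        return (None, None, None)
    parts = strategy.split(":")
    name = parts[0]
    if name not in KNOWN_STRATEGIES:
        raise ValueError(f"unknown strategy: {strategy!r}")
    if len(parts) == 1:
        return (name, None, None)
    if name != "fsdp" or len(parts) != 3:
        raise ValueError(f"expected 'fsdp:A:B', got {strategy!r}")
    # A = data parallelism, B = tensor parallelism.
    return (name, int(parts[1]), int(parts[2]))


def as_timeout(value: Union[timedelta, int, float, None]) -> Optional[timedelta]:
    """Interpret a bare number as seconds; pass timedelta and None through."""
    if value is None or isinstance(value, timedelta):
        return value
    return timedelta(seconds=value)


name, data_parallel, tensor_parallel = parse_strategy("fsdp:4:2")
nccl_timeout = as_timeout(1800)  # 1800 seconds, i.e. 30 minutes
```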

Note

Parameters such as metrics, top_k, predictions_threshold, entity_ids, callbacks, and approximate_decoding_params are inherited from the base class but not applicable to foundation model training. Custom metrics are explicitly rejected during pretraining. These parameters are available for Scenario Model Training.

EarlyStopping

Wraps the configuration for PyTorch Lightning's early stopping callback.

from monad.config import EarlyStopping

early_stopping = EarlyStopping(
    min_delta=0.001,
    patience=5,
    verbose=True,
)

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| min_delta | float | 0.0 | Minimum change in the monitored metric to qualify as an improvement. |
| patience | int | 3 | Number of validation checks with no improvement after which training stops. |
| verbose | bool | False | Whether to log information about registered improvements. |
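
The way min_delta and patience interact can be illustrated with a small standalone sketch. This follows the usual patience-counter logic and is not the callback's actual implementation; `EarlyStopper` is a hypothetical class, and its `mode` field stands in for metric_monitoring_mode.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class EarlyStopper:
    min_delta: float = 0.0
    patience: int = 3
    mode: str = "min"  # "min" for losses, "max" for scores
    best: Optional[float] = None
    wait: int = 0  # validation checks since the last improvement

    def update(self, value: float) -> bool:
        """Record one validation result; return True when training should stop."""
        if self.best is None:
            improved = True
        elif self.mode == "min":
            improved = value < self.best - self.min_delta
        else:
            improved = value > self.best + self.min_delta

        if improved:
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


stopper = EarlyStopper(min_delta=0.001, patience=2)
# The last two values improve on 0.40 by less than min_delta, so they
# count as "no improvement" and exhaust the patience of 2.
decisions = [stopper.update(loss) for loss in [0.50, 0.40, 0.3995, 0.3999]]
```

In this run, `decisions` ends with `True`: the two small changes after 0.40 do not clear the min_delta threshold, so the stop signal is raised on the second of them.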

Precision Values

Valid values for the precision parameter:

| Value | Description |
| --- | --- |
| 32 or "32" or "32-true" | Full 32-bit float precision. |
| 64 or "64" or "64-true" | Full 64-bit float precision. |
| 16 or "16" or "16-true" | Pure 16-bit float precision. |
| "16-mixed" | Mixed precision with float16. |
| "bf16" or "bf16-true" | Pure bfloat16 precision. |
| "bf16-mixed" | Mixed precision with bfloat16. Default on supported GPUs. |

Tip

"bf16-mixed" is the default on GPUs that support bfloat16 (e.g., A100, H100). On older GPUs, it falls back to "16-mixed". Mixed precision provides a good balance between training speed and numerical stability.

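The aliases and the hardware-dependent default described above could be resolved along the following lines. This is a sketch under stated assumptions: `normalize_precision` and `default_precision` are hypothetical helpers rather than monad functions, and the caller is expected to supply the bfloat16 capability flag (in a PyTorch environment, `torch.cuda.is_bf16_supported()` is one way to obtain it).

```python
from typing import Union

# Short forms from the table mapped to their explicit "-true" spelling.
_ALIASES = {
    "32": "32-true",
    "64": "64-true",
    "16": "16-true",
    "bf16": "bf16-true",
}

_CANONICAL = {
    "32-true",
    "64-true",
    "16-true",
    "16-mixed",
    "bf16-true",
    "bf16-mixed",
}


def normalize_precision(value: Union[int, str]) -> str:
    """Map any accepted spelling (32, "32", "32-true", ...) to one canonical string."""
    key = str(value)
    key = _ALIASES.get(key, key)
    if key not in _CANONICAL:
        raise ValueError(f"unsupported precision: {value!r}")
    return key


def default_precision(bf16_supported: bool) -> str:
    """Pick the mixed-precision default based on bfloat16 support."""
    return "bf16-mixed" if bf16_supported else "16-mixed"


full_precision = normalize_precision(32)
older_gpu_default = default_precision(bf16_supported=False)
```
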
YAML Example

In the foundation model config file:

training_params:
  learning_rate: 0.0003
  epochs: 3
  precision: "bf16-mixed"
  strategy: "ddp"
  devices: [0, 1]
  early_stopping:
    min_delta: 0.001
    patience: 5

Python Example

from monad.config import TrainingParams, EarlyStopping

training_params = TrainingParams(
    epochs=3,
    learning_rate=0.0003,
    precision="bf16-mixed",
    devices=[0, 1],
    strategy="ddp",
    early_stopping=EarlyStopping(
        min_delta=0.001,
        patience=5,
    ),
)