Distributed Training
By default, training runs on a single GPU. For larger datasets or faster iteration, you can distribute training across multiple GPUs using PyTorch Lightning's distributed strategies.
Multi-GPU with DDP
The most common setup is DDP (Distributed Data Parallel), where each GPU processes a different subset of the data in parallel:
from monad.ui.config import TrainingParams
training_params = TrainingParams(
    checkpoint_dir="./model",
    epochs=3,
    strategy="ddp",
    devices=[0, 1],  # GPU indices to use
)
| Parameter | Default | Description |
|---|---|---|
| devices | [0] | List of GPU indices to train on |
| strategy | None | Distributed strategy; set to "ddp" for multi-GPU |
| accelerator | "auto" | Hardware accelerator ("gpu", "cpu", or "auto") |
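To make "each GPU processes a different subset of the data" concrete, here is a minimal pure-Python sketch of round-robin sharding, similar to what PyTorch's DistributedSampler does when shuffling is off. The helper name shard_indices is hypothetical and not part of monad; the library's actual data pipeline may differ.

```python
def shard_indices(num_samples: int, world_size: int, rank: int) -> list[int]:
    """Return the sample indices that one rank would process in an epoch."""
    if num_samples <= 0 or world_size <= 0:
        raise ValueError("num_samples and world_size must be positive")
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    # Round up so every rank gets the same number of samples; the tail
    # wraps around to the start of the dataset as padding.
    per_rank = -(-num_samples // world_size)
    padded_total = per_rank * world_size
    return [i % num_samples for i in range(rank, padded_total, world_size)]


# Two GPUs, ten samples: each rank sees a disjoint half of the data.
for rank in range(2):
    print(rank, shard_indices(10, world_size=2, rank=rank))
```

When the dataset size is not divisible by the number of GPUs, a few samples are repeated so that every rank runs the same number of steps.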
Choosing the Number of GPUs
- Single GPU (devices=[0]): simplest setup, suitable for most scenarios. Start here.
- Two GPUs (devices=[0, 1]): halves data-loading time per GPU; good default for production runs.
- More GPUs: useful for very large datasets or foundation models. Scaling beyond 4 GPUs has diminishing returns for most scenario models.
Adjust learning rate for multi-GPU training
The effective batch size scales with the number of GPUs. If you change devices, consider adjusting learning_rate proportionally (e.g. double the learning rate when doubling GPUs).
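The proportional adjustment described here is often called the linear scaling rule. A minimal sketch, assuming the per-GPU batch size stays fixed; scale_learning_rate is a hypothetical helper, not a monad function, and its result is a starting point to validate rather than a guaranteed optimum.

```python
def scale_learning_rate(base_lr: float, base_devices: int, new_devices: int) -> float:
    """Scale a learning rate in proportion to the change in GPU count."""
    if base_devices <= 0 or new_devices <= 0:
        raise ValueError("device counts must be positive")
    # Effective batch size grows with the GPU count, so the learning
    # rate is scaled by the same ratio.
    return base_lr * new_devices / base_devices


# A rate tuned on one GPU, adjusted for a two-GPU run.
lr_two_gpus = scale_learning_rate(0.0001, base_devices=1, new_devices=2)
print(lr_two_gpus)
```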
Example
The onboarding scripts use DDP by default:
# --- parallelised training ---
strategy = "ddp"
devices = [0, 1]
training_params = TrainingParams(
    checkpoint_dir=scenario_model_path,
    learning_rate=0.0001,
    epochs=3,
    devices=devices,
    strategy=strategy,
)
To fall back to a single GPU, remove strategy and set devices=[0].
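The single-GPU fallback can be sketched as a small helper that builds the device-related keyword arguments. build_device_kwargs is a hypothetical name; only the devices and strategy parameters come from this page.

```python
def build_device_kwargs(devices: list[int]) -> dict:
    """Build device-related keyword arguments for TrainingParams."""
    if not devices:
        raise ValueError("at least one GPU index is required")
    kwargs = {"devices": list(devices)}
    # Only request DDP when more than one GPU is listed; for a single
    # GPU the strategy key is left out entirely.
    if len(devices) > 1:
        kwargs["strategy"] = "ddp"
    return kwargs


print(build_device_kwargs([0, 1]))
print(build_device_kwargs([0]))

# Intended use (assumes monad is installed):
# training_params = TrainingParams(
#     checkpoint_dir="./model", epochs=3, **build_device_kwargs([0])
# )
```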
Troubleshooting
| Issue | Solution |
|---|---|
| CUDA out of memory | Reduce devices list or use limit_train_batches to cap batch count |
| Hanging at start | Ensure all listed GPU indices are available (nvidia-smi) |
| Different results across runs | DDP introduces non-determinism; set seed in trainer.fit() for closer reproducibility |
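For the hanging-at-start case, a quick pre-flight check can compare the requested indices with the number of GPUs visible to the process. This is a minimal sketch: unavailable_devices is a hypothetical helper, and the visible GPU count is passed in explicitly (in practice it could come from torch.cuda.device_count()).

```python
def unavailable_devices(requested: list[int], visible_gpu_count: int) -> list[int]:
    """Return the requested GPU indices that are not visible to this process."""
    return [idx for idx in requested if not 0 <= idx < visible_gpu_count]


# Requesting GPUs 0 and 1 on a machine that exposes only one GPU.
missing = unavailable_devices([0, 1], visible_gpu_count=1)
if missing:
    print(f"GPU indices not available: {missing}")
```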
Synchronization Timeouts
Long multi-GPU runs can hit two distinct kinds of timeout. By default neither is set, and both are ignored on a single device.
from datetime import timedelta
from monad.ui.config import TrainingParams
training_params = TrainingParams(
    checkpoint_dir="./model",
    epochs=3,
    strategy="ddp",
    devices=[0, 1],
    nccl_timeout=timedelta(minutes=30),  # or just: 1800 (seconds)
    rank_sync_timeout=timedelta(hours=1),
)
| Parameter | Default | Description |
|---|---|---|
| nccl_timeout | None | Timeout for NCCL collective operations (the gradient all-reduce). Raise it when large or uneven batches make collectives run long. A bare number is read as seconds. |
| rank_sync_timeout | None | Timeout for a dedicated per-step barrier that lets ranks wait on each other's data loading without tripping nccl_timeout. Set it when data-loading skew between ranks, rather than the gradient sync, is the bottleneck. |
With both set, rank_sync_timeout absorbs data-loading skew on a separate process group, so nccl_timeout only needs to cover the gradient synchronization itself.
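The parameter table notes that a bare number passed as nccl_timeout is read as seconds. A minimal sketch of that normalization, assuming it behaves like the hypothetical helper to_timedelta below; the library's actual handling may differ.

```python
from datetime import timedelta
from typing import Optional, Union


def to_timedelta(value: Union[None, int, float, timedelta]) -> Optional[timedelta]:
    """Normalize a timeout given as None, seconds, or a timedelta."""
    # None (no timeout set) and timedelta values pass through unchanged.
    if value is None or isinstance(value, timedelta):
        return value
    # A bare number is interpreted as seconds.
    return timedelta(seconds=value)


# Both spellings from the example above describe the same timeout.
print(to_timedelta(1800) == to_timedelta(timedelta(minutes=30)))
```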