
Distributed Training

By default, training runs on a single GPU. For larger datasets or faster iteration, you can distribute training across multiple GPUs using PyTorch Lightning's distributed strategies.

Multi-GPU with DDP

The most common setup is DDP (Distributed Data Parallel). Each GPU runs its own copy of the model on a different subset of the data in parallel, and gradients are averaged across GPUs after every backward pass:

Python
from monad.ui.config import TrainingParams

training_params = TrainingParams(
    checkpoint_dir="./model",
    epochs=3,
    strategy="ddp",
    devices=[0, 1],        # GPU indices to use
)
| Parameter | Default | Description |
|---|---|---|
| devices | [0] | List of GPU indices to train on |
| strategy | None | Distributed strategy; set to "ddp" for multi-GPU |
| accelerator | "auto" | Hardware accelerator ("gpu", "cpu", or "auto") |
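To make the "different subset of the data" idea concrete, here is a minimal sketch of a strided split across ranks (one rank per GPU). It is similar in spirit to PyTorch's DistributedSampler, which additionally shuffles and pads the indices so every rank gets the same count. The helper shard_indices is hypothetical and not part of monad.

```python
def shard_indices(num_samples: int, rank: int, world_size: int) -> list[int]:
    """Return the sample indices one rank (GPU process) would see.

    Strided split: rank r takes samples r, r + world_size, r + 2 * world_size, ...
    """
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} out of range for world_size {world_size}")
    return list(range(rank, num_samples, world_size))


# devices=[0, 1] means world_size=2: each rank sees half of the data.
for rank in range(2):
    print(rank, shard_indices(10, rank, world_size=2))
# 0 [0, 2, 4, 6, 8]
# 1 [1, 3, 5, 7, 9]
```

Together the shards cover every sample exactly once, which is why each GPU only has to load and process its own share per epoch.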

Choosing the Number of GPUs

  • Single GPU (devices=[0]): the simplest setup, suitable for most scenarios. Start here.
  • Two GPUs (devices=[0, 1]): each GPU loads and processes half of the data per epoch; a good default for production runs.
  • More GPUs: useful for very large datasets or foundation models. Scaling beyond 4 GPUs has diminishing returns for most scenario models.

Adjust learning rate for multi-GPU training

Because each GPU processes its own batch, the effective batch size per optimizer step is the per-GPU batch size multiplied by the number of GPUs. If you change devices, consider adjusting learning_rate proportionally (e.g. double the learning rate when doubling GPUs).
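A minimal sketch of that proportional rule (often called linear scaling), assuming the per-GPU batch size stays fixed when you change devices. The helpers effective_batch_size and scale_learning_rate, and the batch size of 32, are hypothetical; treat the scaled value as a starting point to validate rather than a guarantee.

```python
def effective_batch_size(per_gpu_batch_size: int, num_gpus: int) -> int:
    # In DDP every GPU processes its own batch, and one optimizer step
    # uses the averaged gradients of all of them.
    return per_gpu_batch_size * num_gpus


def scale_learning_rate(base_lr: float, base_gpus: int, new_gpus: int) -> float:
    # Linear scaling: the learning rate grows in proportion to the GPU count,
    # and therefore to the effective batch size.
    if base_gpus < 1 or new_gpus < 1:
        raise ValueError("GPU counts must be positive")
    return base_lr * new_gpus / base_gpus


# Tuned on one GPU with learning_rate=0.0001, moving to devices=[0, 1]:
print(effective_batch_size(32, 2))        # 64
print(scale_learning_rate(0.0001, 1, 2))  # 0.0002
```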

Example

The onboarding scripts use DDP by default:

Python
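from monad.ui.config import TrainingParams

# Placeholder value (hypothetical); the onboarding scripts define their own path.
scenario_model_path = "./model"

# --- parallelised training ---
strategy = "ddp"
devices = [0, 1]  # GPU indices to use

training_params = TrainingParams(
    checkpoint_dir=scenario_model_path,
    learning_rate=0.0001,
    epochs=3,
    devices=devices,
    strategy=strategy,
)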

To fall back to a single GPU, remove strategy and set devices=[0].
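One way to switch between the single-GPU and DDP configurations without editing the call each time is to build the keyword arguments from a GPU count. A minimal sketch: gpu_kwargs is a hypothetical helper, and it relies only on the devices and strategy parameters documented above.

```python
def gpu_kwargs(num_gpus: int) -> dict:
    """Build the devices/strategy arguments for TrainingParams from a GPU count."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    if num_gpus == 1:
        # Single GPU: leave strategy unset and train on GPU 0.
        return {"devices": [0]}
    # Multi-GPU: use GPU indices 0..num_gpus-1 with DDP.
    return {"devices": list(range(num_gpus)), "strategy": "ddp"}


print(gpu_kwargs(1))  # {'devices': [0]}
print(gpu_kwargs(2))  # {'devices': [0, 1], 'strategy': 'ddp'}

# Usage (requires monad):
# training_params = TrainingParams(checkpoint_dir="./model", epochs=3, **gpu_kwargs(2))
```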

Troubleshooting

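| Issue | Solution |
|---|---|
| CUDA out of memory | Reduce the per-GPU batch size. Each DDP process holds a full copy of the model and its own batch, so using fewer GPUs does not lower the memory each GPU needs. Also check with nvidia-smi that no other process is occupying the listed GPUs. |
| Hanging at start | Ensure every listed GPU index exists and is free (check with nvidia-smi) |
| Different results across runs | DDP introduces non-determinism; set seed in trainer.fit() for closer reproducibility |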

Synchronization Timeouts

Long multi-GPU runs can hit two distinct kinds of timeout. By default neither is set, and both are ignored on a single device.

Python
from datetime import timedelta
from monad.ui.config import TrainingParams

training_params = TrainingParams(
    checkpoint_dir="./model",
    epochs=3,
    strategy="ddp",
    devices=[0, 1],
    nccl_timeout=timedelta(minutes=30),   # or just: 1800 (seconds)
    rank_sync_timeout=timedelta(hours=1),
)
| Parameter | Default | Description |
|---|---|---|
| nccl_timeout | None | Timeout for NCCL collective operations (the gradient all-reduce). Raise it when large or uneven batches make collectives run long. A bare number is read as seconds. |
| rank_sync_timeout | None | Timeout for a dedicated per-step barrier that lets ranks wait on each other's data loading without tripping nccl_timeout. Set it when data-loading skew between ranks, rather than the gradient sync, is the bottleneck. |
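A timeout can be given either as a timedelta or as a bare number of seconds. A minimal sketch of that normalization, assuming int/float values map to seconds and None means "not set"; to_timeout is a hypothetical helper, not monad's actual parsing code.

```python
from datetime import timedelta
from typing import Optional, Union


def to_timeout(value: Optional[Union[int, float, timedelta]]) -> Optional[timedelta]:
    """Normalize a timeout setting to a timedelta, or None when unset."""
    if value is None:
        return None  # default: no timeout configured
    if isinstance(value, timedelta):
        return value
    # bool is a subclass of int, so reject it explicitly before the numeric check.
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return timedelta(seconds=value)  # bare numbers are seconds
    raise TypeError("timeout must be a number of seconds or a timedelta")


print(to_timeout(1800) == timedelta(minutes=30))  # True
print(to_timeout(None))                           # None
```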

With both set, rank_sync_timeout absorbs data-loading skew on a separate process group, so nccl_timeout only needs to cover the gradient synchronization itself.
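The interplay of the two timeouts can be illustrated with a toy, single-process analogy: Python threads stand in for ranks, and threading.Barrier timeouts stand in for rank_sync_timeout and nccl_timeout. This sketch only illustrates the idea; it does not use NCCL or torch.distributed, and run_step and its timing values are hypothetical.

```python
import threading
import time
from typing import Optional


def run_step(skew_s: float, nccl_timeout_s: float,
             rank_sync_timeout_s: Optional[float]) -> bool:
    """Simulate one step on two ranks; return True if the gradient sync succeeded."""
    rank_sync = threading.Barrier(2)  # stand-in for the per-step rank barrier
    grad_sync = threading.Barrier(2)  # stand-in for the gradient all-reduce
    ok = [True, True]

    def rank(r: int) -> None:
        try:
            if r == 1:
                time.sleep(skew_s)  # rank 1 is slow to load its batch
            if rank_sync_timeout_s is not None:
                # Ranks absorb data-loading skew here, under the generous timeout.
                rank_sync.wait(timeout=rank_sync_timeout_s)
            # The "collective" then only has to cover the sync itself.
            grad_sync.wait(timeout=nccl_timeout_s)
        except threading.BrokenBarrierError:
            ok[r] = False  # a timeout breaks the barrier for every rank

    threads = [threading.Thread(target=rank, args=(r,)) for r in range(2)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return all(ok)


# 0.5 s of data-loading skew against a 0.2 s "nccl_timeout":
print(run_step(0.5, 0.2, rank_sync_timeout_s=None))  # False: skew trips the collective timeout
print(run_step(0.5, 0.2, rank_sync_timeout_s=5.0))   # True: the barrier absorbs the skew
```

Without the barrier, the fast rank waits inside the "collective" for the slow rank's data and times out; with it, both ranks enter the collective together, so a short nccl_timeout is enough.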