Custom Loss & Callbacks
Custom Loss Function
Each task type has a default loss function suited to it. You can replace it with any loss function from PyTorch or with your own Python function.
Using a PyTorch Loss
```python
from torch.nn.functional import mse_loss
from monad.ui.config import TrainingParams

training_params = TrainingParams(
    checkpoint_dir="./model",
    epochs=1,
    loss=mse_loss,
)
```
Defining a Custom Loss
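```python
from torch import tensor
from torch.nn.functional import binary_cross_entropy_with_logits
from monad.ui.config import TrainingParams


def weighted_bce(input, target, weight=None, size_average=None,
                 reduce=None, reduction="mean"):
    # Same signature as binary_cross_entropy_with_logits, plus a fixed pos_weight.
    # pos_weight is created on the same device as the predictions instead of a
    # hard-coded "cuda:0", so the loss also works on CPU or on another GPU.
    return binary_cross_entropy_with_logits(
        input, target,
        weight=weight,
        size_average=size_average,
        reduce=reduce,
        reduction=reduction,
        pos_weight=tensor([0.9], device=input.device),
    )


training_params = TrainingParams(
    checkpoint_dir="./model",
    epochs=1,
    loss=weighted_bce,
)
```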
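To see what pos_weight changes, here is a framework-free sketch of the formula PyTorch documents for binary_cross_entropy_with_logits with reduction="mean" and a scalar pos_weight. Only the positive-target term is multiplied by pos_weight, so a value of 0.9 slightly down-weights positive examples and leaves negatives untouched. The helpers softplus and bce_with_logits are hypothetical names written for this illustration; they are not part of monad.

```python
import math


def softplus(z):
    # log(1 + exp(z)), rearranged so large |z| cannot overflow exp()
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def bce_with_logits(logits, targets, pos_weight=1.0):
    """Mean binary cross-entropy on raw logits with a positive-class weight.

    Per element:  pos_weight * y * softplus(-x) + (1 - y) * softplus(x)
    using -log(sigmoid(x)) == softplus(-x) and -log(1 - sigmoid(x)) == softplus(x).
    """
    terms = [
        pos_weight * y * softplus(-x) + (1.0 - y) * softplus(x)
        for x, y in zip(logits, targets)
    ]
    return sum(terms) / len(terms)
```

For a quick check, bce_with_logits([0.0], [1.0], pos_weight=0.9) is 0.9 × ln 2 ≈ 0.624, whereas a negative target at the same logit still costs ln 2 ≈ 0.693.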
Redefine custom loss in all load scripts
If you use a custom loss function, define it in every script that loads the trained model via load_from_checkpoint. At load time, the loss function must be defined or importable under the same name it had during training.
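A likely reason for this rule is that Python's pickle, which torch.save relies on, stores a plain function by reference (module name plus function name) rather than by value. Assuming the custom loss reaches the checkpoint this way, the pure-Python sketch below reproduces the failure and the fix. The module name my_losses is hypothetical and stands in for your training script or a shared losses module.

```python
import pickle
import sys
import types


def weighted_bce(input, target):
    return 0.0  # placeholder body; only the name matters to pickle


# Register the function on a throwaway module named "my_losses" (hypothetical).
losses = types.ModuleType("my_losses")
sys.modules["my_losses"] = losses
weighted_bce.__module__ = "my_losses"
losses.weighted_bce = weighted_bce

# pickle records only "my_losses" + "weighted_bce", never the function body.
payload = pickle.dumps(weighted_bce)

# Simulate a load script in which weighted_bce was never defined.
del losses.weighted_bce
try:
    pickle.loads(payload)
    missing_definition_failed = False
except AttributeError:
    # The stored name cannot be resolved, so unpickling fails.
    missing_definition_failed = True

# Making the name resolvable again lets loading succeed.
losses.weighted_bce = weighted_bce
restored = pickle.loads(payload)
```

Because only the name is stored, the loading script must supply a function under that exact name; the function body it supplies is the one that gets used.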
Callbacks
Attach PyTorch Lightning callbacks to add functionality to training, such as progress bars, custom logging, or learning-rate scheduling.
```python
from pytorch_lightning.callbacks import TQDMProgressBar
from monad.ui.config import TrainingParams

training_params = TrainingParams(
    checkpoint_dir="./model",
    epochs=1,
    callbacks=[TQDMProgressBar(refresh_rate=100)],
)
```
Pass callbacks as a list via TrainingParams.callbacks. Any callback compatible with the PyTorch Lightning Trainer can be used.
Early Stopping
Early stopping halts training when the monitored metric stops improving, preventing overfitting and saving compute time. Configure it via TrainingParams.early_stopping:
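```python
from monad.ui.config import TrainingParams, EarlyStopping, MetricMonitoringMode

training_params = TrainingParams(
    checkpoint_dir="./model",
    epochs=20,  # upper bound; early stopping may end training sooner
    metric_to_monitor="val_auroc_0",
    # Higher values of the monitored metric count as improvements.
    metric_monitoring_mode=MetricMonitoringMode.MAX,
    # Stop after 5 epochs with no improvement.
    early_stopping=EarlyStopping(patience=5),
)
```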
| Parameter | Default | Description |
|---|---|---|
| patience | required | Number of epochs with no improvement before stopping |
| min_delta | 0.0 | Minimum change in the monitored metric to qualify as an improvement |
| verbose | False | Log a message when early stopping triggers |
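To make patience and min_delta concrete, here is a small pure-Python sketch of the stopping rule, assuming monad's EarlyStopping follows the semantics of PyTorch Lightning's EarlyStopping: a value counts as an improvement only if it beats the best value so far by more than min_delta, the no-improvement counter resets on every improvement, and training stops once the counter reaches patience. The function stopping_epoch is hypothetical and exists only for this illustration.

```python
import math


def stopping_epoch(values, patience, min_delta=0.0, mode="max"):
    """Return the 0-based epoch at which training would stop, or None.

    values: the monitored metric, one entry per epoch.
    mode: "max" if higher is better, "min" if lower is better.
    """
    sign = 1.0 if mode == "max" else -1.0  # flip "min" into "higher is better"
    best = -math.inf  # best score so far, in higher-is-better units
    wait = 0  # consecutive epochs without improvement
    for epoch, value in enumerate(values):
        score = sign * value
        if score - min_delta > best:
            best = score  # improvement: remember it and reset the counter
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None  # ran through every epoch without triggering
```

For example, stopping_epoch([0.70, 0.72, 0.71, 0.72, 0.715], patience=2) returns 3: matching the earlier best of 0.72 does not count as an improvement, so epochs 2 and 3 exhaust the patience.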
Use early stopping to prevent overfitting
Without early stopping, the best checkpoint is still saved based on metric_to_monitor, but training always runs for the full number of epochs. Early stopping is most useful when you set a high epochs ceiling and want training to end as soon as the monitored metric plateaus.