Custom Loss & Callbacks

Custom Loss Function

The default loss functions are optimized for each task type. You can replace them with any PyTorch loss function or with one you write yourself in Python.

Using a PyTorch Loss

Python
from torch.nn.functional import mse_loss
from monad.ui.config import TrainingParams

training_params = TrainingParams(
    checkpoint_dir="./model",
    epochs=1,
    loss=mse_loss,
)

Defining a Custom Loss

Python
from torch import tensor
from torch.nn.functional import binary_cross_entropy_with_logits
from monad.ui.config import TrainingParams

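def weighted_bce(input, target, weight=None, size_average=None,
                 reduce=None, reduction="mean"):
    # Create pos_weight on the same device as the logits so the loss also
    # works on CPU or on a GPU other than cuda:0.
    pos_weight = tensor([0.9], device=input.device)
    return binary_cross_entropy_with_logits(
        input, target, weight, size_average, reduce, reduction,
        pos_weight=pos_weight,
    )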

training_params = TrainingParams(
    checkpoint_dir="./model",
    epochs=1,
    loss=weighted_bce,
)
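
To make the effect of pos_weight concrete, the sketch below recomputes the per-element loss in plain Python, following the formula given in the PyTorch documentation for BCEWithLogitsLoss: loss = -[pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]. The helper name bce_with_logits_scalar is hypothetical and exists only for illustration. It is not part of monad or PyTorch, and it is not numerically stable for logits of large magnitude.

```python
import math


def bce_with_logits_scalar(logit: float, target: float, pos_weight: float = 1.0) -> float:
    """Per-element binary cross-entropy on a raw logit (illustrative only)."""
    # log(sigmoid(x)) = -log(1 + exp(-x)); log(1 - sigmoid(x)) = -log(1 + exp(x))
    log_sig = -math.log1p(math.exp(-logit))
    log_one_minus_sig = -math.log1p(math.exp(logit))
    return -(pos_weight * target * log_sig + (1.0 - target) * log_one_minus_sig)


# A positive example (target=1) with an uninformative logit of 0.
unweighted = bce_with_logits_scalar(0.0, 1.0)           # log(2)
down_weighted = bce_with_logits_scalar(0.0, 1.0, 0.9)   # 0.9 * log(2)

# Negative examples (target=0) are unaffected by pos_weight.
negative = bce_with_logits_scalar(0.0, 0.0, 0.9)        # log(2)
```

With pos_weight=0.9, as in weighted_bce above, each positive example contributes 10% less to the loss than it would unweighted. Values above 1 would emphasize positives instead.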

Redefine custom loss in all load scripts

If you use a custom loss function, define it in every script that loads the trained model via load_from_checkpoint. The loss function must be importable at load time.

Callbacks

Attach PyTorch Lightning callbacks to add functionality to training, such as progress bars, custom logging, or learning-rate scheduling.

Python
from pytorch_lightning.callbacks import TQDMProgressBar
from monad.ui.config import TrainingParams

training_params = TrainingParams(
    checkpoint_dir="./model",
    epochs=1,
    callbacks=[TQDMProgressBar(refresh_rate=100)],
)

Pass callbacks as a list via TrainingParams.callbacks. Any callback compatible with the PyTorch Lightning Trainer can be used.

Early Stopping

Early stopping halts training when the monitored metric stops improving, preventing overfitting and saving compute time. Configure it via TrainingParams.early_stopping:

Python
from monad.ui.config import TrainingParams, EarlyStopping, MetricMonitoringMode

training_params = TrainingParams(
    checkpoint_dir="./model",
    epochs=20,
    metric_to_monitor="val_auroc_0",
    metric_monitoring_mode=MetricMonitoringMode.MAX,
    early_stopping=EarlyStopping(patience=5),
)

Parameter   Default    Description
patience    required   Number of epochs with no improvement before stopping
min_delta   0.0        Minimum change to qualify as an improvement
verbose     False      Log a message when early stopping triggers

Use early stopping to prevent overfitting

Without early stopping, the best checkpoint is still saved based on metric_to_monitor, but training runs for the full epochs count. Early stopping is most useful when you set a high epochs ceiling and want training to finish as soon as gains plateau.
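
How patience and min_delta interact can be sketched in a few lines of plain Python. This is a simplified illustration of patience-based stopping, not the actual implementation in monad or PyTorch Lightning. The function name epochs_run and the metric values are made up for the example.

```python
def epochs_run(metric_per_epoch, patience, min_delta=0.0, maximize=True):
    """Return how many epochs run before a patience-based stop (illustrative)."""
    best = None
    stale = 0  # consecutive epochs without a qualifying improvement
    for epoch, value in enumerate(metric_per_epoch, start=1):
        if best is None:
            improved = True
        elif maximize:
            improved = value > best + min_delta
        else:
            improved = value < best - min_delta
        if improved:
            best, stale = value, 0
        else:
            stale += 1
            if stale >= patience:
                return epoch  # stop: `patience` epochs in a row without improvement
    return len(metric_per_epoch)  # ran to the epochs ceiling


# Made-up validation scores that plateau at 0.76 from epoch 3 onward.
val_auroc = [0.70, 0.74, 0.76, 0.76, 0.755, 0.76, 0.758, 0.759, 0.76, 0.76]
stopped_at = epochs_run(val_auroc, patience=5)
```

Here the best value is reached at epoch 3, the next five epochs fail to beat it, and the run ends at epoch 8 instead of continuing to the ceiling. Raising min_delta makes small gains count as no improvement, so training stops sooner.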