
Fine-tuning training parameters

⚠️

Check This First!

This article refers to BaseModel accessed via a Docker container. Please refer to the Snowflake Native App section if you are using BaseModel as a Snowflake GUI application.


Training parameters defined at the foundation model training stage are loaded and used by default when training the scenario model. Thus, the scenario model target location (the checkpoint_dir parameter) is the only mandatory parameter in the scenario model training script.
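
For instance, assuming the stored defaults are sufficient, a minimal scenario training configuration needs nothing beyond the target location:

training_params = TrainingParams(
    checkpoint_dir='/location/to/save/your/scenario/model'
)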

However, all stored constructor parameters of the PyTorch Lightning Trainer can be overwritten by providing them as input to the training_params declaration. Please refer to this article for the full list of modifiable parameters.

Example

The example provided below demonstrates training parameters. In addition to specifying the scenario model location, several stored parameters are overwritten or added (devices and limit_train_batches, for instance, are passed through to the PyTorch Lightning Trainer).

training_params = TrainingParams(
    checkpoint_dir='/location/to/save/your/scenario/model',
    epochs=1,
    learning_rate=0.0001,
    overwrite=True,
    devices=[2],
    limit_train_batches=10
)

Callbacks

A Callback is an extension that can be used to supplement training with additional functionality. Please refer to the PyTorch Lightning documentation for more details.

Example

The example provided below demonstrates how to pass the TQDMProgressBar callback to training parameters.

from pytorch_lightning.callbacks import TQDMProgressBar

training_params = TrainingParams(
    checkpoint_dir='/location/to/save/your/scenario/model',
    epochs=1,
    callbacks=[TQDMProgressBar(refresh_rate=100)],
)
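
Custom callbacks can be passed the same way. Below is a minimal sketch of a hypothetical custom callback; the base Callback class and its on_train_epoch_end hook are standard PyTorch Lightning API, while the EpochLogger name and printed message are illustrative only:

from pytorch_lightning.callbacks import Callback

class EpochLogger(Callback):  # hypothetical example class
    """Print a short message at the end of every training epoch."""

    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Finished epoch {trainer.current_epoch}")

training_params = TrainingParams(
    checkpoint_dir='/location/to/save/your/scenario/model',
    epochs=1,
    callbacks=[EpochLogger()],
)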

Metrics

Downstream models can be validated with any metric from TorchMetrics.

The default values of the metric parameters for each downstream task are listed below.

Binary Classification
    metrics:
        AUROC(num_labels=num_classes, task="binary", average=None),
        AveragePrecision(num_labels=num_classes, task="binary", average=None)
    metric_to_monitor: val_auroc_0
    metric_monitoring_mode: MetricMonitoringMode.MAX

Multiclass Classification
    metrics:
        Precision(num_classes=num_classes, task="multiclass", average=None),
        Recall(num_classes=num_classes, task="multiclass", average=None)
    metric_to_monitor: val_precision_0
    metric_monitoring_mode: MetricMonitoringMode.MAX

Multi-label Classification
    metrics:
        AUROC(num_labels=num_classes, task="multilabel", average=None),
        AveragePrecision(num_labels=num_classes, task="multilabel", average=None)
    metric_to_monitor: val_auroc_0
    metric_monitoring_mode: MetricMonitoringMode.MAX

Regression
    metrics:
        MeanSquaredError(squared=False)
    metric_to_monitor: val_loss
    metric_monitoring_mode: MetricMonitoringMode.MIN

Recommendations
    metrics:
        HitRateAtK(k=1), HitRateAtK(k=10), HitRateAtK(k=25),
        HitRateAtK(k=50), HitRateAtK(k=100), HitRateAtK(k=top_k),
        MeanAveragePrecisionAtK(k=12), MeanAveragePrecisionAtK(k=top_k),
        PrecisionAtK(k=10), PrecisionAtK(k=top_k)
        (top_k defaults to 12)
    metric_to_monitor: val_HR@10_0
    metric_monitoring_mode: MetricMonitoringMode.MAX



Example

The example below demonstrates how to use recall and AUROC during validation and how to monitor recall when selecting the best epoch. Monitored metric names combine the val_ prefix, the metric alias, and the label index, as in val_recall_0 below.

training_params = TrainingParams(
    checkpoint_dir='/location/to/save/your/scenario/model',
    epochs=1,
    metrics=[
        {"alias": "auroc", "metric_name": "AUROC", "kwargs": {"task": "binary", "average": None}},
        {"alias": "recall", "metric_name": "Recall", "kwargs": {"task": "binary"}},
    ],
    metric_to_monitor="val_recall_0",
    metric_monitoring_mode=MetricMonitoringMode.MAX
)

Loss Function

The default loss functions are a proprietary extension of cross-entropy tailored to provide optimal training. Nevertheless, the loss function can be changed to either a function imported from PyTorch loss functions or one defined in Python.


Example

The examples provided below demonstrate how to change the loss function in training parameters.

from torch.nn.functional import mse_loss

training_params = TrainingParams(
    checkpoint_dir='/location/to/save/your/scenario/model',
    epochs=1,
    loss=mse_loss,
)
Alternatively, a custom loss function can be defined in Python:

import torch
from torch.nn.functional import binary_cross_entropy_with_logits

def weighted_binary_cross_entropy_with_logits(
    input, target, weight=None, size_average=None, reduce=None, reduction="mean"
):
    return binary_cross_entropy_with_logits(
        input,
        target,
        weight,
        size_average,
        reduce,
        reduction,
        pos_weight=torch.tensor([0.9], device="cuda:0"),
    )

training_params = TrainingParams(
    checkpoint_dir='/location/to/save/your/scenario/model',
    epochs=1,
    loss=weighted_binary_cross_entropy_with_logits,
)

⚠️

Important!

If you use a custom loss function, remember to define it in all scripts that load the trained model with the load_from_checkpoint function.
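
For illustration, a minimal sketch of such a loading script is shown below. The loss is redefined exactly as in the training script; the exact import of load_from_checkpoint depends on your BaseModel setup and is omitted here, and the checkpoint path is a placeholder:

import torch
from torch.nn.functional import binary_cross_entropy_with_logits

# Redefine the custom loss exactly as in the training script,
# so it can be resolved when the checkpoint is restored.
def weighted_binary_cross_entropy_with_logits(
    input, target, weight=None, size_average=None, reduce=None, reduction="mean"
):
    return binary_cross_entropy_with_logits(
        input, target, weight, size_average, reduce, reduction,
        pos_weight=torch.tensor([0.9], device="cuda:0"),
    )

# load_from_checkpoint is the function referenced above; import it
# the same way as in your other BaseModel scripts.
model = load_from_checkpoint('/location/of/your/scenario/model')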