Fine-tuning training parameters
Check This First!
This article refers to BaseModel accessed via a Docker container. Please refer to the Snowflake Native App section if you are using BaseModel as a Snowflake GUI application.
To configure scenario model training, you need to define your training parameters using the TrainingParams object. At a minimum, you should specify where to save the model by setting the checkpoint_dir parameter, unless you are fine with the model not being stored at all.
All parameters have sensible default values, but you can override any of them by providing custom values in training_params. These include all constructor parameters supported by the PyTorch Lightning Trainer. For the complete list of configurable options, see this article.
Example
The example below demonstrates training parameters. In addition to specifying the scenario model location, some parameters have been overridden or added.
training_params = TrainingParams(
    checkpoint_dir='/location/to/save/your/scenario/model',
    epochs=1,
    learning_rate=0.0001,
    overwrite=True,
    devices=[2],
    limit_train_batches=10
)
Callbacks
A Callback is an extension that supplements training with additional functionality. Please refer to the PyTorch Lightning documentation for more details.
Example
The example below demonstrates how to pass the TQDMProgressBar callback to training parameters.
from pytorch_lightning.callbacks import TQDMProgressBar
training_params = TrainingParams(
    checkpoint_dir='/location/to/save/your/scenario/model',
    epochs=1,
    callbacks=[TQDMProgressBar(refresh_rate=100)],
)
Metrics
Downstream models can be validated with any metric from TorchMetrics.
The table below contains default values of metric parameters for each downstream task.
Task | Default metric_to_monitor
---|---
Binary Classification | val_auroc_0
Multiclass Classification | val_precision_0
Multi-label Classification | val_auroc_0
Regression | val_loss
Recommendations | val_HR@10_0
Example
The example below demonstrates how to use recall and AUC during validation and how to monitor recall when selecting the best epoch.
training_params = TrainingParams(
    checkpoint_dir='/location/to/save/your/scenario/model',
    epochs=1,
    metrics=[
        {"alias": "auroc", "metric_name": "AUROC", "kwargs": {"task": "binary", "average": None}},
        {"alias": "recall", "metric_name": "Recall", "kwargs": {"task": "binary"}},
    ],
    metric_to_monitor="val_recall_0",
    metric_monitoring_mode=MetricMonitoringMode.MAX
)
Loss Function
The default loss functions are a proprietary extension of cross-entropy tailored to provide optimal training. Nevertheless, the loss function can be changed to either a function imported from PyTorch's loss functions or one defined in Python.
Example
The examples below demonstrate how to change the loss function in training parameters.
from torch.nn.functional import mse_loss

training_params = TrainingParams(
    checkpoint_dir='/location/to/save/your/scenario/model',
    epochs=1,
    loss=mse_loss,
)
import torch
from torch.nn.functional import binary_cross_entropy_with_logits

def weighted_binary_cross_entropy_with_logits(
    input, target, weight=None, size_average=None, reduce=None, reduction="mean"
):
    # Wraps the standard PyTorch loss with a fixed positive-class weight.
    return binary_cross_entropy_with_logits(
        input,
        target,
        weight,
        size_average,
        reduce,
        reduction,
        pos_weight=torch.tensor([0.9], device="cuda:0"),
    )

training_params = TrainingParams(
    checkpoint_dir='/location/to/save/your/scenario/model',
    epochs=1,
    loss=weighted_binary_cross_entropy_with_logits,
)
Important!
If you use a custom loss function, remember to define it in all scripts that load the trained model with the load_from_checkpoint function.
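For illustration, a minimal sketch of such a loading script is shown below. It assumes the model was trained with the weighted loss from the previous example, and uses a placeholder class name (ScenarioModel) together with the PyTorch Lightning load_from_checkpoint convention; substitute your actual model class and checkpoint path. The point is simply that the custom loss must be defined (or imported) in the loading script before the checkpoint is loaded.
import torch
from torch.nn.functional import binary_cross_entropy_with_logits

# The custom loss must exist under the same name in the loading script;
# otherwise resolving the checkpoint's loss reference will fail.
def weighted_binary_cross_entropy_with_logits(
    input, target, weight=None, size_average=None, reduce=None, reduction="mean"
):
    return binary_cross_entropy_with_logits(
        input, target, weight, size_average, reduce, reduction,
        pos_weight=torch.tensor([0.9], device="cuda:0"),
    )

# ScenarioModel is a hypothetical placeholder for your trained model class.
model = ScenarioModel.load_from_checkpoint(
    '/location/to/save/your/scenario/model/checkpoint.ckpt'
)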