Fine-tuning training parameters

⚠️ Check This First!

This article covers BaseModel accessed via a Docker container. If you are using BaseModel as a Snowflake GUI application, please refer to the Snowflake Native App section.


Training parameters defined at the foundation model training stage are loaded and used by default when training the scenario model. As a result, the scenario model target location (the checkpoint_dir parameter) is the only mandatory parameter in the scenario model training script.

However, all constructor parameters for the PyTorch Lightning Trainer stored in MonadTrainingParams can be overwritten by providing them as input to the training_params declaration, as in the sketch below.
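
For illustration, here is a minimal sketch of such an override. It assumes the scenario training script consumes a plain training_params declaration whose fields mirror the parameters listed below; the exact call site and container type depend on your setup.

```python
# Hedged sketch: the surrounding API is assumed; parameter names
# follow the reference list below.
training_params = dict(
    checkpoint_dir="/checkpoints/scenario_model",  # the only mandatory parameter
    epochs=3,                # overrides the inherited default of 1
    learning_rate=0.0005,    # overrides the inherited default of 0.001
    devices=[0, 1],          # train on GPU devices 0 and 1
    accelerator="gpu",
)
```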


Parameters
  • epochs: int
    default: 1
    Number of epochs to train the model for.

  • learning_rate: float
    default: 0.001
    The learning rate.

  • devices : Union[List[int], str, int, None]
    default: [0]
    The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to indicate that all available devices should be used, or "auto" for automatic selection based on the chosen accelerator.

  • accelerator : Literal["cpu", "gpu"]
    default: "gpu"
    The accelerator to use: GPU or CPU.

  • precision : Literal[64, 32, 16, "64", "32", "16", "bf16", "16-true", "16-mixed", "bf16-true", "bf16-mixed", "32-true", "64-true"]
    default: DEFAULT_PRECISION
    Controls the float precision used for training: double precision (64, "64" or "64-true"), full precision (32, "32" or "32-true"), 16-bit mixed precision (16, "16" or "16-mixed"), or bfloat16 mixed precision ("bf16" or "bf16-mixed"). The DEFAULT_PRECISION constant sets precision to "bf16-mixed" if CUDA is available, else "16-mixed".

  • limit_train_batches : Union[int, float]
    default: None
    Limits the number of training batches per epoch (float = fraction, int = number of batches). Useful, e.g., for speeding up testing.

  • limit_val_batches : Union[int, float]
    default: None
    Limits the number of validation batches per epoch (float = fraction, int = number of batches). Useful, e.g., for speeding up testing.

  • loss : Callable | None
    default: None
    The loss function to use. If not provided, the default loss function for the task will be used (see the loss and metrics sketch after this list).

  • metrics : Dict[str, Metric]
    default: None
    Metrics to use in validation. If not provided, the default validation metrics for the task will be used.

  • checkpoint_dir : Union[str, Path]
    default: None
    The path to the location where model checkpoints should be stored.

  • metric_to_monitor : str
    default: None
    Determines which metric should be used to select the best model for saving (see the checkpointing sketch after this list).

  • metric_monitoring_mode : str, one of "min", "max"
    default: None
    Indicates whether a smaller or greater value of the selected metric is better.

  • callbacks : List[Callback]
    default: Lightning factory default
    List of additional PyTorch Lightning callbacks to add to training.

  • resume : bool
    default: False
    Whether to resume training. If True, training will be resumed from the last checkpoint if one exists; otherwise an error will be raised.

  • overwrite : bool
    default: False
    Whether to overwrite previous training results. If True, results will be overwritten. Otherwise, if resume is not set and checkpoints from a previous training run are present, an error will be raised.

  • gradient_clip_val : Union[int, float]
    default: None
    Gradient clipping value (above which gradients are clipped), passed to the PyTorch Lightning Trainer.

  • warm_start_steps : int
    default: 0
    Number of warm-start training steps used while fine-tuning a supervised model. Ignored if no pretrained model is used.

  • top_k : int
    default: 12
    Only valid for a recommendation task. Number of targets to recommend. The top k targets will be included in validation metrics; this does not affect model training (see the recommendation sketch after this list).

  • targets_to_include : List[str]
    default: None
    Only valid for a recommendation task. Names of the targets that should be used for validation metrics; this does not affect model training.

  • checkpoint_every_n_steps : int
    default: None
    If set, intra-epoch checkpointing is performed: a checkpoint is saved every n training steps.
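
Under the same assumptions as the first sketch, the checkpoint-related parameters combine as follows: a monitored metric selects the best model, intra-epoch checkpointing is enabled, and an interrupted run is resumed. The metric name is a hypothetical example.

```python
# Hedged sketch; "val_loss" is a hypothetical metric name.
training_params = dict(
    checkpoint_dir="/checkpoints/scenario_model",
    metric_to_monitor="val_loss",  # select the best checkpoint by this metric
    metric_monitoring_mode="min",  # a smaller val_loss is better
    checkpoint_every_n_steps=500,  # additionally checkpoint every 500 steps
    resume=True,                   # continue from the last checkpoint if one exists
    # overwrite=True would instead discard previous results
)
```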
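
Because loss accepts any callable and metrics a dictionary of Metric objects, a custom objective can be sketched like this. A binary classification task, and torch and torchmetrics being available, are assumptions here.

```python
import torch.nn.functional as F
from torchmetrics.classification import BinaryAUROC, BinaryF1Score

# Hedged sketch: a binary classification scenario is assumed.
training_params = dict(
    checkpoint_dir="/checkpoints/scenario_model",
    loss=F.binary_cross_entropy_with_logits,  # replaces the task's default loss
    metrics={                                 # replaces the default validation metrics
        "auroc": BinaryAUROC(),
        "f1": BinaryF1Score(),
    },
)
```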
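
Finally, for a recommendation task, the validation-only parameters can be sketched as follows; the target names are hypothetical.

```python
# Hedged sketch; the target names are hypothetical examples.
training_params = dict(
    checkpoint_dir="/checkpoints/recommender_model",
    top_k=24,                                     # include the top 24 targets in validation metrics
    targets_to_include=["target_a", "target_b"],  # restrict validation metrics to these targets
)
```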