
Model training configuration

training_params block in YAML configuration file

⚠️

Check This First!

This article applies to BaseModel accessed via a Docker container. If you are using BaseModel as a Snowflake (SF) GUI application, please refer to the Snowflake Native App section instead.


The training_params block lets you set constructor parameters for the PyTorch Lightning Trainer.
These settings (e.g. learning rate, number of epochs) control the training process.
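As a starting point, the sketch below writes out the documented defaults of four common parameters explicitly. It assumes the defaults listed under Parameters are applied whenever a key is omitted, so this block should behave the same as leaving these keys out entirely.

```yaml
training_params:
  epochs: 1              # default: a single training epoch
  learning_rate: 0.0003  # default learning rate
  accelerator: "gpu"     # default accelerator
  devices: [0]           # default: the device with index 0
```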

Parameters
  • epochs: int
    default: 1
    Number of epochs to train the model for.
  • learning_rate: float
    default: 0.0003
    The learning rate.
  • devices : Union[List[int], str, int, None]
    default: [0]
    The devices to use. A positive integer sets how many devices to use, a list of integers selects the devices by index, and -1 means that all available devices should be used.
  • accelerator : Literal["cpu", "gpu"]
    default: "gpu"
    The accelerator to use: GPU or CPU.
  • strategy: str
    default: None
    Strategy for distributed training. Supported strategies are:
    • None: PyTorch Lightning's default strategy,
    • "ddp": Distributed Data Parallel,
    • "fsdp": Fully Sharded Data Parallel 2 with full tensor parallelism,
    • "fsdp:%d:%d": Fully Sharded Data Parallel 2, where the first integer defines data parallelism (replication) and the second defines tensor parallelism (sharding).
  • precision : Literal[64, 32, 16, "64", "32", "16", "bf16", "16-true", "16-mixed", "bf16-true", "bf16-mixed", "32-true", "64-true"]
    default: DEFAULT_PRECISION
    Controls the floating-point precision used for training: double precision (64, "64" or "64-true"), full precision (32, "32" or "32-true"), 16-bit mixed precision (16, "16" or "16-mixed") or bfloat16 mixed precision ("bf16" or "bf16-mixed"). The DEFAULT_PRECISION constant sets the precision to "bf16-mixed" if CUDA is available, and to "16-mixed" otherwise.
  • limit_train_batches : Union[int, float]
    default: None
    Limits the number of training batches per epoch (a float is a fraction of all batches, an int is an absolute number of batches). Useful, e.g., to speed up testing.
  • limit_val_batches : Union[int, float]
    default: None
    Limits the number of validation batches per epoch (a float is a fraction of all batches, an int is an absolute number of batches). Useful, e.g., to speed up testing.
  • loss : Callable | None
    default: None
    IGNORE AT FOUNDATION LEVEL STAGE. The loss function to use. If not provided, the default loss function for the task is used.
  • metrics : [Dict[str, Metric]]
    default: None
    IGNORE AT FOUNDATION LEVEL STAGE. Metrics to compute during validation. If not provided, the default validation metrics for the task are used.
  • checkpoint_dir : [Union[str, Path]]
    default: None
    IGNORE AT FOUNDATION LEVEL STAGE. If provided, points to the location where model checkpoints will be stored.
  • metric_to_monitor : str
    default: None
    IGNORE AT FOUNDATION LEVEL STAGE. Determines which metric should be used to select the best model for saving.
  • metric_monitoring_mode : str ["min", "max"]
    default: None
    IGNORE AT FOUNDATION LEVEL STAGE. Indicates whether a smaller ("min") or greater ("max") value of the selected metric is better.
  • callbacks : list[Callback]
    default: Lightning factory default
    IGNORE AT FOUNDATION LEVEL STAGE. List of additional PyTorch Lightning callbacks to add to training.
  • resume : boolean
    default: False
    IGNORE AT FOUNDATION LEVEL STAGE. Whether to resume training. If True, training resumes from the last checkpoint if one exists; otherwise an error is raised.
  • overwrite : boolean
    default: False
    IGNORE AT FOUNDATION LEVEL STAGE. Whether to overwrite the results of a previous training run. If True, the results are overwritten. Otherwise, if resume is not set and checkpoints from a previous training run are present, an error is raised.
  • gradient_clip_val : Union[int, float]
    default: None
    Gradient clipping threshold (gradients above this value are clipped), passed to the PyTorch Lightning Trainer.
  • warm_start_steps : int
    default: 0
    Number of warm-start training steps used while fine-tuning a supervised model. Ignored if no pretrained model is used.
  • top_k : int
    default: 12
    IGNORE AT FOUNDATION LEVEL STAGE. Only valid at the downstream model stage, for a recommendation task. Number of targets to recommend. The top k targets are included in the validation metrics; this has no impact on model training.
  • targets_to_include : List[str]
    default: None
    IGNORE AT FOUNDATION LEVEL STAGE. Only valid at the downstream model stage, for a recommendation task. Target names that should be used for the validation metrics; this has no impact on model training.
  • checkpoint_every_n_steps : int
    default: None
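    Controls intra-epoch checkpointing: if set, a checkpoint is saved every n training steps, so checkpoints are also written within an epoch. If None, no intra-epoch checkpointing is performed.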
  • early_stopping: monad.config.early_stopping.EarlyStopping
    default: None
    Adds an early stopping callback to training. If the model's performance does not improve over subsequent validation checks, training ends before the configured number of epochs. It accepts the following sub-keys / keywords:
    • min_delta: float
      default: 0.0
      Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than or equal to min_delta counts as no improvement. Should be greater than 0.
    • patience: int
      default: 3
      Number of checks with no improvement after which training is stopped. The checks compare the metrics of subsequent validation epochs. Should be greater than 1.
    • verbose: bool
      default: False
      Whether to log information about a registered improvement or the lack of it.
  • entity_ids : monad.config.EntityIds
    default: None
    Restricts the set of entity IDs used during training or testing. It accepts the following sub-keys / keywords:
    • subquery : str | None
      A SQL subquery used to define which entity IDs should be included or excluded.
    • file : Path | None
      Path to a file containing the list of entity IDs to include or exclude.
    • matching : bool
      Determines the filtering mode for the IDs specified with subquery or file:
      • If True, only the specified IDs are used.
      • If False, the specified IDs are excluded.
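The sketch below shows how several of the parameters above might be combined for a longer multi-GPU run that stops early once validation stops improving. The values are illustrative, not recommendations. We assume that with "fsdp:%d:%d" the two integers should multiply to the number of devices (here 2 × 2 = 4); the parameter list does not state this explicitly.

```yaml
training_params:
  epochs: 20              # upper bound; early stopping may end training sooner
  accelerator: "gpu"
  devices: 4              # use 4 devices (illustrative)
  strategy: "fsdp:2:2"    # 2-way data parallelism, 2-way tensor parallelism (sharding)
  early_stopping:
    min_delta: 0.001      # changes of at most 0.001 count as no improvement
    patience: 3           # stop after 3 checks without improvement
    verbose: True         # log whether an improvement was registered
```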

Example

In the example below, we:

  • exclude specific entity IDs by loading them from a file at the given path,
  • set the learning rate to 0.0003,
  • configure the training to run for exactly 3 epochs, without enabling early stopping,
  • define the directory where model checkpoints will be saved,
  • and specify GPU device [1] for the training run.

```yaml
training_params:
  entity_ids:
    file: path/to/file
    matching: False
  learning_rate: 0.0003
  epochs: 3
  checkpoint_dir: "my_fm/"
  devices: [1]
```
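The same exclusion can presumably be expressed with the subquery key instead of file. In the sketch below, the table and column names (excluded_entities, entity_id) are hypothetical placeholders, and we assume the subquery should return a single column of entity IDs; check the expected shape against your data source.

```yaml
training_params:
  entity_ids:
    # hypothetical table/column names; assumed to return one column of entity IDs
    subquery: "SELECT entity_id FROM excluded_entities"
    matching: False       # exclude the returned IDs rather than keep only them
  learning_rate: 0.0003
  epochs: 3
```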

📘

Did you know?

You can quickly test your configuration (e.g. data sources) by limiting training to 1 epoch and the number of batches per epoch to, e.g., 10. The resulting model will not be very useful, but the runtime drops drastically and you can check that all your settings are correct. The training_params block below does exactly this.


```yaml
training_params:
  learning_rate: 0.0001
  epochs: 1
  limit_train_batches: 10
  limit_val_batches: 10
```
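On a machine without a GPU, a similar quick check can be sketched with the CPU accelerator. This assumes the values are passed unchanged to the PyTorch Lightning Trainer, which for CPU runs expects devices to be a positive integer (a number of processes) rather than a list of GPU indices. Precision is set explicitly because DEFAULT_PRECISION falls back to "16-mixed" when CUDA is unavailable, and Lightning may not support 16-bit mixed precision on CPU.

```yaml
training_params:
  accelerator: "cpu"
  devices: 1              # one CPU process instead of the default [0]
  precision: "32-true"    # full precision instead of the "16-mixed" fallback
  learning_rate: 0.0001
  epochs: 1
  limit_train_batches: 10
  limit_val_batches: 10
```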