
Customizing loading from the foundation model

⚠️

Check This First!

This article refers to BaseModel accessed via a Docker container. If you are using BaseModel as a Snowflake GUI application, please refer to the Snowflake Native App section instead.


The arguments of load_from_foundation_model are required to instantiate the scenario model trainer. They specify the location of the foundation model for the scenario, the modelling task and expected output, and the target function. Optionally, they also let you, for example, use a customized logger, adapt the data source loading configuration, and modify the majority of the parameters described in the Data configuration section of the foundation model.

Scenario Model Parameters
  • checkpoint_path : str
    No default, required.
    Directory where all the checkpoint artifacts of the selected foundation model are stored.
  • downstream_task : Task
    No default, required.
    One of the machine learning tasks defined in BaseModel. Possible values are RegressionTask(), RecommendationTask(), BinaryClassificationTask(), MultilabelClassificationTask(), MulticlassClassificationTask().
  • target_fn : Callable[[Events, Events, Attributes, Dict], Union[Tensor, ndarray, Sketch]]
    No default, required.
    Target function for the specified task. It must be defined in the script, and its return type must match the task.
  • with_head: bool
    Default: False
    Whether to use the last layer of the foundation model. May improve the quality of recommendation tasks.
  • split: Optional[dict[DataMode, TimeRange]]
    Default: None
    Overwrites the split defined during the foundation model training phase, in case a time-based split was used. If the test period was not defined in the pretrain configuration file, it can be assigned here. You need to specify a DataMode (TRAIN - training dataset, VALIDATION - validation dataset, TEST - test dataset) and provide dates in a TimeRange object, for example: {DataMode.TEST: TimeRange(start_date=datetime(2023, 8, 1), end_date=datetime(2023, 8, 22))}
  • pl_logger : Logger
    Default: None, optional
    Instance of a PyTorch Lightning logger.
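Because target_fn must be defined in your own script, a minimal sketch of a binary-classification target may help. This is illustrative only: at runtime the function receives monad's Events and Attributes objects, and here we merely assume they support len(); the labeling rule (any event in the target window) is a hypothetical example, not a documented default.

```python
import numpy as np

# Hedged sketch of a binary-classification target function. The real
# arguments are monad's Events / Attributes objects; this sketch only
# assumes they support len().
def bin_target_fn(input_events, target_events, attributes, ctx):
    # Label 1.0 if the entity has any event in the target window, else 0.0.
    return np.array([1.0 if len(target_events) > 0 else 0.0])
```

Note that the return type (here an ndarray) must be aligned with the chosen task, as stated in the target_fn signature above.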

📘

Good to know

Additionally, as part of the load_from_foundation_model input, you can extend or overwrite configurations made during the foundation model training stage.

  • data_params: dates separating the training, validation and test sets, sampling management, the number of split points, declaring extra columns to be made available to the target function, etc.
  • query_optimization: query chunking, capping the sample size or the number of CPUs in case of infrastructure constraints.

To review the list of modifiable parameters refer to:


Extra columns

In this section, we demonstrate how to include additional columns that were not used during model training. These columns can be accessed within the target function when needed. To learn how to retrieve them, see Target Function: Extra Columns.

To include additional columns, you need to use the extra_columns parameter within data_params (see Data configuration). This can be passed to the load_from_foundation_model function.

The example below demonstrates how to include the basket_id column from the transactions data source. Additionally, it shows how to include extra columns derived from SQL expressions, in this case the color column from the joined products data source, converted to lowercase.

from monad.ui.module import load_from_foundation_model, BinaryClassificationTask

extra_columns = [{
  "data_source_name": "transactions", 
  "columns": [
    {
      "alias":"basket_id", 
      "expression":"basket_id",
    },
    {
      "alias":"color", 
      "expression": "LOWER({{ resolve_fn('color', data_sources_path=['products']) }})",
    },
  ]
}]


trainer = load_from_foundation_model(
    checkpoint_path="/path/to/your/models/pretrain/fm",
    downstream_task=BinaryClassificationTask(),
    target_fn=bin_target_fn,  # your target function
    extra_columns=extra_columns,
)

📘

Good to know

In additional SQL expressions, column names should be wrapped in {{ resolve_fn('column_name', data_sources_path=['A', 'B']) }}, where column_name is the name of the column in the 'B' data source and data_sources_path represents the join hierarchy. The data source in which the SQL expression is defined should be omitted from the path. Columns belonging to the top-level data source can be passed without wrapping, but all necessary escaping is then up to the user.
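To make the wrapping rule concrete, the fragment below contrasts the two cases side by side, following the transactions/products example above. The column names quantity and brand are hypothetical placeholders, not documented columns.

```python
# Hypothetical config fragment contrasting the two cases described above.
extra_columns = [{
    # SQL expressions are defined in 'transactions', so 'transactions'
    # itself is omitted from any data_sources_path below.
    "data_source_name": "transactions",
    "columns": [
        # Top-level column: may be passed unwrapped; escaping is on you.
        {"alias": "quantity", "expression": "quantity"},
        # Column from the joined 'products' data source: must be wrapped.
        {
            "alias": "brand",
            "expression": "{{ resolve_fn('brand', data_sources_path=['products']) }}",
        },
    ],
}]
```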