Fine-tuning the Training Process

How to configure the loading and training process?

In this module, we configure the model and training parameters of the downstream task and write the target function that defines the model's target.

By default, the training parameters are defined as in the Pretrain config. This means that:

  1. Data sources are loaded automatically.
  2. Dates for the training, validation, and test sets are defined as in the Pretrain config.
  3. All training parameters for the downstream task are the same as those used for the Foundation Model.

This is done for simplicity and to avoid duplicating configuration. However, the parameters for your downstream task may differ from those used for the Foundation Model, so you can override them.
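
The inherit-then-override behavior described above can be pictured with a plain Python dataclass. This is a minimal sketch: InheritedParams and with_overrides are hypothetical names used only for illustration, not part of the monad API, and the real configuration objects hold many more fields.

```python
from dataclasses import dataclass, replace
from datetime import date
from typing import Optional


# Hypothetical stand-in for parameters inherited from the Pretrain config.
@dataclass(frozen=True)
class InheritedParams:
    data_start_date: date
    validation_start_date: Optional[date] = None
    test_start_date: Optional[date] = None
    learning_rate: float = 0.001
    epochs: int = 1


def with_overrides(inherited: InheritedParams, **overrides) -> InheritedParams:
    """Return a copy of the inherited params with only the given fields changed."""
    # replace() rejects unknown field names with a TypeError, so a typo in an
    # override fails loudly instead of being silently ignored.
    return replace(inherited, **overrides)


pretrain = InheritedParams(
    data_start_date=date(2023, 1, 1),
    validation_start_date=date(2023, 6, 1),
)
# Everything not listed here keeps its inherited value.
downstream = with_overrides(pretrain, learning_rate=0.0005, epochs=3)
```

Keeping the inherited object immutable and producing a modified copy makes it easy to see exactly which values the downstream task changed.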

The following sections cover two scenarios: loading the Foundation Model and training the downstream task.

Loading Foundation Model

To load the Foundation Model trained in the Pretrain phase, use the load_from_foundation_model function.

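Example for a churn model:

from monad.ui.module import load_from_foundation_model
from monad.ui.module import BinaryClassificationTask

trainer = load_from_foundation_model(
    checkpoint_path="<path/to/foundation_model>",  # directory with the Foundation Model checkpoint
    downstream_task=BinaryClassificationTask(),
    target_fn=churn_target_fn,  # your target function (see Model Target Function)
    num_outputs=1,
)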

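Example for a propensity model:

from monad.ui.module import load_from_foundation_model
from monad.ui.module import MultilabelClassificationTask

trainer = load_from_foundation_model(
    checkpoint_path="<path/to/foundation_model>",  # directory with the Foundation Model checkpoint
    downstream_task=MultilabelClassificationTask(),
    target_fn=propensity_target_fn,  # your target function (see Model Target Function)
    num_targets=10,
)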

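Example for a recommendation model:

from monad.ui.module import load_from_foundation_model
from monad.ui.module import RecommendationTask

trainer = load_from_foundation_model(
    checkpoint_path="<path/to/foundation_model>",  # directory with the Foundation Model checkpoint
    downstream_task=RecommendationTask(),
    target_fn=recommendation_target_fn,  # hypothetical name: your own target function
)
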
load_from_foundation_model creates a MonadModuleImpl from a MonadCheckpoint, assuming that the model saved under checkpoint_path is a Foundation Model. Its parameters are:

  • checkpoint_path (str): Directory where all the checkpoint artifacts are stored. Required.
  • downstream_task (Task): One of the machine learning tasks defined in BaseModel. Possible values are RecommendationTask(), BinaryClassificationTask(), MultilabelClassificationTask(), and MulticlassClassificationTask(). Required.
  • target_fn (Callable[[Events, Events, Attributes, Dict], Union[Tensor, ndarray, Sketch]]): Target function for the specified task. Required.
  • pl_logger (Optional[Logger]): Instance of a PyTorch Lightning logger. Default: None.
  • loading_config (Optional[LoadingConfigParams]): A dictionary mapping a data source name (or a data source name and mode) to fields of DataSourceLoadingConfig. If provided, the listed parameters are overwritten. The datasource_cfg field cannot be changed. Default: None.
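
To make the target_fn signature concrete, here is a minimal sketch of a churn target function. The Events and Attributes classes below are hypothetical stand-ins for the real BaseModel types, and treating the second argument as the events inside the target window is an assumption made for illustration; see the Model Target Function section for the actual contract.

```python
from dataclasses import dataclass, field
from datetime import date
from typing import Dict, List

import numpy as np


# Hypothetical stand-ins for the Events and Attributes types in the signature;
# the real BaseModel types carry much more information.
@dataclass
class Events:
    timestamps: List[date] = field(default_factory=list)


@dataclass
class Attributes:
    values: Dict[str, object] = field(default_factory=dict)


def churn_target_fn(
    history: Events, future: Events, attributes: Attributes, ctx: Dict
) -> np.ndarray:
    # Assumption: `future` holds the entity's events inside the target window.
    # The entity is labelled as churned (1.0) when it has no events there.
    churned = len(future.timestamps) == 0
    return np.array([1.0 if churned else 0.0], dtype=np.float32)
```

The function returns a single value per entity, which matches the one-output setup of the churn example.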

Additionally, you can pass any parameter defined in MonadDataParams to override the values configured for Foundation Model training:

  • features_path (str): Path to the folder with features created during the pretrain phase. Required.
  • data_start_date (datetime): Events after this date are considered for training. Required.
  • check_target_for_next_N_days (int): The number of days used to create the model's target. Not applicable to recommendation models. Default: None.
  • validation_start_date (datetime): Start date of the validation set. Default: None.
  • test_start_date (datetime): The date for which the prediction is calculated. At least one of validation_start_date or test_start_date must be provided. Default: None.
  • test_end_date (datetime): End date of the test period (prediction end date). Default: None.
  • timebased_encoding (str): How to encode time-based features; available options are "fourier" and "two-hot". Default: 'two-hot'.
  • target_sampling_strategy (str): "valid" or "random" sampling strategy. For the Foundation Model it should always be "random". Default: 'random'.
  • maximum_splitpoints_per_entity (int): The maximum number of splits into input and target events per entity. Default: 1.
  • num_query_chunks (int): The number of segments a query is divided into. Splitting the query into smaller pieces can reduce memory consumption on the database side. Default: 1.
  • use_recency_sketches (boolean): If true, recency sketches are used in training. Default: True.
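
The two timebased_encoding options correspond to well-known ways of encoding cyclic time features. The sketch below shows the generic techniques applied to an hour-of-day value; it illustrates the idea only and is not BaseModel's implementation. The bin count, period, and number of frequencies are arbitrary choices.

```python
import math
from typing import List


def two_hot(value: float, num_bins: int, period: float) -> List[float]:
    """Spread a cyclic value over its two nearest bins with linear weights."""
    pos = (value % period) * num_bins / period  # position in [0, num_bins)
    lower = int(math.floor(pos)) % num_bins
    upper = (lower + 1) % num_bins  # wraps around, e.g. hour 23.5 -> bins 23 and 0
    frac = pos - math.floor(pos)
    vec = [0.0] * num_bins
    vec[lower] += 1.0 - frac
    vec[upper] += frac
    return vec


def fourier(value: float, period: float, num_frequencies: int) -> List[float]:
    """Encode a cyclic value as sine/cosine pairs at increasing frequencies."""
    feats: List[float] = []
    for k in range(1, num_frequencies + 1):
        angle = 2.0 * math.pi * k * value / period
        feats.extend([math.sin(angle), math.cos(angle)])
    return feats


hour_two_hot = two_hot(6.25, num_bins=24, period=24.0)
hour_fourier = fourier(6.25, period=24.0, num_frequencies=3)
```

Two-hot keeps the representation sparse and interpretable per bin, while the Fourier features give a smooth, dense encoding in which 23:00 and 00:00 end up close together.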

Training downstream task

Once you have loaded the Foundation Model, specify checkpoint_dir in MonadTrainingParams and override any default parameters you need.

MonadTrainingParams holds the constructor parameters for the PyTorch Lightning Trainer:

  • epochs (int): Number of epochs to train. Default: 1.
  • learning_rate (float): The learning rate. Default: 0.001.
  • devices (Union[List[int], str, int, None]): The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to use all available devices, or "auto" for automatic selection based on the chosen accelerator. Default: [0].
  • accelerator (Literal['cpu', 'gpu']): The accelerator to use, as in the PyTorch Lightning Trainer. Default: 'gpu'.
  • precision (Literal[64, 32, 16, '64', '32', '16', 'bf16', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', '32-true', '64-true']): Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'), 16-bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed'). Default: DEFAULT_PRECISION.
  • limit_train_batches (Optional[Union[int, float]]): How much of the training dataset to check (float = fraction, int = number of batches). Default: 1.0.
  • limit_val_batches (Optional[Union[int, float]]): How much of the validation dataset to check (float = fraction, int = number of batches). Default: 1.0.
  • loss (Optional[Callable]): The loss function to use. If not provided, the default loss function for the task is used. Default: None.
  • metrics (Optional[Dict[str, Metric]]): Metrics to use in validation. If not provided, the default validation metrics for the task are used. Default: None.
  • checkpoint_dir (Optional[Union[str, Path]]): If provided, the location where checkpoints are stored. Default: None.
  • metric_to_monitor (Optional[str]): The metric that determines model quality when deciding which state dict to save. Default: None.
  • metric_monitoring_mode (Optional[MetricMonitoringMode]): Whether a smaller or greater value of the metric is better. Possible values are "min" and "max". Default: None.
  • callbacks (List[Callback]): List of additional callbacks to add to training. Default: empty list.
  • gradient_clip_val (Optional[Union[int, float]]): Gradient clipping value passed to the PyTorch Lightning Trainer. Default: None.
  • warm_start_steps (int): Number of warm-start training steps used while fine-tuning a supervised model. Ignored if no pretrained model is used. Default: 0.
  • top_k (int): Only for the recommendation task. Number of targets to recommend. The top k targets are included in validation metrics; this does not affect model training. Default: 12.
  • targets_to_include (List[str]): Only for the recommendation task. Target names to use for validation metrics; this does not affect model training. Default: None.
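
The interaction of metric_to_monitor and metric_monitoring_mode follows the usual "save on improvement" rule, similar in spirit to the monitor and mode arguments of PyTorch Lightning's ModelCheckpoint callback. The sketch below only picks the best epoch from a list of metric values; best_epoch is a hypothetical helper, and how BaseModel wires this rule into training is not shown here.

```python
from typing import Iterable, Optional, Tuple


def best_epoch(
    metric_history: Iterable[float], mode: str
) -> Tuple[Optional[int], Optional[float]]:
    """Return (epoch, value) of the best metric under a 'min' or 'max' rule."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    best_idx: Optional[int] = None
    best_val: Optional[float] = None
    for epoch, value in enumerate(metric_history):
        improved = (
            best_val is None
            or (mode == "min" and value < best_val)
            or (mode == "max" and value > best_val)
        )
        if improved:
            # This is the point where a checkpointing callback would save the state dict.
            best_idx, best_val = epoch, value
    return best_idx, best_val


# For example, an AUC-like metric should be maximized and a loss minimized.
auc_best = best_epoch([0.70, 0.74, 0.73], mode="max")
loss_best = best_epoch([0.50, 0.40, 0.45], mode="min")
```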

Task-specific training

BaseModelAI supports a broad range of ML tasks:

Classification Tasks

  • BinaryClassificationTask — Each data point belongs to one of two categories, e.g. churn prediction. For binary classification, we use the binary cross-entropy with logits loss function.
  • MulticlassClassificationTask — Each data point belongs to one of multiple categories (3 or more), e.g. predicting the user's favorite brand. For multiclass classification, we use the cross-entropy loss function.
  • MultilabelClassificationTask — Each data point can be assigned zero or more labels, e.g. predicting the probability that the user will buy a product in each of several predefined categories. For multilabel classification, we use the binary cross-entropy with logits loss function.
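
For reference, the loss functions named above can be written out from their standard definitions. This is a plain-Python sketch of the textbook formulas, not BaseModel's code; averaging the per-label losses in the multilabel case is an assumption (it matches the default reduction='mean' of PyTorch's BCEWithLogitsLoss).

```python
import math
from typing import List


def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy computed directly on a raw logit."""
    # Numerically stable form of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))].
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def cross_entropy(logits: List[float], target_class: int) -> float:
    """Multiclass cross-entropy: -log(softmax(logits)[target_class])."""
    m = max(logits)  # subtract the max before exponentiating for stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target_class]


def multilabel_bce_with_logits(logits: List[float], targets: List[float]) -> float:
    """Multilabel loss: an independent BCE-with-logits per label, averaged."""
    losses = [bce_with_logits(x, y) for x, y in zip(logits, targets)]
    return sum(losses) / len(losses)
```

The multilabel case reuses the binary loss because each label is treated as its own yes/no decision, whereas multiclass cross-entropy normalizes over all classes so exactly one can be correct.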

Recommendation Tasks

  • RecommendationTask — For each data example, a list of n items is returned, e.g. personalized recommendation of offers or products.
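
Returning a list of n items usually means ranking candidate items by a model score and keeping the best ones. A minimal sketch, assuming the model has already produced one score per item id; recommend_top_k is a hypothetical helper, and its default of 12 only mirrors the top_k training parameter.

```python
import heapq
from typing import Dict, List


def recommend_top_k(scores: Dict[str, float], k: int = 12) -> List[str]:
    """Return the ids of the k highest-scoring items, best first."""
    # heapq.nlargest avoids sorting the full catalogue when k is small.
    top = heapq.nlargest(k, scores.items(), key=lambda item_score: item_score[1])
    return [item for item, _ in top]


scores = {"shoes": 0.1, "jacket": 0.9, "scarf": 0.5}
recommended = recommend_top_k(scores, k=2)
```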

Regression Tasks

  • RegressionTask — Predict continuous values based on the provided input, e.g. predicting how much the user will spend in the next year. For regression, we use the cross-entropy loss function.


Data split

For training purposes, BaseModelAI can split data into train, validation, and test sets based on the event time windows configured for each phase. data_start_date is the starting point of the dataset: only events after this date are considered.

To create the train, validation, and test sets, we use the subsequent dates validation_start_date and test_start_date to define targets. At each stage, we consider user history from data_start_date; what differs are the split points, which define which portion of events is used as input and which is used to create the model's target.

In the configuration file of the downstream task, you must set at least one of validation_start_date or test_start_date.



If you set test_start_date but not validation_start_date, you cannot train the downstream task; you can only run predictions with an already trained downstream model.
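
The date rules above can be summarized as a small validation function. This is a hypothetical sketch, not a monad API: treating prediction as requiring test_start_date is our reading of the parameter description ("the date that the prediction is being calculated for"), and the ordering check follows from validation_start_date and test_start_date being subsequent dates.

```python
from datetime import date
from typing import Optional, Set


def allowed_modes(
    validation_start_date: Optional[date], test_start_date: Optional[date]
) -> Set[str]:
    """Return which downstream operations a date configuration permits (hypothetical)."""
    if validation_start_date is None and test_start_date is None:
        raise ValueError("set at least validation_start_date or test_start_date")
    if (
        validation_start_date is not None
        and test_start_date is not None
        and validation_start_date >= test_start_date
    ):
        raise ValueError("validation_start_date must precede test_start_date")
    modes: Set[str] = set()
    if validation_start_date is not None:
        modes.add("train")  # training is only possible with a validation period
    if test_start_date is not None:
        modes.add("predict")  # assumption: predictions are computed for test_start_date
    return modes


prediction_only = allowed_modes(None, date(2023, 7, 1))
```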

Classification and regression

For training, we consider user history from data_start_date until training_max_split_date, which is calculated automatically from validation_start_date:

training_max_split_date = validation_start_date - 1 - check_target_for_next_N_days
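
A worked example of the formula, assuming dates are counted in whole days and using illustrative values (validation_start_date = 2023-06-01, check_target_for_next_N_days = 21) rather than defaults. The helper name simply mirrors the quantity being computed.

```python
from datetime import date, timedelta


def training_max_split_date(
    validation_start_date: date, check_target_for_next_N_days: int
) -> date:
    """validation_start_date - 1 - check_target_for_next_N_days, counted in days."""
    return validation_start_date - timedelta(days=1 + check_target_for_next_N_days)


validation_start = date(2023, 6, 1)
n_days = 21
max_split = training_max_split_date(validation_start, n_days)
# Last day covered by the target window of the latest possible split point.
last_target_day = max_split + timedelta(days=n_days)
```

With these values, training_max_split_date is 2023-05-10, and the target window of the latest possible split point ends on 2023-05-31, the day before validation starts.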

We randomly select split points between data_start_date and training_max_split_date. This guarantees that even when the latest possible split point (training_max_split_date itself) is chosen, the target window of the next check_target_for_next_N_days days ends before validation_start_date, so training targets never overlap the validation period and no data leakage occurs. Multiple split points can be selected for one entity, up to the number set by the maximum_splitpoints_per_entity attribute. This lets us create multiple data points from a single user's history.
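
The split-point selection can be sketched as follows. Drawing distinct dates uniformly over calendar days is an assumption made for illustration; the actual implementation may, for example, sample among the entity's own event timestamps. sample_split_points is a hypothetical helper.

```python
import random
from datetime import date, timedelta
from typing import List


def sample_split_points(
    data_start_date: date,
    training_max_split_date: date,
    maximum_splitpoints_per_entity: int,
    rng: random.Random,
) -> List[date]:
    """Draw up to maximum_splitpoints_per_entity distinct split dates, uniformly by day."""
    num_days = (training_max_split_date - data_start_date).days + 1  # inclusive range
    if num_days <= 0:
        return []
    k = min(maximum_splitpoints_per_entity, num_days)
    offsets = rng.sample(range(num_days), k)  # distinct day offsets
    return sorted(data_start_date + timedelta(days=o) for o in offsets)


rng = random.Random(0)  # seeded so the draw is reproducible
splits = sample_split_points(date(2023, 1, 1), date(2023, 5, 10), 3, rng)
```

Each sampled date yields one (input, target) pair for the entity, which is how a single history can contribute several training examples.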

Analogously, for testing, we use the history from data_start_date to test_start_date as input and the subsequent check_target_for_next_N_days days to create the target.

For validation, the model input is the whole user history from data_start_date to validation_start_date, and the target is created from the check_target_for_next_N_days days after validation_start_date. This way, we use all available history to make predictions while preventing data leakage, because the predicted events were not seen in the training phase.
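
Validation and testing follow the same pattern with different cutoffs, which a small helper makes explicit. This sketch is hypothetical and assumes half-open windows (start inclusive, end exclusive); the exact boundary handling in BaseModel may differ. The dates are illustrative.

```python
from datetime import date, timedelta
from typing import NamedTuple, Tuple


class Window(NamedTuple):
    start: date  # inclusive
    end: date  # exclusive


def input_and_target_windows(
    data_start_date: date, cutoff: date, check_target_for_next_N_days: int
) -> Tuple[Window, Window]:
    """Input: history up to the cutoff. Target: the following N days."""
    input_window = Window(data_start_date, cutoff)
    target_window = Window(cutoff, cutoff + timedelta(days=check_target_for_next_N_days))
    return input_window, target_window


# Validation uses validation_start_date as the cutoff; testing uses test_start_date.
val_input, val_target = input_and_target_windows(date(2023, 1, 1), date(2023, 6, 1), 21)
test_input, test_target = input_and_target_windows(date(2023, 1, 1), date(2023, 7, 1), 21)
```

With these values, the validation input covers 2023-01-01 up to 2023-06-01 and the validation target covers 2023-06-01 up to 2023-06-22.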


Recommendation

In the recommendation setup, we predict the next item in the basket, so the check_target_for_next_N_days parameter is not relevant and is not set.
For recommendations, training_max_split_date defaults to validation_start_date.
To set aside a test dataset, you must define test_start_date. Otherwise, the whole period is used.
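
A simplified sketch of building a next-item target around a split date. The (date, item id) event layout and the next_item_target helper are hypothetical, and predicting a single next item simplifies the basket-level setup described above.

```python
from datetime import date
from typing import List, Optional, Tuple

Purchase = Tuple[date, str]  # hypothetical layout: (event date, item id)


def next_item_target(
    purchases: List[Purchase], split_date: date
) -> Tuple[List[str], Optional[str]]:
    """Split a purchase history into input items and the next item to predict."""
    ordered = sorted(purchases)  # chronological order, ties broken by item id
    history = [item for day, item in ordered if day <= split_date]
    future = [item for day, item in ordered if day > split_date]
    return history, (future[0] if future else None)


purchases = [
    (date(2023, 5, 2), "b"),
    (date(2023, 5, 1), "a"),
    (date(2023, 6, 3), "c"),
    (date(2023, 6, 5), "d"),
]
inputs, target = next_item_target(purchases, split_date=date(2023, 5, 31))
```

No fixed-length target window is needed here, which is why check_target_for_next_N_days plays no role for recommendations.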

Target function

Please refer to the Model Target Function section for details on how to create the target function that will fulfill your business needs.