Run Predictions

Once you have trained a downstream model, you will probably want to use it to make predictions. To do that, you must:

  1. Load the trained downstream model with the load_from_checkpoint method. The best model according to the metric specified during training is loaded.
  2. Define testing parameters in the MonadTestingParams class. You can overwrite parameters configured for training the downstream model.

checkpoint_dir = "<path/to/downstream/model/checkpoints>"
testing_module = load_from_checkpoint(checkpoint_dir)

load_from_checkpoint creates a MonadModuleImpl from a MonadCheckpoint. It accepts the following parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| checkpoint_path | str | Directory where all the checkpoint artifacts are stored. | required |
| pl_logger | Optional[Logger] | An instance of a PyTorch Lightning logger to use. | None |
| loading_config | Optional[LoadingConfigParams] | A dictionary mapping a datasource name (or a datasource name and mode) to fields of DataSourceLoadingConfig. If provided, the listed parameters are overwritten. The datasource_cfg field can't be changed. | None |
| kwargs | | Data parameters to change. | {} |

Returns:

| Type | Description |
|---|---|
| MonadModuleImpl | Instance of the monad module, loaded from the checkpoint. |

Additionally, you can pass any of the parameters defined in MonadDataParams to load_from_checkpoint (as kwargs) to overwrite the values configured for training the downstream model:


Good to know

This is where you define the prediction window, i.e. the time window for which you want to predict your target function.

For example, if you plan to predict the propensity to purchase something within 21 days from a given date, set test_start_date to that date and check_target_for_next_N_days = 21.
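To make the example concrete, the sketch below computes the prediction window from these two parameters using only the standard library. It assumes the window starts at test_start_date and spans check_target_for_next_N_days days with an exclusive end; the boundary handling is an assumption, not something stated in this guide. The helper prediction_window and the example date are hypothetical and not part of monad.

```python
from datetime import datetime, timedelta


def prediction_window(
    test_start_date: datetime, check_target_for_next_N_days: int
) -> tuple[datetime, datetime]:
    """Hypothetical helper: return (start, end) of the prediction window.

    Assumes the window is [test_start_date, test_start_date + N days).
    """
    end = test_start_date + timedelta(days=check_target_for_next_N_days)
    return test_start_date, end


# Example date chosen for illustration only.
start, end = prediction_window(datetime(2023, 8, 1), 21)
print(start.date(), "->", end.date())  # 2023-08-01 -> 2023-08-22
```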


MonadDataParams parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| features_path | str | A path to the folder with features created during the pretrain phase. | required |
| data_start_date | datetime | Events after this date will be considered for training. | required |
| check_target_for_next_N_days | int | The number of days used to create the model's target. Not suitable for recommendation models. | None |
| validation_start_date | datetime | Start date for the validation set. | None |
| test_start_date | datetime | The date that the prediction is being calculated for. Either validation_start_date or test_start_date must be provided. | None |
| test_end_date | datetime | End date of the test period. | None |
| timebased_encoding | str | How to encode time-based features; available options are "fourier" and "two-hot". | 'two-hot' |
| target_sampling_strategy | str | "valid" or "random" sampling strategy. For a foundation model, it should always be "random". | 'random' |
| maximum_splitpoints_per_entity | int | The maximum number of splits into input and target events per entity. | 1 |
| num_query_chunks | int | The number of segments a query is divided into to reduce memory consumption on the database side. | 1 |
| use_recency_sketches | boolean | If true, recency sketches are used in training. | True |
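The override behaviour described above (values you pass explicitly replace the ones stored from training, everything else is inherited) can be illustrated with a plain dataclass. This is a hypothetical sketch: DataParamsSketch mirrors only a few fields from the table, and the merge logic and date check are assumptions about the behaviour, not monad's implementation.

```python
from dataclasses import dataclass, replace
from datetime import datetime
from typing import Optional


@dataclass(frozen=True)
class DataParamsSketch:
    """Hypothetical stand-in for a subset of MonadDataParams fields."""
    features_path: str
    data_start_date: datetime
    check_target_for_next_N_days: Optional[int] = None
    validation_start_date: Optional[datetime] = None
    test_start_date: Optional[datetime] = None
    num_query_chunks: int = 1


def override(trained: DataParamsSketch, **changes) -> DataParamsSketch:
    """Return a copy of the training-time params with only `changes` replaced."""
    params = replace(trained, **changes)  # unknown field names raise TypeError
    # Mirrors the table: either validation_start_date or test_start_date is required.
    if params.validation_start_date is None and params.test_start_date is None:
        raise ValueError("provide validation_start_date or test_start_date")
    return params


trained = DataParamsSketch(
    features_path="<path/to/features>",
    data_start_date=datetime(2023, 1, 1),
    check_target_for_next_N_days=21,
    validation_start_date=datetime(2023, 6, 1),
)
testing = override(trained, test_start_date=datetime(2023, 8, 1))
print(testing.check_target_for_next_N_days)  # 21, inherited from training
```

Using a frozen dataclass with dataclasses.replace keeps the training-time parameters untouched, so the same checkpoint configuration can be reused for several prediction runs.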

Then, instantiate MonadTestingParams with the parameters you need. Any parameter you do not specify is taken from the downstream training module.

from monad.ui.config import MonadTestingParams

testing_params = MonadTestingParams(
    save_path="<path/to/predictions>",  # predictions are stored here in CSV format
)


MonadTestingParams parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| save_path | str | If provided, points to the location where predictions will be stored in CSV format. | required |
| limit_test_batches | Optional[Union[int, float]] | How much of the test dataset to check (float = fraction, int = number of batches). | None |
| devices | Union[List[int], str, int, None] | The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to use all available devices, or "auto" for automatic selection based on the chosen accelerator. | [0] |
| accelerator | str | The accelerator to use, as in the PyTorch Lightning Trainer. | 'gpu' |
| precision | Literal[64, 32, 16, '64', '32', '16', 'bf16', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', '32-true', '64-true'] | Double precision, full precision, 16-bit mixed precision, or bfloat16 mixed precision. | DEFAULT_PRECISION |
| metrics | Dict[str, Metric] | Metrics to use in validation. If not provided, the default validation metrics for the task are used. | None |
| callbacks | List[Callback] | List of additional callbacks to add to validation/testing. | [] |
| top_k | int | Recommendation tasks only. Number of targets to recommend; the top k targets are included in predictions. | 12 |
| targets_to_include | List[str] | Recommendation tasks only. Names of the targets that will be included in predictions. | None |
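For recommendation tasks, top_k and targets_to_include shape what ends up in the predictions. The sketch below shows one plausible reading: restrict the scored targets to targets_to_include (when given), then keep the k highest-scoring ones. The function select_top_k and the score dictionary are hypothetical, and the order in which monad applies the two parameters is an assumption.

```python
import heapq
from typing import Dict, List, Optional, Tuple


def select_top_k(
    scores: Dict[str, float],
    top_k: int = 12,
    targets_to_include: Optional[List[str]] = None,
) -> List[Tuple[str, float]]:
    """Hypothetical: pick the top_k highest-scoring targets for one entity."""
    if targets_to_include is not None:
        allowed = set(targets_to_include)
        scores = {t: s for t, s in scores.items() if t in allowed}
    # heapq.nlargest returns the items sorted by score, highest first.
    return heapq.nlargest(top_k, scores.items(), key=lambda item: item[1])


scores = {"product_A": 0.91, "product_B": 0.35, "product_C": 0.78, "product_D": 0.12}
print(select_top_k(scores, top_k=2))
# [('product_A', 0.91), ('product_C', 0.78)]
print(select_top_k(scores, top_k=2, targets_to_include=["product_B", "product_D"]))
# [('product_B', 0.35), ('product_D', 0.12)]
```

heapq.nlargest avoids sorting every candidate when only a few targets per entity are needed.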

Finally, make predictions:



Did you know?

You can use subqueries to filter the users you run predictions on. For example, you can calculate the propensity to purchase product A only for users who have never purchased it, and get the list of those users with the highest propensity to purchase it.

To use this functionality, call set_entities_ids_subquery, which is accessible from the load_from_checkpoint and load_from_foundation_model methods, and provide a SQL query in the dialect of the database you are using.
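As an illustration of the kind of SQL you might pass, the sketch below uses Python's built-in sqlite3 module with a hypothetical purchases table to select the IDs of users who have never purchased product A. The table and column names are assumptions, and the exact shape of query that set_entities_ids_subquery expects (for example, which column it must return) is not specified here; adapt the SQL to your own schema and database dialect.

```python
import sqlite3

# Hypothetical event table: one row per purchase.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE purchases (user_id TEXT, product TEXT)")
conn.executemany(
    "INSERT INTO purchases VALUES (?, ?)",
    [("u1", "A"), ("u1", "B"), ("u2", "B"), ("u3", "C"), ("u3", "B")],
)

# Users who appear in the data but have never purchased product A.
ENTITIES_SUBQUERY = """
SELECT DISTINCT user_id
FROM purchases
WHERE user_id NOT IN (
    SELECT user_id FROM purchases WHERE product = 'A'
)
"""

user_ids = sorted(row[0] for row in conn.execute(ENTITIES_SUBQUERY))
print(user_ids)  # ['u2', 'u3']
```

If the inner query could return NULL user IDs, NOT IN would filter out every row; a NOT EXISTS subquery avoids that pitfall.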