Making predictions with BaseModel

How does the inference work?

⚠️

Check This First!

This article refers to BaseModel accessed via a Docker container. Please refer to the Snowflake Native App section if you are using BaseModel as a Snowflake GUI application.

Once you have trained a downstream model, you will likely want to make predictions. To do this, you should prepare and execute a Python script with the following steps:

  1. Import Required Packages, Classes, and Functions: There are two required BaseModel.AI imports, but you may need additional ones to use custom metrics or loggers, manipulate dates, and so on.

  2. Instantiate the Testing Module: Use the load_from_checkpoint method to load the best model according to the metric specified during training.

  3. Define Testing Parameters: Use the TestingParams class to define testing parameters. You can override the parameters configured during the training of the downstream model.

  4. Run Inference: The predict() method of the testing module will generate and save predictions.

Please see an example end-to-end script below:

from monad.ui.config import TestingParams
from monad.ui.module import load_from_checkpoint
from datetime import datetime

# declare variables
checkpoint_path = "<path/to/downstream/model/checkpoints>" # location of scenario model checkpoints
save_path = "<path/to/predictions/my_predictions.tsv>" # location to store predictions
test_start_date = datetime(2023, 8, 1) # first day of prediction period
test_end_date = datetime(2023, 8, 29) # last day of prediction period

# load scenario model to instantiate testing module
testing_module = load_from_checkpoint(
    checkpoint_path=checkpoint_path,
    test_start_date=test_start_date,
    test_end_date=test_end_date,
)

# define testing parameters
testing_params = TestingParams(
    local_save_location=save_path
)

# run inference
testing_module.predict(testing_params=testing_params)

Necessary imports and variables

To make predictions in BaseModel, you need to import the required BaseModel functions and classes:

  • load_from_checkpoint from monad.ui.module - to instantiate the testing module.
  • TestingParams from monad.ui.config - to configure your predictions.
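
In practice, the top of your prediction script will therefore contain at least these two imports, exactly as in the end-to-end example above:

from datetime import datetime  # optional: used to define the prediction window

from monad.ui.module import load_from_checkpoint
from monad.ui.config import TestingParams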

📘

Note

You may also add other optional imports that allow you to use other BaseModel functionalities, such as extra metrics from monad.ui.metrics.
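
For example, the additional metric used later in this article is imported like this:

from monad.ui.metrics import MultipleTargetsRecall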

Instantiating the testing module

We instantiate the testing module by calling load_from_checkpoint and providing checkpoint_path (the location of your scenario model's checkpoints) along with any of the other optional arguments listed below.
This method uses the best of the model's checkpoints and the provided dataloaders to create an instance of the BaseModel module exposing the predict() method, which we will use to generate predictions.

Arguments
  • checkpoint_path : str, required
    No default
    The directory where all the checkpoint artifacts of the scenario model are stored.

  • pl_logger : [Logger], optional
    Default: None
    An instance of PyTorch Lightning logger to use.

  • loading_config : dict, [LoadingConfigParams], optional
    Default: None
    This parameter can either be:
    • A dictionary containing a mapping from the datasource name (or from the datasource name and mode) to the constructor arguments of LoadingParams.
    • Just the constructor arguments of LoadingParams.
    If provided, the listed parameters will be overwritten; note that the field datasource_cfg cannot be changed. A short sketch of both forms follows this list.

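To illustrate, here is a minimal sketch of the two accepted loading_config forms; the LoadingParams field name used below (num_workers) is a hypothetical placeholder, so substitute the constructor arguments that apply to your setup:

# form 1: map a datasource name to LoadingParams constructor arguments
# ("transactions" is the datasource name used elsewhere in this article)
loading_config = {"transactions": {"num_workers": 4}}

# form 2: pass the constructor arguments directly
loading_config = {"num_workers": 4}

testing_module = load_from_checkpoint(
    checkpoint_path=checkpoint_path,
    loading_config=loading_config,
)
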
Additionally, as kwargs, you can pass any parameters defined in the data_params block of the YAML configuration to overwrite those used during the training of the scenario model.

📘

Good to know

It is in this module that you define the time window to predict.
This is done with test_start_date and either test_end_date or check_target_for_next_N_days.
For example, if you want to predict the propensity to purchase a product within 21 days from a given date, you should define the test_start_date and either set the test_end_date 21 days later or set check_target_for_next_N_days to 21.

  • check_target_for_next_N_days : int
    Default: None
    The number of days after the split point considered for the model's target function period. Not applicable to recommendation tasks, as these predict the next basket regardless of the time interval.

  • test_start_date : datetime
    Default: None
    The initial date of the test period. It will be used for downstream models' predictions, but it can also be set later, as part of the prediction script.

  • test_end_date : datetime
    Default: None
    The last date of the test period.
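
For instance, the 21-day window from the note above can be expressed in either of these two equivalent ways (a small sketch using Python's timedelta):

from datetime import datetime, timedelta

test_start_date = datetime(2023, 8, 1)

# option 1: set the end date explicitly, 21 days after the start...
test_end_date = test_start_date + timedelta(days=21)

# option 2: ...or give the window length in days instead
check_target_for_next_N_days = 21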

Have a look at an example of testing module instantiation with some additional arguments below.

testing_module = load_from_checkpoint(
    checkpoint_path="<path/to/downstream/model/checkpoints>",
    test_start_date=datetime(2023, 8, 1),
    check_target_for_next_N_days=21,
    pl_logger=neptune_logger,  # should be instantiated before
)

Sampling entities for inference

You can restrict the set of entities for which BaseModel will generate predictions. This is useful when, for example, you only want to predict the propensity to buy a product for users who have never purchased it (acquisition).

To make use of this functionality, call the set_entities_ids_subquery method.
You need to provide the data source, the query, and the DataMode set to TEST, as in the example below.

from datetime import datetime, timedelta

def subquery_fn(start_date: datetime, end_date: datetime):
    return f"""
        SELECT DISTINCT
            client_id
        FROM {table_name} AS events
        WHERE client_id NOT IN (
            SELECT client_id
            FROM {table_name} AS sub_events
            WHERE brand IN ('Nike', 'Adidas') 
            AND event_timestamp > '{start_date}' AND event_timestamp < '{end_date}'
        )
    """

# restrict TEST-mode predictions to the entities returned by the subquery;
# data_params refers to the data parameters used during training
testing_module.set_entities_ids_subquery(
    "transactions",
    subquery_fn(data_params.data_start_date, data_params.test_start_date - timedelta(days=1)),
    DataMode.TEST,
)
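
In this example, the NOT IN subquery keeps only the clients with no 'Nike' or 'Adidas' purchases between the start of the data and the day before the test period, which matches the acquisition use case described above.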

Configuring Inference with TestingParams

We should now use TestingParams, a class we imported from monad.ui.config, to set up our predictions.
The only requirement is to specify where BaseModel should store the predictions, so you must provide either local_save_location or remote_save_location.

Parameters
  • local_save_location: str
    No default
    If provided, points to the location in the local filesystem where predictions will be stored in TSV format.

  • remote_save_location: str
    No default
    If provided, defines a table in a remote database where the predictions will be stored.
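
As a minimal sketch, the two options look as follows; note that the table identifier passed to remote_save_location is purely illustrative, since its exact format depends on your database setup:

# store predictions as a TSV file in the local filesystem...
testing_params = TestingParams(local_save_location="/path/to/predictions.tsv")

# ...or in a table of a remote database (identifier shown is illustrative)
testing_params = TestingParams(remote_save_location="my_schema.my_predictions")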

Additionally, stored parameters provided earlier as part of the YAML training_params block or passed as TrainingParams during scenario model training can also be modified. This is useful, for example, if you want to change the device, modify the number of recommendation targets, or add a new callback. Please refer to this article for the full list of modifiable parameters.

Example

The example below demonstrates a TestingParams configuration. In addition to specifying the location for predictions, some parameters have been overwritten or added. Note that the metric passed as an argument requires importing an additional class.

from monad.ui.metrics import MultipleTargetsRecall

testing_params = TestingParams(
    local_save_location="/data1/mcedro/eobuwie/predict/predict_8_08_2023.tsv",
    metrics={"recall": MultipleTargetsRecall()},
    devices=[0],
)

Running predictions

Having instantiated the testing module and configured the testing parameters, we are now ready to start the inference run. This is done by simply calling the predict() method of the testing module and providing testing_params as its argument.

testing_module.predict(testing_params=testing_params)
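
Once the run completes, the predictions are available at the location configured in TestingParams: in our example, the TSV file specified by local_save_location.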