Making predictions with BaseModel
How does inference work?
Check This First!
This article refers to BaseModel accessed via a Docker container. If you are using BaseModel as a Snowflake GUI application, please refer to the Snowflake Native App section instead.
Once you have trained a downstream model, you will likely want to make predictions. To do this, you should prepare and execute a Python script with the following steps:
- Import Required Packages, Classes and Functions: There are two required BaseModel.AI imports, but you may need additional ones if you want to use custom metrics or loggers, manipulate dates, etc.
- Instantiate the Testing Module: Use the `load_from_checkpoint` method to load the best model according to the metric specified during training.
- Define Testing Parameters: Use the `TestingParams` class to define testing parameters. You can override the parameters configured during the training of the downstream model.
- Run Inference: The `predict()` method of the testing module will generate and save predictions.
Please see an example end-to-end script below:
```python
from monad.ui.config import TestingParams
from monad.ui.module import load_from_checkpoint
from datetime import datetime

# declare variables
checkpoint_path = "<path/to/downstream/model/checkpoints>"  # location of scenario model checkpoints
save_path = "<path/to/predictions/my_predictions.tsv>"  # location to store predictions
test_start_date = datetime(2023, 8, 1)  # first day of prediction period
test_end_date = datetime(2023, 8, 29)  # last day of prediction period

# load scenario model to instantiate testing module
testing_module = load_from_checkpoint(
    checkpoint_path=checkpoint_path,
    test_start_date=test_start_date,
    test_end_date=test_end_date,
)

# define testing parameters
testing_params = TestingParams(
    local_save_location=save_path,
)

# run inference
testing_module.predict(testing_params=testing_params)
```
Necessary imports and variables
To make predictions in BaseModel, you need to import the required BaseModel functions and classes:

- `load_from_checkpoint` from `monad.ui.module` - to instantiate the testing module.
- `TestingParams` from `monad.ui.config` - to configure your predictions.
Note
You may also add other optional imports that allow you to use other BaseModel functionalities, such as extra metrics from `monad.ui.metrics`.
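For reference, here is a minimal import block combining the required imports with the optional ones used later in this article:

```python
# Required BaseModel imports
from monad.ui.config import TestingParams
from monad.ui.module import load_from_checkpoint

# Optional: extra metrics (used later in this article) and
# standard-library date handling for defining the test period
from monad.ui.metrics import MultipleTargetsRecall
from datetime import datetime
```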
Instantiating the testing module
We instantiate the testing module by calling `load_from_checkpoint` and providing `checkpoint_path` (the location of your scenario model's checkpoints) along with any of the optional arguments listed below.
This method uses the best of the model's checkpoints and the provided dataloaders to create an instance of the BaseModel module that exposes the `predict()` method, which we will use to generate predictions.
Arguments

- `checkpoint_path` : str, required
  No default
  The directory where all the checkpoint artifacts of the scenario model are stored.
- `pl_logger` : Logger, optional
  Default: None
  An instance of a PyTorch Lightning logger to use.
- `loading_config` : dict or LoadingConfigParams, optional
  Default: None
  This parameter can be either a dictionary mapping the datasource name (or the datasource name and mode) to the constructor arguments of `LoadingParams`, or just the constructor arguments of `LoadingParams`. If provided, the listed parameters will be overwritten. Note that the field `datasource_cfg` cannot be changed.

Additionally, as kwargs, you can pass any parameters defined in the `data_params` block of the YAML configuration to overwrite those used during the training of the scenario model.
Good to know
It is in this module that you define the time window to predict. This is done with `test_start_date` and either `test_end_date` or `check_target_for_next_N_days`.
For example, if you want to predict the propensity to purchase a product within 21 days of a given date, you should define `test_start_date` and either set `test_end_date` 21 days later or set `check_target_for_next_N_days` to 21.
- `check_target_for_next_N_days` : int
  Default: None
  The number of days after the split point considered for the model's target period. Not applicable to recommendation tasks, as these predict the next basket regardless of time interval.
- `test_start_date` : datetime
  Default: None
  The initial date of the test period. It will be used for the downstream model's predictions, but it can also be set later, as part of the prediction script.
- `test_end_date` : datetime
  Default: None
  The last date of the test period.
Have a look at an example of testing module instantiation with some additional arguments below.
```python
testing_module = load_from_checkpoint(
    checkpoint_path="<path/to/downstream/model/checkpoints>",
    test_start_date=datetime(2023, 8, 1),
    check_target_for_next_N_days=21,
    pl_logger=neptune_logger,  # should be instantiated before
)
```
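To make the "Good to know" note above concrete, here is a minimal sketch of the two equivalent ways of defining a 21-day prediction window, reusing the checkpoint path placeholder from earlier:

```python
from datetime import datetime, timedelta

test_start_date = datetime(2023, 8, 1)

# Option 1: set the end date explicitly, 21 days after the start
testing_module = load_from_checkpoint(
    checkpoint_path="<path/to/downstream/model/checkpoints>",
    test_start_date=test_start_date,
    test_end_date=test_start_date + timedelta(days=21),
)

# Option 2: give the window length and let BaseModel derive the end date
testing_module = load_from_checkpoint(
    checkpoint_path="<path/to/downstream/model/checkpoints>",
    test_start_date=test_start_date,
    check_target_for_next_N_days=21,
)
```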
Sampling entities for inference
You can filter the entities for which BaseModel will generate predictions. This is useful, for example, when you only want to predict the propensity to buy a product for users who have never purchased it (acquisition).
To use this functionality, call the `set_entities_ids_subquery` method, providing the data source, the query, and the `DataMode` set to TEST, as in the example below.
```python
from datetime import datetime, timedelta

# DataMode, data_params, and table_name are assumed to be
# imported/defined elsewhere in your script

def subquery_fn(start_date: datetime, end_date: datetime) -> str:
    # select clients with events, excluding those who bought
    # Nike or Adidas products within the given period
    return f"""
    SELECT DISTINCT client_id
    FROM {table_name} AS events
    WHERE client_id NOT IN (
        SELECT client_id
        FROM {table_name} AS sub_events
        WHERE brand IN ('Nike', 'Adidas')
            AND event_timestamp > '{start_date}' AND event_timestamp < '{end_date}'
    )
    """

testing_module.set_entities_ids_subquery(
    "transactions",
    # end one day before the test period starts
    subquery_fn(data_params.data_start_date, data_params.test_start_date - timedelta(days=1)),
    DataMode.TEST,
)
```
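In this example, the subquery returns the IDs of clients who have no Nike or Adidas purchases between the data start date and the day before the test period begins, so predictions are generated only for potential new buyers of those brands.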
Configuring Inference with TestingParams
We should now use `TestingParams`, the class we imported from `monad.ui.config`, to set up our predictions.
The only requirement is to specify where BaseModel should store the predictions, so you must provide either `local_save_location` or `remote_save_location`.
Parameters

- `local_save_location` : str
  No default
  If provided, points to the location in the local filesystem where predictions will be stored in TSV format.
- `remote_save_location` : str
  No default
  If provided, defines a table in a remote database where the predictions will be stored.
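As an illustration, here is a minimal sketch of configuring a remote destination instead of a local file. The table identifier format below is an assumption; use whatever table reference your remote database deployment expects.

```python
# Minimal sketch: save predictions to a remote table rather than a
# local TSV file. The "<schema.predictions_table>" placeholder is an
# assumption, not a confirmed format.
testing_params = TestingParams(
    remote_save_location="<schema.predictions_table>",
)
```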
Additionally, the stored parameters provided earlier as part of the YAML `training_params` block or during scenario model training as `TrainingParams` can also be modified. This is useful, for example, if you want to change the device, modify the number of recommendation targets, or add a new callback. Please refer to this article for the full list of modifiable parameters.
Example

The example below demonstrates testing parameters. In addition to specifying the location for predictions, some parameters have been overwritten or added. Note that the metric added as an argument requires importing an additional class.
```python
from monad.ui.metrics import MultipleTargetsRecall

testing_params = TestingParams(
    local_save_location="/data1/mcedro/eobuwie/predict/predict_8_08_2023.tsv",
    metrics={"recall": MultipleTargetsRecall()},
    devices=[0],
)
```
Running predictions
Having instantiated the testing module and configured the testing parameters, we are now ready to start the inference run. This is done by simply calling the `predict()` method of the testing module and providing `testing_params` as its argument.
```python
testing_module.predict(testing_params=testing_params)
```
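Once the run completes, the predictions are written to the configured location. As a quick sketch (assuming a local TSV destination and the pandas library; the exact columns depend on your task and configuration), you can inspect the output like this:

```python
import pandas as pd

# Load the saved predictions; the column set depends on the task type
predictions = pd.read_csv("<path/to/predictions/my_predictions.tsv>", sep="\t")
print(predictions.columns.tolist())
print(predictions.head())
```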