Making predictions with BaseModel
How does the inference work?
Check This First!
This article refers to BaseModel accessed via a Docker container. Please refer to the Snowflake Native App section if you are using BaseModel as a Snowflake GUI application.
Once you have trained a downstream model, you will likely want to make predictions. To do this, you should prepare and execute a Python script with the following steps:
- Import Required Packages, Classes and Functions: There are two required BaseModel.AI imports, but you may need additional ones if you want to use custom metrics, loggers, manipulate dates, etc.
- Instantiate the Testing Module: Use the load_from_checkpoint method to load the best model according to the metric specified during training.
- Define Testing Parameters: Use the TestingParams class to define testing parameters. You can override the parameters configured during the training of the downstream model.
- Run Inference: The predict() method of the testing module will generate and save predictions.
Please see an example end-to-end script below:
from monad.ui.config import OutputType, TestingParams
from monad.ui.module import load_from_checkpoint
from datetime import datetime

# declare variables
checkpoint_path = "<path/to/downstream/model/checkpoints>"  # location of scenario model checkpoints
save_path = "<path/to/predictions/my_predictions.tsv>"  # location to store predictions
test_start_date = datetime(2023, 8, 1)  # first day of prediction period
test_end_date = datetime(2023, 8, 29)  # last day of prediction period

# load scenario model to instantiate testing module
testing_module = load_from_checkpoint(
    checkpoint_path=checkpoint_path,
    test_start_date=test_start_date,
    test_end_date=test_end_date,
)

# define testing parameters
testing_params = TestingParams(
    local_save_location=save_path,
    output_type=OutputType.DECODED,
)

# run inference
testing_module.predict(testing_params=testing_params)
Necessary imports and variables
To make predictions in BaseModel, you need to import the required BaseModel functions and classes:
- load_from_checkpoint from monad.ui.module - to instantiate the testing module.
- TestingParams from monad.ui.config - to configure your predictions.
Note
You may also add other optional imports that allow you to use other BaseModel functionalities, such as extra metrics from monad.ui.metrics.
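Putting this together, the import block of a typical prediction script looks like the one below (datetime comes from the standard library and is only needed if you define the prediction window as datetime objects):

from monad.ui.module import load_from_checkpoint  # required: instantiates the testing module
from monad.ui.config import OutputType, TestingParams  # required: configures predictions
from datetime import datetime  # optional: used to define the prediction window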
Instantiating the testing module
We instantiate the testing module by calling load_from_checkpoint and providing checkpoint_path (the location of your scenario model's checkpoints) along with any of the other optional arguments listed below.
This method will use the model's best checkpoint and the provided dataloaders to create an instance of the BaseModel module that exposes the predict() method, which we will use to generate predictions.
Arguments

- checkpoint_path: str, required
No default
The directory where all the checkpoint artifacts of the scenario model are stored.
- pl_logger: Logger, optional
Default: None
An instance of a PyTorch Lightning logger to use.
- loading_config: dict, LoadingConfigParams, optional
Default: None
This parameter can be either a dictionary mapping the datasource name (or the datasource name and mode) to the constructor arguments of LoadingParams, or just the constructor arguments of LoadingParams. If provided, the listed parameters will be overwritten. Note that the field datasource_cfg cannot be changed.

Additionally, as kwargs, you can pass any parameters defined in the data_params block of the YAML configuration to override those used during the training of the scenario model.
Good to know
It is in this module that you define the time window to predict. This is done with test_start_date and test_end_date. For example, if you want to predict the propensity to purchase a product within 21 days from a given date, you should define test_start_date and set test_end_date 21 days later.
- test_start_date: datetime
Default: None
The initial date of the test period. It will be used for the downstream model's predictions, but it can also be set later, as part of the prediction script.
- test_end_date: datetime
Default: None
The last date of the test period.
Have a look at an example of testing module instantiation with some additional arguments below.
testing_module = load_from_checkpoint(
    checkpoint_path="<path/to/downstream/model/checkpoints>",
    test_start_date=datetime(2023, 8, 1),
    test_end_date=datetime(2023, 8, 22),
    pl_logger=neptune_logger,  # the logger should be instantiated beforehand
)
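Data-loading overrides can be passed in the same call. Below is a minimal sketch; num_workers is a hypothetical stand-in for whatever field your own data_params block actually defines:

testing_module = load_from_checkpoint(
    checkpoint_path="<path/to/downstream/model/checkpoints>",
    test_start_date=datetime(2023, 8, 1),
    test_end_date=datetime(2023, 8, 22),
    num_workers=4,  # hypothetical data_params field, passed as a kwarg to override the training value
)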
Sampling entities for inference
You can filter the entities for which BaseModel will generate predictions. This is useful, e.g., when you want to predict the propensity to buy a product only for users who have never purchased it (acquisition).
To use this functionality, call the set_entities_ids_subquery method. You need to provide the query (which references your data source) and set DataMode to TEST, as in the example below.
from monad.ui.config import DataMode, EntityIds

testing_module.set_entities_ids_subquery(
    query=EntityIds(subquery="SELECT DISTINCT client_id FROM transactions WHERE client_group = 1"),
    mode=DataMode.TEST,
)
Sampling is configured with the EntityIds class, imported from monad.ui.config.
Parameters

- subquery: str
Default: None
A subquery used to select the entity ids to be used during inference.
- file: Path
Default: None
Path to a file containing the entity ids to be used during inference, with each entity id in a separate row. Currently supported only for Snowflake DB.
- matching: bool
Default: True
Whether the ids specified by subquery or file should be included (True) or excluded (False).
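For example, to exclude the entity ids listed in a file rather than include them, you could combine file with matching=False. This is a sketch; the path is illustrative, and file-based ids are currently supported only for Snowflake DB:

from pathlib import Path
from monad.ui.config import DataMode, EntityIds

testing_module.set_entities_ids_subquery(
    query=EntityIds(file=Path("/path/to/client_ids.csv"), matching=False),  # exclude the listed ids
    mode=DataMode.TEST,
)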
Configuring Inference with TestingParams
We should now use TestingParams, the class we imported from monad.ui.config, to set up our predictions.
Parameters

- output_type: OutputType
No default
The output format in which to save the predictions. The table below explains how different values affect prediction outputs.
| Task | OutputType.RAW_MODEL | OutputType.ENCODED | OutputType.DECODED | OutputType.SEMANTIC |
|---|---|---|---|---|
| Binary classification | Raw model output | Raw model output | Sigmoid of the raw model output (recommended) | Sigmoid of the raw model output |
| Multiclass classification | Raw model output | Raw model output | Softmax of the raw outputs (recommended) | Predicted class |
| Multi-label classification | Raw model output | Raw model output | Sigmoid of the raw outputs (recommended) | Class names sorted by score |
| Regression | Raw model output | Transformed raw model output | Predicted value (recommended) | Predicted value |
| Recommendations | Raw model output | Log softmax of raw model output | Internal BaseModel codes of entities | A ranked list of recommended feature values (recommended) |
- local_save_location: str
Default: None
If provided, points to a location in the local filesystem where predictions will be stored in TSV format.
- remote_save_location: str
Default: None
If provided, defines a table in a remote database where the predictions will be stored.
- limit_test_batches: int
Default: None
If provided, defines how many batches to run inference over.
- top_k: int
Default: None
Only for the recommendation task. The number of top values to recommend. It is highly advised to use this to reduce the size of the prediction file.
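As an illustration, a recommendation-task configuration might combine top_k with a remote save location. This is a sketch; the table name is illustrative:

testing_params = TestingParams(
    remote_save_location="ANALYTICS.PREDICTIONS.PRODUCT_RECOMMENDATIONS",  # illustrative table name
    output_type=OutputType.SEMANTIC,  # ranked list of recommended feature values
    top_k=10,  # keep only the 10 highest-scoring recommendations per entity
)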
Additionally, stored parameters provided earlier as part of YAML training_params
or during scenario model training as TrainingParams
can also be modified. This is useful e.g. if you want to change the device, modify the number of recommendation targets, or add a new callback. Please refer to this article for the full list of modifiable parameters.
Example

The example below demonstrates testing parameters. In addition to specifying the location for predictions, some parameters have been overwritten or added. Note that the metric added as an argument requires additional imports.
from monad.config import MetricParams
from monad.ui.config import OutputType, TestingParams
from monad.ui.metrics import MultipleTargetsRecall

testing_params = TestingParams(
    local_save_location="/path/where/predictions/should/be/stored.tsv",
    output_type=OutputType.DECODED,
    metrics=[MetricParams(alias="recall", metric_name="MultipleTargetsRecall")],
    devices=[0],
)
Running predictions
Having instantiated the testing module and configured the testing parameters, we are now ready to start the inference run. This is done by simply calling the predict() method of the testing module and providing testing_params as its argument.
testing_module.predict(testing_params=testing_params)
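Once the run finishes, predictions written to local_save_location are a plain TSV file, so they can be inspected with any tabular tool; for example, with pandas (assuming it is installed in your environment):

import pandas as pd

# load the predictions produced by predict()
predictions = pd.read_csv("<path/to/predictions/my_predictions.tsv>", sep="\t")
print(predictions.head())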