Interpreting your model's predictions

Understanding your prediction drivers

BaseModel lets you demystify a black-box model by making its predictions interpretable.

BaseModel uses the Integrated Gradients method from this paper to explain model predictions.
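Integrated Gradients attributes a prediction to each input feature by integrating the model's gradient along a straight path from a baseline input x' to the actual input x: IG_i(x) = (x_i - x'_i) * integral over alpha from 0 to 1 of dF(x' + alpha * (x - x'))/dx_i. The sketch below approximates that integral with a midpoint Riemann sum for a toy logistic-regression "model" whose gradient is known in closed form. It only illustrates the method. It is not how monad computes attributions, and the toy weights, zero baseline and step count are assumptions.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def model(x, w, b):
    """Toy stand-in for a black-box model: logistic regression returning a probability."""
    return sigmoid(x @ w + b)


def model_grad(x, w, b):
    """Analytic gradient of the toy model's output with respect to the input x."""
    p = model(x, w, b)
    return p * (1.0 - p) * w


def integrated_gradients(x, baseline, w, b, steps=200):
    """Approximate Integrated Gradients with a midpoint Riemann sum
    along the straight line from baseline to x."""
    alphas = (np.arange(steps) + 0.5) / steps  # midpoints of [0, 1]
    grads = np.array([model_grad(baseline + a * (x - baseline), w, b) for a in alphas])
    return (x - baseline) * grads.mean(axis=0)


if __name__ == "__main__":
    w = np.array([1.5, -2.0, 0.5])  # assumed toy weights
    b = -0.25
    x = np.array([1.0, 0.5, 2.0])
    baseline = np.zeros_like(x)  # assumed all-zero baseline
    attributions = integrated_gradients(x, baseline, w, b)
    print(attributions)
    # Completeness check: attributions sum (approximately) to F(x) - F(baseline)
    print(attributions.sum(), model(x, w, b) - model(baseline, w, b))
```

The completeness property checked at the end (attributions add up to the change in the prediction relative to the baseline) is what makes Integrated Gradients attributions readable as contributions to the predicted probability.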

Running Interpretability

To compute interpretability results for a model, run the Python script below, adjusting the parameters as needed. It generates results for our HM Kaggle example:

```python
from pathlib import Path

from monad.ui.interpretability import interpret, TreemapGenerator, TreemapHierarchy

if __name__ == "__main__":
    # Root directory for all interpretability outputs
    results_path = Path("/home/<USERNAME>/code/interpretability/results/")

    # Compute attributions for output index 0, using the saved predictions
    # and the checkpoint of the model that produced them
    interpret(
        predictions_path=Path("/data1/<USERNAME>/hm/churn_preds.csv"),
        output_path=Path(results_path, "hm", "churn"),
        checkpoint_path=Path("/data1/<USERNAME>/hm/churn"),
        target_index=0,
    )

    # Treemap built directly from the files written by interpret()
    tg_from_files = TreemapGenerator(
        interpretability_files_path=Path(results_path, "hm", "churn"),
    )
    tg_from_files.plot_treemap(
        output_file_path=Path(results_path, "hm", "churn_treemap.html"),
        n_largest_per_feature=200,
        n_largest=10000,
    )

    # Treemap with predefined hierarchy:
    # department_name > section_name > colour_group_name > article_id,
    # with prod_name as the entity name column
    hierarchy = TreemapHierarchy(
        hierarchy_path=Path(
            "/home/<USERNAME>/code/interpretability/monad/monad_interpretability_scripts/classification/hm_treemap_mapping.csv"
        ),
        levels=["department_name", "section_name", "colour_group_name", "article_id"],
        entity_name_column="prod_name",
        feature_values_importance_path=Path(
            results_path, "hm", "churn", "transactions", "article_id", "values_importance.json"
        ),
    )
    tg_from_hierarchy = TreemapGenerator(hierarchy=hierarchy)
    tg_from_hierarchy.plot_treemap(
        output_file_path=Path(results_path, "hm", "churn_hierarchy_treemap.html"),
        n_largest=10000,
        n_largest_per_feature=None,
        max_depth=4,
    )
```
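The hm_treemap_mapping.csv file passed as hierarchy_path is not shown in this guide. As a hedged sketch, assuming TreemapHierarchy expects one row per article with one column per entry in levels plus the entity_name_column, such a file could be written from article metadata (for example, the articles table of the HM Kaggle dataset) with the standard csv module. The helper name write_hierarchy_mapping, the file layout and the sample rows are hypothetical.

```python
import csv
import tempfile
from pathlib import Path

# Column names taken from the example script; the overall file layout is an assumption.
LEVELS = ["department_name", "section_name", "colour_group_name", "article_id"]
ENTITY_NAME_COLUMN = "prod_name"


def write_hierarchy_mapping(articles, output_path, levels=LEVELS, entity_col=ENTITY_NAME_COLUMN):
    """Hypothetical helper: write one row per leaf (last level) containing the
    level columns plus the entity name column. Extra keys are dropped and
    repeated leaves are skipped. Returns the number of rows written."""
    leaf = levels[-1]
    seen = set()
    with open(output_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=[*levels, entity_col], extrasaction="ignore")
        writer.writeheader()
        for article in articles:
            if article[leaf] in seen:
                continue
            seen.add(article[leaf])
            writer.writerow(article)
    return len(seen)


if __name__ == "__main__":
    # Illustrative rows only, not real catalogue data
    sample = [
        {"department_name": "Trousers", "section_name": "Womens Casual",
         "colour_group_name": "Light Blue", "article_id": "A1", "prod_name": "Example trousers 1"},
        {"department_name": "Trousers", "section_name": "Womens Casual",
         "colour_group_name": "White", "article_id": "A2", "prod_name": "Example trousers 2"},
    ]
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "hm_treemap_mapping.csv"
        print(write_hierarchy_mapping(sample, path))  # 2
        print(path.read_text(encoding="utf-8"))
```

Check the expected format against the TreemapHierarchy documentation or an existing mapping file before relying on this layout.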


Important parameters:

  • predictions_path - path to your saved predictions.
  • output_path - directory where the interpretability results are stored.
  • checkpoint_path - path to the checkpoint of the model that was used to generate the predictions.
  • target_index - index of the model output to explain; interpretability is computed with regard to one specific target. In this churn example, target_index=0 gives more insight into clients who did not churn.
  • limit_batches - number of batches used to compute attributions. If None, all batches are used. Defaults to None.
  • resample - whether to resample the data to obtain balanced classes. Defaults to False.
  • resampling_class_size - number of samples to take from each class. If a class has fewer observations, all of them are taken. Used only when resample is True. Defaults to 500.

Currently supported model types: binary classification, multilabel classification and multiclass classification.
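The resampling behaviour described for resample and resampling_class_size can be illustrated with a short, hypothetical helper: take up to resampling_class_size examples per class, or every example when a class is smaller. This sketches the documented behaviour only and is not monad's implementation; the function name, the parallel examples/labels input format and the fixed random seed are assumptions.

```python
import random
from collections import defaultdict


def resample_balanced(examples, labels, class_size=500, seed=0):
    """Hypothetical sketch: draw up to class_size examples per class without
    replacement; classes with fewer examples are kept in full."""
    by_class = defaultdict(list)
    for example, label in zip(examples, labels):
        by_class[label].append(example)
    rng = random.Random(seed)  # fixed seed so the sketch is reproducible
    resampled = {}
    for label, members in by_class.items():
        if len(members) <= class_size:
            resampled[label] = list(members)
        else:
            resampled[label] = rng.sample(members, class_size)
    return resampled


if __name__ == "__main__":
    labels = [0] * 900 + [1] * 120  # imbalanced toy labels
    examples = list(range(len(labels)))
    balanced = resample_balanced(examples, labels, class_size=500)
    print({label: len(items) for label, items in balanced.items()})  # {0: 500, 1: 120}
```

Note that a class smaller than the requested size stays smaller, so the result is only as balanced as the data allows.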

Understanding results

The results are easy to read, and the treemap is interactive, so you can explore them by clicking through its blocks.

The results are also saved as .json files, which you can process in any other way.
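For example, a short script can collect the saved JSON files for further analysis. The schema of these files is not documented here, so the sketch below only assumes that results are written as .json files somewhere under output_path (the example script reads transactions/article_id/values_importance.json from that directory) and that each file holds valid JSON. The helper name is hypothetical.

```python
import json
from pathlib import Path


def load_interpretability_results(output_path):
    """Hypothetical helper: load every JSON file under output_path, keyed by its
    path relative to output_path, e.g. 'transactions/article_id/values_importance.json'."""
    root = Path(output_path)
    results = {}
    for json_file in sorted(root.rglob("*.json")):
        with json_file.open(encoding="utf-8") as f:
            results[json_file.relative_to(root).as_posix()] = json.load(f)
    return results
```

With the example script, you would call load_interpretability_results(Path(results_path, "hm", "churn")) and inspect the returned dictionary before building your own analysis on top of it.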

Treemap without hierarchy

The initial, top-level view looks roughly like this:

It shows each feature and its relative importance, in both magnitude and direction. You can click any block to drill down into more detail; for example, clicking the price block opens the following view:

Two elements are important when interpreting this view:

  • attribution color - the direction of the attribution. Green means a positive influence on the final prediction, red a negative influence. For example, drilling down into sales_channel_id shows:

Transactions through sales_channel_id = 1 are attributed an increased probability of churning in the future, while sales_channel_id = 2 decreases the probability of churn.

  • attribution value - the influence in percentage points, that is, by how many percentage points a given feature increases or decreases the probability of churn.
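As a worked reading of these two elements, assume an attribution is expressed as a change in the predicted churn probability (for example 0.031): its sign gives the colour, and multiplying by 100 gives the percentage-point value. The helper below is hypothetical and only restates this reading rule; the numbers in the usage example are made up, not actual results.

```python
def describe_attribution(feature, value, attribution):
    """Hypothetical helper: turn an attribution (change in predicted probability)
    into the colour and percentage-point reading used in the treemap."""
    points = attribution * 100  # probability change -> percentage points
    if points > 0:
        return f"{feature} = {value}: green, +{points:.1f} pp (raises the predicted probability)"
    if points < 0:
        return f"{feature} = {value}: red, {points:.1f} pp (lowers the predicted probability)"
    return f"{feature} = {value}: no influence"


if __name__ == "__main__":
    # Illustrative values only
    print(describe_attribution("sales_channel_id", 1, 0.031))
    # sales_channel_id = 1: green, +3.1 pp (raises the predicted probability)
    print(describe_attribution("sales_channel_id", 2, -0.012))
    # sales_channel_id = 2: red, -1.2 pp (lowers the predicted probability)
```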

Treemap with hierarchy

You can also define a hierarchy for the treemap, as in the Python snippet above. For example, with these levels:

```python
levels=["department_name", "section_name", "colour_group_name", "article_id"],
```

We will get the following chart:

You can explore any of the blocks in more detail, for example Trousers and then Womens Casual:

The interpretation is the same as for the treemap without hierarchy:

  • attribution color - the direction of the attribution. Green means a positive influence on the final prediction, red a negative influence. For example, transactions of light blue trousers are attributed negatively towards churn, while white ones increase the probability of churn.

  • attribution value - the influence in percentage points, that is, by how many percentage points a given product (prod_name) increases or decreases the probability of churn.
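Conceptually, each parent block in the hierarchical treemap groups the leaves below it (department_name, then section_name, then colour_group_name, then article_id). The sketch below shows one plausible way to roll article-level attributions up those levels by summing them, with an optional depth limit in the spirit of max_depth. Whether TreemapGenerator aggregates by summing, and how it sizes blocks, is an assumption here; the function name and sample attributions are hypothetical, not model output.

```python
from collections import defaultdict


def roll_up(rows, levels, max_depth=None):
    """Hypothetical sketch: add each leaf attribution to every ancestor path.

    rows: dicts holding one value per level plus an 'attribution' key.
    Returns {path_tuple: summed_attribution} for paths of up to max_depth levels.
    """
    depth = len(levels) if max_depth is None else min(max_depth, len(levels))
    totals = defaultdict(float)
    for row in rows:
        path = tuple(row[level] for level in levels)
        for d in range(1, depth + 1):
            totals[path[:d]] += row["attribution"]
    return dict(totals)


if __name__ == "__main__":
    levels = ["department_name", "section_name", "colour_group_name", "article_id"]
    rows = [  # illustrative values only
        {"department_name": "Trousers", "section_name": "Womens Casual",
         "colour_group_name": "Light Blue", "article_id": "A1", "attribution": -0.02},
        {"department_name": "Trousers", "section_name": "Womens Casual",
         "colour_group_name": "White", "article_id": "A2", "attribution": 0.03},
    ]
    totals = roll_up(rows, levels)
    print(round(totals[("Trousers",)], 4))  # 0.01
```

Summing keeps a parent's value equal to the net effect of its children, so a block can look small even when it contains large positive and negative contributions that cancel out; drilling down reveals them.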