
Interpretability

BaseModel provides model interpretation through gradient-based attribution, using either Integrated Gradients (deterministic, the default) or GradientSHAP (stochastic, typically faster). Both produce feature-level and event-level attributions that help explain why the model makes a specific prediction for each entity. The same interpret() entry point can also render a client-ready SHAP-library-style report (beeswarm, bar, heatmap, waterfall, and force plots) alongside the standard outputs.

Functions

from monad.interpretability import (
    interpret,
    interpret_entity,
    attributions_to_shap_explanation,
    save_shap_report,
)
from monad.interpretability.treemap import TreemapGenerator

interpret()

Generate aggregate interpretability analysis across many entities. Produces feature importance plots and JSON summaries.

from pathlib import Path
from datetime import datetime, timezone
from monad.interpretability import interpret

interpret(
    predictions_path=Path("./predictions.tsv"),
    output_path=Path("./interpretations"),
    checkpoint_path=Path("./my_model"),
    device="cuda",
)

Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| predictions_path | Path | required | Path to TSV file with model predictions. |
| output_path | Path | required | Directory to store interpretation results. |
| checkpoint_path | Path | required | Path to the scenario model checkpoint. |
| device | str | required | Device for computation: "cpu" or "cuda". |
| prediction_date | datetime | datetime.now(tz=timezone.utc) | Date for which predictions are interpreted. Uses UTC timezone. |
| limit_batches | int \| None | None | Limit number of batches to compute attributions. None = all batches. |
| target_index | int \| None | None | Output index for gradients (multiclass/multilabel). Not needed for regression. |
| classification_resample | bool | False | Resample data for balanced classes. Only for classification models. |
| recommended_value | str \| None | None | Item ID to interpret. Only for recommendation models. |
| group_size | int | 500 | Maximum samples per group. For classification: per class. For recommendation: total observations. |
| method | Literal["integrated_gradients", "gradient_shap"] | "integrated_gradients" | Attribution method. "integrated_gradients" runs a deterministic path integral; "gradient_shap" averages gradients across n_samples Gaussian-perturbed baselines (stochastic, typically faster, denser attributions). Both operate on the input-space Cleora/EMDE float sketches. See Attribution Methods. |
| save_shap_plots | bool | False | If True, additionally render a SHAP-library-style report (beeswarm / bar / heatmap / waterfall / force + static shap_report.html index) under <output_path>/shap/. Requires the shap optional extra. See SHAP-Style Report. |

Output Files

The output_path directory will contain a nested structure:

output_path/
├── source_importance.json          # Attribution scores per data source
├── source_importance.png           # Bar chart of data source importance
├── {data_source}/                  # One directory per data source
│   ├── feature_importance.json     # Attribution scores per feature
│   ├── feature_importance.png      # Bar chart of feature importance
│   └── {feature}/                  # One directory per feature
│       ├── values_importance.json      # Attribution scores per feature value
│       ├── values_highest_importance.png  # Top positive attributions
│       └── values_lowest_importance.png   # Top negative attributions
└── shap/                           # Only when save_shap_plots=True
    ├── shap_beeswarm.png           # Per-feature signed-distribution beeswarm
    ├── shap_bar_global.png         # Global mean-absolute attribution bar
    ├── shap_heatmap.png            # Per-entity × per-feature heatmap (n_entities ≥ 10)
    ├── shap_waterfall_top{i}.png   # Top-N entity waterfall plots
    ├── shap_force_top{i}.html      # Top-N entity interactive force plots
    └── shap_report.html            # Static index linking all artifacts
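
The nested layout above can be explored programmatically. The following is a minimal sketch that relies only on the directory and file names listed in the tree; the helper name summarize_interpretation_tree is hypothetical and not part of monad, and the contents of the JSON files are deliberately not parsed because their schema is not described here.

```python
from pathlib import Path


def summarize_interpretation_tree(output_path: Path) -> dict[str, list[str]]:
    """Map each data-source directory to the feature directories inside it.

    Hypothetical helper: it only inspects the file layout documented above.
    """
    summary: dict[str, list[str]] = {}
    for source_dir in sorted(p for p in output_path.iterdir() if p.is_dir()):
        # A data-source directory is recognised by its feature_importance.json;
        # the optional shap/ report directory has none and is skipped.
        if not (source_dir / "feature_importance.json").is_file():
            continue
        summary[source_dir.name] = sorted(
            feature_dir.name
            for feature_dir in source_dir.iterdir()
            if feature_dir.is_dir()
            and (feature_dir / "values_importance.json").is_file()
        )
    return summary
```

Calling summarize_interpretation_tree(Path("./interpretations")) after interpret() has finished should list every data source together with the features that received a values_importance.json file.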

Attribution Methods

interpret() supports two gradient-based attribution methods, selected with the method parameter. Both operate on the same input (the pre-computed Cleora/EMDE sketches that monad consumes) and return attributions as a list[torch.Tensor] of the same shape, so the downstream output structure is identical.

| Method | method= value | Determinism | Typical runtime | Output density | Recommended for |
| --- | --- | --- | --- | --- | --- |
| Integrated Gradients | "integrated_gradients" | Deterministic (path integral) | Baseline | Sparser | Reproducible single runs, regulated reporting |
| GradientSHAP | "gradient_shap" | Stochastic (seeded; reproducible per-call) | ~2× faster on representative workloads | Denser | Production-scale attribution, SHAP-style downstream plots |

Switch methods with one keyword:

from pathlib import Path
from monad.interpretability import interpret

interpret(
    predictions_path=Path("./predictions.tsv"),
    output_path=Path("./interpretations"),
    checkpoint_path=Path("./my_model"),
    device="cuda",
    method="gradient_shap",
)

GradientSHAP-specific knobs (n_samples, stdevs, n_baselines, seed) are not exposed through interpret(). To tune them, instantiate one of the task-specific GradientShapInterpreter classes directly (see GradientShapInterpreter).
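
To make the difference between the two methods concrete, here is a toy sketch in pure Python. It is not monad's implementation: the model, its analytic gradient, and both function names are hypothetical stand-ins. Integrated Gradients integrates the gradient along a straight path from one fixed baseline to the input, while the GradientSHAP variant shown here samples a baseline, perturbs it with Gaussian noise (following the parameter description in this document; other implementations may perturb the input instead), and evaluates the gradient at a random point on the path.

```python
import random
from typing import Callable, Sequence

Vector = Sequence[float]
GradFn = Callable[[Vector], list[float]]


def model(x: Vector) -> float:
    """Toy differentiable model standing in for the network."""
    return x[0] ** 2 + 3.0 * x[0] * x[1]


def model_grad(x: Vector) -> list[float]:
    """Analytic gradient of the toy model."""
    return [2.0 * x[0] + 3.0 * x[1], 3.0 * x[0]]


def integrated_gradients(
    x: Vector, baseline: Vector, grad_fn: GradFn, steps: int = 200
) -> list[float]:
    """Deterministic path integral, approximated with the midpoint rule."""
    total = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_fn(point)):
            total[i] += g
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, total)]


def gradient_shap(
    x: Vector,
    baselines: Sequence[Vector],
    grad_fn: GradFn,
    n_samples: int = 25,
    stdevs: float = 0.15,
    seed: int = 42,
) -> list[float]:
    """Stochastic estimate averaged over noisy, randomly drawn baselines."""
    if n_samples < 1 or stdevs <= 0:
        raise ValueError("n_samples must be >= 1 and stdevs must be > 0")
    rng = random.Random(seed)  # re-seeded on every call, so calls are repeatable
    total = [0.0] * len(x)
    for _ in range(n_samples):
        reference = rng.choice(baselines)
        noisy = [b + rng.gauss(0.0, stdevs) for b in reference]
        alpha = rng.random()  # random position on the baseline-to-input path
        point = [b + alpha * (xi - b) for xi, b in zip(x, noisy)]
        grads = grad_fn(point)
        for i in range(len(x)):
            total[i] += (x[i] - noisy[i]) * grads[i]
    return [t / n_samples for t in total]


x = [1.0, 2.0]
ig = integrated_gradients(x, [0.0, 0.0], model_grad)
gs = gradient_shap(x, [[0.0, 0.0], [0.5, 0.5]], model_grad)
```

In this toy case the Integrated Gradients attributions sum to model(x) minus model(baseline) up to floating-point error, whereas the GradientSHAP estimate varies with seed but is identical across calls that use the same seed.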


SHAP-Style Report

When save_shap_plots=True, interpret() writes a parallel shap/ subdirectory with a SHAP-library-style report (beeswarm, global bar, heatmap, per-entity waterfall and interactive force plots), plus a static shap_report.html index that links everything. The report is designed for client hand-off and renders without a Jupyter runtime.

from pathlib import Path
from monad.interpretability import interpret

interpret(
    predictions_path=Path("./predictions.tsv"),
    output_path=Path("./interpretations"),
    checkpoint_path=Path("./my_model"),
    device="cuda",
    method="gradient_shap",
    save_shap_plots=True,
)

Optional dependency

The SHAP report requires the shap library, shipped in the interpretability extra. Install with poetry install -E interpretability or pip install '.[interpretability]'. Without it, interpret() raises ImportError only when save_shap_plots=True; the default attribution flow remains unaffected.
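
Because the ImportError only surfaces when save_shap_plots=True, a script can check for the optional packages up front and fall back to the default outputs. The sketch below uses the standard library's importlib.util.find_spec; the helper name optional_modules_available is hypothetical, and checking matplotlib alongside shap is based on the ImportError note for save_shap_report().

```python
from importlib.util import find_spec
from typing import Iterable


def optional_modules_available(names: Iterable[str]) -> bool:
    """Return True when every named top-level module can be located."""
    return all(find_spec(name) is not None for name in names)


# Enable the SHAP-style report only when its dependencies are installed.
save_shap_plots = optional_modules_available(("shap", "matplotlib"))
```

The resulting boolean can then be passed as save_shap_plots to interpret().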

For rendering the report from already-computed attributions (e.g., a custom batch pipeline), use the two helpers below directly.

attributions_to_shap_explanation()

Convert monad per-entity attribution tensors into a shap.Explanation object that can be passed to any shap.plots.* function.

from monad.interpretability import attributions_to_shap_explanation

# attributions, explainer, and base_value are assumed to come from an earlier attribution run.

explanation = attributions_to_shap_explanation(
    attributions=attributions,
    feature_slicer=explainer.feature_slicer,
    base_value=base_value,
)

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| attributions | list[torch.Tensor] | required | Per-modality attribution tensors with shape (n_entities, ...) on dim 0. Same format returned by GradientShapInterpreter.get_attributions(). |
| feature_slicer | FeatureSlicer | required | Pre-built slicer. Pass explainer.feature_slicer to avoid rebuilding. |
| base_value | float | required | Scalar E[f(baseline)]: the model output at the reference baseline. |
| output_name | str \| None | None | Optional label stamped onto explanation.output_names (e.g. "target_0" for classification). |

Returns a shap.Explanation with values of shape (n_entities, n_features), signed-summed per feature slice, and a stable feature-name order keyed by "<data_source>.<feature>" to avoid cross-source name collisions.

Raises ValueError if attributions is empty, contains no entities, or the slicer has no slices; ImportError if shap is not installed.
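
The "signed-summed per feature slice" step can be illustrated without torch or shap. The sketch below is an assumption-laden stand-in: it represents each entity as a flat list of floats and each feature slice as a (start, stop) column range, which is a hypothetical simplification of whatever FeatureSlicer does internally. Keys follow the "<data_source>.<feature>" convention described above.

```python
from typing import Mapping, Sequence


def sum_attributions_per_feature(
    rows: Sequence[Sequence[float]],
    slices: Mapping[str, tuple[int, int]],
) -> tuple[list[str], list[list[float]]]:
    """Collapse flat attribution rows into one signed sum per feature slice.

    `slices` maps "<data_source>.<feature>" to a (start, stop) column range;
    this layout is a hypothetical simplification for illustration only.
    """
    if not rows:
        raise ValueError("attributions contain no entities")
    if not slices:
        raise ValueError("slicer has no slices")
    names = list(slices)  # insertion order gives a stable feature order
    values = [
        [sum(row[start:stop]) for start, stop in (slices[name] for name in names)]
        for row in rows
    ]
    return names, values


names, values = sum_attributions_per_feature(
    rows=[[0.5, -0.25, 0.125], [-1.0, 0.25, 0.0]],
    # Two sources both expose "price"; the source prefix keeps them apart.
    slices={"transactions.price": (0, 2), "returns.price": (2, 3)},
)
```

Positive and negative contributions inside a slice cancel each other, which is why the result is a signed sum rather than a magnitude.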

save_shap_report()

Render the full SHAP-style report from a shap.Explanation.

from pathlib import Path
from monad.interpretability import save_shap_report

save_shap_report(
    explanation=explanation,
    output_path=Path("./interpretations"),
    top_n_waterfalls=5,
)

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| explanation | shap.Explanation | required | A SHAP Explanation, typically from attributions_to_shap_explanation(). |
| output_path | Path | required | Parent directory; a shap/ subdirectory is created under it. |
| top_n_waterfalls | int | 5 | How many highest-impact entities to render waterfall and force plots for (capped at n_entities). Kept small by default because each force HTML embeds ~2 MB of shap.js. |

Returns the Path to the created shap/ directory. Writes shap_beeswarm.png, shap_bar_global.png, shap_heatmap.png (only when n_entities >= 10), shap_waterfall_top{i}.png, shap_force_top{i}.html, and shap_report.html. Individual plot failures are logged and skipped rather than aborting the whole report.

Raises ImportError if shap or matplotlib is not installed.
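
Since individual plot failures are logged and skipped, it can be useful to compare the files in shap/ against what the rules above predict. The following sketch derives the expected file names from the documented heatmap threshold and the top_n_waterfalls cap. The helper is hypothetical, and the zero-based numbering of the {i} placeholder is an assumption that should be checked against real output.

```python
def expected_shap_artifacts(n_entities: int, top_n_waterfalls: int = 5) -> list[str]:
    """List the report files the documented rules predict for a given run."""
    names = ["shap_beeswarm.png", "shap_bar_global.png"]
    if n_entities >= 10:  # the heatmap is only written for 10 or more entities
        names.append("shap_heatmap.png")
    n_top = min(top_n_waterfalls, n_entities)  # capped at n_entities
    for i in range(n_top):  # assumption: {i} starts at 0
        names.append(f"shap_waterfall_top{i}.png")
        names.append(f"shap_force_top{i}.html")
    names.append("shap_report.html")
    return names
```

Any expected name that is missing from the shap/ directory then points to a plot that failed and was skipped.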


interpret_entity()

Explain a single entity's prediction at the event level. Produces a JSON file with per-event, per-feature attributions.

from pathlib import Path
from datetime import datetime, timezone
from monad.interpretability import interpret_entity

interpret_entity(
    output_path=Path("./interpretations/customer_123.json"),
    checkpoint_path=Path("./my_model"),
    predictions_path=Path("./predictions.tsv"),
    main_entity_id="customer_123",
    device="cuda",
    prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
)

Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| output_path | Path | required | Path for the output JSON file. |
| checkpoint_path | Path | required | Path to the scenario model checkpoint. |
| predictions_path | Path | required | Path to TSV file with model predictions. |
| main_entity_id | str | required | Entity ID to interpret. For some databases (e.g., Snowflake), the value may need escaping. |
| device | str | required | Device: "cpu" or "cuda". |
| prediction_date | datetime | datetime.now(tz=timezone.utc) | Date for which to interpret. Uses UTC timezone. |
| target_index | int \| None | None | Output index for gradients (multiclass/multilabel). |
| recommended_value | str \| None | None | Item ID to interpret (recommendation only). |

Output Format

The output JSON contains event-level attributions:

{
    "transactions": [
        {
            "timestamp": "25-01-2020, 00:00:00",
            "modality_attributions": [
                {
                    "data_source_name": "transactions",
                    "name": "article_id",
                    "value": "0854796002",
                    "attribution": -0.124
                },
                {
                    "data_source_name": "transactions",
                    "name": "price",
                    "value": 0.017,
                    "attribution": -0.009
                }
            ]
        }
    ]
}
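
The event-level JSON is straightforward to post-process. The sketch below parses a payload with the structure shown above and totals the attribution per "<data_source>.<feature>" pair, ordered by absolute impact. The function name is hypothetical, and the timestamp format string is inferred from the single sample value rather than from a documented specification.

```python
import json
from collections import defaultdict
from datetime import datetime

# Inferred from the sample "25-01-2020, 00:00:00" (day-first); an assumption.
TIMESTAMP_FORMAT = "%d-%m-%Y, %H:%M:%S"

SAMPLE = """
{"transactions": [{"timestamp": "25-01-2020, 00:00:00",
  "modality_attributions": [
    {"data_source_name": "transactions", "name": "article_id",
     "value": "0854796002", "attribution": -0.124},
    {"data_source_name": "transactions", "name": "price",
     "value": 0.017, "attribution": -0.009}]}]}
"""


def total_attribution_per_feature(payload: dict) -> dict[str, float]:
    """Sum attributions over all events, strongest absolute impact first."""
    totals: dict[str, float] = defaultdict(float)
    for events in payload.values():
        for event in events:
            for item in event["modality_attributions"]:
                key = f'{item["data_source_name"]}.{item["name"]}'
                totals[key] += item["attribution"]
    return dict(sorted(totals.items(), key=lambda kv: abs(kv[1]), reverse=True))


payload = json.loads(SAMPLE)
totals = total_attribution_per_feature(payload)
first_event_time = datetime.strptime(
    payload["transactions"][0]["timestamp"], TIMESTAMP_FORMAT
)
```

For a real entity, replace SAMPLE with the contents of the file written by interpret_entity().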

Batch Processing Multiple Entities

from pathlib import Path
from datetime import datetime, timezone
from monad.interpretability import interpret_entity

entities_to_explain = ["cust_001", "cust_002", "cust_003"]

for entity_id in entities_to_explain:
    interpret_entity(
        output_path=Path(f"./interpretations/{entity_id}.json"),
        checkpoint_path=Path("./my_model"),
        predictions_path=Path("./predictions.tsv"),
        main_entity_id=entity_id,
        device="cuda",
        prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
    )
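
For longer entity lists, one failing entity should not abort the whole batch. The sketch below wraps the loop above with per-entity error capture and skips entities whose output file already exists. The helper explain_entities and its explain_one callback are hypothetical; in practice explain_one would be a small wrapper that calls interpret_entity() with output_path and main_entity_id filled in and the remaining arguments fixed.

```python
from pathlib import Path
from typing import Callable, Iterable


def explain_entities(
    entity_ids: Iterable[str],
    explain_one: Callable[[str, Path], None],
    output_dir: Path,
) -> dict[str, str]:
    """Run explain_one per entity and return a mapping of failures."""
    output_dir.mkdir(parents=True, exist_ok=True)
    failures: dict[str, str] = {}
    for entity_id in entity_ids:
        target = output_dir / f"{entity_id}.json"
        if target.exists():
            continue  # already explained in an earlier run
        try:
            explain_one(entity_id, target)
        except Exception as exc:  # keep going; report the failure afterwards
            failures[entity_id] = repr(exc)
    return failures
```

The returned dictionary maps each failed entity ID to the error it raised, so the batch can be retried for just those IDs.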

TreemapGenerator

Generate treemap visualizations from interpretability attribution data. Treemaps provide a hierarchical view of feature importance across data sources and features.

from pathlib import Path
from monad.interpretability.treemap import TreemapGenerator

Constructor

generator = TreemapGenerator(
    interpretability_files_path=Path("./interpretations"),
)

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| interpretability_files_path | Path \| None | None | Path to directory with interpretation output (from interpret()). |
| hierarchy | TreemapHierarchy \| None | None | Custom hierarchy definition for treemap levels. |

Note

Exactly one of interpretability_files_path or hierarchy must be provided.

plot_treemap()

Generate and save a treemap chart as an HTML file.

generator.plot_treemap(
    output_file_path=Path("./treemap.html"),
    n_largest_per_feature=1500,
    max_depth=3,
)

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| output_file_path | Path | required | Path for the output HTML file. |
| n_largest_per_feature | int \| None | 1500 | Maximum number of values per feature to include. |
| n_largest | int \| None | None | Maximum total number of values to include. |
| exclude_positive_attributions | bool | False | Exclude features with positive attributions. |
| exclude_negative_attributions | bool | False | Exclude features with negative attributions. |
| max_depth | int | 3 | Maximum depth of the treemap hierarchy. |
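
The way the filtering parameters interact can be sketched in a few lines. This is an illustration of one plausible reading, not TreemapGenerator's actual logic: it assumes that values are ranked by absolute attribution, that sign filters are applied per value, and that the per-feature cap is applied before the global cap. The ValueAttribution type and select_treemap_values function are hypothetical.

```python
from collections import defaultdict
from typing import NamedTuple, Optional, Sequence


class ValueAttribution(NamedTuple):
    feature: str
    value: str
    attribution: float


def select_treemap_values(
    items: Sequence[ValueAttribution],
    n_largest_per_feature: Optional[int] = 1500,
    n_largest: Optional[int] = None,
    exclude_positive_attributions: bool = False,
    exclude_negative_attributions: bool = False,
) -> list[ValueAttribution]:
    """Filter by sign, cap per feature, then cap overall (assumed order)."""
    kept = [
        item
        for item in items
        if not (exclude_positive_attributions and item.attribution > 0)
        and not (exclude_negative_attributions and item.attribution < 0)
    ]
    by_feature: dict[str, list[ValueAttribution]] = defaultdict(list)
    for item in kept:
        by_feature[item.feature].append(item)
    selected: list[ValueAttribution] = []
    for feature_items in by_feature.values():
        ranked = sorted(feature_items, key=lambda it: abs(it.attribution), reverse=True)
        selected.extend(ranked[:n_largest_per_feature])  # a None cap keeps everything
    selected.sort(key=lambda it: abs(it.attribution), reverse=True)
    return selected[:n_largest]
```

Under these assumptions, lowering n_largest_per_feature mainly keeps the HTML file small for high-cardinality features, while n_largest bounds the total number of tiles.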

TreemapHierarchy

For custom treemap hierarchies:

from pathlib import Path
from monad.interpretability.treemap import TreemapGenerator, TreemapHierarchy

hierarchy = TreemapHierarchy(
    levels=["category", "brand", "product_id"],
    hierarchy_path=Path("./data/product_hierarchy.csv"),
    feature_values_importance_path=Path("./interpretations/transactions/article_id/values_importance.json"),
    entity_name_column="product_name",  # Optional
)

generator = TreemapGenerator(hierarchy=hierarchy)
generator.plot_treemap(output_file_path=Path("./custom_treemap.html"))

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| levels | list[str] | required | Hierarchy level names for the treemap. |
| hierarchy_path | Path | required | Path to a CSV file defining the hierarchy. Each row maps the last level to higher levels. |
| feature_values_importance_path | Path | required | Path to the values importance JSON file. |
| entity_name_column | str \| None | None | Column name for entity labels in the visualization. |
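
To show what "each row maps the last level to higher levels" could look like in practice, here is a small sketch that rolls leaf attributions up a hierarchy. It rests on several assumptions that are not confirmed by this document: that the CSV has one column per entry in levels, that leaf importance is available as a mapping from leaf value to score, and that parent tiles aggregate by summation. The function roll_up_importance is hypothetical.

```python
import csv
import io
from collections import defaultdict

# Assumed CSV layout: one column per hierarchy level, leaf level last.
HIERARCHY_CSV = """category,brand,product_id
shoes,acme,p1
shoes,acme,p2
shoes,zen,p3
hats,acme,p4
"""


def roll_up_importance(
    hierarchy_csv: str,
    levels: list[str],
    leaf_importance: dict[str, float],
) -> dict[tuple[str, ...], float]:
    """Sum leaf attributions into every ancestor path of the hierarchy."""
    totals: dict[tuple[str, ...], float] = defaultdict(float)
    for row in csv.DictReader(io.StringIO(hierarchy_csv)):
        score = leaf_importance.get(row[levels[-1]])
        if score is None:
            continue  # leaf has no attribution, so it contributes nothing
        for depth in range(1, len(levels) + 1):
            path = tuple(row[level] for level in levels[:depth])
            totals[path] += score
    return dict(totals)


rolled = roll_up_importance(
    HIERARCHY_CSV,
    levels=["category", "brand", "product_id"],
    leaf_importance={"p1": 0.5, "p2": -0.25, "p3": 0.125},
)
```

Each key in the result is a path such as ("shoes", "acme"), which corresponds to one tile at that depth of the treemap.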

GradientShapInterpreter

For advanced or scripted attribution outside interpret() — batched runs, custom baselines, or feeding attributions straight into attributions_to_shap_explanation() — instantiate an interpreter directly. The package exports three task-specific variants:

| Class | Task |
| --- | --- |
| ClassificationGradientShapInterpreter | BinaryClassificationTask, MulticlassClassificationTask, MultilabelClassificationTask |
| RecommendationGradientShapInterpreter | BaseRecommendationTask |
| RegressionGradientShapInterpreter | RegressionTask |

from pathlib import Path
from monad.interpretability import ClassificationGradientShapInterpreter

Constructor

interpreter = ClassificationGradientShapInterpreter(
    training_module=training_module,
    checkpoint_path=Path("./my_model"),
    predictions_path=Path("./predictions.tsv"),
    resample=False,
    group_size=500,
    n_samples=25,
    stdevs=0.15,
    n_baselines=100,
    seed=42,
)

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| n_samples | int | 25 | Number of perturbed samples to average gradients over. Must be >= 1. Higher values reduce attribution variance at proportional runtime cost. |
| stdevs | float | 0.15 | Standard deviation of the Gaussian noise added to each sampled baseline. Must be > 0; 0 would collapse GradientSHAP onto the deterministic IG path. |
| n_baselines | int | 100 | Size of the background reference distribution, sampled uniformly at random across the whole predict set via reservoir sampling. Baselines are drawn from real observations rather than an all-zero point, keeping attributions on-manifold. Bounds only the memory held for the reference set; runtime is governed by n_samples. Must be >= 1. |
| seed | int | 42 | RNG seed for reproducibility of the stochastic baseline draws. Re-applied at the start of every get_attributions() call, so outputs are reproducible across invocations on the same interpreter. |

Remaining positional and keyword arguments (training_module, checkpoint_path, predictions_path, resample, group_size, recommended_value, main_entity_id) are forwarded to the underlying Integrated Gradients base class.

Baseline distribution

GradientSHAP draws its baseline from a real background distribution sampled across the predict set. Because the baseline is no longer a fixed all-zero point, attributions are stochastic and seed-dependent, and will differ from releases prior to 1.7.
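
Reservoir sampling is what lets the background set be drawn uniformly from the whole predict set while holding only n_baselines items in memory. The sketch below shows the classic single-pass form (often called Algorithm R) over a generic stream; it illustrates the technique and is not monad's code. The function name and the stream of integers are placeholders.

```python
import random
from typing import Iterable, TypeVar

T = TypeVar("T")


def reservoir_sample(stream: Iterable[T], k: int, seed: int = 42) -> list[T]:
    """Uniformly sample up to k items from a stream in one pass, O(k) memory."""
    if k < 1:
        raise ValueError("k must be >= 1")
    rng = random.Random(seed)
    reservoir: list[T] = []
    for index, item in enumerate(stream):
        if index < k:
            reservoir.append(item)  # fill the reservoir first
        else:
            # Item number index + 1 replaces a slot with probability k / (index + 1).
            j = rng.randint(0, index)  # randint is inclusive on both ends
            if j < k:
                reservoir[j] = item
    return reservoir


baselines = reservoir_sample(range(10_000), k=100, seed=42)
```

Memory stays bounded by k regardless of stream length, which matches the note that n_baselines bounds only the memory held for the reference set.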

get_attributions()

Run attribution and return the raw per-modality tensors, suitable for downstream tools such as attributions_to_shap_explanation().

from datetime import datetime, timezone

attributions = interpreter.get_attributions(
    prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
    target=0,
    limit_batches=None,
    device="cuda",
)

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| prediction_date | datetime | required | Date for which the predictions should be interpreted. |
| target | int \| None | required | Output index for which gradients are computed. None for scalar-output tasks (regression). |
| limit_batches | int \| None | required | Number of batches to compute attributions on. None = all batches. |
| device | str | required | Device for computation: "cpu" or "cuda". |

Returns a list[torch.Tensor] — one tensor per input modality, with shape (n_entities, ...) on dim 0.


Complete Workflow Example

from pathlib import Path
from datetime import datetime, timezone
from monad.ui.module import load_from_checkpoint
from monad.config import TestingParams, OutputType
from monad.interpretability import interpret, interpret_entity
from monad.interpretability.treemap import TreemapGenerator

# 1. Generate predictions with attributions enabled
module = load_from_checkpoint(Path("./churn_model"))

testing_params = TestingParams(
    output_type=OutputType.SEMANTIC,
    devices=[0],
    local_save_location=Path("./predictions.tsv"),
)
module.predict(testing_params)

# 2. Global feature importance
interpret(
    predictions_path=Path("./predictions.tsv"),
    output_path=Path("./interpretations"),
    checkpoint_path=Path("./churn_model"),
    device="cuda",
    prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
)

# 3. Treemap visualization
generator = TreemapGenerator(
    interpretability_files_path=Path("./interpretations"),
)
generator.plot_treemap(
    output_file_path=Path("./treemap.html"),
)

# 4. Single entity deep-dive
interpret_entity(
    output_path=Path("./interpretations/customer_123.json"),
    checkpoint_path=Path("./churn_model"),
    predictions_path=Path("./predictions.tsv"),
    main_entity_id="customer_123",
    device="cuda",
    prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
)

# 5. SHAP-style report for client hand-off
interpret(
    predictions_path=Path("./predictions.tsv"),
    output_path=Path("./interpretations"),
    checkpoint_path=Path("./churn_model"),
    device="cuda",
    method="gradient_shap",
    save_shap_plots=True,
)