Interpretability
BaseModel provides model interpretation through gradient-based attribution, using either Integrated Gradients (deterministic, the default) or GradientSHAP (stochastic, typically faster), and produces feature-level and event-level attribution analyses. These explain why the model makes a specific prediction for each entity. The same interpret() entry point can also render a client-ready, SHAP-library-style report (beeswarm, bar, heatmap, waterfall, and force plots) alongside the standard outputs.
Functions
```python
from monad.interpretability import (
    interpret,
    interpret_entity,
    attributions_to_shap_explanation,
    save_shap_report,
)
from monad.interpretability.treemap import TreemapGenerator
```
interpret()
Generate aggregate interpretability analysis across many entities. Produces feature importance plots and JSON summaries.
```python
from pathlib import Path
from datetime import datetime, timezone
from monad.interpretability import interpret

interpret(
    predictions_path=Path("./predictions.tsv"),
    output_path=Path("./interpretations"),
    checkpoint_path=Path("./my_model"),
    device="cuda",
)
```
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| predictions_path | Path | required | Path to the TSV file with model predictions. |
| output_path | Path | required | Directory to store interpretation results. |
| checkpoint_path | Path | required | Path to the scenario model checkpoint. |
| device | str | required | Device for computation: "cpu" or "cuda". |
| prediction_date | datetime | datetime.now(tz=timezone.utc) | Date for which predictions are interpreted, in UTC. |
| limit_batches | int \| None | None | Maximum number of batches to compute attributions on. None = all batches. |
| target_index | int \| None | None | Output index for gradients (multiclass/multilabel). Not needed for regression. |
| classification_resample | bool | False | Resample data for balanced classes. Only for classification models. |
| recommended_value | str \| None | None | Item ID to interpret. Only for recommendation models. |
| group_size | int | 500 | Maximum samples per group. For classification: per class. For recommendation: total observations. |
| method | Literal["integrated_gradients", "gradient_shap"] | "integrated_gradients" | Attribution method. "integrated_gradients" runs a deterministic path integral; "gradient_shap" averages gradients across n_samples Gaussian-perturbed baselines (stochastic, typically faster, with denser attributions). Both operate on the input-space Cleora/EMDE float sketches. See Attribution Methods. |
| save_shap_plots | bool | False | If True, additionally render a SHAP-library-style report (beeswarm / bar / heatmap / waterfall / force, plus a static shap_report.html index) under <output_path>/shap/. Requires the shap optional extra. See SHAP-Style Report. |
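The group_size and classification_resample parameters cap how many samples are taken per class. As a rough illustration only, here is a hypothetical down-sampling sketch in plain Python; the helper name cap_per_class and the (entity_id, label) row format are assumptions, and monad's own resampling logic is not described in this document.

```python
import random
from collections import defaultdict


def cap_per_class(rows, group_size, seed=42):
    """Keep at most `group_size` rows per class label (hypothetical helper).

    `rows` is an iterable of (entity_id, label) pairs. Classes with more than
    `group_size` rows are down-sampled at random; smaller classes are kept
    whole, so the result is as balanced as the data allows.
    """
    if group_size < 1:
        raise ValueError("group_size must be >= 1")
    by_label = defaultdict(list)
    for entity_id, label in rows:
        by_label[label].append(entity_id)

    rng = random.Random(seed)  # seeded so repeated runs pick the same rows
    capped = {}
    for label, ids in sorted(by_label.items()):
        if len(ids) > group_size:
            ids = rng.sample(ids, group_size)
        capped[label] = ids
    return capped


# 10 "churn" rows and 90 "stay" rows, capped to 5 per class
rows = [(f"e{i}", "churn" if i % 10 == 0 else "stay") for i in range(100)]
groups = cap_per_class(rows, group_size=5)
print({label: len(ids) for label, ids in groups.items()})
```

Down-sampling the majority class is only one way to balance; it keeps runtime bounded by group_size at the cost of discarding data.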
Output Files
The output_path directory will contain a nested structure:
```
output_path/
├── source_importance.json                # Attribution scores per data source
├── source_importance.png                 # Bar chart of data source importance
├── {data_source}/                        # One directory per data source
│   ├── feature_importance.json           # Attribution scores per feature
│   ├── feature_importance.png            # Bar chart of feature importance
│   └── {feature}/                        # One directory per feature
│       ├── values_importance.json        # Attribution scores per feature value
│       ├── values_highest_importance.png # Top positive attributions
│       └── values_lowest_importance.png  # Top negative attributions
└── shap/                                 # Only when save_shap_plots=True
    ├── shap_beeswarm.png                 # Per-feature signed-distribution beeswarm
    ├── shap_bar_global.png               # Global mean-absolute attribution bar
    ├── shap_heatmap.png                  # Per-entity × per-feature heatmap (n_entities ≥ 10)
    ├── shap_waterfall_top{i}.png         # Top-N entity waterfall plots
    ├── shap_force_top{i}.html            # Top-N entity interactive force plots
    └── shap_report.html                  # Static index linking all artifacts
```
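Because the output is a plain directory tree, it can be walked with pathlib to see which data sources and features were written. This sketch relies only on the directory and file names in the layout above; it does not parse the JSON files, whose schema is not specified here, and the helper name list_feature_dirs is hypothetical.

```python
from pathlib import Path


def list_feature_dirs(output_path):
    """Map each data source directory to the feature directories under it.

    A data source directory is recognised by its feature_importance.json and a
    feature directory by its values_importance.json, following the documented
    layout. The shap/ report directory has no feature_importance.json, so it
    is skipped.
    """
    root = Path(output_path)
    layout = {}
    for source_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        if not (source_dir / "feature_importance.json").is_file():
            continue
        layout[source_dir.name] = sorted(
            f.name
            for f in source_dir.iterdir()
            if f.is_dir() and (f / "values_importance.json").is_file()
        )
    return layout
```

After interpret() finishes, call list_feature_dirs(Path("./interpretations")) to get a dict of data source name to sorted feature names.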
Attribution Methods
interpret() supports two gradient-based attribution methods via the method parameter. Both operate on the same input — the pre-computed Cleora/EMDE sketches monad consumes — and emit attributions in the same List[torch.Tensor] shape, so the downstream output structure is identical.
| Method | method= value | Determinism | Typical runtime | Output density | Recommended for |
|---|---|---|---|---|---|
| Integrated Gradients | "integrated_gradients" | Deterministic (path integral) | Baseline | Sparser | Reproducible single runs, regulated reporting |
| GradientSHAP | "gradient_shap" | Stochastic (seeded; reproducible per call) | ~2× faster on representative workloads | Denser | Production-scale attribution, SHAP-style downstream plots |
Switch methods with one keyword:
```python
from pathlib import Path
from monad.interpretability import interpret

interpret(
    predictions_path=Path("./predictions.tsv"),
    output_path=Path("./interpretations"),
    checkpoint_path=Path("./my_model"),
    device="cuda",
    method="gradient_shap",  # the only change from the default call
)
```
GradientSHAP-specific knobs (n_samples, stdevs, n_baselines, seed) are not exposed through interpret(); to tune them, instantiate one of the GradientShapInterpreter classes directly.
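To make the comparison in the table above concrete, the following sketch implements both methods on a toy two-input function with a hand-written gradient, in plain Python. It is an illustration under stated assumptions, not monad's implementation (which runs on the Cleora/EMDE sketches through the model). The IG part approximates the path integral with a midpoint Riemann sum; the GradientSHAP part follows this document's description: draw a baseline from a background set, add Gaussian noise with standard deviation stdevs, pick a random point on the straight line to the input, and average gradient × (input - baseline) over n_samples draws with a seeded RNG. All function names are hypothetical.

```python
import random


def f(x):
    # Toy model: f(x0, x1) = x0 * x1 + x0 ** 2
    return x[0] * x[1] + x[0] ** 2


def grad_f(x):
    # Analytic gradient of f: (x1 + 2 * x0, x0)
    return [x[1] + 2 * x[0], x[0]]


def integrated_gradients(x, baseline, steps=50):
    """Deterministic IG: (x - b) * average gradient along the straight path b -> x."""
    totals = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint Riemann sum over alpha in [0, 1]
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_f(point)):
            totals[i] += g / steps
    return [(xi - b) * t for xi, b, t in zip(x, baseline, totals)]


def gradient_shap(x, background, n_samples=25, stdevs=0.15, seed=42):
    """Stochastic GradientSHAP: average of gradient * (x - noisy baseline)."""
    if n_samples < 1:
        raise ValueError("n_samples must be >= 1")
    if stdevs <= 0:
        raise ValueError("stdevs must be > 0")
    rng = random.Random(seed)  # re-seeded on every call, so calls are reproducible
    totals = [0.0] * len(x)
    for _ in range(n_samples):
        base = rng.choice(background)                       # draw a real baseline
        base = [b + rng.gauss(0.0, stdevs) for b in base]   # perturb it with noise
        alpha = rng.random()                                # random point on the path
        point = [b + alpha * (xi - b) for xi, b in zip(x, base)]
        for i, g in enumerate(grad_f(point)):
            totals[i] += g * (x[i] - base[i]) / n_samples
    return totals


x = [1.0, 2.0]
ig = integrated_gradients(x, baseline=[0.0, 0.0])
# Completeness check: IG attributions sum to f(x) - f(baseline)
print(ig, sum(ig), f(x) - f([0.0, 0.0]))
```

The IG sketch returns identical values on every call because its baseline and path points are fixed; the GradientSHAP sketch is repeatable only because its RNG is re-seeded per call, and its attributions sum to f(x) - f(baseline) only in expectation, so larger n_samples tightens that match.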
SHAP-Style Report
When save_shap_plots=True, interpret() writes a parallel shap/ subdirectory with a SHAP-library-style report (beeswarm, global bar, heatmap, per-entity waterfall and interactive force plots), plus a static shap_report.html index that links everything. The report is designed for client hand-off and renders without a Jupyter runtime.
```python
from pathlib import Path
from monad.interpretability import interpret

interpret(
    predictions_path=Path("./predictions.tsv"),
    output_path=Path("./interpretations"),
    checkpoint_path=Path("./my_model"),
    device="cuda",
    method="gradient_shap",
    save_shap_plots=True,  # also writes ./interpretations/shap/
)
```
Optional dependency
The SHAP report requires the shap library, shipped in the interpretability extra. Install with poetry install -E interpretability or pip install '.[interpretability]'. Without it, interpret() raises ImportError only when save_shap_plots=True; the default attribution flow remains unaffected.
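The behaviour described above, where a missing optional dependency only matters when the feature that needs it is requested, is the usual lazy-import pattern. Here is a minimal sketch using importlib, assuming a hypothetical render_report() function; it is not monad's code, and the error message is illustrative.

```python
import importlib


def render_report(save_shap_plots, module_name="shap"):
    """Run the core flow, importing the optional plotting library only on demand."""
    outputs = ["attributions"]  # stand-in for the default attribution outputs
    if save_shap_plots:
        try:
            importlib.import_module(module_name)
        except ImportError as exc:
            raise ImportError(
                f"save_shap_plots=True requires the '{module_name}' package; "
                "install the interpretability extra."
            ) from exc
        outputs.append("shap_report")
    return outputs
```

Deferring the import to the branch that needs it keeps the base install usable without the extra, while still failing loudly as soon as the report is requested.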
For rendering the report from already-computed attributions (e.g., a custom batch pipeline), use the two helpers below directly.
attributions_to_shap_explanation()
Convert monad per-entity attribution tensors into a shap.Explanation object that can be passed to any shap.plots.* function.
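```python
from monad.interpretability import attributions_to_shap_explanation

# attributions: list[torch.Tensor], e.g. from GradientShapInterpreter.get_attributions()
# explainer.feature_slicer: the pre-built FeatureSlicer for the same model
# base_value: scalar model output at the reference baseline, E[f(baseline)]
explanation = attributions_to_shap_explanation(
    attributions=attributions,
    feature_slicer=explainer.feature_slicer,
    base_value=base_value,
)
```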
| Parameter | Type | Default | Description |
|---|---|---|---|
| attributions | list[torch.Tensor] | required | Per-modality attribution tensors with shape (n_entities, ...) on dim 0. Same format as returned by GradientShapInterpreter.get_attributions(). |
| feature_slicer | FeatureSlicer | required | Pre-built slicer. Pass explainer.feature_slicer to avoid rebuilding. |
| base_value | float | required | Scalar E[f(baseline)]: the model output at the reference baseline. |
| output_name | str \| None | None | Optional label stamped onto explanation.output_names (e.g. "target_0" for classification). |
Returns a shap.Explanation whose values have shape (n_entities, n_features): each feature's value is the signed sum of the attributions in that feature's slice. Features follow a stable order and are named "<data_source>.<feature>" so that sources sharing a feature name do not collide.
Raises ValueError if attributions is empty, contains no entities, or the slicer has no slices; ImportError if shap is not installed.
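The reduction described above amounts to summing each entity's attributions over the positions that belong to each feature and labelling the columns "<data_source>.<feature>". The sketch below shows that reduction in plain Python, assuming each entity's attributions are already a flat list and that slices are given as a hypothetical dict of (data_source, feature) -> (start, stop); the real FeatureSlicer interface is not described in this document, and sorting the keys is just one possible stable order.

```python
def signed_sum_per_feature(entity_attributions, slices):
    """Reduce flat per-entity attributions to one signed value per feature.

    entity_attributions: list of flat lists, one per entity.
    slices: dict mapping (data_source, feature) -> (start, stop) positions.
    Returns (feature_names, matrix) with matrix shape (n_entities, n_features).
    """
    if not entity_attributions:
        raise ValueError("no entities")
    if not slices:
        raise ValueError("no feature slices")
    # Sorted keys give a stable column order; prefixing the data source keeps
    # same-named features from different sources apart.
    keys = sorted(slices)
    names = [f"{source}.{feature}" for source, feature in keys]
    matrix = [
        [sum(row[slices[key][0]:slices[key][1]]) for key in keys]
        for row in entity_attributions
    ]
    return names, matrix
```

Summing with signs (rather than absolute values) lets positive and negative contributions within one feature cancel, which is what a waterfall or force plot expects to add up to the prediction.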
save_shap_report()
Render the full SHAP-style report from a shap.Explanation.
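```python
from pathlib import Path
from monad.interpretability import save_shap_report

# explanation: a shap.Explanation, e.g. from attributions_to_shap_explanation()
shap_dir = save_shap_report(
    explanation=explanation,
    output_path=Path("./interpretations"),
    top_n_waterfalls=5,
)
```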
| Parameter | Type | Default | Description |
|---|---|---|---|
| explanation | shap.Explanation | required | A SHAP Explanation, typically from attributions_to_shap_explanation(). |
| output_path | Path | required | Parent directory; a shap/ subdirectory is created under it. |
| top_n_waterfalls | int | 5 | How many highest-impact entities to render waterfall and force plots for (capped at n_entities). Kept small by default because each force HTML embeds ~2 MB of shap.js. |
Returns the Path to the created shap/ directory. Writes shap_beeswarm.png, shap_bar_global.png, shap_heatmap.png (only when n_entities >= 10), shap_waterfall_top{i}.png, shap_force_top{i}.html, and shap_report.html. Individual plot failures are logged and skipped rather than aborting the whole report.
Raises ImportError if shap or matplotlib is not installed.
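This page does not say how "highest-impact" is measured. Assuming it means the largest total absolute attribution per entity, choosing which entities get waterfall and force plots, capped at n_entities, could look like the hypothetical sketch below, which works on the (n_entities, n_features) value matrix as nested lists.

```python
def top_impact_entities(matrix, top_n=5):
    """Return row indices of the top_n entities by total |attribution|.

    Impact is measured as the sum of absolute per-feature values (an
    assumption); the count is capped at the number of entities.
    """
    n = min(top_n, len(matrix))
    order = sorted(
        range(len(matrix)),
        key=lambda i: sum(abs(v) for v in matrix[i]),
        reverse=True,
    )
    return order[:n]
```

Ranking by absolute magnitude surfaces entities with strong pushes in either direction, which is usually what a reviewer wants to see first.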
interpret_entity()
Explain a single entity's prediction at the event level. Produces a JSON file with per-event, per-feature attributions.
```python
from pathlib import Path
from datetime import datetime, timezone
from monad.interpretability import interpret_entity

interpret_entity(
    output_path=Path("./interpretations/customer_123.json"),
    checkpoint_path=Path("./my_model"),
    predictions_path=Path("./predictions.tsv"),
    main_entity_id="customer_123",
    device="cuda",
    prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
)
```
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| output_path | Path | required | Path for the output JSON file. |
| checkpoint_path | Path | required | Path to the scenario model checkpoint. |
| predictions_path | Path | required | Path to TSV file with model predictions. |
| main_entity_id | str | required | Entity ID to interpret. For some databases (e.g., Snowflake), the value may need escaping. |
| device | str | required | Device: "cpu" or "cuda". |
| prediction_date | datetime | datetime.now(tz=timezone.utc) | Date for which to interpret. Uses UTC timezone. |
| target_index | int \| None | None | Output index for gradients (multiclass/multilabel). |
| recommended_value | str \| None | None | Item ID to interpret (recommendation only). |
Output Format
The output JSON contains event-level attributions:
```json
{
  "transactions": [
    {
      "timestamp": "25-01-2020, 00:00:00",
      "modality_attributions": [
        {
          "data_source_name": "transactions",
          "name": "article_id",
          "value": "0854796002",
          "attribution": -0.124
        },
        {
          "data_source_name": "transactions",
          "name": "price",
          "value": 0.017,
          "attribution": -0.009
        }
      ]
    }
  ]
}
```
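Because the output is plain JSON, a few lines of standard-library Python are enough to list the most influential events. This sketch assumes the schema is exactly as in the example above (each top-level key maps to a list of events, each with a timestamp and modality_attributions); rank_event_attributions is a hypothetical helper, not part of monad.

```python
import json


def rank_event_attributions(entity_json, top_k=10):
    """Flatten interpret_entity() output and sort by absolute attribution.

    Returns (data_source, timestamp, feature, value, attribution) tuples.
    """
    doc = json.loads(entity_json)
    rows = []
    for events in doc.values():
        for event in events:
            for attr in event["modality_attributions"]:
                rows.append((
                    attr["data_source_name"],
                    event["timestamp"],
                    attr["name"],
                    attr["value"],
                    attr["attribution"],
                ))
    # Largest magnitude first, regardless of sign
    rows.sort(key=lambda r: abs(r[4]), reverse=True)
    return rows[:top_k]
```

For a file written by interpret_entity(), pass Path("./interpretations/customer_123.json").read_text() as entity_json.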
Batch Processing Multiple Entities
```python
from pathlib import Path
from datetime import datetime, timezone
from monad.interpretability import interpret_entity

entities_to_explain = ["cust_001", "cust_002", "cust_003"]

for entity_id in entities_to_explain:
    # One JSON file per entity, named after its ID
    interpret_entity(
        output_path=Path(f"./interpretations/{entity_id}.json"),
        checkpoint_path=Path("./my_model"),
        predictions_path=Path("./predictions.tsv"),
        main_entity_id=entity_id,
        device="cuda",
        prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
    )
```
TreemapGenerator
Generate treemap visualizations from interpretability attribution data. Treemaps provide a hierarchical view of feature importance across data sources and features.
Constructor
| Parameter | Type | Default | Description |
|---|---|---|---|
| interpretability_files_path | Path \| None | None | Path to directory with interpretation output (from interpret()). |
| hierarchy | TreemapHierarchy \| None | None | Custom hierarchy definition for treemap levels. |
Note
Exactly one of interpretability_files_path or hierarchy must be provided.
plot_treemap()
Generate and save a treemap chart as an HTML file.
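```python
from pathlib import Path
from monad.interpretability.treemap import TreemapGenerator

generator = TreemapGenerator(interpretability_files_path=Path("./interpretations"))
generator.plot_treemap(
    output_file_path=Path("./treemap.html"),
    n_largest_per_feature=1500,
    max_depth=3,
)
```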
| Parameter | Type | Default | Description |
|---|---|---|---|
| output_file_path | Path | required | Path for the output HTML file. |
| n_largest_per_feature | int \| None | 1500 | Maximum number of values per feature to include. |
| n_largest | int \| None | None | Maximum total number of values to include. |
| exclude_positive_attributions | bool | False | Exclude features with positive attributions. |
| exclude_negative_attributions | bool | False | Exclude features with negative attributions. |
| max_depth | int | 3 | Maximum depth of the treemap hierarchy. |
TreemapHierarchy
For custom treemap hierarchies:
```python
from pathlib import Path
from monad.interpretability.treemap import TreemapGenerator, TreemapHierarchy

hierarchy = TreemapHierarchy(
    levels=["category", "brand", "product_id"],
    hierarchy_path=Path("./data/product_hierarchy.csv"),
    feature_values_importance_path=Path("./interpretations/transactions/article_id/values_importance.json"),
    entity_name_column="product_name",  # Optional
)
generator = TreemapGenerator(hierarchy=hierarchy)
generator.plot_treemap(output_file_path=Path("./custom_treemap.html"))
```
| Field | Type | Default | Description |
|---|---|---|---|
| levels | list[str] | required | Hierarchy level names for the treemap. |
| hierarchy_path | Path | required | Path to a CSV file defining the hierarchy. Each row maps a value of the last level to its values at the higher levels. |
| feature_values_importance_path | Path | required | Path to the values importance JSON file. |
| entity_name_column | str \| None | None | Column name for entity labels in the visualization. |
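To picture what a custom hierarchy contributes, the sketch below rolls value-level attributions up through the levels of a hierarchy CSV, so each parent node carries the summed attribution of the leaves beneath it. It assumes the CSV has one column per level (named as in levels) and that the value importances are already loaded into a dict of value -> attribution; the actual schema of values_importance.json and the treemap's internal aggregation are not documented here, and roll_up is a hypothetical name.

```python
import csv
import io
from collections import defaultdict


def roll_up(csv_text, levels, value_importance):
    """Sum leaf-level attributions into every ancestor node.

    Returns a dict mapping a path tuple, e.g. ("shoes",) or ("shoes", "acme"),
    to the summed attribution of all leaves below it.
    """
    totals = defaultdict(float)
    for row in csv.DictReader(io.StringIO(csv_text)):
        leaf = row[levels[-1]]
        if leaf not in value_importance:
            continue  # leaf had no attribution, so it contributes nothing
        path = tuple(row[level] for level in levels)
        for depth in range(1, len(levels) + 1):
            totals[path[:depth]] += value_importance[leaf]
    return dict(totals)
```

For a real file, pass Path("./data/product_hierarchy.csv").read_text() as csv_text.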
GradientShapInterpreter
For advanced or scripted attribution outside interpret() — batched runs, custom baselines, or feeding attributions straight into attributions_to_shap_explanation() — instantiate an interpreter directly. The package exports three task-specific variants:
| Class | Task |
|---|---|
| ClassificationGradientShapInterpreter | BinaryClassificationTask, MulticlassClassificationTask, MultilabelClassificationTask |
| RecommendationGradientShapInterpreter | BaseRecommendationTask |
| RegressionGradientShapInterpreter | RegressionTask |
Constructor
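```python
from pathlib import Path

# training_module: the loaded training module of a classification task (not constructed here)
interpreter = ClassificationGradientShapInterpreter(
    training_module=training_module,
    checkpoint_path=Path("./my_model"),
    predictions_path=Path("./predictions.tsv"),
    resample=False,
    group_size=500,
    n_samples=25,
    stdevs=0.15,
    n_baselines=100,
    seed=42,
)
```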
| Parameter | Type | Default | Description |
|---|---|---|---|
| n_samples | int | 25 | Number of perturbed samples to average gradients over. Must be >= 1. Higher values reduce attribution variance at proportional runtime cost. |
| stdevs | float | 0.15 | Standard deviation of the Gaussian noise added to each sampled baseline. Must be > 0; 0 would collapse GradientSHAP onto the deterministic IG path. |
| n_baselines | int | 100 | Size of the background reference distribution, sampled uniformly at random across the whole predict set via reservoir sampling. Baselines are drawn from real observations rather than an all-zero point, keeping attributions on-manifold. Bounds only the memory held for the reference set; runtime is governed by n_samples. Must be >= 1. |
| seed | int | 42 | RNG seed for reproducibility of the stochastic baseline draws. Re-applied at the start of every get_attributions() call, so outputs are reproducible across invocations on the same interpreter. |
Remaining positional and keyword arguments (training_module, checkpoint_path, predictions_path, resample, group_size, recommended_value, main_entity_id) are forwarded to the underlying Integrated Gradients base class.
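The n_baselines description mentions reservoir sampling, the technique that draws a fixed-size, uniformly random sample from a stream whose length is not known in advance while holding only k items in memory. Below is a textbook sketch of it (Algorithm R) in plain Python; it illustrates the idea and is not monad's internal code, and the name reservoir_sample is hypothetical.

```python
import random


def reservoir_sample(stream, k, seed=42):
    """Algorithm R: a uniform random sample of k items from a stream of unknown length.

    Memory is bounded by k regardless of how long the stream is.
    """
    if k < 1:
        raise ValueError("k must be >= 1")
    rng = random.Random(seed)
    reservoir = []
    for i, item in enumerate(stream):
        if i < k:
            reservoir.append(item)  # fill the reservoir first
        else:
            j = rng.randint(0, i)   # inclusive bounds: j in [0, i]
            if j < k:
                reservoir[j] = item  # item i is kept with probability k / (i + 1)
    return reservoir
```

Every item ends up in the reservoir with the same probability k / n, which is why a single pass over the predict set is enough to build an unbiased background set.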
Baseline distribution
GradientSHAP draws its baseline from a real background distribution sampled across the predict set. Because the baseline is no longer a fixed all-zero point, attributions are stochastic and seed-dependent, and will differ from releases prior to 1.7.
get_attributions()
Run attribution and return the raw per-modality tensors, suitable for downstream tools such as attributions_to_shap_explanation().
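```python
from datetime import datetime, timezone

attributions = interpreter.get_attributions(
    prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
    target=0,             # output index; None for regression
    limit_batches=None,   # all batches
    device="cuda",
)
```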
| Parameter | Type | Default | Description |
|---|---|---|---|
| prediction_date | datetime | required | Date for which the predictions should be interpreted. |
| target | int \| None | required | Output index for which gradients are computed. None for scalar-output tasks (regression). |
| limit_batches | int \| None | required | Number of batches to compute attributions on. None = all batches. |
| device | str | required | Device for computation: "cpu" or "cuda". |
Returns a list[torch.Tensor] — one tensor per input modality, with shape (n_entities, ...) on dim 0.
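get_attributions() returns one tensor per modality rather than one row per entity. If you need entity-major rows (for example, to inspect one entity across all modalities), the regrouping can be sketched in plain Python with nested lists standing in for tensors. per_entity_vectors and flatten_row are hypothetical helpers, and whether this concatenation order matches the positions a FeatureSlicer expects is an assumption this page does not confirm.

```python
def flatten_row(x):
    # Recursively flatten nested lists (stand-in for tensor[i].flatten())
    if isinstance(x, (list, tuple)):
        out = []
        for item in x:
            out.extend(flatten_row(item))
        return out
    return [x]


def per_entity_vectors(modality_tensors):
    """Concatenate each entity's flattened attributions across modalities."""
    if not modality_tensors:
        raise ValueError("expected at least one modality")
    n_entities = len(modality_tensors[0])
    if any(len(t) != n_entities for t in modality_tensors):
        raise ValueError("all modalities must share dim 0 (n_entities)")
    vectors = []
    for i in range(n_entities):
        row = []
        for tensor in modality_tensors:  # modality order is preserved
            row.extend(flatten_row(tensor[i]))
        vectors.append(row)
    return vectors
```

With real tensors the same regrouping would usually be a single call, torch.cat([t.flatten(start_dim=1) for t in attributions], dim=1), which yields one (n_entities, total_width) tensor.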
Complete Workflow Example
```python
from pathlib import Path
from datetime import datetime, timezone
from monad.ui.module import load_from_checkpoint
from monad.config import TestingParams, OutputType
from monad.interpretability import interpret, interpret_entity
from monad.interpretability.treemap import TreemapGenerator

# 1. Generate predictions with attributions enabled
module = load_from_checkpoint(Path("./churn_model"))
testing_params = TestingParams(
    output_type=OutputType.SEMANTIC,
    devices=[0],
    local_save_location=Path("./predictions.tsv"),
)
module.predict(testing_params)

# 2. Global feature importance
interpret(
    predictions_path=Path("./predictions.tsv"),
    output_path=Path("./interpretations"),
    checkpoint_path=Path("./churn_model"),
    device="cuda",
    prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
)

# 3. Treemap visualization
generator = TreemapGenerator(
    interpretability_files_path=Path("./interpretations"),
)
generator.plot_treemap(
    output_file_path=Path("./treemap.html"),
)

# 4. Single entity deep-dive
interpret_entity(
    output_path=Path("./interpretations/customer_123.json"),
    checkpoint_path=Path("./churn_model"),
    predictions_path=Path("./predictions.tsv"),
    main_entity_id="customer_123",
    device="cuda",
    prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
)

# 5. SHAP-style report for client hand-off, written to a separate directory
#    so the Integrated Gradients outputs from step 2 are not overwritten
interpret(
    predictions_path=Path("./predictions.tsv"),
    output_path=Path("./interpretations_shap"),
    checkpoint_path=Path("./churn_model"),
    device="cuda",
    prediction_date=datetime(2024, 6, 1, tzinfo=timezone.utc),
    method="gradient_shap",
    save_shap_plots=True,
)
```