Implementation:Pytorch Serve Captum Explanations
| Field | Value |
|---|---|
| Page Type | Implementation |
| Title | Captum Explanations |
| Type | API Doc |
| Short Description | Implementation of model explainability using Captum Layer Integrated Gradients to compute per-token importance scores for HuggingFace Transformer predictions |
| Domains | NLP, Explainability |
| Source | examples/Huggingface_Transformers/Transformer_handler_generalized.py:L357-513 |
| Knowledge Sources | TorchServe |
| Workflow | HuggingFace_Transformer_Serving |
| Last Updated | 2026-02-13 00:00 GMT |
Overview
The Captum Explanations implementation provides model interpretability for HuggingFace Transformer predictions served through TorchServe. It uses Captum's LayerIntegratedGradients algorithm applied to the model's embedding layer to compute per-token attribution scores. The implementation consists of the get_insights() method on the handler class and four supporting module-level functions that handle input construction, forward pass wrapping, attribution summarization, and token extraction.
Description
The implementation computes how much each input token contributes to the model's prediction for a specified target class. It supports both classification tasks (sequence and token classification) and question answering, producing different output formats for each.
Usage
Explanations are triggered through TorchServe's explanation endpoint:
# Sequence classification explanation
curl -X POST http://localhost:8080/explanations/bert_model \
-H "Content-Type: text/plain" \
-d '{"text": "This movie was fantastic", "target": 1}'
# Question answering explanation
curl -X POST http://localhost:8080/explanations/bert_qa \
-H "Content-Type: text/plain" \
-d '{"question": "Who founded Tesla?", "context": "Tesla was founded by Elon Musk.", "target": 0}'
Requires captum_explanation: true and embedding_name set in model-config.yaml.
Code Reference
Source Location
| Field | Value |
|---|---|
| Repository | pytorch/serve |
| File | examples/Huggingface_Transformers/Transformer_handler_generalized.py |
| Lines | L357-513 |
Method and Function Signatures
get_insights(self, input_batch, text, target) (L357-427)
def get_insights(self, input_batch, text, target):
"""This function initialize and calls the layer integrated gradient to get word importance
of the input text if captum explanation has been selected through setup_config
Args:
input_batch (int): Batches of tokens IDs of text
text (str): The Text specified in the input request
target (int): The Target can be set to any acceptable label under the user's discretion.
Returns:
(list): Returns a list of importances and words.
"""
Main method on TransformersSeqClassifierHandler that orchestrates the explanation computation:
- Initializes LayerIntegratedGradients with captum_sequence_forward and the model's embedding layer (accessed via getattr(self.model, self.setup_config["embedding_name"]).embeddings)
- Parses input text to extract text content and target class
- Constructs input IDs, reference IDs, and attention mask via construct_input_ref()
- Extracts word tokens via get_word_token()
- For sequence/token classification: runs lig.attribute() once for the target class
- For question answering: runs lig.attribute() twice (position=0 for start, position=1 for end)
- Summarizes attributions and returns response dictionary
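The parsing step in the list above can be sketched as follows. This is a minimal illustration, not the handler's code: parse_explanation_request is a hypothetical helper name, and the assumption that the body is read with ast.literal_eval rests only on the import ast line listed under Import.

```python
import ast


def parse_explanation_request(raw, mode):
    """Split a raw request body into (text, target).

    Hypothetical helper: the real handler does this inline in get_insights().
    """
    if isinstance(raw, (bytes, bytearray)):
        raw = raw.decode("utf-8")
    # Assumption: the body is a dict literal such as
    # {"text": "...", "target": 1}, so literal_eval can read it.
    payload = ast.literal_eval(raw)
    target = payload["target"]
    if mode == "question_answering":
        # QA keeps the full payload string; question and context are
        # extracted later, when the input tensors are built.
        return raw, target
    return payload["text"], target


text, target = parse_explanation_request(
    '{"text": "This movie was fantastic", "target": 1}',
    "sequence_classification",
)
```

Note that ast.literal_eval accepts Python literals only, so a body containing JSON-only values such as true or null would not parse under this assumption.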
construct_input_ref(text, tokenizer, device, mode) (L430-466)
def construct_input_ref(text, tokenizer, device, mode):
"""For a given text, this function creates token id, reference id and
attention mask based on encode which is faster for captum insights
Args:
text (str): The text specified in the input request
tokenizer (AutoTokenizer Class Object): To word tokenize the input text
device (cpu or gpu): Type of the Environment the server runs on.
Returns:
input_id(Tensor): It attributes to the tensor of the input tokenized words
ref_input_ids(Tensor): Ref Input IDs are used as baseline for the attributions
attention mask() : The attention mask is a binary tensor indicating the position
of the padded indices so that the model does not attend to them.
"""
Constructs three tensors required for integrated gradients:
- input_ids: [CLS] + text_token_ids + [SEP] as a tensor
- ref_input_ids: [CLS] + [PAD]*len(text_tokens) + [SEP] as a tensor (the baseline)
- attention_mask: All ones, same shape as input_ids
For question answering mode, encodes the question-context pair first, then falls through to standard processing.
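The layout of the three sequences can be illustrated without loading a real model. The sketch below uses plain lists instead of tensors and a hypothetical ToyTokenizer stand-in (whitespace splitting, made-up IDs); only the attribute names cls_token_id, sep_token_id, pad_token_id and the encode(..., add_special_tokens=False) call mirror the HuggingFace tokenizer interface.

```python
class ToyTokenizer:
    """Hypothetical stand-in for a HuggingFace tokenizer (illustration only)."""

    cls_token_id, sep_token_id, pad_token_id = 101, 102, 0

    def __init__(self):
        self.vocab = {}

    def encode(self, text, add_special_tokens=False):
        # Assign made-up IDs starting at 1000, in order of first appearance.
        return [
            self.vocab.setdefault(word, 1000 + len(self.vocab))
            for word in text.lower().split()
        ]


def build_input_and_baseline(text, tokenizer):
    text_ids = tokenizer.encode(text, add_special_tokens=False)
    # Real input: [CLS] + tokens + [SEP]
    input_ids = [tokenizer.cls_token_id] + text_ids + [tokenizer.sep_token_id]
    # Baseline: same length, but every content token replaced by [PAD]
    ref_input_ids = (
        [tokenizer.cls_token_id]
        + [tokenizer.pad_token_id] * len(text_ids)
        + [tokenizer.sep_token_id]
    )
    attention_mask = [1] * len(input_ids)
    # The outer list plays the role of the batch dimension of size 1.
    return [input_ids], [ref_input_ids], [attention_mask]


input_ids, ref_input_ids, attention_mask = build_input_and_baseline(
    "this movie was fantastic", ToyTokenizer()
)
```

Keeping [CLS] and [SEP] identical in both sequences means the attribution path only varies the content tokens, so the special tokens contribute no input-minus-baseline difference at the ID level.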
captum_sequence_forward(inputs, attention_mask, position, model) (L469-485)
def captum_sequence_forward(inputs, attention_mask=None, position=0, model=None):
"""This function is used to get the predictions from the model and this function
can be used independent of the type of the BERT Task.
Args:
inputs (list): Input for Predictions
attention_mask (list, optional): The attention mask is a binary tensor indicating the position
of the padded indices so that the model does not attend to them, it defaults to None.
position (int, optional): Position depends on the BERT Task.
model ([type], optional): Name of the model, it defaults to None.
Returns:
list: Prediction Outcome
"""
Wrapper function that Captum calls during attribution computation:
- Sets model to eval mode and zeros gradients
- Runs model(inputs, attention_mask=attention_mask)
- Returns pred[position] where position selects the output (0 = logits or start_logits, 1 = end_logits)
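The role of position can be shown with a toy stand-in for the model. In the sketch below, fake_qa_model is hypothetical and returns plain lists shaped like (start_logits, end_logits); the eval-mode and gradient-zeroing steps are omitted because they only apply to a real nn.Module.

```python
def fake_qa_model(inputs, attention_mask=None):
    """Hypothetical model stub returning (start_logits, end_logits)."""
    n = len(inputs[0])
    start_logits = [[float(i) for i in range(n)]]
    end_logits = [[float(n - i) for i in range(n)]]
    return (start_logits, end_logits)


def sequence_forward(inputs, attention_mask=None, position=0, model=None):
    # Same argument order as captum_sequence_forward: Captum passes
    # `inputs` first, then the additional forward args positionally.
    pred = model(inputs, attention_mask=attention_mask)
    return pred[position]


start = sequence_forward([[7, 8, 9]], None, 0, fake_qa_model)
end = sequence_forward([[7, 8, 9]], None, 1, fake_qa_model)
```

Because the model is the last positional parameter, the tuple (attention_mask, position, self.model) passed as additional_forward_args in the usage examples lines up with this signature.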
summarize_attributions(attributions) (L488-497)
def summarize_attributions(attributions):
"""Summarizes the attribution across multiple runs
Args:
attributions ([list): attributions from the Layer Integrated Gradients
Returns:
list : Returns the attributions after normalizing them.
"""
Reduces the per-dimension attributions to per-token scores:
- Sums across the embedding dimension: attributions.sum(dim=-1).squeeze(0)
- Normalizes by L2 norm: attributions / torch.norm(attributions)
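The same two steps can be written in dependency-free Python to make the arithmetic visible. This is a sketch on nested lists with made-up numbers, not the torch implementation; summarize is a hypothetical name.

```python
import math


def summarize(attributions):
    """attributions: nested lists shaped [1][seq_len][embedding_dim]."""
    # Step 1: collapse the embedding dimension to one score per token.
    per_token = [sum(dims) for dims in attributions[0]]
    # Step 2: divide by the L2 norm of the per-token vector.
    norm = math.sqrt(sum(v * v for v in per_token))
    return [v / norm for v in per_token]


# Three tokens, embedding dimension 2 (illustrative values).
raw = [[[0.5, 0.5], [1.0, 2.0], [-1.0, -1.0]]]
scores = summarize(raw)
```

L2 normalization gives the score vector unit length but preserves signs, so the scores do not sum to 1 and a token that pushes the prediction away from the target class keeps a negative score.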
get_word_token(input_ids, tokenizer) (L500-513)
def get_word_token(input_ids, tokenizer):
"""constructs word tokens from token id using the BERT's
Auto Tokenizer
Args:
input_ids (list): Input IDs from construct_input_ref method
tokenizer (class): The Auto Tokenizer Pre-Trained model object
Returns:
(list): Returns the word tokens
"""
Converts token IDs back to readable string tokens:
- Extracts indices from the first batch element: input_ids[0].detach().tolist()
- Converts IDs to tokens: tokenizer.convert_ids_to_tokens(indices)
- Cleans BPE artifacts: removes the "Ġ" marker (the character that byte-level BPE tokenizers such as GPT-2's use to encode a leading space)
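The cleanup step can be sketched in isolation. The example assumes the marker being stripped is "Ġ" (U+0120), which byte-level BPE tokenizers prepend to tokens that follow a space; clean_tokens is a hypothetical name and the token list is made up.

```python
BPE_SPACE_MARKER = "\u0120"  # "Ġ", assumed to be the character the handler strips


def clean_tokens(tokens):
    """Drop the byte-level BPE space marker from each token string."""
    return [token.replace(BPE_SPACE_MARKER, "") for token in tokens]


tokens = clean_tokens(["This", "\u0120movie", "\u0120was", "\u0120fantastic"])
```

WordPiece tokens such as those produced by BERT do not contain this marker, so the replacement leaves them unchanged.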
Import
import ast
import torch
from captum.attr import LayerIntegratedGradients
from transformers import AutoTokenizer
I/O Contract
Input
| Function | Input | Type | Description |
|---|---|---|---|
| get_insights | input_batch | tuple(Tensor, Tensor) | Token IDs and attention mask from preprocess |
| get_insights | text | str | Raw input text, formatted as {"text": "...", "target": N} |
| get_insights | target | int | Target class index for attribution |
| construct_input_ref | text | str | The text string to tokenize |
| construct_input_ref | tokenizer | AutoTokenizer | The loaded tokenizer object |
| construct_input_ref | device | torch.device | CPU or GPU device |
| construct_input_ref | mode | str | NLP task mode string |
| captum_sequence_forward | inputs | Tensor | Token ID tensor; Captum substitutes interpolated embeddings at the hooked embedding layer |
| captum_sequence_forward | attention_mask | Tensor | Attention mask tensor |
| captum_sequence_forward | position | int | Output position index (0 or 1) |
| captum_sequence_forward | model | nn.Module | The Transformer model |
| summarize_attributions | attributions | Tensor | Raw attribution tensor from LIG |
| get_word_token | input_ids | Tensor | Token ID tensor |
| get_word_token | tokenizer | AutoTokenizer | The loaded tokenizer object |
Output
Sequence Classification / Token Classification:
[{
"words": ["[CLS]", "this", "movie", "was", "great", "[SEP]"],
"importances": [0.01, 0.05, 0.15, 0.02, 0.77, 0.00],
"delta": 0.0012
}]
Question Answering:
[{
"words": ["[CLS]", "who", "founded", "tesla", "?", "[SEP]", "tesla", "was", "founded", "by", "elon", "musk", ".", "[SEP]"],
"importances_answer_start": [0.01, 0.05, 0.08, 0.03, 0.02, 0.00, 0.03, 0.02, 0.05, 0.10, 0.45, 0.15, 0.01, 0.00],
"importances_answer_end": [0.01, 0.03, 0.05, 0.02, 0.01, 0.00, 0.02, 0.01, 0.03, 0.08, 0.12, 0.55, 0.07, 0.00],
"delta_start": 0.0015,
"delta_end": 0.0023
}]
| Output Field | Type | Description |
|---|---|---|
| words | list of str | The input tokens including special tokens |
| importances | list of float | Normalized per-token attribution scores (seq/token classification) |
| importances_answer_start | list of float | Attributions for answer start position (QA) |
| importances_answer_end | list of float | Attributions for answer end position (QA) |
| delta | float | Convergence delta measuring approximation quality (seq/token classification) |
| delta_start | float | Convergence delta for answer start (QA) |
| delta_end | float | Convergence delta for answer end (QA) |
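The meaning of the delta fields can be illustrated with integrated gradients on a toy function. The sketch below is not Captum's code: it uses a hand-written gradient and a midpoint Riemann sum for simplicity, and it assumes the delta is the sum of the attributions minus the change in model output between baseline and input (the completeness property of integrated gradients).

```python
def f(x):
    # Toy "model": a scalar function of two inputs.
    return x[0] ** 2 + 3.0 * x[0] * x[1]


def grad_f(x):
    # Analytic gradient of f.
    return [2.0 * x[0] + 3.0 * x[1], 3.0 * x[0]]


def integrated_gradients(x, baseline, n_steps=50):
    attributions = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps  # midpoint of the k-th interval
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        g = grad_f(point)
        for i in range(len(x)):
            # Average gradient along the path, scaled by (input - baseline).
            attributions[i] += g[i] * (x[i] - baseline[i]) / n_steps
    return attributions


x, baseline = [1.0, 2.0], [0.0, 0.0]
attributions = integrated_gradients(x, baseline)
# A delta near zero means the numerical integral matches the change in output.
delta = sum(attributions) - (f(x) - f(baseline))
```

A delta that is large relative to the output difference suggests the path integral was approximated too coarsely, in which case increasing the number of integration steps is the usual remedy.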
Usage Examples
Example 1: Sequence Classification Explanation
# Input to get_insights:
# text = '{"text": "Bloomberg has decided to publish a new report on global warming.", "target": 1}'
# Internal flow:
embedding_layer = getattr(self.model, "bert") # from setup_config["embedding_name"]
embeddings = embedding_layer.embeddings
self.lig = LayerIntegratedGradients(captum_sequence_forward, embeddings)
input_ids, ref_input_ids, attention_mask = construct_input_ref(
text="Bloomberg has decided to publish a new report on global warming.",
tokenizer=self.tokenizer,
device=self.device,
mode="sequence_classification"
)
# input_ids: [101, 23280, 2038, 2787, ..., 102] (CLS + tokens + SEP)
# ref_input_ids: [101, 0, 0, 0, ..., 102] (CLS + PADs + SEP)
attributions, delta = self.lig.attribute(
inputs=input_ids,
baselines=ref_input_ids,
target=1,
additional_forward_args=(attention_mask, 0, self.model),
return_convergence_delta=True,
)
attributions_sum = summarize_attributions(attributions)
# Output: [{"words": ["[CLS]", "bloomberg", "has", ...], "importances": [...], "delta": 0.001}]
Example 2: Question Answering Explanation
# Input to get_insights:
# text = '{"question": "Who is CEO?", "context": "Elon Musk is CEO of Tesla.", "target": 0}'
# Runs LIG twice:
# 1. position=0 for answer start attributions
attributions_start, delta_start = self.lig.attribute(
inputs=input_ids, baselines=ref_input_ids, target=0,
additional_forward_args=(attention_mask, 0, self.model),
return_convergence_delta=True,
)
# 2. position=1 for answer end attributions
attributions_end, delta_end = self.lig.attribute(
inputs=input_ids, baselines=ref_input_ids, target=0,
additional_forward_args=(attention_mask, 1, self.model),
return_convergence_delta=True,
)
# Output: [{"words": [...], "importances_answer_start": [...],
# "importances_answer_end": [...], "delta_start": ..., "delta_end": ...}]
Example 3: Using construct_input_ref
input_ids, ref_input_ids, attention_mask = construct_input_ref(
text="This is a test sentence.",
tokenizer=tokenizer,
device=torch.device("cpu"),
mode="sequence_classification"
)
# input_ids shape: (1, num_tokens + 2) -- includes CLS and SEP
# ref_input_ids shape: (1, num_tokens + 2) -- PAD tokens in place of content
# attention_mask shape: (1, num_tokens + 2) -- all ones
Example 4: Using summarize_attributions
# Raw attributions shape: (1, seq_len, embedding_dim)
# e.g., (1, 10, 768) for a 10-token input with BERT
attributions_sum = summarize_attributions(attributions)
# Step 1: sum(dim=-1) -> (1, 10) -> squeeze(0) -> (10,)
# Step 2: normalize by L2 norm -> (10,) with unit norm
# Result: signed per-token scores; relative magnitudes show each token's contribution (they do not sum to 1)
Related Pages
- Principle:Pytorch_Serve_Model_Explainability - The principle of making model predictions interpretable through attribution methods
- Implementation:Pytorch_Serve_TransformersSeqClassifierHandler - The handler class containing the get_insights() method
- Implementation:Pytorch_Serve_Transformer_Handler_Config - Configuration that enables Captum explanations