Implementation:Pytorch Serve Captum Explanations

From Leeroopedia
Field | Value
Page Type | Implementation
Title | Captum Explanations
Type | API Doc
Short Description | Implementation of model explainability using Captum Layer Integrated Gradients to compute per-token importance scores for HuggingFace Transformer predictions
Domains | NLP, Explainability
Source | examples/Huggingface_Transformers/Transformer_handler_generalized.py:L357-513
Knowledge Sources | TorchServe
Workflow | HuggingFace_Transformer_Serving
Last Updated | 2026-02-13 00:00 GMT

Overview

The Captum Explanations implementation provides model interpretability for HuggingFace Transformer predictions served through TorchServe. It uses Captum's LayerIntegratedGradients algorithm applied to the model's embedding layer to compute per-token attribution scores. The implementation consists of the get_insights() method on the handler class and four supporting module-level functions that handle input construction, forward pass wrapping, attribution summarization, and token extraction.

Description

The implementation computes how much each input token contributes to the model's prediction for a specified target class. It supports both classification tasks (sequence and token classification) and question answering, producing different output formats for each.

Usage

Explanations are triggered through TorchServe's explanation endpoint:

# Sequence classification explanation
curl -X POST http://localhost:8080/explanations/bert_model \
  -H "Content-Type: text/plain" \
  -d '{"text": "This movie was fantastic", "target": 1}'

# Question answering explanation
curl -X POST http://localhost:8080/explanations/bert_qa \
  -H "Content-Type: text/plain" \
  -d '{"question": "Who founded Tesla?", "context": "Tesla was founded by Elon Musk.", "target": 0}'

Requires captum_explanation: true and embedding_name set in model-config.yaml.
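For Python clients, the same request can be assembled with the standard library. This is a minimal sketch that assumes the server address and model name from the curl examples (localhost:8080, bert_model); build_explanation_request is a hypothetical helper, and the request is only constructed here, not sent.

```python
import json
import urllib.request


def build_explanation_request(host, model_name, payload):
    """Build (but do not send) a POST to TorchServe's explanations endpoint.

    build_explanation_request is a hypothetical helper for illustration only.
    """
    url = f"{host}/explanations/{model_name}"
    body = json.dumps(payload).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "text/plain"},  # mirrors the curl examples
        method="POST",
    )


req = build_explanation_request(
    "http://localhost:8080",
    "bert_model",
    {"text": "This movie was fantastic", "target": 1},
)

# Against a running server you would send it with:
# with urllib.request.urlopen(req) as resp:
#     print(json.loads(resp.read()))
```

json.dumps produces a dict literal that ast.literal_eval can also parse, as long as the payload contains no true/false/null values.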

Code Reference

Source Location

Field | Value
Repository | pytorch/serve
File | examples/Huggingface_Transformers/Transformer_handler_generalized.py
Lines | L357-513

Method and Function Signatures

get_insights(self, input_batch, text, target) (L357-427)

def get_insights(self, input_batch, text, target):
    """This function initializes and calls Layer Integrated Gradients to get the word importances
    of the input text if Captum explanation has been selected through setup_config
    Args:
        input_batch (tuple): Batches of token IDs of the text
        text (str): The text specified in the input request
        target (int): The target label index to attribute against, chosen by the user.
    Returns:
        (list): A single-element list holding a dict of words and their importances.
    """

Main method on TransformersSeqClassifierHandler that orchestrates the explanation computation:

  1. Initializes LayerIntegratedGradients with captum_sequence_forward and the model's embedding layer (accessed via getattr(self.model, self.setup_config["embedding_name"]).embeddings)
  2. Parses input text to extract text content and target class
  3. Constructs input IDs, reference IDs, and attention mask via construct_input_ref()
  4. Extracts word tokens via get_word_token()
  5. For sequence/token classification: runs lig.attribute() once for the target class
  6. For question answering: runs lig.attribute() twice (position=0 for start, position=1 for end)
  7. Summarizes attributions and returns response dictionary
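The dispatch in steps 2 and 5-7 can be sketched without Captum or torch. In this sketch, explain_sketch and its attribute callable are hypothetical stand-ins: attribute(target, position) plays the role of lig.attribute(...) followed by summarize_attributions(), scores and deltas are plain lists and floats instead of tensors, and the words field (filled by get_word_token()) is omitted.

```python
import ast


def explain_sketch(text, mode, attribute):
    """Dependency-free sketch of the get_insights() control flow.

    `attribute(target, position)` is a hypothetical stand-in that returns
    (per_token_scores, delta) for one output position.
    """
    if isinstance(text, (bytes, bytearray)):
        text = text.decode("utf-8")
    request = ast.literal_eval(text)  # '{"text": ..., "target": N}' -> dict
    target = request["target"]

    response = {}
    if mode in ("sequence_classification", "token_classification"):
        # One attribution pass against output position 0 (the logits).
        scores, delta = attribute(target, 0)
        response["importances"] = scores
        response["delta"] = delta
    elif mode == "question_answering":
        # Two passes: position 0 = start logits, position 1 = end logits.
        start, delta_start = attribute(target, 0)
        end, delta_end = attribute(target, 1)
        response["importances_answer_start"] = start
        response["importances_answer_end"] = end
        response["delta_start"] = delta_start
        response["delta_end"] = delta_end
    return [response]
```

The question-answering branch costs two full attribution passes, which is why QA explanations take roughly twice as long as classification ones.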

construct_input_ref(text, tokenizer, device, mode) (L430-466)

def construct_input_ref(text, tokenizer, device, mode):
    """For a given text, this function creates the token ids, reference ids and
    attention mask using encode, which is faster for Captum insights
    Args:
        text (str): The text specified in the input request
        tokenizer (AutoTokenizer Class Object): To word tokenize the input text
        device (cpu or gpu): Type of the environment the server runs on.
        mode (str): The NLP task mode, e.g. "question_answering"
    Returns:
        input_ids (Tensor): Tensor of the tokenized input words
        ref_input_ids (Tensor): Reference input IDs used as the baseline for the attributions
        attention_mask (Tensor): A binary tensor indicating the positions
         of the padded indices so that the model does not attend to them.
    """

Constructs three tensors required for integrated gradients:

  • input_ids: [CLS] + text_token_ids + [SEP] as a tensor
  • ref_input_ids: [CLS] + [PAD]*len(text_tokens) + [SEP] as a tensor (the baseline)
  • attention_mask: All ones, same shape as input_ids

For question answering mode, encodes the question-context pair first, then falls through to standard processing.
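The tensor construction above can be mirrored with plain lists. In this sketch, ToyTokenizer and construct_input_ref_lists are hypothetical: the tokenizer only splits on whitespace, the word ids in the vocabulary are made up, and only the special-token ids (101 for [CLS], 102 for [SEP], 0 for [PAD]) follow bert-base-uncased. The non-QA path is shown.

```python
class ToyTokenizer:
    """Hypothetical stand-in for a BERT tokenizer; word ids are made up."""

    cls_token_id = 101
    sep_token_id = 102
    pad_token_id = 0

    def __init__(self, vocab):
        self.vocab = vocab

    def encode(self, text, add_special_tokens=False):
        # Whitespace split only; real tokenizers apply WordPiece or BPE.
        ids = [self.vocab[w] for w in text.lower().split()]
        if add_special_tokens:
            ids = [self.cls_token_id] + ids + [self.sep_token_id]
        return ids


def construct_input_ref_lists(text, tokenizer):
    """List-based sketch of construct_input_ref() for non-QA modes."""
    text_ids = tokenizer.encode(text, add_special_tokens=False)
    input_ids = [tokenizer.cls_token_id] + text_ids + [tokenizer.sep_token_id]
    # Baseline: same length, content tokens replaced by [PAD].
    ref_input_ids = (
        [tokenizer.cls_token_id]
        + [tokenizer.pad_token_id] * len(text_ids)
        + [tokenizer.sep_token_id]
    )
    attention_mask = [1] * len(input_ids)
    # Wrap each in a batch of one, matching the documented (1, len + 2) shape.
    return [input_ids], [ref_input_ids], [attention_mask]


tok = ToyTokenizer({"this": 11, "movie": 12, "was": 13, "fantastic": 14})
input_ids, ref_input_ids, attention_mask = construct_input_ref_lists(
    "This movie was fantastic", tok
)
# input_ids     -> [[101, 11, 12, 13, 14, 102]]
# ref_input_ids -> [[101, 0, 0, 0, 0, 102]]
```

Keeping [CLS] and [SEP] in the baseline means the attributions for those positions are near zero, so the scores concentrate on the content tokens.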

captum_sequence_forward(inputs, attention_mask, position, model) (L469-485)

def captum_sequence_forward(inputs, attention_mask=None, position=0, model=None):
    """This function is used to get the predictions from the model and this function
    can be used independent of the type of the BERT Task.
    Args:
        inputs (Tensor): Input token IDs for the predictions
        attention_mask (Tensor, optional): The attention mask is a binary tensor indicating the positions
         of the padded indices so that the model does not attend to them; it defaults to None.
        position (int, optional): Index of the model output to return; depends on the BERT task.
        model (torch.nn.Module, optional): The model used for the predictions; it defaults to None.
    Returns:
        Tensor: The selected prediction output, pred[position]
    """

Wrapper function that Captum calls during attribution computation:

  1. Sets model to eval mode and zeros gradients
  2. Runs model(inputs, attention_mask=attention_mask)
  3. Returns pred[position] where position selects the output (0 = logits or start_logits, 1 = end_logits)
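The output selection in step 3 can be illustrated with a fake model. FakeQAModel is hypothetical: it returns a plain tuple (start_logits, end_logits) where a HuggingFace question-answering model returns an output object that also supports integer indexing. The forward wrapper follows the three steps listed above, written here with lists instead of tensors.

```python
class FakeQAModel:
    """Hypothetical stand-in for a Transformer QA model."""

    def __init__(self):
        self.training = True
        self.grad_zeroed = False

    def eval(self):
        self.training = False
        return self

    def zero_grad(self):
        self.grad_zeroed = True

    def __call__(self, inputs, attention_mask=None):
        # Made-up logits derived from the input ids.
        start_logits = [float(i) for i in inputs]
        end_logits = [-float(i) for i in inputs]
        return (start_logits, end_logits)


def captum_sequence_forward(inputs, attention_mask=None, position=0, model=None):
    """Sketch of the wrapper: eval mode, zero grads, forward, select output."""
    model.eval()
    model.zero_grad()
    pred = model(inputs, attention_mask=attention_mask)
    return pred[position]  # 0 -> logits / start_logits, 1 -> end_logits


model = FakeQAModel()
start = captum_sequence_forward([1, 2], position=0, model=model)
end = captum_sequence_forward([1, 2], position=1, model=model)
```

Because Captum passes position through additional_forward_args, one forward function serves both the classification pass and the two QA passes.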

summarize_attributions(attributions) (L488-497)

def summarize_attributions(attributions):
    """Summarizes the attributions across the embedding dimension
    Args:
        attributions (Tensor): attributions from the Layer Integrated Gradients
    Returns:
        Tensor: Returns the per-token attributions after normalizing them.
    """

Reduces the per-dimension attributions to per-token scores:

  1. Sums across the embedding dimension: attributions.sum(dim=-1).squeeze(0)
  2. Normalizes by L2 norm: attributions / torch.norm(attributions)
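The two reduction steps can be written out on nested lists. summarize_attributions_lists is a hypothetical, dependency-free mirror of the tensor version; the toy input has batch size 1, three tokens and an embedding dimension of 2.

```python
import math


def summarize_attributions_lists(attributions):
    """List-based sketch of summarize_attributions().

    `attributions` has shape (1, seq_len, embedding_dim) as nested lists.
    """
    # Step 1: sum over the embedding dimension, then drop the batch axis.
    per_token = [sum(vec) for vec in attributions[0]]
    # Step 2: divide by the L2 norm so the squared scores sum to 1.
    # An all-zero input would divide by zero here (NaN in the tensor version).
    norm = math.sqrt(sum(x * x for x in per_token))
    return [x / norm for x in per_token]


scores = summarize_attributions_lists([[[1.0, 2.0], [0.0, 0.0], [-4.0, 0.0]]])
# per-token sums [3.0, 0.0, -4.0], norm 5.0 -> [0.6, 0.0, -0.8]
```

The sign survives normalization, so the third token here counts against the target while the first counts toward it.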

get_word_token(input_ids, tokenizer) (L500-513)

def get_word_token(input_ids, tokenizer):
    """Constructs word tokens from token IDs using the model's
    AutoTokenizer
    Args:
        input_ids (list): Input IDs from construct_input_ref method
        tokenizer (class): The Auto Tokenizer Pre-Trained model object
    Returns:
        (list): Returns the word tokens
    """

Converts token IDs back to readable string tokens:

  1. Extracts indices from the first batch element: input_ids[0].detach().tolist()
  2. Converts IDs to tokens: tokenizer.convert_ids_to_tokens(indices)
  3. Cleans BPE artifacts: removes "Ġ" (U+0120, the marker byte-level BPE tokenizers such as GPT-2's use for a leading space)
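The cleanup in step 3 is a plain string replacement. clean_bpe_tokens is a hypothetical name for that single step, and the token list is a hypothetical convert_ids_to_tokens() result for a RoBERTa-style model.

```python
def clean_bpe_tokens(tokens):
    """Sketch of get_word_token()'s cleanup step.

    Byte-level BPE marks a leading space with "\u0120" (rendered "Ġ");
    stripping it makes tokens read as plain words.
    """
    return [token.replace("\u0120", "") for token in tokens]


tokens = ["<s>", "This", "\u0120movie", "\u0120was", "\u0120fantastic", "</s>"]
cleaned = clean_bpe_tokens(tokens)
# -> ["<s>", "This", "movie", "was", "fantastic", "</s>"]
```

Only the "Ġ" marker is removed; WordPiece continuation prefixes such as "##ing" from BERT tokenizers pass through unchanged.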

Import

import ast
import torch
from captum.attr import LayerIntegratedGradients
from transformers import AutoTokenizer

I/O Contract

Input

Function | Parameter | Type | Description
get_insights | input_batch | tuple(Tensor, Tensor) | Token IDs and attention mask from preprocess
get_insights | text | str | Raw input text, formatted as {"text": "...", "target": N} (or {"question": "...", "context": "...", "target": N} for question answering)
get_insights | target | int | Target class index for attribution
construct_input_ref | text | str | The text string to tokenize
construct_input_ref | tokenizer | AutoTokenizer | The loaded tokenizer object
construct_input_ref | device | torch.device | CPU or GPU device
construct_input_ref | mode | str | NLP task mode string
captum_sequence_forward | inputs | Tensor | Token ID tensor as passed to lig.attribute(); LIG interpolates at the embedding layer's output, not at these IDs
captum_sequence_forward | attention_mask | Tensor | Attention mask tensor
captum_sequence_forward | position | int | Output position index (0 or 1)
captum_sequence_forward | model | nn.Module | The Transformer model
summarize_attributions | attributions | Tensor | Raw attribution tensor from LIG
get_word_token | input_ids | Tensor | Token ID tensor
get_word_token | tokenizer | AutoTokenizer | The loaded tokenizer object

Output

Sequence Classification / Token Classification (illustrative values):

[{
    "words": ["[CLS]", "this", "movie", "was", "great", "[SEP]"],
    "importances": [0.01, 0.05, 0.15, 0.02, 0.77, 0.00],
    "delta": 0.0012
}]

Question Answering (illustrative values):

[{
    "words": ["[CLS]", "who", "founded", "tesla", "?", "[SEP]", "tesla", "was", "founded", "by", "elon", "musk", ".", "[SEP]"],
    "importances_answer_start": [0.01, 0.05, 0.08, 0.03, 0.02, 0.00, 0.03, 0.02, 0.05, 0.10, 0.45, 0.15, 0.01, 0.00],
    "importances_answer_end": [0.01, 0.03, 0.05, 0.02, 0.01, 0.00, 0.02, 0.01, 0.03, 0.08, 0.12, 0.55, 0.07, 0.00],
    "delta_start": 0.0015,
    "delta_end": 0.0023
}]

Output Field | Type | Description
words | list of str | The input tokens, including special tokens
importances | list of float | Normalized per-token attribution scores; signed, so negative scores mark tokens that lowered the target output relative to the baseline (seq/token classification)
importances_answer_start | list of float | Attributions for the answer start position (QA)
importances_answer_end | list of float | Attributions for the answer end position (QA)
delta | float | Convergence delta measuring approximation quality; closer to zero is better (seq/token classification)
delta_start | float | Convergence delta for the answer start (QA)
delta_end | float | Convergence delta for the answer end (QA)
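The delta fields come from the completeness property of Integrated Gradients: the attributions should sum to f(input) - f(baseline), and the delta is the gap. The sketch below shows this on a hypothetical two-feature function with an analytic gradient rather than a Transformer, and it uses a midpoint Riemann sum; Captum's default integration method is Gauss-Legendre with n_steps=50, with Riemann variants available through its method argument.

```python
def integrated_gradients(grad_fn, x, baseline, n_steps=50):
    """Midpoint Riemann approximation of Integrated Gradients.

    attribution_i = (x_i - b_i) * mean_k grad_i(b + alpha_k * (x - b))
    """
    totals = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_fn(point)):
            totals[i] += g
    return [(xi - b) * t / n_steps for xi, b, t in zip(x, baseline, totals)]


def convergence_delta(f, attributions, x, baseline):
    """Completeness gap: sum of attributions minus (f(x) - f(baseline))."""
    return sum(attributions) - (f(x) - f(baseline))


# Hypothetical toy model f(x) = x0**2 + 3*x1 and its gradient.
def f(v):
    return v[0] ** 2 + 3 * v[1]


def grad_f(v):
    return [2 * v[0], 3.0]


x, baseline = [2.0, 1.0], [0.0, 0.0]
attrs = integrated_gradients(grad_f, x, baseline)  # approx. [4.0, 3.0]
delta = convergence_delta(f, attrs, x, baseline)   # approx. 0.0
```

For a real Transformer the path integral is only approximated, so a delta that is large relative to the attribution magnitudes suggests raising the number of integration steps.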

Usage Examples

Example 1: Sequence Classification Explanation

# Input to get_insights:
# text = '{"text": "Bloomberg has decided to publish a new report on global warming.", "target": 1}'

# Internal flow:
embedding_layer = getattr(self.model, "bert")  # from setup_config["embedding_name"]
embeddings = embedding_layer.embeddings
self.lig = LayerIntegratedGradients(captum_sequence_forward, embeddings)

input_ids, ref_input_ids, attention_mask = construct_input_ref(
    text="Bloomberg has decided to publish a new report on global warming.",
    tokenizer=self.tokenizer,
    device=self.device,
    mode="sequence_classification"
)
# input_ids: [101, 23280, 2038, 2787, ..., 102]  (CLS + tokens + SEP)
# ref_input_ids: [101, 0, 0, 0, ..., 102]  (CLS + PADs + SEP)

attributions, delta = self.lig.attribute(
    inputs=input_ids,
    baselines=ref_input_ids,
    target=1,
    additional_forward_args=(attention_mask, 0, self.model),
    return_convergence_delta=True,
)
attributions_sum = summarize_attributions(attributions)

# Output: [{"words": ["[CLS]", "bloomberg", "has", ...], "importances": [...], "delta": 0.001}]

Example 2: Question Answering Explanation

# Input to get_insights:
# text = '{"question": "Who is CEO?", "context": "Elon Musk is CEO of Tesla.", "target": 0}'

# Runs LIG twice:
# 1. position=0 for answer start attributions
attributions_start, delta_start = self.lig.attribute(
    inputs=input_ids, baselines=ref_input_ids, target=0,
    additional_forward_args=(attention_mask, 0, self.model),
    return_convergence_delta=True,
)
# 2. position=1 for answer end attributions
attributions_end, delta_end = self.lig.attribute(
    inputs=input_ids, baselines=ref_input_ids, target=0,
    additional_forward_args=(attention_mask, 1, self.model),
    return_convergence_delta=True,
)

# Output: [{"words": [...], "importances_answer_start": [...],
#           "importances_answer_end": [...], "delta_start": ..., "delta_end": ...}]

Example 3: Using construct_input_ref

input_ids, ref_input_ids, attention_mask = construct_input_ref(
    text="This is a test sentence.",
    tokenizer=tokenizer,
    device=torch.device("cpu"),
    mode="sequence_classification"
)
# input_ids shape: (1, num_tokens + 2)  -- includes CLS and SEP
# ref_input_ids shape: (1, num_tokens + 2)  -- PAD tokens in place of content
# attention_mask shape: (1, num_tokens + 2)  -- all ones

Example 4: Using summarize_attributions

# Raw attributions shape: (1, seq_len, embedding_dim)
# e.g., (1, 10, 768) for a 10-token input with BERT

attributions_sum = summarize_attributions(attributions)
# Step 1: sum(dim=-1) -> (1, 10) -> squeeze(0) -> (10,)
# Step 2: normalize by L2 norm -> (10,) with unit norm
# Result: signed per-token scores whose squares sum to 1; compare magnitudes to rank tokens, and read the sign as pushing toward (+) or away from (-) the target
