Implementation:Pytorch Serve Captum Explanations

From Leeroopedia
Revision as of 13:45, 16 February 2026 by Admin (auto-imported from implementations/Pytorch_Serve_Captum_Explanations.md)
Field | Value
Page Type | Implementation
Title | Captum Explanations
Type | API Doc
Short Description | Implementation of model explainability using Captum Layer Integrated Gradients to compute per-token importance scores for HuggingFace Transformer predictions
Domains | NLP, Explainability
Source | examples/Huggingface_Transformers/Transformer_handler_generalized.py:L357-513
Knowledge Sources | TorchServe
Workflow | HuggingFace_Transformer_Serving
Last Updated | 2026-02-13 00:00 GMT

Overview

The Captum Explanations implementation provides model interpretability for HuggingFace Transformer predictions served through TorchServe. It uses Captum's LayerIntegratedGradients algorithm applied to the model's embedding layer to compute per-token attribution scores. The implementation consists of the get_insights() method on the handler class and four supporting module-level functions that handle input construction, forward pass wrapping, attribution summarization, and token extraction.

Description

The implementation computes how much each input token contributes to the model's prediction for a specified target class. It supports both classification tasks (sequence and token classification) and question answering, producing different output formats for each.

Usage

Explanations are triggered through TorchServe's explanation endpoint:

# Sequence classification explanation
curl -X POST http://localhost:8080/explanations/bert_model \
  -H "Content-Type: text/plain" \
  -d '{"text": "This movie was fantastic", "target": 1}'

# Question answering explanation
curl -X POST http://localhost:8080/explanations/bert_qa \
  -H "Content-Type: text/plain" \
  -d '{"question": "Who founded Tesla?", "context": "Tesla was founded by Elon Musk.", "target": 0}'

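Explanations require captum_explanation: true and embedding_name to be set in model-config.yaml. embedding_name names the model attribute that holds the base encoder (for example bert), whose embeddings submodule is used as the attribution layer.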
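For scripted testing, the same requests can be assembled from Python. This is a minimal sketch using only the standard library: build_explanation_request is a hypothetical helper, the host, port, and model names are the ones from the curl examples, and the Content-Type simply mirrors those examples. The request is built but not sent; pass it to urllib.request.urlopen against a running TorchServe instance.

```python
import json
import urllib.request


def build_explanation_request(model_name, payload, host="http://localhost:8080"):
    """Hypothetical helper: build (but do not send) a POST to the explanations endpoint."""
    body = json.dumps(payload).encode("utf-8")
    return urllib.request.Request(
        url=f"{host}/explanations/{model_name}",
        data=body,
        headers={"Content-Type": "text/plain"},  # mirrors the curl examples
        method="POST",
    )


# Sequence classification: attribute the prediction for class 1.
seq_req = build_explanation_request(
    "bert_model", {"text": "This movie was fantastic", "target": 1}
)

# Question answering: the payload carries question, context, and target keys.
qa_req = build_explanation_request(
    "bert_qa",
    {
        "question": "Who founded Tesla?",
        "context": "Tesla was founded by Elon Musk.",
        "target": 0,
    },
)

# Sending requires a running server:
# with urllib.request.urlopen(seq_req) as resp:
#     print(json.loads(resp.read()))
```

json.dumps produces double-quoted strings and plain integers here, so the body is also a valid Python literal for the handler's parsing step.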

Code Reference

Source Location

Field | Value
Repository | pytorch/serve
File | examples/Huggingface_Transformers/Transformer_handler_generalized.py
Lines | L357-513

Method and Function Signatures

get_insights(self, input_batch, text, target) (L357-427)

def get_insights(self, input_batch, text, target):
    """This function initializes and calls the layer integrated gradients to get word importance
    of the input text if captum explanation has been selected through setup_config
    Args:
        input_batch (tuple): Batches of token IDs and attention masks of the text
        text (str): The text specified in the input request
        target (int): The target can be set to any acceptable label at the user's discretion.
    Returns:
        (list): A single-element list holding the response dictionary of words and importances.
    """

Main method on TransformersSeqClassifierHandler that orchestrates the explanation computation:

  1. Initializes LayerIntegratedGradients with captum_sequence_forward and the model's embedding layer (accessed via getattr(self.model, self.setup_config["embedding_name"]).embeddings)
  2. Parses the request text with ast.literal_eval to extract the text content and target class
  3. Constructs input IDs, reference IDs, and attention mask via construct_input_ref()
  4. Extracts word tokens via get_word_token()
  5. For sequence/token classification: runs lig.attribute() once for the target class
  6. For question answering: runs lig.attribute() twice (position=0 for start, position=1 for end)
  7. Summarizes attributions and returns response dictionary
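The request parsing (step 2) and the mode-dependent dispatch (steps 5 to 7) can be sketched without Captum. Both helpers below are hypothetical, not functions from the handler. parse_explanation_text assumes that in question answering mode the raw request string is passed on unchanged so that construct_input_ref can read the question and context itself, and it also accepts bytes as an assumption about how the body may arrive. attribution_plan only encodes which output position each Captum run uses and which response keys it fills.

```python
import ast


def parse_explanation_text(text, mode):
    """Hypothetical helper mirroring step 2: parse the request and pull out the target."""
    if isinstance(text, (bytes, bytearray)):
        text = text.decode("utf-8")
    payload = ast.literal_eval(text)  # e.g. '{"text": "...", "target": 1}'
    target = payload["target"]
    if mode == "question_answering":
        # Assumption: QA keeps the raw string; construct_input_ref parses question/context.
        return text, target
    return payload["text"], target


def attribution_plan(mode):
    """Map a task mode to (importance key, delta key, output position) per LIG run."""
    if mode in ("sequence_classification", "token_classification"):
        return [("importances", "delta", 0)]
    if mode == "question_answering":
        return [
            ("importances_answer_start", "delta_start", 0),  # start_logits
            ("importances_answer_end", "delta_end", 1),  # end_logits
        ]
    raise ValueError(f"unsupported mode: {mode}")


text, target = parse_explanation_text(
    b'{"text": "This movie was fantastic", "target": 1}', "sequence_classification"
)
print(text, target)  # This movie was fantastic 1
for importance_key, delta_key, position in attribution_plan("question_answering"):
    print(importance_key, delta_key, position)
```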

construct_input_ref(text, tokenizer, device, mode) (L430-466)

def construct_input_ref(text, tokenizer, device, mode):
    """For a given text, this function creates the token IDs, reference IDs, and
    attention mask using tokenizer.encode, which is faster for Captum insights
    Args:
        text (str): The text specified in the input request
        tokenizer (AutoTokenizer class object): Used to word-tokenize the input text
        device (cpu or gpu): Type of the environment the server runs on.
        mode (str): The NLP task mode, e.g. "sequence_classification" or "question_answering"
    Returns:
        input_ids (Tensor): The tensor of the tokenized input words
        ref_input_ids (Tensor): Reference input IDs used as the baseline for the attributions
        attention_mask (Tensor): A binary tensor indicating the positions
         of padded indices so that the model does not attend to them.
    """

Constructs three tensors required for integrated gradients:

  • input_ids: [CLS] + text_token_ids + [SEP] as a tensor
  • ref_input_ids: [CLS] + [PAD]*len(text_tokens) + [SEP] as a tensor (the baseline)
  • attention_mask: All ones, same shape as input_ids

For question answering mode, encodes the question-context pair first, then falls through to standard processing.
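The tensor construction for the non-QA path can be sketched as follows. ToyTokenizer is a hypothetical stand-in exposing only what the sketch needs (encode, cls_token_id, sep_token_id, pad_token_id). The special-token IDs 101, 102, and 0 match the ones shown in Example 1; the word IDs are arbitrary, and the question answering branch is omitted.

```python
import torch


class ToyTokenizer:
    """Hypothetical tokenizer: maps each whitespace-separated word to a fixed ID."""

    cls_token_id, sep_token_id, pad_token_id = 101, 102, 0

    def __init__(self, vocab):
        self.vocab = vocab

    def encode(self, text, add_special_tokens=False):
        # Special tokens are added by the caller, so they are never inserted here.
        return [self.vocab[word] for word in text.lower().split()]


def construct_input_ref_sketch(text, tokenizer, device):
    text_ids = tokenizer.encode(text, add_special_tokens=False)
    # Input: [CLS] + text tokens + [SEP], with a leading batch dimension of 1.
    input_ids = torch.tensor(
        [[tokenizer.cls_token_id] + text_ids + [tokenizer.sep_token_id]], device=device
    )
    # Baseline: keep [CLS]/[SEP], replace every content token with [PAD].
    ref_input_ids = torch.tensor(
        [[tokenizer.cls_token_id]
         + [tokenizer.pad_token_id] * len(text_ids)
         + [tokenizer.sep_token_id]],
        device=device,
    )
    # Nothing is padded, so the model attends to every position.
    attention_mask = torch.ones_like(input_ids)
    return input_ids, ref_input_ids, attention_mask


tok = ToyTokenizer({"this": 11, "is": 12, "a": 13, "test": 14})
ids, ref_ids, mask = construct_input_ref_sketch("This is a test", tok, torch.device("cpu"))
print(ids.tolist())      # [[101, 11, 12, 13, 14, 102]]
print(ref_ids.tolist())  # [[101, 0, 0, 0, 0, 102]]
```

Keeping [CLS] and [SEP] in the baseline means only the content tokens differ between input and reference, so attribution is concentrated on the words themselves.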

captum_sequence_forward(inputs, attention_mask, position, model) (L469-485)

def captum_sequence_forward(inputs, attention_mask=None, position=0, model=None):
    """This function is used to get the predictions from the model and
    can be used independently of the type of BERT task.
    Args:
        inputs (Tensor): Input token IDs for predictions
        attention_mask (Tensor, optional): A binary tensor indicating the positions
         of padded indices so that the model does not attend to them; defaults to None.
        position (int, optional): Index of the model output to return; depends on the BERT task.
        model (nn.Module, optional): The Transformer model; defaults to None.
    Returns:
        Tensor: The selected prediction output
    """

Wrapper function that Captum calls during attribution computation:

  1. Sets model to eval mode and zeros gradients
  2. Runs model(inputs, attention_mask=attention_mask)
  3. Returns pred[position] where position selects the output (0 = logits or start_logits, 1 = end_logits)
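The wrapper's three steps translate directly into code. In this sketch the wrapper body follows the steps listed above, while ToyQAModel is a hypothetical stand-in that, like a HuggingFace question answering model, returns start and end logits as its first two outputs (here a plain tuple; HuggingFace model outputs also support integer indexing). The toy model ignores attention_mask.

```python
import torch
import torch.nn as nn


def captum_sequence_forward(inputs, attention_mask=None, position=0, model=None):
    """Sketch of the wrapper: run the model and return one of its outputs."""
    model.eval()        # disable dropout so repeated calls are deterministic
    model.zero_grad()   # clear gradients left over from earlier runs
    pred = model(inputs, attention_mask=attention_mask)
    return pred[position]  # 0 -> logits / start_logits, 1 -> end_logits


class ToyQAModel(nn.Module):
    """Hypothetical QA model: per-token start and end logits from an embedding layer."""

    def __init__(self, vocab_size=200, dim=8):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, dim)
        self.qa_outputs = nn.Linear(dim, 2)

    def forward(self, input_ids, attention_mask=None):  # mask unused in this toy
        logits = self.qa_outputs(self.embeddings(input_ids))  # (batch, seq_len, 2)
        start_logits, end_logits = logits.unbind(dim=-1)
        return start_logits, end_logits  # each (batch, seq_len)


model = ToyQAModel()
input_ids = torch.tensor([[101, 7, 9, 102]])
mask = torch.ones_like(input_ids)
start = captum_sequence_forward(input_ids, mask, position=0, model=model)
end = captum_sequence_forward(input_ids, mask, position=1, model=model)
print(start.shape, end.shape)  # torch.Size([1, 4]) torch.Size([1, 4])
```

For question answering, each selected output has shape (batch, seq_len), so Captum's target indexes a token position rather than a class: target=0 attributes the logit at the first ([CLS]) position.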

summarize_attributions(attributions) (L488-497)

def summarize_attributions(attributions):
    """Summarizes the attributions across the embedding dimension
    Args:
        attributions (Tensor): Attributions from the Layer Integrated Gradients
    Returns:
        Tensor: The per-token attributions after L2 normalization.
    """

Reduces the per-dimension attributions to per-token scores:

  1. Sums across the embedding dimension: attributions.sum(dim=-1).squeeze(0)
  2. Normalizes by L2 norm: attributions / torch.norm(attributions)
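The two reduction steps amount to the following sketch. The input values are made up so the arithmetic is easy to follow: the per-token sums are 3, 0, -4, and 0, whose L2 norm is 5. Note that normalized scores keep their sign, so a token can receive a negative importance.

```python
import torch


def summarize_attributions(attributions):
    """Sketch of the two reduction steps described above."""
    # (1, seq_len, embedding_dim) -> (seq_len,): one summed score per token
    attributions = attributions.sum(dim=-1).squeeze(0)
    # Divide by the L2 norm so the per-token scores have unit Euclidean length
    return attributions / torch.norm(attributions)


# Made-up attributions for 4 tokens with a 3-dimensional embedding: shape (1, 4, 3)
raw = torch.tensor([[[1.0, 1.0, 1.0],
                     [0.0, 0.0, 0.0],
                     [-2.0, -1.0, -1.0],
                     [0.5, -0.5, 0.0]]])
scores = summarize_attributions(raw)
print(scores)  # approximately [0.6, 0.0, -0.8, 0.0]
```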

get_word_token(input_ids, tokenizer) (L500-513)

def get_word_token(input_ids, tokenizer):
    """Constructs word tokens from token IDs using the BERT
    AutoTokenizer
    Args:
        input_ids (Tensor): Input IDs from the construct_input_ref method
        tokenizer (class): The pre-trained AutoTokenizer object
    Returns:
        (list): The word tokens
    """

Converts token IDs back to readable string tokens:

  1. Extracts indices from the first batch element: input_ids[0].detach().tolist()
  2. Converts IDs to tokens: tokenizer.convert_ids_to_tokens(indices)
  3. Cleans BPE artifacts: removes the "Ġ" character (U+0120), which GPT-2-style byte-level BPE tokenizers prepend to tokens that follow a space
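The three conversion steps can be sketched with a stub tokenizer. StubTokenizer is hypothetical and implements only convert_ids_to_tokens; the RoBERTa-style tokens in its table are illustrative.

```python
import torch


class StubTokenizer:
    """Hypothetical tokenizer exposing only convert_ids_to_tokens."""

    def __init__(self, id_to_token):
        self.id_to_token = id_to_token

    def convert_ids_to_tokens(self, ids):
        return [self.id_to_token[i] for i in ids]


def get_word_token_sketch(input_ids, tokenizer):
    indices = input_ids[0].detach().tolist()  # first (and only) batch element
    tokens = tokenizer.convert_ids_to_tokens(indices)
    # Byte-level BPE marks a preceding space with "Ġ" (U+0120); strip it for display.
    return [token.replace("\u0120", "") for token in tokens]


tok = StubTokenizer({0: "<s>", 1: "This", 2: "Ġmovie", 3: "Ġrocks", 4: "</s>"})
print(get_word_token_sketch(torch.tensor([[0, 1, 2, 3, 4]]), tok))
# ['<s>', 'This', 'movie', 'rocks', '</s>']
```

For WordPiece tokenizers such as BERT's, no "Ġ" appears, so the replace is a no-op and subword pieces keep their "##" prefix.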

Import

import ast
import torch
from captum.attr import LayerIntegratedGradients
from transformers import AutoTokenizer

I/O Contract

Input

Function | Parameter | Type | Description
get_insights | input_batch | tuple(Tensor, Tensor) | Token IDs and attention mask from preprocess
get_insights | text | str | Raw request text, formatted as {"text": "...", "target": N} ({"question": "...", "context": "...", "target": N} for question answering)
get_insights | target | int | Target class index for attribution
construct_input_ref | text | str | The text string to tokenize
construct_input_ref | tokenizer | AutoTokenizer | The loaded tokenizer object
construct_input_ref | device | torch.device | CPU or GPU device
construct_input_ref | mode | str | NLP task mode string
captum_sequence_forward | inputs | Tensor | Token ID tensor; Captum substitutes interpolated embeddings at the hooked embedding layer
captum_sequence_forward | attention_mask | Tensor | Attention mask tensor
captum_sequence_forward | position | int | Output position index (0 or 1)
captum_sequence_forward | model | nn.Module | The Transformer model
summarize_attributions | attributions | Tensor | Raw attribution tensor from LIG
get_word_token | input_ids | Tensor | Token ID tensor
get_word_token | tokenizer | AutoTokenizer | The loaded tokenizer object

Output

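Sequence Classification / Token Classification (values below are illustrative; actual importances are L2-normalized, so individual scores can be negative and their squares sum to 1):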

[{
    "words": ["[CLS]", "this", "movie", "was", "great", "[SEP]"],
    "importances": [0.01, 0.05, 0.15, 0.02, 0.77, 0.00],
    "delta": 0.0012
}]

Question Answering:

[{
    "words": ["[CLS]", "who", "founded", "tesla", "?", "[SEP]", "tesla", "was", "founded", "by", "elon", "musk", ".", "[SEP]"],
    "importances_answer_start": [0.01, 0.05, 0.08, 0.03, 0.02, 0.00, 0.03, 0.02, 0.05, 0.10, 0.45, 0.15, 0.01, 0.00],
    "importances_answer_end": [0.01, 0.03, 0.05, 0.02, 0.01, 0.00, 0.02, 0.01, 0.03, 0.08, 0.12, 0.55, 0.07, 0.00],
    "delta_start": 0.0015,
    "delta_end": 0.0023
}]
Output Field | Type | Description
words | list of str | The input tokens, including special tokens
importances | list of float | L2-normalized per-token attribution scores (sequence/token classification)
importances_answer_start | list of float | L2-normalized attributions for the answer start position (QA)
importances_answer_end | list of float | L2-normalized attributions for the answer end position (QA)
delta | float | Convergence delta: summed raw attributions minus the change in model output from baseline to input; values near 0 indicate a good approximation (sequence/token classification)
delta_start | float | Convergence delta for the answer start (QA)
delta_end | float | Convergence delta for the answer end (QA)
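On the client side, words and the importance lists align index by index. A minimal sketch with a hypothetical top_tokens helper ranks tokens by the magnitude of their attribution; it uses the absolute value because L2-normalized scores can be negative. The response reuses the illustrative values from the sequence classification output above.

```python
def top_tokens(response, key="importances", k=3):
    """Hypothetical client helper: rank tokens by absolute attribution."""
    words, scores = response["words"], response[key]
    if len(words) != len(scores):
        raise ValueError("words and importances must align one-to-one")
    ranked = sorted(zip(words, scores), key=lambda pair: abs(pair[1]), reverse=True)
    return ranked[:k]


# Illustrative response in the sequence classification format.
response = {
    "words": ["[CLS]", "this", "movie", "was", "great", "[SEP]"],
    "importances": [0.01, 0.05, 0.15, 0.02, 0.77, 0.00],
    "delta": 0.0012,
}
print(top_tokens(response))  # [('great', 0.77), ('movie', 0.15), ('this', 0.05)]
```

For QA responses, pass key="importances_answer_start" or key="importances_answer_end".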

Usage Examples

Example 1: Sequence Classification Explanation

# Input to get_insights:
# text = '{"text": "Bloomberg has decided to publish a new report on global warming.", "target": 1}'

# Internal flow:
embedding_layer = getattr(self.model, "bert")  # from setup_config["embedding_name"]
embeddings = embedding_layer.embeddings
self.lig = LayerIntegratedGradients(captum_sequence_forward, embeddings)

input_ids, ref_input_ids, attention_mask = construct_input_ref(
    text="Bloomberg has decided to publish a new report on global warming.",
    tokenizer=self.tokenizer,
    device=self.device,
    mode="sequence_classification"
)
# input_ids: [101, 23280, 2038, 2787, ..., 102]  (CLS + tokens + SEP)
# ref_input_ids: [101, 0, 0, 0, ..., 102]  (CLS + PADs + SEP)

attributions, delta = self.lig.attribute(
    inputs=input_ids,
    baselines=ref_input_ids,
    target=1,
    additional_forward_args=(attention_mask, 0, self.model),
    return_convergence_delta=True,
)
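attributions_sum = summarize_attributions(attributions)
response = {
    "words": get_word_token(input_ids, self.tokenizer),
    "importances": attributions_sum.tolist(),
    "delta": delta[0].tolist(),  # delta has shape (1,) for a single input
}

# Output: [response], e.g. [{"words": ["[CLS]", "bloomberg", "has", ...], "importances": [...], "delta": 0.001}]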
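To make the attribute call less of a black box, here is a from-scratch sketch of what layer integrated gradients computes, on a toy model. It is not Captum's implementation: Captum integrates with Gauss-Legendre quadrature by default (n_steps=50) and patches the layer output through hooks, whereas this sketch uses a midpoint Riemann sum and calls a hypothetical forward_from_embeddings method directly. ToyClassifier and its sizes are assumptions. Each embedding coordinate's attribution is its difference from the baseline times the gradient averaged along the straight path from baseline to input; delta follows Captum's convergence-delta definition, the summed attributions minus the difference in the target logit between input and baseline.

```python
import torch
import torch.nn as nn


class ToyClassifier(nn.Module):
    """Hypothetical stand-in for a Transformer: embeddings -> mean pool -> tanh -> linear."""

    def __init__(self, vocab_size=200, dim=8, num_classes=2):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, dim)
        self.classifier = nn.Linear(dim, num_classes)

    def forward_from_embeddings(self, emb):
        return self.classifier(torch.tanh(emb.mean(dim=1)))  # (batch, num_classes)


def layer_integrated_gradients(model, input_ids, baseline_ids, target, n_steps=200):
    """Integrated gradients at the embedding layer via a midpoint Riemann sum."""
    x = model.embeddings(input_ids).detach()         # (batch, seq_len, dim)
    x_ref = model.embeddings(baseline_ids).detach()  # baseline embeddings
    grad_sum = torch.zeros_like(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        point = (x_ref + alpha * (x - x_ref)).requires_grad_(True)
        score = model.forward_from_embeddings(point)[:, target].sum()
        (grad,) = torch.autograd.grad(score, point)
        grad_sum += grad
    attributions = (x - x_ref) * grad_sum / n_steps
    with torch.no_grad():
        out_x = model.forward_from_embeddings(x)
        out_ref = model.forward_from_embeddings(x_ref)
        gap = (out_x - out_ref)[:, target]
    # Convergence delta: total attribution minus the output change it should explain.
    delta = attributions.sum(dim=(1, 2)) - gap
    return attributions, delta


torch.manual_seed(0)
model = ToyClassifier().double()
input_ids = torch.tensor([[101, 7, 9, 102]])     # [CLS] w1 w2 [SEP]
baseline_ids = torch.tensor([[101, 0, 0, 102]])  # [CLS] [PAD] [PAD] [SEP]
attributions, delta = layer_integrated_gradients(model, input_ids, baseline_ids, target=1)
print(attributions.shape, delta)
```

Because the [CLS] and [SEP] embeddings are identical in the input and the baseline, their attributions come out exactly zero, a direct consequence of building the baseline by swapping only content tokens for [PAD].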

Example 2: Question Answering Explanation

# Input to get_insights:
# text = '{"question": "Who is CEO?", "context": "Elon Musk is CEO of Tesla.", "target": 0}'

# Runs LIG twice:
# 1. position=0 for answer start attributions
attributions_start, delta_start = self.lig.attribute(
    inputs=input_ids, baselines=ref_input_ids, target=0,
    additional_forward_args=(attention_mask, 0, self.model),
    return_convergence_delta=True,
)
# 2. position=1 for answer end attributions
attributions_end, delta_end = self.lig.attribute(
    inputs=input_ids, baselines=ref_input_ids, target=0,
    additional_forward_args=(attention_mask, 1, self.model),
    return_convergence_delta=True,
)

# Output: [{"words": [...], "importances_answer_start": [...],
#           "importances_answer_end": [...], "delta_start": ..., "delta_end": ...}]

Example 3: Using construct_input_ref

input_ids, ref_input_ids, attention_mask = construct_input_ref(
    text="This is a test sentence.",
    tokenizer=tokenizer,
    device=torch.device("cpu"),
    mode="sequence_classification"
)
# input_ids shape: (1, num_tokens + 2)  -- includes CLS and SEP
# ref_input_ids shape: (1, num_tokens + 2)  -- PAD tokens in place of content
# attention_mask shape: (1, num_tokens + 2)  -- all ones

Example 4: Using summarize_attributions

# Raw attributions shape: (1, seq_len, embedding_dim)
# e.g., (1, 10, 768) for a 10-token input with BERT

attributions_sum = summarize_attributions(attributions)
# Step 1: sum(dim=-1) -> (1, 10) -> squeeze(0) -> (10,)
# Step 2: normalize by L2 norm -> (10,) with unit norm
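# Result: per-token scores with unit L2 norm; the sign shows whether a token raised or lowered the target score relative to the baseline, and the magnitude shows its relative strength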
