

Implementation:NVIDIA TransformerEngine Replace Decoder Context Manager

From Leeroopedia


Overview

Context manager that monkey-patches the HuggingFace decoder layer class for TransformerEngine substitution.

Doc Type

Pattern Doc -- This function implements the monkey-patching pattern for transparent layer class substitution at model construction time.

Description

replace_decoder(te_decoder_cls) is a context manager (decorated with @contextmanager) that temporarily replaces the LlamaDecoderLayer class reference in the transformers.models.llama.modeling_llama module namespace with the provided TE decoder class.

The function follows this sequence:

  1. Save the original LlamaDecoderLayer class reference.
  2. Replace the module-level reference with the provided te_decoder_cls.
  3. Yield control to the caller (within the with block).
  4. Restore the original class reference in the finally block, ensuring restoration even if an exception occurs.

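Because the restore runs in a finally block, the module attribute is reset whether the with block exits normally or raises, so the substitution is scoped to the with block. Note that the patch is process-wide while it is active: any other code (for example, another thread) that constructs a Llama model during the block also picks up te_decoder_cls.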
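
The four steps above can be sketched without installing transformers. The following is a minimal illustration, not the TransformerEngine code itself: fake_module, OriginalLayer, SubstituteLayer, and replace_layer are hypothetical stand-ins for the modeling_llama module, LlamaDecoderLayer, the TE decoder class, and replace_decoder. It raises inside the with block to show that the original reference is still restored.

```python
import types
from contextlib import contextmanager

# Hypothetical stand-in for transformers.models.llama.modeling_llama
fake_module = types.ModuleType("fake_modeling")


class OriginalLayer:
    """Stand-in for LlamaDecoderLayer."""


class SubstituteLayer:
    """Stand-in for the TE decoder class."""


fake_module.DecoderLayer = OriginalLayer


@contextmanager
def replace_layer(new_cls):
    original = fake_module.DecoderLayer       # step 1: save the original reference
    fake_module.DecoderLayer = new_cls        # step 2: replace the module attribute
    try:
        yield                                 # step 3: run the body of the with block
    finally:
        fake_module.DecoderLayer = original   # step 4: restore, even on exceptions


try:
    with replace_layer(SubstituteLayer):
        inside = fake_module.DecoderLayer
        raise RuntimeError("simulated failure during model construction")
except RuntimeError:
    pass

restored = fake_module.DecoderLayer
```

After this runs, inside refers to SubstituteLayer and restored refers to OriginalLayer, even though the with block exited through an exception.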

Source

  • File: docs/examples/te_llama/te_llama.py
  • Function: replace_decoder
  • Lines: L27-37

Signature

@contextmanager
def replace_decoder(te_decoder_cls):
    """
    Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.
    """
    original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer
    transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls
    try:
        yield
    finally:
        transformers.models.llama.modeling_llama.LlamaDecoderLayer = original_llama_decoder_cls

I/O

Direction | Name | Type | Description
Input | te_decoder_cls | type (class) | The TE decoder layer class to substitute for LlamaDecoderLayer. Must accept the same constructor signature as LlamaDecoderLayer (i.e., (config, *args, **kwargs)).
Output | context manager | generator-based context manager | Yields nothing. The substitution is active for the duration of the with block.
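
The constructor requirement on te_decoder_cls can be illustrated with a small sketch. Everything here is hypothetical: FakeConfig stands in for LlamaConfig, SubstituteLayer for a TE decoder class, and build_layers for the model code that instantiates layers. The extra positional index passed by build_layers is an assumption used to show why accepting *args and **kwargs keeps the substitute compatible with whatever the caller passes.

```python
class FakeConfig:
    """Hypothetical stand-in for LlamaConfig with only the fields used here."""

    def __init__(self, hidden_size, num_hidden_layers):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers


class SubstituteLayer:
    """Hypothetical substitute that accepts (config, *args, **kwargs)."""

    def __init__(self, config, *args, **kwargs):
        # Extra arguments supplied by the caller are accepted and ignored;
        # only the config is used to size the layer.
        self.hidden_size = config.hidden_size


def build_layers(layer_cls, config):
    # Assumed caller behavior: one extra positional argument per layer
    return [layer_cls(config, idx) for idx in range(config.num_hidden_layers)]


layers = build_layers(SubstituteLayer, FakeConfig(hidden_size=64, num_hidden_layers=2))
```

A substitute whose __init__ accepted only config would raise a TypeError under this assumed caller, which is why the signature needs to match.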

Behavior

The context manager modifies and restores a single module-level attribute:

Phase | Action | Module State
Entry | Saves original class, assigns te_decoder_cls | transformers.models.llama.modeling_llama.LlamaDecoderLayer points to te_decoder_cls
During | Caller constructs model | Any LlamaDecoderLayer() calls instantiate te_decoder_cls
Exit (normal) | Restores original class | transformers.models.llama.modeling_llama.LlamaDecoderLayer points to original
Exit (exception) | Restores original class via finally | transformers.models.llama.modeling_llama.LlamaDecoderLayer points to original
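
Since only the module-level attribute changes, the substitution affects code that looks the name up on the module at call time. A name bound earlier (for example through a from-import executed before the patch) keeps pointing at the original class. The sketch below illustrates this general Python behavior with hypothetical stand-ins (fake_module, OriginalLayer, SubstituteLayer, replace_layer); it does not use transformers.

```python
import types
from contextlib import contextmanager

# Hypothetical stand-in for the patched module
fake_module = types.ModuleType("fake_modeling")


class OriginalLayer:
    """Stand-in for LlamaDecoderLayer."""


class SubstituteLayer:
    """Stand-in for the TE decoder class."""


fake_module.DecoderLayer = OriginalLayer

# Equivalent to `from fake_modeling import DecoderLayer` executed before the patch
early_alias = fake_module.DecoderLayer


@contextmanager
def replace_layer(new_cls):
    original = fake_module.DecoderLayer
    fake_module.DecoderLayer = new_cls
    try:
        yield
    finally:
        fake_module.DecoderLayer = original


with replace_layer(SubstituteLayer):
    via_module = fake_module.DecoderLayer   # attribute lookup: sees the substitute
    via_alias = early_alias                 # earlier binding: still the original
```

Here via_module is SubstituteLayer while via_alias remains OriginalLayer, so the pattern relies on the model code resolving the class through the module namespace.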

Example Usage

from te_llama import TELlamaDecoderLayer, replace_decoder
from transformers import LlamaForCausalLM, LlamaConfig

config = LlamaConfig(...)

# Within the context, LlamaForCausalLM will use TELlamaDecoderLayer
with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):
    model = LlamaForCausalLM(config)

# After the context, the original LlamaDecoderLayer is restored
# but the model retains its TELlamaDecoderLayer instances
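
The last two comments in the example can be checked with a stand-in that needs neither transformers nor TransformerEngine. This is a sketch under assumptions: fake_module, build_model, and the two layer classes are hypothetical, and build_model mimics a model constructor that looks up the decoder class on the module each time it is called.

```python
import types
from contextlib import contextmanager

# Hypothetical stand-in for the patched module
fake_module = types.ModuleType("fake_modeling")


class OriginalLayer:
    """Stand-in for LlamaDecoderLayer."""


class SubstituteLayer:
    """Stand-in for TELlamaDecoderLayer."""


fake_module.DecoderLayer = OriginalLayer


def build_model(num_layers):
    # Resolves the layer class on the module at call time
    return [fake_module.DecoderLayer() for _ in range(num_layers)]


@contextmanager
def replace_layer(new_cls):
    original = fake_module.DecoderLayer
    fake_module.DecoderLayer = new_cls
    try:
        yield
    finally:
        fake_module.DecoderLayer = original


with replace_layer(SubstituteLayer):
    patched_model = build_model(2)   # built while the substitution is active

plain_model = build_model(2)         # built after the original class is restored
```

The layers in patched_model remain SubstituteLayer instances after the block exits, while plain_model contains OriginalLayer instances, matching the behavior described for the real model.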
