

Implementation:NVIDIA TransformerEngine TEGemmaForCausalLM

From Leeroopedia
Revision as of 16:00, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/NVIDIA_TransformerEngine_TEGemmaForCausalLM.md)


Overview

A complete Gemma causal language model accelerated with TransformerEngine (TE), with support for FP8 training and generation.

Description

TEGemmaForCausalLM extends HuggingFace's GemmaForCausalLM with four additions: TransformerEngine decoder layers; a custom generate() method that supports InferenceParams-based KV caching and optional CUDA graphs; a forward() method for training and calibration; and hardware-adaptive FP8 recipe selection.

While the parent constructor runs, the class monkey-patches GemmaDecoderLayer with TEGemmaDecoderLayer through the replace_decoder() context manager, so every decoder layer the model builds is a TE layer. After construction, it creates two wrapper objects for inference:

  • _model_context_phase (GemmaModelWrapper): Handles full-sequence prefill, applying "padding_causal" attention masking
  • _model_generation_phase (GemmaGenerationWrapper): Handles single-token decode, applying "padding" attention masking and argmax token selection
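
The class-swapping pattern described above can be illustrated with a minimal, library-free sketch. It is not the code of replace_decoder(): the names swap_decoder, fake_modeling, OriginalDecoderLayer, and AcceleratedDecoderLayer are hypothetical stand-ins, and the sketch only assumes that the model constructor looks the decoder class up on a module at construction time.

```python
from contextlib import contextmanager
from types import SimpleNamespace


class OriginalDecoderLayer:
    """Stand-in for the stock decoder layer class (hypothetical)."""


class AcceleratedDecoderLayer:
    """Stand-in for the accelerated replacement class (hypothetical)."""


# Stand-in for the module that exposes the decoder layer class.
fake_modeling = SimpleNamespace(DecoderLayer=OriginalDecoderLayer)


@contextmanager
def swap_decoder(module, attr_name, replacement_cls):
    """Temporarily rebind module.<attr_name> to replacement_cls."""
    original_cls = getattr(module, attr_name)
    setattr(module, attr_name, replacement_cls)
    try:
        yield
    finally:
        # Restore the original class even if construction raises.
        setattr(module, attr_name, original_cls)


def build_layers(module, num_layers):
    # Resolves the class at call time, as a model constructor would.
    return [module.DecoderLayer() for _ in range(num_layers)]


with swap_decoder(fake_modeling, "DecoderLayer", AcceleratedDecoderLayer):
    patched_layers = build_layers(fake_modeling, 2)

# Outside the context, the original class is back in place.
restored_layers = build_layers(fake_modeling, 1)
```

Layers built inside the with block are instances of the replacement class, while the module attribute is restored afterwards, so other models constructed later are unaffected.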

A subclass TEGemmaForCausalLMCudaGraphs further extends this class by pre-allocating static buffers and capturing the two phases as CUDA graph callables via te.pytorch.make_graphed_callables().

This is a Wrapper Doc.

Source

docs/examples/te_gemma/te_gemma.py, class TEGemmaForCausalLM at lines 305-558.

Signature

class TEGemmaForCausalLM(GemmaForCausalLM):
    def __init__(self, config: GemmaConfig):
        dtype = torch.bfloat16
        with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer):
            super().__init__(config)

        self.config = config
        self.to(dtype).cuda()
        self.hidden_size = config.hidden_size

        self._model_context_phase = GemmaModelWrapper(self.model, dtype, self.lm_head)
        self._model_generation_phase = GemmaGenerationWrapper(
            lm_head=self.lm_head,
            model=self.model,
            dtype=dtype,
        )

        if self.config.fp8:
            self.fp8_recipe = get_default_fp8_recipe()

        self.te_rope_emb = RotaryPositionEmbedding(self.config.head_dim)(
            max_seq_len=self.config.max_position_embeddings
        ).cuda()

    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        pad_token_id: int = 0,
        max_new_tokens: int = 0,
        *args,
        **kwargs,
    ):
        """
        Autoregressive generation with KV cache.
        1. Shifts left-padded inputs to right-padded
        2. Creates InferenceParams for KV caching
        3. Runs context/prefill phase
        4. Loops max_new_tokens times for generation phase
        5. Returns [batch, input_len + max_new_tokens + 1] tensor
        """
        ...

    def forward(self, *args, **kwargs):
        """
        Training/calibration forward pass.
        Sets inference_params=None, applies padding_causal mask,
        runs through model layers and lm_head.
        """
        ...

I/O

Constructor Input:

  • config: GemmaConfig -- HuggingFace Gemma configuration, extended with TE-specific fields:
    • config.fp8: bool -- Whether to enable FP8
    • config.fuse_qkv_params: bool -- Whether to fuse QKV projections
    • config.is_paged: bool -- Whether to use paged KV cache
    • config.max_seq_length: int -- Maximum sequence length for inference
    • config.generation_cuda_graphs: bool -- Whether CUDA graphs are enabled
    • config.cuda_graphs_static_batch_size: int -- Static batch size for CUDA graphs
    • config.cuda_graphs_static_max_context_len: int -- Static max context length for CUDA graphs
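
As a rough illustration of how these extra fields might be attached to a configuration object, the sketch below groups them in a dataclass and copies them onto a stand-in config. The field names come from the list above; the dataclass TEGemmaExtraFields, the helper apply_te_fields, the default values, and the use of SimpleNamespace in place of GemmaConfig are all assumptions made for illustration.

```python
from dataclasses import asdict, dataclass
from types import SimpleNamespace


@dataclass
class TEGemmaExtraFields:
    """TE-specific fields; the default values here are illustrative only."""
    fp8: bool = False
    fuse_qkv_params: bool = False
    is_paged: bool = False
    max_seq_length: int = 4096
    generation_cuda_graphs: bool = False
    cuda_graphs_static_batch_size: int = 16
    cuda_graphs_static_max_context_len: int = 512


def apply_te_fields(config, fields):
    """Copy every TE-specific field onto an existing config object."""
    for name, value in asdict(fields).items():
        setattr(config, name, value)
    return config


# SimpleNamespace stands in for a GemmaConfig instance in this sketch.
config = SimpleNamespace(hidden_size=2048)
config = apply_te_fields(config, TEGemmaExtraFields(fp8=True, max_seq_length=1024))
```

Keeping the extra fields in one place makes it easier to see which settings belong to TE rather than to the base HuggingFace configuration.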

Constructor Output:

  • TEGemmaForCausalLM instance on CUDA with bfloat16 precision

generate() Input:

  • input_ids: torch.Tensor -- Input token IDs of shape [batch, seq_len], left-padded with pad_token_id
  • pad_token_id: int -- Padding token ID (default: 0)
  • max_new_tokens: int -- Number of new tokens to generate (default: 0)

generate() Output:

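  • torch.Tensor -- Token tensor of shape [batch, input_len + max_new_tokens + 1]: the input tokens followed by the generated tokens. The extra position holds the token selected during the prefill phase, so one token is produced even when max_new_tokens is 0.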

forward() Input:

  • input_ids: torch.Tensor -- Input token IDs, passed as a keyword argument

forward() Output:

  • torch.Tensor -- Logits tensor of shape [batch, seq_len, vocab_size]

Key Components

Component | Type | Purpose
_model_context_phase | GemmaModelWrapper | Runs prefill: embeds input, iterates all layers with "padding_causal" mask, applies final norm and lm_head
_model_generation_phase | GemmaGenerationWrapper | Runs decode: processes single token, applies "padding" mask, selects next token via argmax
te_rope_emb | torch.Tensor | Pre-computed rotary position embeddings shared across all layers
fp8_recipe | DelayedScaling | FP8 recipe selected based on hardware capabilities (only when config.fp8=True)
inference_params | InferenceParams | KV cache manager created per generate() call (or static for CUDA graphs)
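
The difference between the two mask types named in the table can be shown with plain boolean lists. This is a conceptual sketch, not TransformerEngine code: it assumes right-padded sequences, uses True to mean "may attend" (a convention chosen for this sketch only), and the function names are hypothetical.

```python
def padding_causal_mask(seq_lens, max_len):
    """Prefill: query i may attend key j only if j <= i and both are real tokens."""
    return [
        [
            [(j <= i) and (i < n) and (j < n) for j in range(max_len)]
            for i in range(max_len)
        ]
        for n in seq_lens
    ]


def padding_mask_for_decode(valid_lens, max_len):
    """Decode: the single new query may attend every real (non-padding) key."""
    return [[j < n for j in range(max_len)] for n in valid_lens]


# One sequence with 2 real tokens, padded to length 3.
prefill_mask = padding_causal_mask([2], 3)

# Two sequences with 2 and 3 valid keys, cache width 3.
decode_mask = padding_mask_for_decode([2, 3], 3)
```

During prefill the causal constraint matters because many query positions are processed at once. During decode there is one query per sequence and it is always the newest position, so only padding needs to be excluded.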

Generation Flow

The generate() method follows this sequence:

  1. Padding shift: Left-padded HF-style inputs are shifted to right-padded format for TE attention mask compatibility
  2. InferenceParams creation: KV cache is allocated with batch size, max sequence length, KV head count, and head dimensions
  3. Context/prefill phase: Full input sequence is processed, K/V cached, last-token logits extracted, first generated token selected
  4. Generation loop: For each of max_new_tokens iterations:
    • Single-token forward pass through _model_generation_phase
    • New K/V appended to cache via inference_params.pre_step() and internal step()
    • Next token selected via argmax
  5. Result assembly: Input tokens concatenated with all generated tokens
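
The control flow of steps 3 to 5 can be sketched without any model, which also makes the output length of input_len + max_new_tokens + 1 easy to verify. The function greedy_generate and the toy prefill_fn and decode_fn below are hypothetical stand-ins: the real wrappers operate on tensors, hidden states, and a KV cache, none of which are modelled here.

```python
def greedy_generate(input_rows, max_new_tokens, prefill_fn, decode_fn):
    """Prefill once, decode max_new_tokens times, then append all tokens.

    prefill_fn(rows) returns the first generated token for each row.
    decode_fn(tokens) returns the next token for each row.
    """
    next_tokens = prefill_fn(input_rows)
    generated_steps = [next_tokens]  # the prefill already yields one token
    for _ in range(max_new_tokens):
        next_tokens = decode_fn(next_tokens)
        generated_steps.append(next_tokens)
    # Transpose [step][batch] to [batch][step] and append to each input row.
    return [
        row + [step[b] for step in generated_steps]
        for b, row in enumerate(input_rows)
    ]


# Toy stand-ins for the two phases (not real model calls).
def toy_prefill(rows):
    return [sum(row) % 100 for row in rows]


def toy_decode(tokens):
    return [(tok + 1) % 100 for tok in tokens]


result = greedy_generate([[1, 2, 3], [4, 5, 6]], 2, toy_prefill, toy_decode)
```

With input length 3 and max_new_tokens set to 2, each output row has 3 + 2 + 1 = 6 tokens, because the prefill phase contributes one token before the loop starts.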

Both torch.amp.autocast (for BF16) and te.pytorch.autocast (for FP8) are applied during generation.

Notes

  • The model is moved to CUDA and cast to bfloat16 in the constructor.
  • _model_context_phase and _model_generation_phase share the same underlying model layers and lm_head -- they are wrappers, not separate models.
  • The forward() method sets inference_params=None, making it suitable for training and FP8 calibration but not generation.
  • The _padding_to_end() static method handles the conversion from HF-style left-padding to TE-compatible right-padding.
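
The left-to-right padding conversion can be illustrated on plain lists. This sketch is not the implementation of _padding_to_end(), which works on tensors: the function name shift_padding_to_end is hypothetical, and it assumes that pad_token_id never appears as a leading real token.

```python
def shift_padding_to_end(rows, pad_token_id=0):
    """Move leading padding to the end of each row and report real lengths."""
    shifted_rows, lengths = [], []
    for row in rows:
        # Count the leading padding tokens.
        num_pad = 0
        while num_pad < len(row) and row[num_pad] == pad_token_id:
            num_pad += 1
        # Real tokens first, then the same amount of padding at the end.
        shifted_rows.append(row[num_pad:] + [pad_token_id] * num_pad)
        lengths.append(len(row) - num_pad)
    return shifted_rows, lengths


shifted, seq_lens = shift_padding_to_end([[0, 0, 5, 6], [0, 7, 8, 9], [1, 2, 3, 4]])
```

The per-row lengths are returned alongside the shifted rows because a padding-aware attention mask needs to know where the real tokens end.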

Related

Environment:NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements
