
Implementation:NVIDIA TransformerEngine TEGemmaDecoderLayer

From Leeroopedia


Overview

TE-accelerated replacement for HuggingFace's GemmaDecoderLayer using TransformerEngine's TransformerLayer.

Description

TEGemmaDecoderLayer subclasses te.pytorch.TransformerLayer to serve as a drop-in replacement for HuggingFace's GemmaDecoderLayer within the Gemma model. It maps GemmaConfig fields to TE TransformerLayer constructor parameters with Gemma-specific settings: zero_centered_gamma=True, activation="geglu", bias=False, normalization="RMSNorm", and attn_input_format="bshd".

The class tracks the layer index (layer_idx), which is incremented by 1 when passed to TE, since TE numbers layers from 1 rather than 0. The forward() method filters out HuggingFace-specific keyword arguments that do not apply to TE's TransformerLayer, passes rotary_pos_emb through for RoPE support, and forwards inference_params for KV cache management.

This is a Wrapper Doc.

Source

docs/examples/te_gemma/te_gemma.py, class TEGemmaDecoderLayer at lines 139-191.

Signature

class TEGemmaDecoderLayer(te.pytorch.TransformerLayer):
    def __init__(self, config: GemmaConfig, layer_idx: int, *args, **kwargs):

        self.gemma_config = config

        super().__init__(
            hidden_size=config.hidden_size,
            ffn_hidden_size=config.intermediate_size,
            num_attention_heads=config.num_attention_heads,
            bias=False,
            layernorm_epsilon=config.rms_norm_eps,
            hidden_dropout=0,
            attention_dropout=0,
            fuse_qkv_params=config.fuse_qkv_params,
            normalization="RMSNorm",
            activation="geglu",
            attn_input_format="bshd",
            num_gqa_groups=config.num_key_value_heads,
            kv_channels=self.gemma_config.head_dim,
            layer_number=(
                layer_idx + 1
            ),  # Layer numbers in TE start from 1, not 0 as in HF.
            zero_centered_gamma=True,
        )

    def forward(self, *args, **kwargs):
        # Drop HF-specific kwargs that TE's TransformerLayer does not accept.
        for key in ("position_ids", "past_key_value", "output_attentions",
                    "use_cache", "cache_position"):
            kwargs.pop(key, None)
        # RoPE embeddings arrive as `rope_emb` and are forwarded as TE's
        # `rotary_pos_emb`; `inference_params` passes through in kwargs.
        rope_emb = kwargs.pop("rope_emb", None)
        # Wrap the output in a tuple for HF compatibility.
        return (super().forward(*args, rotary_pos_emb=rope_emb, **kwargs),)
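For reference, the kwarg filtering performed in forward() can be sketched as a standalone helper, independent of transformer_engine (the helper name and constant are illustrative, not part of the source):

```python
# HF-specific kwargs that TE's TransformerLayer does not accept.
HF_ONLY_KWARGS = {
    "position_ids",
    "past_key_value",
    "output_attentions",
    "use_cache",
    "cache_position",
}

def filter_hf_kwargs(kwargs):
    """Return (rope_emb, kwargs safe to pass to TE's TransformerLayer)."""
    te_kwargs = {k: v for k, v in kwargs.items() if k not in HF_ONLY_KWARGS}
    # RoPE embeddings are extracted separately and passed as rotary_pos_emb.
    rope_emb = te_kwargs.pop("rope_emb", None)
    return rope_emb, te_kwargs
```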

I/O

Input:

  • config: GemmaConfig -- HuggingFace Gemma configuration object containing model hyperparameters
  • layer_idx: int -- Zero-based layer index (converted to 1-based for TE internally)
  • *args, **kwargs: Additional positional and keyword arguments for HF compatibility

Forward Input:

  • hidden_states: torch.Tensor -- Input hidden states tensor
  • rope_emb: torch.Tensor (optional, via kwargs) -- Rotary position embedding tensor
  • inference_params: InferenceParams (optional, via kwargs) -- KV cache manager for inference
  • attention_mask: torch.Tensor (optional, via kwargs) -- Attention mask tensor
  • self_attn_mask_type: str (optional, via kwargs) -- Mask type, e.g. "padding_causal"

Output:

  • tuple containing a single torch.Tensor -- The layer output wrapped in a tuple for HF compatibility

Key Parameters

  • hidden_size (int) -- Model hidden dimension, from config.hidden_size
  • ffn_hidden_size (int) -- FFN intermediate dimension, from config.intermediate_size
  • num_attention_heads (int) -- Number of query attention heads
  • num_gqa_groups (int) -- Number of KV heads for grouped-query attention
  • kv_channels (int) -- Per-head dimension, from config.head_dim
  • zero_centered_gamma (bool) -- True; RMSNorm weight initialized to 0, effective gamma = 1 + weight
  • activation (str) -- "geglu"; gated GELU activation for the FFN
  • fuse_qkv_params (bool) -- Whether to fuse QKV projections into a single parameter
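The zero_centered_gamma convention can be illustrated with a minimal NumPy RMSNorm, assuming the standard RMSNorm formula (this is a sketch, not TE's fused kernel):

```python
import numpy as np

def rmsnorm_zero_centered(x, weight, eps=1e-6):
    """RMSNorm with zero_centered_gamma: the effective scale is (1 + weight),
    so a weight initialized to all zeros starts as an identity scale.
    Shapes: x is (..., hidden), weight is (hidden,)."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return (x / rms) * (1.0 + weight)
```

With weight initialized to zeros, the output is exactly the RMS-normalized input, which is why TE stores the learned offset rather than the full gamma.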

Notes

  • The forward method filters out HF-specific keyword arguments (position_ids, past_key_value, output_attentions, use_cache, cache_position) that are not used by TE's TransformerLayer.
  • The output is wrapped in a tuple (output,) to match the return signature expected by HuggingFace's GemmaModel layer iteration.
  • Layer numbering is adjusted from 0-based (HF convention) to 1-based (TE convention) via layer_number = layer_idx + 1.
  • Both hidden_dropout and attention_dropout are set to 0, consistent with Gemma's architecture.
