Implementation:NVIDIA TransformerEngine TELlamaDecoderLayer

From Leeroopedia


Overview

TE-accelerated replacement for HuggingFace's LlamaDecoderLayer using TransformerEngine's TransformerLayer.

Doc Type

Wrapper Doc -- This class wraps TE's TransformerLayer to provide an API-compatible replacement for HuggingFace's LlamaDecoderLayer.

Description

TELlamaDecoderLayer subclasses te.pytorch.TransformerLayer and maps LlamaConfig parameters to TE's TransformerLayer constructor arguments. The wrapper preserves the HuggingFace decoder layer interface while replacing the computational implementation with TE's fused CUDA kernels.

Key configuration choices:

  • RMSNorm: Uses normalization="RMSNorm" to match LLaMA's normalization scheme.
  • SwiGLU: Uses activation="swiglu" to match LLaMA's MLP activation function.
  • No Bias: Sets bias=False since LLaMA does not use bias terms.
  • BSHD Format: Uses attn_input_format="bshd" (Batch, Sequence, Head, Dimension) tensor layout.
  • Unfused QKV: Sets fuse_qkv_params=False to keep Q, K, V as separate weight parameters, which simplifies weight loading from HF checkpoints.
  • GQA Support: Maps config.num_key_value_heads to num_gqa_groups for Grouped Query Attention.
  • Pre-computed RoPE: Computes rotary position embeddings once during __init__ and reuses them in every forward pass.
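The head-dimension arithmetic behind the pre-computed RoPE table can be made concrete with a small sketch. The numbers below assume a LLaMA-7B-like config; the inverse-frequency formula is the standard RoPE definition, which TE's `RotaryPositionEmbedding` computes internally.

```python
# Illustrative only: TE's RotaryPositionEmbedding does the real work.
hidden_size, num_attention_heads = 4096, 32
head_dim = hidden_size // num_attention_heads   # rotary dim passed to RoPE

# Standard RoPE inverse frequencies, one per pair of head-dim channels:
inv_freq = [1.0 / (10000 ** (2 * i / head_dim)) for i in range(head_dim // 2)]

assert head_dim == 128 and len(inv_freq) == 64
```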

The forward method handles format adaptation:

  • Accepts hidden_states and attention_mask from HF's layer interface
  • Handles the case where hidden_states is a tuple (compatibility with older HF versions)
  • Injects pre-computed RoPE embeddings via rotary_pos_emb
  • Returns output directly as a tensor (compatible with HuggingFace transformers >= 4.57)
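The tuple-unwrapping step can be sketched as a standalone helper (the function name is illustrative; in the actual class this logic sits inline in `forward`):

```python
def unwrap_hidden_states(hidden_states):
    """Accept either a bare tensor or the (hidden_states, ...) tuple
    that older HuggingFace versions pass between decoder layers."""
    if isinstance(hidden_states, tuple):
        return hidden_states[0]
    return hidden_states

# Both call styles yield the same object:
x = object()                          # stand-in for a torch.Tensor
assert unwrap_hidden_states(x) is x
assert unwrap_hidden_states((x, None)) is x
```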

Source

  • File: docs/examples/te_llama/te_llama.py
  • Class: TELlamaDecoderLayer
  • Lines: L40-84

Signature

class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
    """
    Wrapper class over TE's `TransformerLayer`, making it API-compatible with
    HF's `LlamaDecoderLayer` and easy to swap in for it in the code.

    Args:
        config: LlamaConfig
        args: positional args (for compatibility with `LlamaDecoderLayer`)
        kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)
    """

    def __init__(self, config, *args, **kwargs):
        super().__init__(
            hidden_size=config.hidden_size,
            ffn_hidden_size=config.intermediate_size,
            num_attention_heads=config.num_attention_heads,
            num_gqa_groups=config.num_key_value_heads,
            layernorm_epsilon=config.rms_norm_eps,
            hidden_dropout=0,
            attention_dropout=0,
            bias=False,
            normalization="RMSNorm",
            activation="swiglu",
            attn_input_format="bshd",
            fuse_qkv_params=False,
        )
        te_rope = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
        self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()

    def forward(self, hidden_states, *args, attention_mask, **kwargs):
        """
        Custom forward to make sure we only pass relevant arguments to the
        forward pass of the `TransformerLayer`. Also, make sure the output
        format matches the output of the HF's `LlamaDecoderLayer`.
        """
        if isinstance(hidden_states, tuple):
            hidden_states = hidden_states[0]

        return super().forward(
            hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
        )
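The wrapper is typically swapped in for HF's `LlamaDecoderLayer` before the model is constructed, so that `AutoModelForCausalLM`-style builders instantiate the TE layer instead. A generic sketch of that substitution pattern follows; the helper name and the stand-in "module" are illustrative, not taken from te_llama.py.

```python
from contextlib import contextmanager
from types import SimpleNamespace

@contextmanager
def replace_attr(module, name, replacement):
    """Temporarily rebind module.name (e.g. swap LlamaDecoderLayer for
    TELlamaDecoderLayer while the model is built), restoring the
    original binding on exit."""
    original = getattr(module, name)
    setattr(module, name, replacement)
    try:
        yield
    finally:
        setattr(module, name, original)

# Demonstration with a stand-in for transformers.models.llama.modeling_llama:
modeling = SimpleNamespace(LlamaDecoderLayer="hf_layer")
with replace_attr(modeling, "LlamaDecoderLayer", "te_layer"):
    assert modeling.LlamaDecoderLayer == "te_layer"   # TE class used here
assert modeling.LlamaDecoderLayer == "hf_layer"       # original restored
```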

I/O

| Direction | Name | Type | Description |
| --- | --- | --- | --- |
| Input | `hidden_states` | `torch.Tensor` or tuple | Hidden states from the previous layer, shape `(batch_size, seq_len, hidden_size)`. May be a tuple for compatibility with older HF versions. |
| Input | `attention_mask` | `torch.Tensor` or `None` | Attention mask, passed through to TE's `TransformerLayer.forward()`. |
| Input | `*args`, `**kwargs` | various | Additional positional and keyword arguments accepted for HF compatibility but not forwarded to TE. |
| Output | return value | `torch.Tensor` | Transformed hidden states, shape `(batch_size, seq_len, hidden_size)`. Returned directly as a tensor (HF transformers >= 4.57 compatibility). |

Parameter Mapping

| HF `LlamaConfig` Field | TE `TransformerLayer` Parameter | Purpose |
| --- | --- | --- |
| `config.hidden_size` | `hidden_size` | Model hidden dimension |
| `config.intermediate_size` | `ffn_hidden_size` | Feed-forward network inner dimension |
| `config.num_attention_heads` | `num_attention_heads` | Number of attention heads |
| `config.num_key_value_heads` | `num_gqa_groups` | Number of key/value heads for GQA |
| `config.rms_norm_eps` | `layernorm_epsilon` | RMSNorm epsilon for numerical stability |
| `config.hidden_size // config.num_attention_heads` | RoPE dimension | Head dimension for rotary embedding computation |
| `config.max_position_embeddings` | RoPE max sequence length | Maximum sequence length for pre-computed RoPE |
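The mapping above can be expressed as a plain function for reference. This is a sketch, not part of the class: `te_layer_kwargs` is a hypothetical helper, and `config` is any object with `LlamaConfig`-style attributes.

```python
from types import SimpleNamespace

def te_layer_kwargs(config):
    """Build TE TransformerLayer constructor kwargs from LlamaConfig fields,
    mirroring the table above (dropout disabled, bias off, RMSNorm + SwiGLU)."""
    return {
        "hidden_size": config.hidden_size,
        "ffn_hidden_size": config.intermediate_size,
        "num_attention_heads": config.num_attention_heads,
        "num_gqa_groups": config.num_key_value_heads,
        "layernorm_epsilon": config.rms_norm_eps,
        "hidden_dropout": 0,
        "attention_dropout": 0,
        "bias": False,
        "normalization": "RMSNorm",
        "activation": "swiglu",
        "attn_input_format": "bshd",
        "fuse_qkv_params": False,
    }

# LLaMA-7B-like values, for illustration:
cfg = SimpleNamespace(hidden_size=4096, intermediate_size=11008,
                      num_attention_heads=32, num_key_value_heads=32,
                      rms_norm_eps=1e-5)
kwargs = te_layer_kwargs(cfg)
assert kwargs["num_gqa_groups"] == 32 and kwargs["bias"] is False
```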
