Implementation: NVIDIA TransformerEngine TELlamaDecoderLayer
Overview
TE-accelerated replacement for HuggingFace's LlamaDecoderLayer using TransformerEngine's TransformerLayer.
Doc Type
Wrapper Doc -- This class wraps TE's TransformerLayer to provide an API-compatible replacement for HuggingFace's LlamaDecoderLayer.
Description
TELlamaDecoderLayer subclasses te.pytorch.TransformerLayer and maps LlamaConfig parameters to TE's TransformerLayer constructor arguments. The wrapper preserves the HuggingFace decoder layer interface while replacing the computational implementation with TE's fused CUDA kernels.
Key configuration choices:
- RMSNorm: Uses `normalization="RMSNorm"` to match LLaMA's normalization scheme.
- SwiGLU: Uses `activation="swiglu"` to match LLaMA's MLP activation function.
- No Bias: Sets `bias=False` since LLaMA does not use bias terms.
- BSHD Format: Uses the `attn_input_format="bshd"` (Batch, Sequence, Head, Dimension) tensor layout.
- Unfused QKV: Sets `fuse_qkv_params=False` to keep Q, K, V as separate weight parameters, which simplifies weight loading from HF checkpoints.
- GQA Support: Maps `config.num_key_value_heads` to `num_gqa_groups` for Grouped Query Attention.
- Pre-computed RoPE: Computes rotary position embeddings once during `__init__` and reuses them in every forward pass.
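The configuration choices above can be seen in one place as a plain mapping from HF config fields to TE constructor kwargs. This is a dependency-free sketch: `te_layer_kwargs` and the `SimpleNamespace` stub are illustrative stand-ins, not part of the source, and the numeric values merely resemble a LLaMA-2-7B config.

```python
from types import SimpleNamespace

def te_layer_kwargs(config):
    """Translate LlamaConfig fields into TransformerLayer constructor kwargs."""
    return {
        "hidden_size": config.hidden_size,
        "ffn_hidden_size": config.intermediate_size,
        "num_attention_heads": config.num_attention_heads,
        "num_gqa_groups": config.num_key_value_heads,  # GQA: KV-head count
        "layernorm_epsilon": config.rms_norm_eps,
        "hidden_dropout": 0,
        "attention_dropout": 0,
        "bias": False,                 # LLaMA uses no bias terms
        "normalization": "RMSNorm",    # LLaMA's normalization scheme
        "activation": "swiglu",        # LLaMA's MLP activation
        "attn_input_format": "bshd",   # (batch, seq, head, dim) layout
        "fuse_qkv_params": False,      # separate Q/K/V eases HF weight loading
    }

# Stub standing in for a real transformers.LlamaConfig:
cfg = SimpleNamespace(hidden_size=4096, intermediate_size=11008,
                      num_attention_heads=32, num_key_value_heads=32,
                      rms_norm_eps=1e-5)
kwargs = te_layer_kwargs(cfg)
```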
The forward method handles format adaptation:
- Accepts `hidden_states` and `attention_mask` from HF's layer interface
- Handles the case where `hidden_states` is a tuple (compatibility with older HF versions)
- Injects pre-computed RoPE embeddings via `rotary_pos_emb`
- Returns output directly as a tensor (compatible with HuggingFace transformers >= 4.57)
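The control flow of this adaptation can be sketched without TE or torch by replacing `TransformerLayer.forward` with a stub; `adapted_forward` and `te_forward_stub` below are illustrative names, not the real API.

```python
def te_forward_stub(hidden_states, attention_mask=None, rotary_pos_emb=None):
    # Stands in for te.pytorch.TransformerLayer.forward in this sketch.
    return hidden_states

def adapted_forward(hidden_states, *args, attention_mask=None, rope_emb=None, **kwargs):
    # Older HF versions may pass a tuple; unwrap to the bare tensor.
    if isinstance(hidden_states, tuple):
        hidden_states = hidden_states[0]
    # Only the relevant arguments are forwarded; extra HF args/kwargs are dropped.
    return te_forward_stub(hidden_states, attention_mask=attention_mask,
                           rotary_pos_emb=rope_emb)

# Tuple input plus an extra HF kwarg (position_ids) that is silently ignored:
out = adapted_forward(("h",), attention_mask=None, rope_emb="rope", position_ids=None)
```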
Source
- File: `docs/examples/te_llama/te_llama.py`
- Class: `TELlamaDecoderLayer`
- Lines: L40-84
Signature
class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
"""
Wrapper class over TE's `TransformerLayer`. This makes the wrapper very
similar to HF's `LlamaDecoderLayer` and easier to replace it in the code.
Args:
config: LlamaConfig
args: positional args (for compatibility with `LlamaDecoderLayer`)
kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)
"""
def __init__(self, config, *args, **kwargs):
super().__init__(
hidden_size=config.hidden_size,
ffn_hidden_size=config.intermediate_size,
num_attention_heads=config.num_attention_heads,
num_gqa_groups=config.num_key_value_heads,
layernorm_epsilon=config.rms_norm_eps,
hidden_dropout=0,
attention_dropout=0,
bias=False,
normalization="RMSNorm",
activation="swiglu",
attn_input_format="bshd",
fuse_qkv_params=False,
)
te_rope = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()
def forward(self, hidden_states, *args, attention_mask, **kwargs):
"""
Custom forward to make sure we only pass relevant arguments to the
forward pass of the `TransformerLayer`. Also, make sure the output
format matches the output of the HF's `LlamaDecoderLayer`.
"""
if isinstance(hidden_states, tuple):
hidden_states = hidden_states[0]
return super().forward(
hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
)
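The `__init__` above builds the RoPE table exactly once and caches it in `self.te_rope_emb`. A dependency-free sketch of the angle table such a precompute produces (assuming the standard RoPE inverse-frequency formula with base 10000; `precompute_rope_angles` is an illustrative helper, not TE's `RotaryPositionEmbedding` API):

```python
def precompute_rope_angles(head_dim, max_seq_len, base=10000.0):
    """Build the rotary angle table once, mirroring the precompute in __init__."""
    # Inverse frequency for each even channel index i: base^(-i / head_dim).
    inv_freq = [base ** (-i / head_dim) for i in range(0, head_dim, 2)]
    # One row of angles per position; cached and reused every forward pass.
    return [[pos * f for f in inv_freq] for pos in range(max_seq_len)]

# head_dim corresponds to config.hidden_size // config.num_attention_heads.
angles = precompute_rope_angles(head_dim=8, max_seq_len=4)
```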
I/O
| Direction | Name | Type | Description |
|---|---|---|---|
| Input | `hidden_states` | `torch.Tensor` or tuple | Hidden state tensor from the previous layer. Shape: `(batch_size, seq_len, hidden_size)`. May be a tuple for compatibility with older HF versions. |
| Input | `attention_mask` | `torch.Tensor` or `None` | Attention mask tensor. Passed through to TE's `TransformerLayer.forward()`. |
| Input | `*args`, `**kwargs` | various | Additional positional and keyword arguments accepted for HF compatibility but not forwarded to TE. |
| Output | return value | `torch.Tensor` | Transformed hidden states. Shape: `(batch_size, seq_len, hidden_size)`. Returned directly as a tensor (HF transformers >= 4.57 compatibility). |
Parameter Mapping
| HF LlamaConfig Field | TE TransformerLayer Parameter | Purpose |
|---|---|---|
| `config.hidden_size` | `hidden_size` | Model hidden dimension |
| `config.intermediate_size` | `ffn_hidden_size` | Feed-forward network inner dimension |
| `config.num_attention_heads` | `num_attention_heads` | Number of attention heads |
| `config.num_key_value_heads` | `num_gqa_groups` | Number of key/value heads for GQA |
| `config.rms_norm_eps` | `layernorm_epsilon` | RMSNorm epsilon for numerical stability |
| `config.hidden_size // config.num_attention_heads` | RoPE dimension | Head dimension for rotary embedding computation |
| `config.max_position_embeddings` | RoPE max sequence length | Maximum sequence length for pre-computed RoPE |
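The mapped values are only consistent when the head dimension is integral and the query heads divide evenly into KV groups; a small sanity-check sketch (the helper name and the numeric values are illustrative, not from the source):

```python
def check_mapping(hidden_size, num_attention_heads, num_key_value_heads):
    # The quotient is the RoPE dimension row of the mapping above.
    assert hidden_size % num_attention_heads == 0, "head_dim must be integral"
    # GQA requires query heads to divide evenly into KV groups.
    assert num_attention_heads % num_key_value_heads == 0, "invalid GQA grouping"
    return hidden_size // num_attention_heads

head_dim = check_mapping(4096, 32, 8)  # illustrative GQA values
```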