Implementation:NVIDIA TransformerEngine TEGemmaForCausalLM
Overview
Complete TE-accelerated Gemma model with FP8 training and generation support.
Description
TEGemmaForCausalLM extends HuggingFace's GemmaForCausalLM with TransformerEngine decoder layers, a custom generate() method supporting InferenceParams-based KV caching and optional CUDA graphs, a forward() method for training and calibration, and hardware-adaptive FP8 recipe selection.
The class monkey-patches GemmaDecoderLayer with TEGemmaDecoderLayer during model construction via the replace_decoder() context manager. After construction, it creates two wrapper objects for inference:
- _model_context_phase (GemmaModelWrapper): handles full-sequence prefill, applying "padding_causal" attention masking.
- _model_generation_phase (GemmaGenerationWrapper): handles single-token decode, applying "padding" attention masking and argmax token selection.
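The difference between the two masks can be seen in a small pure-Python sketch. It assumes right-padded sequences with known lengths and uses its own boolean convention (True means the query may attend to the key). The function names are hypothetical and not part of te_gemma.py or TransformerEngine, whose mask tensors may use a different convention.

```python
# Hypothetical convention for this sketch: True = "query may attend to key".
# TransformerEngine's own mask tensors may use the opposite convention.


def padding_causal_mask(lengths, seq_len):
    """Prefill mask: causal within each sequence, padded keys masked out."""
    masks = []
    for n in lengths:
        rows = []
        for q in range(seq_len):
            # Key k is visible if it is not in the future and not padding.
            rows.append([k <= q and k < n for k in range(seq_len)])
        masks.append(rows)
    return masks


def padding_mask(cache_lengths, max_cache_len):
    """Decode mask: the single new query sees every valid cached key.

    cache_lengths counts cached tokens per sequence, including the token
    just appended. No causal term is needed because every cached key
    precedes (or is) the current token.
    """
    return [[k < n for k in range(max_cache_len)] for n in cache_lengths]


print(padding_causal_mask([2], 3)[0])
# [[True, False, False], [True, True, False], [True, True, False]]
print(padding_mask([2], 4))
# [[True, True, False, False]]
```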
A subclass TEGemmaForCausalLMCudaGraphs further extends this class by pre-allocating static buffers and capturing the two phases as CUDA graph callables via te.pytorch.make_graphed_callables().
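The class swap performed by replace_decoder() can be illustrated with a minimal context manager that rebinds a class attribute on a module and restores it afterwards. This is a sketch under stated assumptions: a SimpleNamespace stands in for the HuggingFace Gemma modeling module, and the two decoder classes are empty placeholders, not the real HF or TE layers.

```python
import types
from contextlib import contextmanager

# Hypothetical stand-in for the HF module that defines GemmaDecoderLayer.
modeling_gemma = types.SimpleNamespace()


class GemmaDecoderLayer:  # placeholder for the original HF layer
    pass


class TEGemmaDecoderLayer:  # placeholder for the TE-backed layer
    pass


modeling_gemma.GemmaDecoderLayer = GemmaDecoderLayer


@contextmanager
def replace_decoder(te_decoder_cls, module=modeling_gemma):
    """Temporarily rebind module.GemmaDecoderLayer to te_decoder_cls."""
    original = module.GemmaDecoderLayer
    module.GemmaDecoderLayer = te_decoder_cls
    try:
        yield
    finally:
        # Restore the original class even if model construction raises.
        module.GemmaDecoderLayer = original


def build_layers(n, module=modeling_gemma):
    # The model code looks the class up on the module when it runs,
    # so it picks up whichever class is bound there at that moment.
    return [module.GemmaDecoderLayer() for _ in range(n)]


with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer):
    layers = build_layers(2)
```

The patch only works because the layer class is resolved by name at construction time; layers built inside the `with` block are TE layers, and the module is back to its original state afterwards.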
This is a Wrapper Doc.
Source
docs/examples/te_gemma/te_gemma.py, class TEGemmaForCausalLM at lines 305-558.
Signature
```python
class TEGemmaForCausalLM(GemmaForCausalLM):
    def __init__(self, config: GemmaConfig):
        dtype = torch.bfloat16
        # Swap HF decoder layers for TE layers while the parent constructor builds them.
        with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer):
            super().__init__(config)

        self.config = config
        self.to(dtype).cuda()
        self.hidden_size = config.hidden_size
        # Both wrappers share self.model and self.lm_head; they are not separate models.
        self._model_context_phase = GemmaModelWrapper(self.model, dtype, self.lm_head)
        self._model_generation_phase = GemmaGenerationWrapper(
            lm_head=self.lm_head,
            model=self.model,
            dtype=dtype,
        )

        if self.config.fp8:
            self.fp8_recipe = get_default_fp8_recipe()

        # Rotary position embeddings are computed once and shared by all layers.
        self.te_rope_emb = RotaryPositionEmbedding(self.config.head_dim)(
            max_seq_len=self.config.max_position_embeddings
        ).cuda()

    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        pad_token_id: int = 0,
        max_new_tokens: int = 0,
        *args,
        **kwargs,
    ):
        """
        Autoregressive generation with KV cache.
        1. Shifts left-padded inputs to right-padded
        2. Creates InferenceParams for KV caching
        3. Runs context/prefill phase
        4. Loops max_new_tokens times for generation phase
        5. Returns [batch, input_len + max_new_tokens + 1] tensor
        """
        ...

    def forward(self, *args, **kwargs):
        """
        Training/calibration forward pass.
        Sets inference_params=None, applies padding_causal mask,
        runs through model layers and lm_head.
        """
        ...
```
I/O
Constructor Input:
config: GemmaConfig -- HuggingFace Gemma configuration, extended with these TE-specific fields:
- config.fp8: bool -- whether to enable FP8
- config.fuse_qkv_params: bool -- whether to fuse the Q, K, and V projections
- config.is_paged: bool -- whether to use a paged KV cache
- config.max_seq_length: int -- maximum sequence length for inference
- config.generation_cuda_graphs: bool -- whether CUDA graphs are enabled
- config.cuda_graphs_static_batch_size: int -- static batch size for CUDA graphs
- config.cuda_graphs_static_max_context_len: int -- static maximum context length for CUDA graphs
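One way to attach the TE-specific fields to an existing config object is sketched below. It is an illustration, not the example's own setup code: a SimpleNamespace stands in for GemmaConfig (a real HF config should accept extra attributes via setattr the same way), the default values are illustrative only, and the length check is a hypothetical sanity check rather than a documented requirement.

```python
from dataclasses import asdict, dataclass
from types import SimpleNamespace


@dataclass
class TEFields:
    """TE-specific fields listed above; defaults are illustrative only."""
    fp8: bool = False
    fuse_qkv_params: bool = False
    is_paged: bool = False
    max_seq_length: int = 512
    generation_cuda_graphs: bool = False
    cuda_graphs_static_batch_size: int = 16
    cuda_graphs_static_max_context_len: int = 256


def extend_config(config, **overrides):
    """Attach TE fields to an existing config object via setattr."""
    fields = TEFields(**overrides)
    # Hypothetical sanity check: the static context must fit in max_seq_length.
    if fields.generation_cuda_graphs and (
        fields.cuda_graphs_static_max_context_len > fields.max_seq_length
    ):
        raise ValueError("static context length exceeds max_seq_length")
    for name, value in asdict(fields).items():
        setattr(config, name, value)
    return config


# SimpleNamespace stands in for a loaded GemmaConfig here.
cfg = extend_config(SimpleNamespace(hidden_size=2048), fp8=True)
```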
Constructor Output:
A TEGemmaForCausalLM instance on CUDA with bfloat16 precision
generate() Input:
- input_ids: torch.Tensor -- input token IDs of shape [batch, seq_len], left-padded with pad_token_id
- pad_token_id: int -- padding token ID (default: 0)
- max_new_tokens: int -- number of new tokens to generate (default: 0)
generate() Output:
torch.Tensor -- generated token tensor of shape [batch, input_len + max_new_tokens + 1]
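The input convention and the output shape can be checked with plain lists. This sketch assumes HF-style left padding (what a tokenizer produces with padding_side="left"); the helper names are hypothetical. The extra +1 in the output length comes from the prefill phase, which already selects one token before the max_new_tokens decode iterations.

```python
def left_pad(sequences, pad_token_id=0):
    """Left-pad token lists to a common length (HF padding_side='left' style)."""
    width = max(len(s) for s in sequences)
    return [[pad_token_id] * (width - len(s)) + list(s) for s in sequences]


def generate_output_shape(input_ids, max_new_tokens):
    """Shape documented for generate(): [batch, input_len + max_new_tokens + 1]."""
    batch, input_len = len(input_ids), len(input_ids[0])
    # +1: the prefill phase emits one token before the decode loop starts.
    return (batch, input_len + max_new_tokens + 1)


batch = left_pad([[5, 6, 7], [8, 9]], pad_token_id=0)
print(batch)                                           # [[5, 6, 7], [0, 8, 9]]
print(generate_output_shape(batch, max_new_tokens=4))  # (2, 8)
```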
forward() Input:
input_ids: torch.Tensor -- input token IDs (passed via kwargs)
forward() Output:
torch.Tensor -- logits tensor of shape [batch, seq_len, vocab_size]
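Given logits shaped [batch, seq_len, vocab_size] for right-padded sequences, the last real token of each sequence sits at index length - 1, which is what a "last-token logits, then argmax" step has to pick out. The sketch below uses nested lists in place of tensors and hypothetical helper names; it illustrates the indexing, not the wrappers' actual code.

```python
def last_token_logits(logits, lengths):
    """Pick each sequence's logits at its last real (non-pad) position.

    logits: nested lists shaped [batch, seq_len, vocab_size]
    lengths: number of real tokens per sequence (right padding assumed)
    """
    return [seq[n - 1] for seq, n in zip(logits, lengths)]


def argmax(values):
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


logits = [
    [[0.1, 0.9, 0.0], [0.3, 0.2, 0.5], [0.0, 0.0, 0.0]],  # length 2, one pad
    [[0.2, 0.1, 0.7], [0.6, 0.3, 0.1], [0.1, 0.8, 0.1]],  # length 3, no pad
]
next_tokens = [argmax(row) for row in last_token_logits(logits, [2, 3])]
print(next_tokens)  # [2, 1]
```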
Key Components
| Component | Type | Purpose |
|---|---|---|
| _model_context_phase | GemmaModelWrapper | Runs prefill: embeds the input, iterates over all layers with the "padding_causal" mask, then applies the final norm and lm_head |
| _model_generation_phase | GemmaGenerationWrapper | Runs decode: processes a single token, applies the "padding" mask, and selects the next token via argmax |
| te_rope_emb | torch.Tensor | Precomputed rotary position embeddings shared across all layers |
| fp8_recipe | DelayedScaling | FP8 recipe selected based on hardware capabilities (only when config.fp8=True) |
| inference_params | InferenceParams | KV cache manager created per generate() call (or static for CUDA graphs) |
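te_rope_emb holds rotary angles that every layer reuses, which is why it can be computed once. The sketch below builds the standard RoPE angle table, angle[p][i] = p * base^(-2i / head_dim), and rotates one feature pair. The base of 10000 is an assumption, and the exact tensor layout that TE's RotaryPositionEmbedding returns is not reproduced here.

```python
import math


def rope_angles(head_dim, max_seq_len, base=10000.0):
    """Standard RoPE angle table: angle[p][i] = p * base**(-2i / head_dim).

    Returns max_seq_len rows of head_dim // 2 angles. Precomputing this once
    lets every layer reuse it instead of recomputing it per layer.
    """
    inv_freq = [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]
    return [[p * f for f in inv_freq] for p in range(max_seq_len)]


def rotate_pair(x0, x1, angle):
    """Rotate one (x0, x1) feature pair by the given angle."""
    c, s = math.cos(angle), math.sin(angle)
    return x0 * c - x1 * s, x0 * s + x1 * c


table = rope_angles(head_dim=4, max_seq_len=3)
print(len(table), len(table[0]))  # 3 2
```

Because each pair is rotated rather than scaled, vector norms are preserved and only the relative angle between query and key positions survives in their dot product.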
Generation Flow
The generate() method follows this sequence:
- Padding shift: Left-padded HF-style inputs are shifted to right-padded format for TE attention mask compatibility
- InferenceParams creation: KV cache is allocated with batch size, max sequence length, KV head count, and head dimensions
- Context/prefill phase: Full input sequence is processed, K/V cached, last-token logits extracted, first generated token selected
- Generation loop: for each of max_new_tokens iterations:
  - Single-token forward pass through _model_generation_phase
  - New K/V appended to the cache via inference_params.pre_step() and the internal step()
  - Next token selected via argmax
- Result assembly: input tokens concatenated with all generated tokens
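The prefill, decode, and assembly steps above can be traced with a toy greedy loop. Everything here is hypothetical: ToyKVCache stands in for InferenceParams (which actually manages per-layer K/V tensors and sequence lengths), and toy_logits is a fake model. No TransformerEngine API is called; the sketch only shows the control flow and why the result has input_len + max_new_tokens + 1 columns.

```python
class ToyKVCache:
    """Hypothetical stand-in for InferenceParams: one growing list per sequence."""

    def __init__(self, batch_size):
        self.entries = [[] for _ in range(batch_size)]

    def append(self, seq_idx, token):
        # A real KV cache stores per-layer key/value tensors; a token id stands in here.
        self.entries[seq_idx].append(token)


def toy_logits(context, vocab_size=10):
    # Fake model: puts all mass on (sum(context) + 1) % vocab_size.
    target = (sum(context) + 1) % vocab_size
    return [1.0 if v == target else 0.0 for v in range(vocab_size)]


def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


def toy_generate(input_ids, lengths, max_new_tokens):
    """Greedy decoding over right-padded input_ids with real lengths `lengths`."""
    cache = ToyKVCache(len(input_ids))
    generated = []
    # Prefill: cache all real input tokens, then pick the first new token.
    for b, (row, n) in enumerate(zip(input_ids, lengths)):
        for tok in row[:n]:
            cache.append(b, tok)
        generated.append([argmax(toy_logits(cache.entries[b]))])
    # Decode loop: append the previous token to the cache, then predict the next.
    for _ in range(max_new_tokens):
        for b in range(len(input_ids)):
            cache.append(b, generated[b][-1])
            generated[b].append(argmax(toy_logits(cache.entries[b])))
    # Result assembly: input rows (padding included) followed by generated tokens.
    return [list(row) + toks for row, toks in zip(input_ids, generated)]


print(toy_generate([[1, 2], [3, 0]], lengths=[2, 1], max_new_tokens=2))
# [[1, 2, 4, 8, 6], [3, 0, 4, 8, 6]]
```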
Both torch.amp.autocast (for BF16) and te.pytorch.autocast (for FP8) are applied during generation.
Notes
- The model is moved to CUDA and cast to bfloat16 in the constructor.
- _model_context_phase and _model_generation_phase share the same underlying model layers and lm_head; they are wrappers, not separate models.
- The forward() method sets inference_params=None, making it suitable for training and FP8 calibration but not for generation.
- The _padding_to_end() static method converts HF-style left padding to TE-compatible right padding.
Related
Environment:NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements