Implementation:NVIDIA TransformerEngine TE Autocast
| Field | Value |
|---|---|
| Page Type | Implementation |
| Repository | NVIDIA TransformerEngine |
| Source File | transformer_engine/pytorch/quantization.py (L800-862) |
| Import | from transformer_engine.pytorch import autocast or import transformer_engine.pytorch as te; te.autocast |
| Implements | Principle:NVIDIA_TransformerEngine_FP8_Quantization |
| Requires Environment | Environment:NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements |
Overview
Concrete tool for enabling FP8 quantized execution of TransformerEngine modules.
Description
te.autocast is a context manager that enables FP8 (or FP4) quantized execution for all TE modules within its scope. It wraps the forward pass and configures global FP8 state. On exit, it triggers amax reduction and scaling factor updates.
When the context manager is entered, it:
- Sets the global FP8 enabled flag so that all TE modules detect FP8 mode.
- Stores the provided recipe for use by individual module forward passes.
- Configures the distributed amax reduction group if provided.
- Optionally enables calibration mode for static scaling recipes.
When the context manager exits, it:
- Triggers amax reduction across the distributed group (if applicable).
- Updates scaling factors based on the collected amax history.
- Restores the previous FP8 global state.
Usage
Use te.autocast to wrap the forward pass of any model built with TransformerEngine modules. It is the primary entry point for enabling FP8 training.
- Wrap only the forward pass inside the context manager; backward is handled automatically.
- Always provide a Recipe object to configure the scaling strategy.
- In distributed training, pass amax_reduction_group to synchronize amax values across ranks.
Code Reference
Source Location
| Attribute | Detail |
|---|---|
| File | transformer_engine/pytorch/quantization.py |
| Function | autocast |
| Lines | L800-862 |
Signature
@contextmanager
def autocast(
    enabled: bool = True,
    calibrating: bool = False,
    recipe: Optional[Recipe] = None,
    amax_reduction_group: Optional[dist_group_type] = None,
) -> ContextManager:
Key Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| enabled | bool | True | Enable FP8 quantized execution for TE modules within the context. |
| calibrating | bool | False | Enable calibration mode for static scaling recipes. In calibration mode, amax values are collected but quantization is not applied. |
| recipe | Optional[Recipe] | None | The FP8 recipe object (e.g., DelayedScaling, Float8CurrentScaling) that configures the scaling strategy. If None, a default DelayedScaling recipe is used. |
| amax_reduction_group | Optional[dist_group_type] | None | Distributed process group for synchronizing amax values across ranks. Required for correct FP8 scaling in data-parallel or tensor-parallel training. |
I/O Contract
Input
| Input | Type | Description |
|---|---|---|
| enabled | bool | Whether to activate FP8 mode. |
| calibrating | bool | Whether to run in calibration mode. |
| recipe | Optional[Recipe] | Scaling configuration recipe. |
| amax_reduction_group | Optional[dist_group_type] | Distributed group for amax synchronization. |
Output
| Output | Type | Description |
|---|---|---|
| Context manager | ContextManager | A context manager that configures global FP8 state for all TE modules executed within its scope. Does not return a value; the effect is on the global state. |
Side Effects
- Modifies global FP8 state variables that TE modules read during their forward pass.
- On exit, triggers amax reduction across distributed groups.
- On exit, updates scaling factors in the global FP8 metadata based on the recipe configuration.
Usage Examples
Basic FP8 Training
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
recipe = DelayedScaling()
model = te.TransformerLayer(
hidden_size=1024,
ffn_hidden_size=4096,
num_attention_heads=16,
)
with te.autocast(enabled=True, recipe=recipe):
output = model(input_tensor)
Distributed Training with Amax Reduction
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format
recipe = DelayedScaling(
fp8_format=Format.HYBRID,
amax_history_len=1024,
amax_compute_algo="max",
)
# In a distributed training script
dp_group = torch.distributed.group.WORLD
with te.autocast(enabled=True, recipe=recipe, amax_reduction_group=dp_group):
output = model(input_tensor)
loss = loss_fn(output, target)
loss.backward()
Disabling FP8 for Evaluation
import transformer_engine.pytorch as te
# FP8 disabled — TE modules run in default precision
with te.autocast(enabled=False):
output = model(input_tensor)
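Calibration Mode
Per the calibrating parameter described above, amax statistics can be collected without applying quantization by combining enabled=False with calibrating=True. This sketch follows that description; the te.Linear module and a CUDA-capable environment are assumed.

```python
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

recipe = DelayedScaling()
model = te.Linear(1024, 1024)  # assumed TE module for illustration

# Run in default precision while collecting amax statistics,
# so scaling factors can be initialized before FP8 training.
with te.autocast(enabled=False, calibrating=True, recipe=recipe):
    output = model(input_tensor)
```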
Related Pages
- Principle:NVIDIA_TransformerEngine_FP8_Quantization -- The principle describing FP8 quantization in TransformerEngine.
- Implementation:NVIDIA_TransformerEngine_DelayedScaling_Recipe -- The delayed scaling recipe passed to te.autocast.
- Implementation:NVIDIA_TransformerEngine_Float8CurrentScaling_Recipe -- The current scaling recipe passed to te.autocast.
- Environment:NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements
- Environment:NVIDIA_TransformerEngine_Python_PyTorch_Requirements
- Environment:NVIDIA_TransformerEngine_GPU_Compute_Capability
- Heuristic:NVIDIA_TransformerEngine_FP8_Recipe_Auto_Selection
- Heuristic:NVIDIA_TransformerEngine_Sequence_Length_Alignment
- Heuristic:NVIDIA_TransformerEngine_FP8_Checkpoint_Compatibility