
Implementation:NVIDIA TransformerEngine TE Autocast

From Leeroopedia


Field Value
Page Type Implementation
Repository NVIDIA TransformerEngine
Source File transformer_engine/pytorch/quantization.py (L800-862)
Import from transformer_engine.pytorch import autocast, or import transformer_engine.pytorch as te and use te.autocast
Implements Principle:NVIDIA_TransformerEngine_FP8_Quantization
Requires Environment Environment:NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements

Overview

Concrete tool for enabling FP8 quantized execution of TransformerEngine modules.

Description

te.autocast is a context manager that enables FP8 (or FP4) quantized execution for all TE modules within its scope. It wraps the forward pass and configures global FP8 state. On exit, it triggers amax reduction and scaling factor updates.

When the context manager is entered, it:

  • Sets the global FP8 enabled flag so that all TE modules detect FP8 mode.
  • Stores the provided recipe for use by individual module forward passes.
  • Configures the distributed amax reduction group if provided.
  • Optionally enables calibration mode for static scaling recipes.

When the context manager exits, it:

  • Triggers amax reduction across the distributed group (if applicable).
  • Updates scaling factors based on the collected amax history.
  • Restores the previous FP8 global state.

Usage

Use te.autocast to wrap the forward pass of any model built with TransformerEngine modules. It is the primary entry point for enabling FP8 training.

  • Wrap only the forward pass inside the context manager; backward is handled automatically.
  • Always provide a Recipe object to configure the scaling strategy.
  • In distributed training, pass amax_reduction_group to synchronize amax values across ranks.

Code Reference

Source Location

Attribute Detail
File transformer_engine/pytorch/quantization.py
Function autocast
Lines L800-862

Signature

@contextmanager
def autocast(
    enabled: bool = True,
    calibrating: bool = False,
    recipe: Optional[Recipe] = None,
    amax_reduction_group: Optional[dist_group_type] = None,
) -> ContextManager:

Key Parameters

Parameter Type Default Description
enabled bool True Enable FP8 quantized execution for TE modules within the context.
calibrating bool False Enable calibration mode for static scaling recipes. In calibration mode, amax values are collected but quantization is not applied.
recipe Optional[Recipe] None The FP8 recipe object (e.g., DelayedScaling, Float8CurrentScaling) that configures the scaling strategy. If None, a default DelayedScaling recipe is used.
amax_reduction_group Optional[dist_group_type] None Distributed process group for synchronizing amax values across ranks. Required for correct FP8 scaling in data-parallel or tensor-parallel training.
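As the recipe row notes, passing None falls back to a default DelayedScaling recipe. A minimal sketch of that selection logic, using a hypothetical stand-in class rather than TE's real one (the default values shown are illustrative):

```python
class DelayedScalingStub:
    """Illustrative stand-in for transformer_engine.common.recipe.DelayedScaling."""
    def __init__(self, amax_history_len=1024, amax_compute_algo="max"):
        self.amax_history_len = amax_history_len
        self.amax_compute_algo = amax_compute_algo

def resolve_recipe(recipe=None):
    # Mirror the documented behavior: if no recipe is supplied,
    # fall back to a default DelayedScaling recipe.
    return recipe if recipe is not None else DelayedScalingStub()
```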

I/O Contract

Input

Input Type Description
enabled bool Whether to activate FP8 mode.
calibrating bool Whether to run in calibration mode.
recipe Optional[Recipe] Scaling configuration recipe.
amax_reduction_group Optional[dist_group_type] Distributed group for amax synchronization.

Output

Output Type Description
Context manager ContextManager A context manager that configures global FP8 state for all TE modules executed within its scope. Does not return a value; the effect is on the global state.

Side Effects

  • Modifies global FP8 state variables that TE modules read during their forward pass.
  • On exit, triggers amax reduction across distributed groups.
  • On exit, updates scaling factors in the global FP8 metadata based on the recipe configuration.

Usage Examples

Basic FP8 Training

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

recipe = DelayedScaling()
model = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
)

with te.autocast(enabled=True, recipe=recipe):
    output = model(input_tensor)

Distributed Training with Amax Reduction

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

recipe = DelayedScaling(
    fp8_format=Format.HYBRID,
    amax_history_len=1024,
    amax_compute_algo="max",
)

# In a distributed training script
dp_group = torch.distributed.group.WORLD

with te.autocast(enabled=True, recipe=recipe, amax_reduction_group=dp_group):
    output = model(input_tensor)
loss = loss_fn(output, target)
loss.backward()

Disabling FP8 for Evaluation

import transformer_engine.pytorch as te

# FP8 disabled — TE modules run in default precision
with te.autocast(enabled=False):
    output = model(input_tensor)
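Calibration Mode

A fourth pattern, suggested by the calibrating parameter above: run a calibration pass that collects amax statistics without applying quantization. The loop below is a sketch; model and calibration_loader are assumed to be defined elsewhere, and the enabled=False, calibrating=True combination follows the parameter descriptions rather than a verified recipe.

```python
import transformer_engine.pytorch as te

# Calibration pass: collect amax statistics for a static scaling recipe
# while the model still executes in its default precision.
# `model` and `calibration_loader` are assumed to exist in the script.
with te.autocast(enabled=False, calibrating=True):
    for batch in calibration_loader:
        model(batch)
```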
