
Implementation:NVIDIA TransformerEngine Float8CurrentScaling Recipe

From Leeroopedia


Field Value
Page Type Implementation
Repository NVIDIA TransformerEngine
Source File transformer_engine/common/recipe/__init__.py (L224-261)
Import from transformer_engine.common.recipe import Float8CurrentScaling
Implements Principle:NVIDIA_TransformerEngine_FP8_Current_Scaling
Requires Environment Environment:NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements

Overview

Concrete FP8 recipe configuration for current scaling provided by TransformerEngine.

Description

Float8CurrentScaling computes per-tensor scaling factors from the current iteration's amax. No history buffer is maintained. The class supports optional power-of-2 scales for hardware efficiency.

Float8CurrentScaling is a frozen dataclass inheriting from Recipe, so instances are immutable after construction. It also presents a simpler configuration surface than DelayedScaling: there is no history buffer length or amax computation algorithm to configure.
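The immutability can be seen with a minimal stand-in (the Recipe and Float8CurrentScaling definitions below are illustrative mimics of the frozen-dataclass pattern, not TransformerEngine's actual classes):

```python
from dataclasses import dataclass, FrozenInstanceError

# Illustrative mimic of the frozen-dataclass pattern described above;
# these are NOT TransformerEngine's real class definitions.
@dataclass(frozen=True)
class Recipe:
    pass

@dataclass(frozen=True)
class Float8CurrentScaling(Recipe):
    fp8_dpa: bool = False
    fp8_mha: bool = False

recipe = Float8CurrentScaling()
try:
    recipe.fp8_dpa = True   # any mutation after construction...
except FrozenInstanceError:
    pass                    # ...raises dataclasses.FrozenInstanceError
```

Because mutation raises, variant configurations are created by constructing a new instance rather than editing an existing one.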

Usage

Use Float8CurrentScaling when current scaling is preferred over delayed scaling:

  • Instantiate with desired parameters (or use defaults).
  • Pass the instance to te.autocast(recipe=recipe).
  • TE modules will compute per-tensor amax values during the forward pass and derive scaling factors immediately.
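The per-tensor derivation in the last step can be sketched in plain Python. This is a simplified illustration, assuming FP8 E4M3's maximum magnitude of 448.0; it is not TE's kernel code:

```python
# Simplified sketch of current scaling: the scale is derived solely from
# the current iteration's amax, with no history buffer.
E4M3_MAX = 448.0  # largest magnitude representable in FP8 E4M3

def current_scale(values, fp8_max=E4M3_MAX):
    """Derive a quantization scale from the current tensor's amax."""
    amax = max(abs(v) for v in values)
    return fp8_max / amax

# Quantization multiplies by the scale; dequantization divides by it.
scale = current_scale([0.5, -2.0, 1.25])   # amax = 2.0 -> scale = 224.0
```

Because the scale is recomputed every iteration, it always reflects the tensor's current range, at the cost of an extra amax reduction in the forward pass.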

Code Reference

Source Location

Attribute Detail
File transformer_engine/common/recipe/__init__.py
Class Float8CurrentScaling
Lines L224-261
Base Class Recipe

Signature

@dataclass(frozen=True)
class Float8CurrentScaling(Recipe):
    fp8_format: Format = Format.HYBRID
    use_power_2_scales: bool = <env_default>
    fp8_dpa: bool = False
    fp8_mha: bool = False

Key Parameters

Parameter Type Default Description
fp8_format Format Format.HYBRID The FP8 format to use. HYBRID uses E4M3 for the forward pass and E5M2 for the backward pass; selecting E4M3 or E5M2 uses that single format for both passes.
use_power_2_scales bool Environment-dependent default Whether to round scaling factors to the nearest power of 2. Power-of-2 scales can be implemented as bit shifts, offering hardware efficiency. The default is determined by an environment variable.
fp8_dpa bool False Enable FP8 execution for dot-product attention (the QK^T computation).
fp8_mha bool False Enable FP8 execution for the full multi-head attention block, including the attention output projection.
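The format semantics in the table can be made concrete with a small lookup. The maximum magnitudes are the standard OCP FP8 values; the string-based helper is illustrative, not TE API:

```python
# Standard FP8 dynamic-range maxima (OCP FP8 spec), for reference only.
FP8_MAX = {"E4M3": 448.0, "E5M2": 57344.0}

def pass_formats(fp8_format):
    """Which FP8 format each pass uses under a given selection (sketch)."""
    if fp8_format == "HYBRID":
        # HYBRID: precision-leaning E4M3 forward, range-leaning E5M2 backward
        return {"forward": "E4M3", "backward": "E5M2"}
    return {"forward": fp8_format, "backward": fp8_format}
```

HYBRID pairs E4M3's finer precision with E5M2's wider range, which suits the typically larger dynamic range of gradients in the backward pass.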

I/O Contract

Input

Input Type Description
fp8_format Format FP8 format selection (HYBRID, E4M3, or E5M2).
use_power_2_scales bool Whether to use power-of-2 scaling factors.
fp8_dpa bool Enable FP8 dot-product attention.
fp8_mha bool Enable FP8 multi-head attention.

Output

Output Type Description
Recipe object Float8CurrentScaling An immutable recipe instance passed to te.autocast(recipe=...). Consumed by TE modules to configure current-scaling FP8 behavior during the forward and backward passes.

Behavioral Notes

  • Unlike DelayedScaling, this recipe does not maintain an amax history buffer. Each iteration's scaling factor is computed fresh from the current tensor.
  • There is no margin parameter. The scaling factor is FP8_MAX / amax_current without additional headroom (unless power-of-2 rounding implicitly provides it).
  • There is no reduce_amax parameter. Amax reduction in distributed settings is handled by the amax_reduction_group parameter in te.autocast.
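The scale formula and the power-of-2 option can be sketched as follows. Flooring to a power of 2 is an illustrative, overflow-safe choice; TE's exact rounding rule is not documented on this page:

```python
import math

E4M3_MAX = 448.0  # largest magnitude representable in FP8 E4M3

def scale_pow2(amax, fp8_max=E4M3_MAX):
    """FP8_MAX / amax_current, rounded down to a power of 2 (sketch).

    Flooring keeps amax * scale <= fp8_max, so the rounding never
    introduces overflow, and the resulting exponent-only scale can be
    applied as a bit shift in hardware.
    """
    exact = fp8_max / amax
    return 2.0 ** math.floor(math.log2(exact))

# e.g. amax = 3.0 -> exact scale ~149.33 -> power-of-2 scale 128.0
```

Rounding down also illustrates the note above: any gap between the exact scale and the power-of-2 scale acts as implicit headroom, even though there is no explicit margin parameter.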

Usage Examples

Default Configuration

from transformer_engine.common.recipe import Float8CurrentScaling

# All defaults: HYBRID format, environment-default power-of-2 scales
recipe = Float8CurrentScaling()

Explicit HYBRID Format

from transformer_engine.common.recipe import Float8CurrentScaling, Format

recipe = Float8CurrentScaling(fp8_format=Format.HYBRID)

With Power-of-2 Scales

from transformer_engine.common.recipe import Float8CurrentScaling

recipe = Float8CurrentScaling(use_power_2_scales=True)

With FP8 Attention Enabled

from transformer_engine.common.recipe import Float8CurrentScaling

recipe = Float8CurrentScaling(
    fp8_dpa=True,
    fp8_mha=True,
)

Full Training Loop Example

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling, Format

recipe = Float8CurrentScaling(
    fp8_format=Format.HYBRID,
    use_power_2_scales=True,
)

model = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# dataloader and loss_fn are assumed to be defined elsewhere
for batch in dataloader:
    optimizer.zero_grad()
    with te.autocast(enabled=True, recipe=recipe):
        output = model(batch["input"])
    loss = loss_fn(output, batch["target"])
    loss.backward()
    optimizer.step()
