Implementation:NVIDIA TransformerEngine Float8CurrentScaling Recipe
| Field | Value |
|---|---|
| Page Type | Implementation |
| Repository | NVIDIA TransformerEngine |
| Source File | transformer_engine/common/recipe/__init__.py (L224-261) |
| Import | from transformer_engine.common.recipe import Float8CurrentScaling |
| Implements | Principle:NVIDIA_TransformerEngine_FP8_Current_Scaling |
| Requires Environment | Environment:NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements |
Overview
Concrete FP8 recipe configuration, provided by TransformerEngine, for per-tensor current scaling.
Description
Float8CurrentScaling computes per-tensor scaling factors from the current iteration's amax. No history buffer is maintained. The class supports optional power-of-2 scales for hardware efficiency.
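The distinction from delayed scaling can be sketched in plain Python. The function names are illustrative, not TE APIs; 448 is the largest finite E4M3 value, and delayed_scale is shown only for contrast, simplifying TE's configurable amax-history reduction to a plain max:

```python
E4M3_MAX = 448.0  # largest finite value in FP8 E4M3

def current_scale(amax_now: float) -> float:
    """Current scaling: the scale is derived from this iteration's amax alone.

    No history buffer exists; the current tensor's largest magnitude is
    mapped onto the FP8 maximum representable value.
    """
    return E4M3_MAX / amax_now

def delayed_scale(amax_history: list[float]) -> float:
    """Delayed scaling (for contrast): derived from a window of past amaxes."""
    return E4M3_MAX / max(amax_history)
```

With current scaling, a tensor whose amax is 112 gets scale 4.0 in the same iteration; delayed scaling would instead use whatever the history window reports.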
As a frozen dataclass inheriting from Recipe, Float8CurrentScaling instances are immutable after construction. It provides a simpler configuration surface than DelayedScaling because there is no history buffer or amax computation algorithm to configure.
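The frozen-dataclass immutability can be illustrated with a minimal stdlib stand-in (RecipeSketch mirrors two of the recipe's fields; this is not TransformerEngine code):

```python
from dataclasses import dataclass, FrozenInstanceError

# Minimal stand-in mirroring Float8CurrentScaling's frozen-dataclass behavior.
@dataclass(frozen=True)
class RecipeSketch:
    fp8_dpa: bool = False
    fp8_mha: bool = False

def try_mutate(recipe: RecipeSketch) -> bool:
    """Return True if assignment after construction is rejected."""
    try:
        recipe.fp8_dpa = True  # frozen=True makes this raise
    except FrozenInstanceError:
        return True
    return False
```

To change a setting, construct a new instance rather than mutating an existing one.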
Usage
Use Float8CurrentScaling when current scaling is preferred over delayed scaling:
- Instantiate with desired parameters (or use defaults).
- Pass the instance to te.autocast(recipe=recipe).
- TE modules will compute per-tensor amax values during the forward pass and derive scaling factors immediately.
Code Reference
Source Location
| Attribute | Detail |
|---|---|
| File | transformer_engine/common/recipe/__init__.py |
| Class | Float8CurrentScaling |
| Lines | L224-261 |
| Base Class | Recipe |
Signature
@dataclass(frozen=True)
class Float8CurrentScaling(Recipe):
fp8_format: Format = Format.HYBRID
use_power_2_scales: bool = <env_default>
fp8_dpa: bool = False
fp8_mha: bool = False
Key Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
fp8_format |
Format |
Format.HYBRID |
The FP8 format to use. HYBRID uses E4M3 for forward and E5M2 for backward. E4M3 and E5M2 use the same format for both passes.
|
use_power_2_scales |
bool |
Environment-dependent default | Whether to round scaling factors to the nearest power of 2. Power-of-2 scales can be implemented as bit shifts, offering hardware efficiency. The default is determined by an environment variable. |
fp8_dpa |
bool |
False |
Enable FP8 execution for dot-product attention (the QK^T computation). |
fp8_mha |
bool |
False |
Enable FP8 execution for the full multi-head attention block, including the attention output projection. |
I/O Contract
Input
| Input | Type | Description |
|---|---|---|
| fp8_format | Format | FP8 format selection (HYBRID, E4M3, or E5M2). |
| use_power_2_scales | bool | Whether to use power-of-2 scaling factors. |
| fp8_dpa | bool | Enable FP8 dot-product attention. |
| fp8_mha | bool | Enable FP8 multi-head attention. |
Output
| Output | Type | Description |
|---|---|---|
| Recipe object | Float8CurrentScaling | An immutable recipe instance passed to te.autocast(recipe=...). Consumed by TE modules to configure current-scaling FP8 behavior during the forward and backward passes. |
Behavioral Notes
- Unlike DelayedScaling, this recipe does not maintain an amax history buffer. Each iteration's scaling factor is computed fresh from the current tensor.
- There is no margin parameter. The scaling factor is FP8_MAX / amax_current without additional headroom (unless power-of-2 rounding implicitly provides it).
- There is no reduce_amax parameter. Amax reduction in distributed settings is handled by the amax_reduction_group parameter in te.autocast.
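The margin-free scale and the optional power-of-2 rounding noted above can be sketched as follows. This is illustrative, not a TE API; 448 is the E4M3 maximum, and rounding the scale *down* to a power of 2 is a choice of this sketch (it guarantees amax * scale still fits in range):

```python
import math

E4M3_MAX = 448.0  # largest finite value in FP8 E4M3

def fp8_scale(amax: float, use_power_2_scales: bool = False) -> float:
    """FP8_MAX / amax_current with no margin, optionally a power of 2."""
    scale = E4M3_MAX / amax
    if use_power_2_scales:
        # Rounding down keeps amax * scale <= FP8_MAX, so values still fit;
        # a power-of-2 scale can be applied as a pure exponent adjustment.
        scale = 2.0 ** math.floor(math.log2(scale))
    return scale
```

For amax = 100, the exact scale is 4.48; power-of-2 rounding lowers it to 4.0, trading a small amount of range utilization for cheaper arithmetic.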
Usage Examples
Default Configuration
from transformer_engine.common.recipe import Float8CurrentScaling
# All defaults: HYBRID format, environment-default power-of-2 scales
recipe = Float8CurrentScaling()
Explicit HYBRID Format
from transformer_engine.common.recipe import Float8CurrentScaling, Format
recipe = Float8CurrentScaling(fp8_format=Format.HYBRID)
With Power-of-2 Scales
from transformer_engine.common.recipe import Float8CurrentScaling
recipe = Float8CurrentScaling(use_power_2_scales=True)
With FP8 Attention Enabled
from transformer_engine.common.recipe import Float8CurrentScaling
recipe = Float8CurrentScaling(
fp8_dpa=True,
fp8_mha=True,
)
Full Training Loop Example
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling, Format
recipe = Float8CurrentScaling(
fp8_format=Format.HYBRID,
use_power_2_scales=True,
)
model = te.TransformerLayer(
hidden_size=1024,
ffn_hidden_size=4096,
num_attention_heads=16,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for batch in dataloader:  # dataloader and loss_fn assumed defined elsewhere
optimizer.zero_grad()
with te.autocast(enabled=True, recipe=recipe):
output = model(batch["input"])
loss = loss_fn(output, batch["target"])
loss.backward()
optimizer.step()
Related Pages
- Principle:NVIDIA_TransformerEngine_FP8_Current_Scaling -- The principle describing current scaling for FP8.
- Implementation:NVIDIA_TransformerEngine_TE_Autocast -- The context manager that consumes this recipe.
- Implementation:NVIDIA_TransformerEngine_DelayedScaling_Recipe -- The alternative delayed scaling recipe.
- Environment:NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements
- Environment:NVIDIA_TransformerEngine_GPU_Compute_Capability
- Heuristic:NVIDIA_TransformerEngine_FP8_Recipe_Auto_Selection