Implementation:NVIDIA TransformerEngine Custom Current Scaling
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Reference implementation of the current-scaling (per-tensor) FP8 quantization recipe using the custom recipe API, serving as both a functional example and a test baseline.
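Per-tensor current scaling can be summarized as follows. Here $X$ is the high-precision tensor, $q_{\max}$ is the largest finite value of the target FP8 format (448 for E4M3, 57344 for E5M2), and $\operatorname{cast}_{\mathrm{FP8}}$ is the round-to-nearest conversion to FP8. "Current" means the scale is computed from the tensor being quantized at that moment, rather than from an amax history as in delayed scaling. This is a sketch of the standard formulation. The reference code may store the reciprocal scale or add guards (for example, for $\operatorname{amax}(X) = 0$) that are not shown here. The last line mirrors qgemm(): both operands are dequantized before a high-precision product.

```latex
\begin{aligned}
\operatorname{amax}(X) &= \max_{i,j} \lvert X_{ij} \rvert, \qquad
s_X = \frac{q_{\max}}{\operatorname{amax}(X)}, \\
Q_X &= \operatorname{cast}_{\mathrm{FP8}}\!\left(s_X \, X\right), \qquad
\hat{X} = s_X^{-1} \, Q_X, \\
A B &\approx \hat{A}\,\hat{B} = \left(s_A^{-1} Q_A\right)\left(s_B^{-1} Q_B\right).
\end{aligned}
```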
Description
The module defines three components:
- current_scaling_ref_quantizer_factory: a factory function that returns CurrentScalingQuantizerRef instances, using the E4M3 FP8 format for inputs and weights and E5M2 for outputs and gradients.
- CurrentScalingTensorRef: a QuantizedTensorStorage subclass that stores the quantized data, the scales, and their transposes, and supports saving and restoring these tensors for autograd.
- CurrentScalingQuantizerRef: a Quantizer subclass implementing quantize(), which computes the per-tensor amax, scales the tensor into the FP8 range, and casts it, and qgemm(), which dequantizes both operands and then performs a high-precision matrix multiply with torch.mm.
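The quantize() and qgemm() steps described above can be sketched in plain Python so the arithmetic is visible without TransformerEngine or a GPU. This is an illustration, not the module's code: FP8Format, cast_to_fp8, quantize, dequantize and qgemm_ref are hypothetical names, the FP8 cast is simulated with saturating round-to-nearest-even on Python floats, and the all-zero guard (scale 1.0) is an assumption.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class FP8Format:
    """Parameters of an FP8 format (E4M3 'fn' and E5M2 variants)."""
    name: str
    mant_bits: int   # explicit mantissa bits
    min_exp: int     # exponent of the smallest normal number
    max_val: float   # largest finite value


E4M3 = FP8Format("E4M3", mant_bits=3, min_exp=-6, max_val=448.0)
E5M2 = FP8Format("E5M2", mant_bits=2, min_exp=-14, max_val=57344.0)


def cast_to_fp8(v: float, fmt: FP8Format) -> float:
    """Simulate a saturating round-to-nearest-even cast to `fmt` (result kept as a float)."""
    if v == 0.0:
        return 0.0
    a = min(abs(v), fmt.max_val)              # saturate instead of overflowing
    _, e = math.frexp(a)                      # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, fmt.min_exp)             # binade of a; subnormals share min_exp
    step = 2.0 ** (exp - fmt.mant_bits)       # spacing of representable values there
    return math.copysign(round(a / step) * step, v)  # round() is half-to-even


def quantize(x: list[list[float]], fmt: FP8Format):
    """Per-tensor current scaling: one scale derived from this tensor's own amax."""
    amax = max(abs(v) for row in x for v in row)
    scale = fmt.max_val / amax if amax > 0.0 else 1.0  # all-zero guard (assumption)
    data = [[cast_to_fp8(v * scale, fmt) for v in row] for row in x]
    return data, scale


def dequantize(data, scale):
    return [[v / scale for v in row] for row in data]


def qgemm_ref(qa, qb):
    """Dequantize both operands, then multiply in float64 (Python float) precision."""
    a, b = dequantize(*qa), dequantize(*qb)
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


A = [[0.5, -2.0], [3.25, 0.001]]
B = [[1.0, 0.25], [-0.75, 2.0]]
print(qgemm_ref(quantize(A, E4M3), quantize(B, E4M3)))
```

Scaling by $q_{\max}/\operatorname{amax}$ maps the largest-magnitude element exactly onto the top of the FP8 range, so the whole tensor uses as much of the format's dynamic range as possible.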
Usage
Demonstrates the complete custom recipe API by implementing current-scaling FP8 in pure Python. Provides a reference for correctness testing and a template for custom quantization strategies.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py
- Lines: 1-525
Signature
```python
# Abbreviated signatures: "..." stands for parameters and bodies omitted here.
def current_scaling_ref_quantizer_factory(role): ...

class CurrentScalingTensorRef(QuantizedTensorStorage):
    def __init__(self, data, scale, ...): ...
    def prepare_for_saving(self) -> list: ...
    def restore_from_saved(self, tensors) -> None: ...

class CurrentScalingQuantizerRef(Quantizer):
    def __init__(self, fp8_dtype, ...): ...
    def quantize(self, tensor, ...): ...
    def qgemm(self, a, b, ...): ...
```
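The prepare_for_saving()/restore_from_saved() pair lets a quantized storage object be split into its raw tensors for autograd and reassembled later. The toy class below sketches that pattern under the signatures listed above. ToyQuantizedStorage and its fields are hypothetical, plain Python values stand in for tensors, and the real CurrentScalingTensorRef may save more fields or order them differently.

```python
from typing import Any, Optional


class ToyQuantizedStorage:
    """Hypothetical stand-in for a QuantizedTensorStorage-like object (not TE's class)."""

    def __init__(self, data, scale, data_t=None):
        self.data: Optional[Any] = data      # quantized values
        self.scale: Optional[Any] = scale    # per-tensor scale
        self.data_t: Optional[Any] = data_t  # transposed quantized copy

    def prepare_for_saving(self) -> list:
        # Hand out every tensor-valued field and drop our own references,
        # so the returned list is the only holder until restore_from_saved().
        saved = [self.data, self.scale, self.data_t]
        self.data = self.scale = self.data_t = None
        return saved

    def restore_from_saved(self, tensors) -> None:
        # Reattach the fields in the same order prepare_for_saving() emitted them.
        self.data, self.scale, self.data_t = tensors


s = ToyQuantizedStorage([1.0, 2.0], 0.5, [[1.0], [2.0]])
saved = s.prepare_for_saving()
s.restore_from_saved(saved)
print(s.data, s.scale, s.data_t)
```

One plausible use inside a torch.autograd.Function is for forward to pass the list to ctx.save_for_backward(*saved) and for backward to call restore_from_saved(ctx.saved_tensors).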
Import
```python
from transformer_engine.pytorch.custom_recipes.quantization_current_scaling import (
    current_scaling_ref_quantizer_factory,
    CurrentScalingQuantizerRef,
    CurrentScalingTensorRef,
)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| role | str | Yes | Role identifier ("input", "weight", "grad_output") that determines the FP8 dtype |
| tensor | torch.Tensor | Yes | High-precision tensor to quantize |
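Role-based format selection can be sketched as a lookup table. This toy version assumes only the three role strings in the table above and the E4M3/E5M2 split given in the Description. toy_quantizer_factory and RoleQuantizerSpec are hypothetical names, and the real factory returns a CurrentScalingQuantizerRef rather than a plain spec. Raising ValueError for unknown roles is a choice made for this sketch only.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RoleQuantizerSpec:
    role: str
    fp8_format: str  # "E4M3" or "E5M2"
    max_val: float   # largest finite value of that format


# Assumed mapping: forward operands use E4M3 (3 mantissa bits, more precision);
# gradients use E5M2 (5 exponent bits, wider dynamic range).
_FORMAT_BY_ROLE = {"input": "E4M3", "weight": "E4M3", "grad_output": "E5M2"}
_MAX_VAL = {"E4M3": 448.0, "E5M2": 57344.0}


def toy_quantizer_factory(role: str) -> RoleQuantizerSpec:
    try:
        fmt = _FORMAT_BY_ROLE[role]
    except KeyError:
        raise ValueError(f"unknown role: {role!r}") from None
    return RoleQuantizerSpec(role, fmt, _MAX_VAL[fmt])


for r in ("input", "weight", "grad_output"):
    print(toy_quantizer_factory(r))
```

Keeping the mapping in one place means a custom strategy only has to change the table, not each call site.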
Outputs
| Name | Type | Description |
|---|---|---|
| quantizer | CurrentScalingQuantizerRef | Quantizer instance for the specified role |
| quantized_storage | CurrentScalingTensorRef | Storage holding the quantized data and scale |
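The storage also keeps transposed data. With a single per-tensor scale and an elementwise cast, transposing the quantized data gives exactly the quantization of the transposed tensor, so the transpose needs no fresh amax or requantization. The sketch below checks this property in plain Python. cast_e4m3, quantize_per_tensor and transpose are hypothetical helpers, and the E4M3 cast is simulated (saturating, round-to-nearest-even) rather than taken from the module.

```python
import math


def cast_e4m3(v: float) -> float:
    """Simulated saturating round-to-nearest-even cast to E4M3 (max 448, min normal exp -6)."""
    if v == 0.0:
        return 0.0
    a = min(abs(v), 448.0)
    _, e = math.frexp(a)                    # a = m * 2**e with 0.5 <= m < 1
    step = 2.0 ** (max(e - 1, -6) - 3)      # 3 mantissa bits
    return math.copysign(round(a / step) * step, v)


def quantize_per_tensor(x):
    amax = max(abs(v) for row in x for v in row)
    scale = 448.0 / amax if amax > 0.0 else 1.0
    return [[cast_e4m3(v * scale) for v in row] for row in x], scale


def transpose(m):
    return [list(col) for col in zip(*m)]


X = [[0.3, -1.7, 2.2], [4.0, 0.05, -0.6]]
q, s = quantize_per_tensor(X)
q_t, s_t = quantize_per_tensor(transpose(X))
print(s == s_t, transpose(q) == q_t)
```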
Usage Examples
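```python
import torch
from transformer_engine.common.recipe import Float8CurrentScaling
from transformer_engine.pytorch.custom_recipes.quantization_current_scaling import (
    current_scaling_ref_quantizer_factory,
)

# Use the custom recipe with TransformerEngine.
# set_quantizer_factory is the hook named on this page; confirm it in your TE version.
recipe = Float8CurrentScaling()
recipe.set_quantizer_factory(current_scaling_ref_quantizer_factory)

# The factory creates the appropriate quantizer for each role.
input_quantizer = current_scaling_ref_quantizer_factory("input")
input_data = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")  # example input
fp8_tensor = input_quantizer.quantize(input_data)  # quantized storage (CurrentScalingTensorRef)
```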