Implementation:NVIDIA TransformerEngine Custom Current Scaling



Sources: TransformerEngine
Domains: Deep_Learning, PyTorch, Quantization
Last Updated: 2026-02-07 14:00 GMT

Overview

Reference implementation of the current-scaling (per-tensor) FP8 quantization recipe using the custom recipe API, serving as both a functional example and a test baseline.

Description

current_scaling_ref_quantizer_factory is a factory function that creates CurrentScalingQuantizerRef instances. It uses the E4M3 FP8 format for inputs and weights and E5M2 for outputs and gradients.

CurrentScalingTensorRef is a QuantizedTensorStorage subclass that stores the quantized data, the scales, and their transposes, and supports saving and restoring them for autograd.

CurrentScalingQuantizerRef is a Quantizer subclass that implements two methods. quantize() computes the per-tensor amax, derives a scale that maps the tensor into the representable range of the target FP8 format, and casts the scaled values to FP8. qgemm() dequantizes both operands and then performs a high-precision matrix multiply with torch.mm.
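The quantize() step described above can be sketched without TransformerEngine. This is a minimal pure-Python illustration under stated assumptions, not the module's code: a tensor is a flat list of floats, the FP8 limits follow the OCP E4M3/E5M2 definitions (largest finite values 448 and 57344), and the FP8 cast is emulated in software as a saturating round-half-to-even. FP8_FORMATS, fp8_round, quantize_current_scaling, and dequantize are hypothetical names.

```python
import math

# Assumed OCP FP8 limits: largest finite value, mantissa bits,
# and minimum normal exponent for E4M3 (fn variant) and E5M2.
FP8_FORMATS = {
    "e4m3": {"max": 448.0, "mant_bits": 3, "min_exp": -6},
    "e5m2": {"max": 57344.0, "mant_bits": 2, "min_exp": -14},
}


def fp8_round(x: float, fmt: str) -> float:
    """Emulate a saturating FP8 cast: round half to even onto the FP8 grid."""
    spec = FP8_FORMATS[fmt]
    if x == 0.0:
        return 0.0
    _, e = math.frexp(abs(x))                 # abs(x) = m * 2**e, 0.5 <= m < 1
    exp = max(e - 1, spec["min_exp"])         # subnormals reuse the minimum exponent
    step = 2.0 ** (exp - spec["mant_bits"])   # spacing of representable values
    y = round(x / step) * step                # built-in round() ties to even
    return max(-spec["max"], min(spec["max"], y))  # saturate at the format max


def quantize_current_scaling(values, fmt="e4m3"):
    """Per-tensor current scaling: one scale computed from this tensor's amax."""
    amax = max(abs(v) for v in values)
    # Map amax onto the format's max; fall back to 1.0 for an all-zero tensor.
    scale = FP8_FORMATS[fmt]["max"] / amax if amax > 0 else 1.0
    data = [fp8_round(v * scale, fmt) for v in values]
    return data, scale


def dequantize(data, scale):
    """Undo the scaling; any FP8 rounding error remains."""
    return [d / scale for d in data]


data, scale = quantize_current_scaling([1.0, -2.0, 0.5, 3.5])
print(scale, data)  # 128.0 [128.0, -256.0, 64.0, 448.0]
```

Because the scale is recomputed from each tensor's own amax, the largest element always lands on the format's maximum, so no history of past amax values is needed (unlike delayed scaling).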

Usage

The module demonstrates the complete custom recipe API by implementing current-scaling FP8 in pure Python. It serves as a reference for correctness testing and as a template for writing custom quantization strategies.

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py
Lines: 1-525

Signature

def current_scaling_ref_quantizer_factory(role): ...

class CurrentScalingTensorRef(QuantizedTensorStorage):
    def __init__(self, data, scale, ...): ...
    def prepare_for_saving(self) -> list: ...
    def restore_from_saved(self, tensors) -> None: ...

class CurrentScalingQuantizerRef(Quantizer):
    def __init__(self, fp8_dtype, ...): ...
    def quantize(self, tensor, ...): ...
    def qgemm(self, a, b, ...): ...
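A similarly hedged sketch of the qgemm() idea described in this page: both operands are dequantized and then multiplied in high precision. It substitutes plain Python lists for torch tensors and a hand-written loop for torch.mm, and it ignores any layout, transpose, or output-quantization handling the real method may perform. QuantizedMatrix, matmul, and qgemm_reference are hypothetical names.

```python
from dataclasses import dataclass
from typing import List

Matrix = List[List[float]]


@dataclass
class QuantizedMatrix:
    """Hypothetical operand: FP8-grid values (kept as floats) plus their scale."""
    data: Matrix
    scale: float

    def dequantize(self) -> Matrix:
        return [[v / self.scale for v in row] for row in self.data]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain high-precision matrix product a @ b."""
    if any(len(row) != len(b) for row in a):
        raise ValueError("inner dimensions do not match")
    return [[sum(x * b[k][j] for k, x in enumerate(row)) for j in range(len(b[0]))]
            for row in a]


def qgemm_reference(a: QuantizedMatrix, b: QuantizedMatrix) -> Matrix:
    """Reference GEMM: dequantize both operands, then multiply in full precision."""
    return matmul(a.dequantize(), b.dequantize())


a = QuantizedMatrix(data=[[128.0, 256.0]], scale=128.0)  # represents [[1.0, 2.0]]
b = QuantizedMatrix(data=[[64.0], [32.0]], scale=32.0)   # represents [[2.0], [1.0]]
print(qgemm_reference(a, b))  # [[4.0]]
```

Dequantizing first is equivalent, up to floating-point rounding, to multiplying the raw FP8 values and dividing the product by scale_a * scale_b, which is how a fused FP8 GEMM usually applies the scales; the reference form is slower but easy to check.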

Import

from transformer_engine.pytorch.custom_recipes.quantization_current_scaling import (
    current_scaling_ref_quantizer_factory,
    CurrentScalingQuantizerRef,
    CurrentScalingTensorRef,
)

I/O Contract

Inputs

role (str, required): role identifier passed to current_scaling_ref_quantizer_factory ("input", "weight", "grad_output"); it determines which FP8 dtype the quantizer uses.
tensor (torch.Tensor, required): high-precision tensor passed to CurrentScalingQuantizerRef.quantize().
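The role-dependent format choice can be illustrated with a small stand-in factory. This sketch assumes only what this page states: the three role values listed above map to E4M3 (input, weight) or E5M2 (grad_output). SketchQuantizer and sketch_quantizer_factory are hypothetical; the real factory returns CurrentScalingQuantizerRef objects and may accept other role strings.

```python
from dataclasses import dataclass

# Role -> FP8 format as described on this page: E4M3 for inputs and weights,
# E5M2 for gradients. FP8_MAX holds the assumed OCP largest finite values.
ROLE_TO_FORMAT = {"input": "e4m3", "weight": "e4m3", "grad_output": "e5m2"}
FP8_MAX = {"e4m3": 448.0, "e5m2": 57344.0}


@dataclass(frozen=True)
class SketchQuantizer:
    """Hypothetical stand-in for CurrentScalingQuantizerRef."""
    fp8_format: str

    def scale_for(self, amax: float) -> float:
        # Current scaling: map this tensor's amax onto the format's max value.
        return FP8_MAX[self.fp8_format] / amax if amax > 0 else 1.0


def sketch_quantizer_factory(role: str) -> SketchQuantizer:
    """Return a quantizer configured for the given tensor role."""
    if role not in ROLE_TO_FORMAT:
        raise ValueError(f"unknown role: {role!r}")
    return SketchQuantizer(ROLE_TO_FORMAT[role])


for role in ("input", "weight", "grad_output"):
    q = sketch_quantizer_factory(role)
    print(role, q.fp8_format, q.scale_for(7.0))
```

E4M3 spends one more bit on the mantissa, while E5M2 keeps a wider exponent range, which suits gradients whose magnitudes vary more widely.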

Outputs

quantizer (CurrentScalingQuantizerRef): quantizer instance for the specified role, returned by current_scaling_ref_quantizer_factory.
quantized_storage (CurrentScalingTensorRef): storage holding the quantized data and scale, returned by quantize().
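The save/restore methods on CurrentScalingTensorRef suggest a common pattern for tensor-like objects used inside autograd: the object hands its buffers to the caller so they can be saved as plain tensors, then takes them back in the backward pass. The sketch below is an assumption about that pattern, mirroring the prepare_for_saving and restore_from_saved signatures listed above; StorageSketch and its field names are hypothetical, and the real class may save different fields or use a different order.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class StorageSketch:
    """Hypothetical stand-in for CurrentScalingTensorRef's save/restore cycle."""
    data: Optional[Any] = None     # quantized values
    scale: Optional[Any] = None    # per-tensor scale
    data_t: Optional[Any] = None   # transposed quantized values, if kept

    def prepare_for_saving(self) -> List[Any]:
        # Hand the buffers out (e.g. to ctx.save_for_backward in an
        # autograd.Function) and drop this object's own references to them.
        saved = [self.data, self.scale, self.data_t]
        self.data = self.scale = self.data_t = None
        return saved

    def restore_from_saved(self, tensors: List[Any]) -> None:
        # Reattach the buffers in the same order they were handed out.
        self.data, self.scale, self.data_t = tensors


s = StorageSketch(data=[[1, 2]], scale=0.5, data_t=[[1], [2]])
saved = s.prepare_for_saving()   # s now holds no buffers
s.restore_from_saved(saved)      # buffers are back in place
```

A plausible reason for this design is that ctx.save_for_backward in torch.autograd.Function accepts tensors (or None) rather than arbitrary Python objects, so a storage wrapper has to expose its tensors explicitly.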

Usage Examples

import torch
from transformer_engine.common.recipe import Float8CurrentScaling
from transformer_engine.pytorch.custom_recipes.quantization_current_scaling import (
    current_scaling_ref_quantizer_factory,
)

# Use the custom recipe with TransformerEngine
recipe = Float8CurrentScaling()
recipe.set_quantizer_factory(current_scaling_ref_quantizer_factory)

# The factory creates the appropriate quantizer for each role
input_quantizer = current_scaling_ref_quantizer_factory("input")

# High-precision input to quantize (shape and dtype chosen for illustration)
input_data = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
fp8_tensor = input_quantizer.quantize(input_data)  # a CurrentScalingTensorRef
