Implementation:NVIDIA TransformerEngine CudaRNGStatesTracker


Overview

A concrete tool, provided by TransformerEngine, for tracking named CUDA RNG states in distributed training.

Description

CudaRNGStatesTracker manages named CUDA RNG states. States are registered with add(name, seed) and activated through the fork(name) context manager. TE modules use the tracker during the forward pass so that dropout is reproducible.

The tracker maintains an internal dictionary mapping state names to CUDA RNG state tensors. On entering fork(name), the current CUDA RNG state is saved and the named state is loaded. On exit, the named state's advanced position is stored back in the tracker and the original state is restored. As a result, only code inside the context manager draws from the named RNG state, and successive forks of the same name continue its sequence rather than replaying it.
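The save/load/restore cycle described here can be tried without a GPU. The following is a minimal sketch, not TransformerEngine's implementation: it uses Python's standard `random` module as a CPU stand-in for the CUDA generator, and `MiniRNGTracker` and `draw_with_fork` are hypothetical names introduced only for illustration. It assumes the advanced named state is written back on exit, so that a later fork continues the sequence.

```python
import random
from contextlib import contextmanager


class MiniRNGTracker:
    """Illustrative stand-in: tracks named states of Python's global `random` generator."""

    def __init__(self):
        self.states = {}

    def add(self, name, seed):
        if name in self.states:
            raise ValueError(f"rng state {name!r} already exists")
        original = random.getstate()           # remember the caller's stream
        random.seed(seed)                      # derive the named state from the seed
        self.states[name] = random.getstate()
        random.setstate(original)              # leave the caller's stream untouched

    @contextmanager
    def fork(self, name):
        original = random.getstate()           # 1. save the current state
        random.setstate(self.states[name])     # 2. load the named state
        try:
            yield
        finally:
            self.states[name] = random.getstate()  # keep the advanced named state
            random.setstate(original)          # 3. restore the original state


def draw_with_fork(fork_first):
    """Return the next default-stream value, optionally forking before drawing."""
    random.seed(0)
    tracker = MiniRNGTracker()
    tracker.add("model-parallel-rng", 1234)
    if fork_first:
        with tracker.fork("model-parallel-rng"):
            random.random()                    # consumes only the named stream
    return random.random()


# The default stream yields the same value whether or not a fork happened first.
same = draw_with_fork(True) == draw_with_fork(False)
```

The `try`/`finally` in `fork` is the important design choice in this sketch: the original state is restored even if the code inside the context raises.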

Key behaviors:

  • State isolation: Each named state is independent; forking one state does not affect others.
  • Checkpoint compatibility: States can be serialized via get_states() and restored via set_states() for training checkpoint/resume.
  • Reset capability: reset() clears all tracked states for reinitialization.
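The checkpoint and isolation behaviors listed above can be illustrated with a small CPU-only sketch. This is an assumption-laden stand-in rather than the library's code: it again substitutes Python's `random` module for the CUDA generator, `CheckpointableTracker` and `sample` are hypothetical names, and `"other-rng"` is an invented second state name used only to show isolation.

```python
import random
from contextlib import contextmanager


class CheckpointableTracker:
    """Illustrative stand-in with get_states/set_states/reset, backed by `random`."""

    def __init__(self):
        self.states = {}

    def reset(self):
        self.states = {}                       # drop every tracked state

    def add(self, name, seed):
        self.states[name] = random.Random(seed).getstate()

    def get_states(self):
        return dict(self.states)               # snapshot suitable for a checkpoint

    def set_states(self, states):
        self.states = dict(states)             # restore from a checkpoint snapshot

    @contextmanager
    def fork(self, name):
        original = random.getstate()
        random.setstate(self.states[name])
        try:
            yield
        finally:
            self.states[name] = random.getstate()
            random.setstate(original)


def sample(tracker, name, n):
    """Draw n values from the named state."""
    with tracker.fork(name):
        return [random.random() for _ in range(n)]


tracker = CheckpointableTracker()
tracker.add("model-parallel-rng", 1234)
tracker.add("other-rng", 5678)                 # hypothetical second state

sample(tracker, "model-parallel-rng", 3)       # training advances one state
checkpoint = tracker.get_states()              # would be saved with the model
expected = sample(tracker, "model-parallel-rng", 2)

resumed = CheckpointableTracker()              # fresh process after a restart
resumed.set_states(checkpoint)
replayed = sample(resumed, "model-parallel-rng", 2)   # matches `expected`
untouched = sample(resumed, "other-rng", 1)    # still at its first value
```

In this sketch the resumed tracker continues exactly where the checkpoint was taken, and the second state is unaffected by draws from the first.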

Source

transformer_engine/pytorch/distributed.py, class CudaRNGStatesTracker at L797-907

Import

from transformer_engine.pytorch.distributed import CudaRNGStatesTracker

Signature

class CudaRNGStatesTracker:
    def __init__(self): ...

    def reset(self) -> None: ...

    def get_states(self) -> Dict[str, torch.Tensor]: ...

    def set_states(self, states: Dict[str, torch.Tensor]) -> None: ...

    def add(self, name: str, seed: int) -> None: ...

    @contextmanager
    def fork(self, name: str = "model-parallel-rng"): ...

I/O

Direction | Description
Input | Name/seed pairs registered via add(name, seed). The seed determines the RNG sequence for that named state.
Output | Context manager returned by fork(name) that switches the CUDA RNG to the named state for the duration of the context.

Key Parameters

Method | Parameters | Description
add(name, seed) | name (str), seed (int) | Registers a named RNG state initialized with the given seed. Raises an error if the name already exists.
fork(name) | name (str, default "model-parallel-rng") | Context manager that saves the current CUDA RNG state, activates the named state, and restores the original state on exit.
get_states() | none | Returns a dictionary of all named states for checkpointing.
set_states(states) | states (Dict[str, torch.Tensor]) | Restores named states from a checkpoint dictionary.
reset() | none | Clears all tracked RNG states.

Example Usage

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker

# Assumes torch.distributed is already initialized (e.g. launched via torchrun).
# tp_group and base_seed are placeholders; substitute your own tensor-parallel
# process group and base seed.
tp_group = torch.distributed.new_group()  # placeholder: all ranks in one group
base_seed = 1234

# Create and configure the tracker
rng_tracker = CudaRNGStatesTracker()

# Register a model-parallel RNG state with a rank-dependent seed
tp_rank = torch.distributed.get_rank(group=tp_group)
seed = base_seed + tp_rank
rng_tracker.add("model-parallel-rng", seed)

# Getter function that TE modules call to obtain the tracker
def get_rng_tracker():
    return rng_tracker

# Pass to TransformerLayer
layer = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=11008,
    num_attention_heads=32,
    get_rng_state_tracker=get_rng_tracker,
    # ... other params
)

# The tracker is used internally during forward pass:
# with rng_tracker.fork("model-parallel-rng"):
#     dropout(activations)
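The rank-dependent seed in the example gives each tensor-parallel rank its own random stream while keeping every rank reproducible across runs. As a rough CPU-only illustration, the sketch below uses Python's `random` in place of the CUDA generator; `dropout_mask` is a hypothetical helper, and the `base_seed` and `tp_world_size` values are assumptions chosen for the demo.

```python
import random


def dropout_mask(seed, size, p=0.5):
    """Hypothetical helper: True marks a kept element, False a dropped one."""
    rng = random.Random(seed)
    return [rng.random() >= p for _ in range(size)]


base_seed = 1234       # assumed value for illustration
tp_world_size = 4      # assumed number of tensor-parallel ranks

# Each simulated rank derives its seed the same way as in the example.
masks = {rank: dropout_mask(base_seed + rank, 64) for rank in range(tp_world_size)}

# Re-running rank 0 with the same seed reproduces its mask.
rerun = dropout_mask(base_seed + 0, 64)
```

Here `rerun` equals `masks[0]`, while the masks for different simulated ranks differ from one another.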

Related

Environment:NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements
