Implementation:NVIDIA TransformerEngine CudaRNGStatesTracker
Overview
A TransformerEngine utility for tracking named CUDA RNG states. It keeps dropout reproducible in distributed (model-parallel) training, where each process needs its own well-defined random stream.
Description
CudaRNGStatesTracker manages named CUDA RNG states. You register a state with add(name, seed) and activate it with the fork(name) context manager. TE modules use the tracker during the forward pass to make dropout reproducible.
The tracker maintains an internal dictionary mapping state names to CUDA RNG state tensors. When fork(name) is entered, the current CUDA RNG state is saved, the named state is loaded, and upon exit the original state is restored. This ensures that only the code within the context manager uses the named RNG state.
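The save/activate/restore sequence that fork(name) performs can be sketched independently of CUDA. The example below is a minimal illustration and is not TE's code. It uses Python's `random.getstate`/`random.setstate` in place of the CUDA generator calls, and a plain dict stands in for the tracker's internal mapping. The helper name `swap_rng_state` is hypothetical. A CUDA variant would presumably pass `torch.cuda.get_rng_state` and `torch.cuda.set_rng_state` as the getter and setter. The sketch also writes the advanced named state back on exit, so successive forks continue the named stream instead of replaying it.

```python
import random
from contextlib import contextmanager


@contextmanager
def swap_rng_state(states, name, get_state=random.getstate, set_state=random.setstate):
    """Hypothetical helper: run the body on states[name], then restore the caller's state.

    get_state/set_state default to Python's `random` module as a stand-in for
    the CUDA generator (an assumption made for illustration only).
    """
    outer = get_state()              # 1. save the caller's current state
    set_state(states[name])          # 2. activate the named state
    try:
        yield
    finally:
        states[name] = get_state()   # remember how far the named stream advanced
        set_state(outer)             # 3. restore the caller's original state
```

Code outside the `with` block never sees the named stream, and code inside it never disturbs the caller's stream.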
Key behaviors:
- State isolation: Each named state is independent; forking one state does not affect others.
- Checkpoint compatibility: States can be serialized via get_states() and restored via set_states() for training checkpoint/resume.
- Reset capability: reset() clears all tracked states for reinitialization.
Source
transformer_engine/pytorch/distributed.py, class CudaRNGStatesTracker at L797-907
Import
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
Signature
```python
class CudaRNGStatesTracker:
    def __init__(self): ...
    def reset(self) -> None: ...
    def get_states(self) -> Dict[str, torch.Tensor]: ...
    def set_states(self, states: Dict[str, torch.Tensor]) -> None: ...
    def add(self, name: str, seed: int) -> None: ...
    @contextmanager
    def fork(self, name: str = "model-parallel-rng"): ...
```
I/O
| Direction | Description |
|---|---|
| Input | Name/seed pairs registered via add(name, seed). The seed determines the RNG sequence for that named state. |
| Output | Context manager returned by fork(name) that switches the CUDA RNG to the named state for the duration of the context. |
Key Parameters
| Method | Parameters | Description |
|---|---|---|
| add(name, seed) | name (str), seed (int) | Registers a named RNG state initialized with the given seed. Raises an error if the name already exists. |
| fork(name) | name (str, default "model-parallel-rng") | Context manager that saves the current CUDA RNG state, activates the named state, and restores the original state on exit. |
| get_states() | none | Returns a dictionary of all named states for checkpointing. |
| set_states(states) | states (Dict[str, torch.Tensor]) | Restores named states from a checkpoint dictionary. |
| reset() | none | Clears all tracked RNG states. |
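The five methods in the table fit together as shown in the framework-free sketch below. `MiniRNGTracker` is a hypothetical class that mirrors the documented behavior, and Python's `random` module stands in for the CUDA generator. The sketch raises `ValueError` on a duplicate name, though the exception type TE uses may differ. The class demonstrates three documented behaviors: state isolation, a checkpoint round-trip through get_states()/set_states(), and reset().

```python
import random
from contextlib import contextmanager


class MiniRNGTracker:
    """Hypothetical tracker mirroring the documented API (not TE's implementation).

    Python's `random` module stands in for the CUDA generator; each stored state
    is the immutable tuple returned by random.getstate().
    """

    def __init__(self):
        self.states = {}

    def reset(self):
        # Drop every tracked state.
        self.states = {}

    def get_states(self):
        # Copy the dict; the state tuples are immutable, so later forks
        # cannot alter a checkpoint taken here.
        return dict(self.states)

    def set_states(self, states):
        self.states = dict(states)

    def add(self, name, seed):
        if name in self.states:
            raise ValueError(f"RNG state {name!r} already exists")
        saved = random.getstate()
        random.seed(seed)
        self.states[name] = random.getstate()
        random.setstate(saved)  # registering must not disturb the caller's stream

    @contextmanager
    def fork(self, name="model-parallel-rng"):
        outer = random.getstate()
        random.setstate(self.states[name])
        try:
            yield
        finally:
            self.states[name] = random.getstate()  # keep the named stream's progress
            random.setstate(outer)                 # restore the caller's stream
```

Suppose you restore a checkpoint taken before a fork and then fork again. The second fork replays exactly the same draws, which is what makes checkpoint/resume reproducible.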
Example Usage
```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker

# Assumes torch.distributed is already initialized and that `tp_group` is this
# rank's tensor-parallel process group (created elsewhere).
base_seed = 1234  # illustrative value; use your training seed

# Create and configure the tracker
rng_tracker = CudaRNGStatesTracker()

# Register a model-parallel RNG state with a rank-dependent seed
tp_rank = torch.distributed.get_rank(tp_group)
seed = base_seed + tp_rank
rng_tracker.add("model-parallel-rng", seed)

# Getter function that TE modules call to obtain the tracker
def get_rng_tracker():
    return rng_tracker

# Pass the getter to TransformerLayer
layer = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=11008,
    num_attention_heads=32,
    get_rng_state_tracker=get_rng_tracker,
    # ... other params
)

# The tracker is used internally during the forward pass:
# with rng_tracker.fork("model-parallel-rng"):
#     dropout(activations)
```
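The example seeds each rank with base_seed + tp_rank. This gives every tensor-parallel rank its own random stream, and each stream stays reproducible from run to run. The sketch below illustrates that property without GPUs or TE. It makes several assumptions for illustration only: `random.Random` stands in for the per-rank CUDA state, and `BASE_SEED`, `TP_SIZE`, `dropout_draws`, and `dropout_mask` are all hypothetical.

```python
import random

BASE_SEED = 1234  # illustrative base seed, not a TE default
TP_SIZE = 4       # hypothetical tensor-parallel world size


def dropout_draws(tp_rank, n=8, base_seed=BASE_SEED):
    """Hypothetical helper: uniform draws a rank would compare against the dropout probability.

    Each rank seeds its own generator with base_seed + tp_rank, mirroring the
    rank-dependent seed passed to rng_tracker.add() in the usage example.
    """
    rng = random.Random(base_seed + tp_rank)
    return [rng.random() for _ in range(n)]


def dropout_mask(draws, p=0.1):
    # Keep an element when its draw is at least the dropout probability p.
    return [d >= p for d in draws]


# Two "runs" of all ranks: streams differ across ranks but repeat across runs.
run1 = {rank: dropout_draws(rank) for rank in range(TP_SIZE)}
run2 = {rank: dropout_draws(rank) for rank in range(TP_SIZE)}
```

Suppose every rank used the same seed instead. All tensor-parallel shards would then draw identical dropout masks, which is usually not what a model-parallel layer wants.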
Related
Environment:NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements