Implementation:Hpcaitech ColossalAI Trainer Utils

From Leeroopedia
Revision as of 15:10, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Hpcaitech_ColossalAI_Trainer_Utils.md)


Knowledge Sources
Domains Distributed Training, Training, Utilities
Last Updated 2026-02-09 00:00 GMT

Overview

Utility functions and classes for the ColossalChat trainer, providing distributed communication helpers, a cycled data loader, a temperature annealing scheduler, and tensor device transfer utilities.

Description

This module provides several utility components used throughout the ColossalChat training infrastructure. AnnealingScheduler implements linear temperature annealing between start and end values with a warmup period. CycledDataLoader wraps a standard DataLoader to automatically restart from the beginning when exhausted. is_rank_0 checks whether the current process is rank 0 in a distributed setup. to_device recursively moves nested tensor structures to a target device using PyTorch's tree_map.

For distributed communication, all_reduce_mean sums a tensor across processes with an all-reduce and divides the result by the process count, all_reduce_sum performs the same all-reduce without the division, and all_gather_tensors gathers each process's list of tensors and concatenates them into one list. Each function accepts an optional plugin argument that scopes the operation to the plugin's data-parallel group.

Usage

These utilities are used internally by ColossalChat trainers and can be imported directly for custom training loops. Use is_rank_0 for conditional logging, to_device for moving batches to GPU, CycledDataLoader for infinite iteration over datasets, and the distributed communication functions for synchronized metric aggregation.
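
To illustrate the conditional-logging pattern mentioned above, here is a minimal sketch of a rank check built directly on torch.distributed. The names is_main_process and log_once are hypothetical and are not part of coati. Treating an uninitialized process group as "rank 0" is an assumption made so the sketch also runs in a single process; is_rank_0 may behave differently in that case.

```python
import torch.distributed as dist


def is_main_process() -> bool:
    """Hypothetical stand-in for a rank-0 check (not the coati function)."""
    # Assumption: with no initialized process group there is only one
    # process, so it is treated as the main one.
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0


def log_once(message: str) -> bool:
    """Print the message on the main process only; report whether it printed."""
    if is_main_process():
        print(message)
        return True
    return False


log_once("Training started")
```

Guarding output this way keeps logs readable when many processes run the same script, since only one of them prints.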

Code Reference

Source Location

Signature

class AnnealingScheduler:
    def __init__(self, start, end, warmup_steps=100, annealing_step=2000): ...
    def get_temperature(self) -> float: ...
    def step_forward(self) -> None: ...

class CycledDataLoader:
    def __init__(self, dataloader: DataLoader) -> None: ...
    def next(self) -> Any: ...

def is_rank_0() -> bool: ...
def to_device(x: Any, device: torch.device) -> Any: ...
def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor: ...
def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor: ...
def all_gather_tensors(local_tensor_list: torch.Tensor, plugin: Plugin = None) -> torch.Tensor: ...

Import

from coati.trainer.utils import (
    AnnealingScheduler,
    CycledDataLoader,
    is_rank_0,
    to_device,
    all_reduce_mean,
    all_reduce_sum,
    all_gather_tensors,
)

I/O Contract

Inputs (AnnealingScheduler)

Name | Type | Required | Description
start | float | Yes | Starting temperature value
end | float | Yes | Ending temperature value
warmup_steps | int | No | Number of warmup steps before annealing begins (default 100)
annealing_step | int | No | Total steps for annealing to reach end value (default 2000)
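
The parameters above can be read as a three-phase schedule: hold the start value during warmup, move linearly toward the end value, then hold the end value. The sketch below is one way to write that down. LinearAnnealingSketch is a hypothetical class, not the coati implementation, and the choice to interpolate over the span between warmup_steps and annealing_step is an assumption; the real AnnealingScheduler may compute the intermediate values differently.

```python
class LinearAnnealingSketch:
    """Hypothetical illustration of linear annealing with a warmup period."""

    def __init__(self, start, end, warmup_steps=100, annealing_step=2000):
        self.start = start
        self.end = end
        self.warmup_steps = warmup_steps
        self.annealing_step = annealing_step
        self.step = 0  # number of times step_forward() has been called

    def get_temperature(self) -> float:
        # Phase 1: keep the starting temperature during warmup.
        if self.step <= self.warmup_steps:
            return self.start
        # Phase 3: annealing has finished, keep the final temperature.
        if self.step >= self.annealing_step:
            return self.end
        # Phase 2 (assumed form): interpolate linearly between the end of
        # warmup and annealing_step.
        span = self.annealing_step - self.warmup_steps
        progress = (self.step - self.warmup_steps) / span
        return self.start + progress * (self.end - self.start)

    def step_forward(self) -> None:
        self.step += 1


scheduler = LinearAnnealingSketch(start=1.0, end=0.0, warmup_steps=2, annealing_step=6)
temperatures = []
for _ in range(8):
    temperatures.append(scheduler.get_temperature())
    scheduler.step_forward()
```

Calling get_temperature before step_forward in each iteration, as in the loop above, means the first iteration sees the start value.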

Inputs (CycledDataLoader)

Name | Type | Required | Description
dataloader | DataLoader | Yes | The original data loader to wrap
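
The restart-on-exhaustion behaviour described for CycledDataLoader can be sketched with plain Python iterators. CycledIterable below is a hypothetical class, not the coati one; it accepts any re-iterable object (a DataLoader would qualify) so that the example runs without PyTorch.

```python
class CycledIterable:
    """Hypothetical sketch: yield items forever by restarting the iterable."""

    def __init__(self, iterable):
        self.iterable = iterable
        self._iterator = iter(iterable)

    def next(self):
        try:
            return next(self._iterator)
        except StopIteration:
            # The underlying iterator is exhausted: build a fresh one and
            # continue from its first item. An empty iterable would raise
            # StopIteration again here.
            self._iterator = iter(self.iterable)
            return next(self._iterator)


cycled = CycledIterable([1, 2, 3])
items = [cycled.next() for _ in range(7)]
```

When the wrapped object is a DataLoader created with shuffle=True, each new iterator draws a new ordering, so every pass over the data is reshuffled.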

Inputs (to_device)

Name | Type | Required | Description
x | Any | Yes | Input tensor or nested structure of tensors
device | torch.device | Yes | Target device to move tensors to
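
As a rough illustration of how a nested structure can be moved with tree_map, the sketch below applies a per-leaf function through torch.utils._pytree.tree_map. move_to_device is a hypothetical name, and leaving non-tensor leaves unchanged is an assumption about the intended behaviour. Note that torch.utils._pytree is a private PyTorch module, so its interface may change between releases.

```python
import torch
from torch.utils._pytree import tree_map


def move_to_device(x, device):
    """Hypothetical sketch: move every tensor leaf of x to the given device."""

    def _move(leaf):
        # Tensors are moved; any other leaf (str, int, ...) is returned as-is.
        if isinstance(leaf, torch.Tensor):
            return leaf.to(device)
        return leaf

    # tree_map rebuilds the same dict/list/tuple layout around the new leaves.
    return tree_map(_move, x)


batch = {
    "input_ids": torch.ones(2, 3, dtype=torch.long),
    "meta": ["sample", 1],
    "pair": (torch.zeros(1), torch.zeros(2)),
}
moved = move_to_device(batch, torch.device("cpu"))
```

The CPU device is used here only so the example runs anywhere; in training code the target would typically be the process's GPU.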

Inputs (all_reduce_mean / all_reduce_sum)

Name | Type | Required | Description
tensor | torch.Tensor | Yes | Input tensor to reduce
plugin | Plugin | No | ColossalAI Plugin for data-parallel group scoping (default None)
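
The mean reduction described for all_reduce_mean (sum across processes, then divide by the process count) can be sketched with torch.distributed alone. mean_across_processes is a hypothetical function: it takes a process group directly instead of a ColossalAI Plugin, and its single-process fallback is an assumption added so the example runs without a launcher.

```python
import torch
import torch.distributed as dist


def mean_across_processes(tensor: torch.Tensor, group=None) -> torch.Tensor:
    """Hypothetical sketch of an all-reduce mean (not the coati function)."""
    # Assumption: without an initialized process group there is nothing to
    # reduce, so the tensor is returned unchanged.
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    # Sum the tensor over every process in the group (in place) ...
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    # ... then divide by the number of participating processes.
    tensor.div_(dist.get_world_size(group=group))
    return tensor


local_loss = torch.tensor([2.0, 4.0])
averaged = mean_across_processes(local_loss)
```

Because the reduction and the division in this sketch happen in place, the input should be a floating-point tensor, and callers that still need the local value should pass a clone.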

Inputs (all_gather_tensors)

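Name | Type | Required | Description
local_tensor_list | torch.Tensor | Yes | List of local tensors to gather (the signature annotates this parameter as torch.Tensor, although it is described and named as a list)
plugin | Plugin | No | ColossalAI Plugin for data-parallel group scoping (default None)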
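
One way to gather a list of tensors from every process and merge the results is torch.distributed.all_gather_object, sketched below. gather_tensor_lists is a hypothetical function, not the coati implementation; it takes a process group directly instead of a Plugin, and whether all_gather_tensors uses an object gather internally is not stated here, so this is an assumption.

```python
import torch
import torch.distributed as dist


def gather_tensor_lists(local_tensors, group=None):
    """Hypothetical sketch: collect each process's tensor list into one list."""
    # Assumption: in a single, uninitialized process the local list is
    # already the full result.
    if not (dist.is_available() and dist.is_initialized()):
        return list(local_tensors)
    world_size = dist.get_world_size(group=group)
    per_rank = [None] * world_size  # one slot per process, filled by the gather
    dist.all_gather_object(per_rank, local_tensors, group=group)
    gathered = []
    for tensors in per_rank:
        gathered.extend(tensors)  # concatenate the per-process lists in rank order
    return gathered


local = [torch.ones(2), torch.zeros(3)]
everything = gather_tensor_lists(local)
```

An object gather tolerates lists of different lengths and tensors of different shapes on each process, at the cost of serializing the tensors.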

Outputs

Name | Type | Description
AnnealingScheduler.get_temperature return | float | Current temperature value
CycledDataLoader.next return | Any | Next batch from the data loader
is_rank_0 return | bool | True if current process is rank 0
to_device return | Any | Input structure with all tensors moved to device
all_reduce_mean return | torch.Tensor | Tensor with mean computed across all processes (or across the plugin's data-parallel group when plugin is given)
all_reduce_sum return | torch.Tensor | Tensor with sum computed across all processes (or across the plugin's data-parallel group when plugin is given)
all_gather_tensors return | list | Concatenated list of tensors from all processes (the signature annotates the return type as torch.Tensor)

Usage Examples

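import torch

from coati.trainer.utils import (
    AnnealingScheduler,
    CycledDataLoader,
    all_reduce_mean,
    is_rank_0,
    to_device,
)

# Assumed to exist in the calling code (not defined in this snippet):
#   train_dataloader: a torch.utils.data.DataLoader
#   loss_tensor: a torch.Tensor holding this process's loss
# all_reduce_mean also assumes torch.distributed is already initialized.

# Conditional logging on rank 0
if is_rank_0():
    print("Training started")

# Infinite data iteration: next() restarts the loader once it is exhausted
cycled_loader = CycledDataLoader(train_dataloader)
for step in range(10000):
    batch = cycled_loader.next()
    # Move the batch (a tensor or nested structure of tensors) to the device
    batch = to_device(batch, torch.device("cuda:0"))

# Synchronized metric averaging
loss_avg = all_reduce_mean(loss_tensor)

# Temperature annealing for sampling
scheduler = AnnealingScheduler(start=1.0, end=0.1, warmup_steps=50, annealing_step=1000)
for step in range(1000):
    temp = scheduler.get_temperature()  # temperature for the current step
    scheduler.step_forward()            # advance the schedule by one step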
