Overview
Utility functions and classes for the ColossalChat trainer, providing distributed communication helpers, a cycled data loader, a temperature annealing scheduler, and tensor device transfer utilities.
Description
This module provides several utility components used throughout the ColossalChat training infrastructure. AnnealingScheduler implements linear temperature annealing between start and end values with a warmup period. CycledDataLoader wraps a standard DataLoader to automatically restart from the beginning when exhausted. is_rank_0 checks whether the current process is rank 0 in a distributed setup. to_device recursively moves nested tensor structures to a target device using PyTorch's tree_map.
For distributed communication, all_reduce_mean performs an all-reduce sum and divides by the process count (optionally within a plugin's data-parallel group), all_reduce_sum performs an all-reduce sum, and all_gather_tensors gathers tensor lists from all processes and concatenates them.
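The sum-then-divide behaviour described for all_reduce_mean can be sketched with plain torch.distributed calls. This is a minimal illustration, not the library's code: `reduce_mean_sketch` is a hypothetical name, and its `group` argument stands in for whatever data-parallel group a plugin would supply. When no process group is initialized, the sketch simply returns a copy of the input so it can be tried in a single process.

```python
import torch
import torch.distributed as dist


def reduce_mean_sketch(tensor, group=None):
    """Hypothetical helper: sum across ranks, then divide by the group size."""
    if not (dist.is_available() and dist.is_initialized()):
        # Single-process run: the mean over one process is the tensor itself
        return tensor.clone()
    out = tensor.clone()  # all_reduce works in place, so keep the input intact
    dist.all_reduce(out, op=dist.ReduceOp.SUM, group=group)
    return out / dist.get_world_size(group=group)


loss = torch.tensor([2.0, 4.0])
print(reduce_mean_sketch(loss))
```

Cloning before the reduction is a design choice of this sketch; it avoids modifying the caller's tensor as a side effect.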
Usage
These utilities are used internally by ColossalChat trainers and can be imported directly for custom training loops. Use is_rank_0 for conditional logging, to_device for moving batches to GPU, CycledDataLoader for infinite iteration over datasets, and the distributed communication functions for synchronized metric aggregation.
Code Reference
Source Location
Signature
class AnnealingScheduler:
    def __init__(self, start, end, warmup_steps=100, annealing_step=2000): ...
    def get_temperature(self) -> float: ...
    def step_forward(self) -> None: ...

class CycledDataLoader:
    def __init__(self, dataloader: DataLoader) -> None: ...
    def next(self) -> Any: ...

def is_rank_0() -> bool: ...
def to_device(x: Any, device: torch.device) -> Any: ...
def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor: ...
def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor: ...
def all_gather_tensors(local_tensor_list: torch.Tensor, plugin: Plugin = None) -> torch.Tensor: ...
Import
from coati.trainer.utils import (
    AnnealingScheduler,
    CycledDataLoader,
    is_rank_0,
    to_device,
    all_reduce_mean,
    all_reduce_sum,
    all_gather_tensors,
)
I/O Contract
Inputs (AnnealingScheduler)
| Name | Type | Required | Description |
|---|---|---|---|
| start | float | Yes | Starting temperature value |
| end | float | Yes | Ending temperature value |
| warmup_steps | int | No | Number of warmup steps before annealing begins (default 100) |
| annealing_step | int | No | Total steps for annealing to reach end value (default 2000) |
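To make the four parameters concrete, here is one plausible reading of linear annealing with a warmup period, written as a standalone function. The name `linear_anneal` and the exact interpolation window (from `warmup_steps` to `annealing_step`) are assumptions made for illustration; the formula used inside AnnealingScheduler may differ in detail.

```python
def linear_anneal(step, start, end, warmup_steps=100, annealing_step=2000):
    """Hypothetical helper: temperature at a given step under linear annealing."""
    if step <= warmup_steps:
        return start  # hold the starting value during warmup
    if step >= annealing_step:
        return end  # clamp once the annealing horizon is reached
    # Fraction of the post-warmup window that has elapsed, strictly between 0 and 1
    progress = (step - warmup_steps) / (annealing_step - warmup_steps)
    return start + progress * (end - start)


for step in (0, 100, 1050, 2000, 5000):
    print(step, linear_anneal(step, start=1.0, end=0.1))
```

Under these assumptions the temperature stays at `start` through step 100, passes the halfway value at step 1050, and stays at `end` from step 2000 onward.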
Inputs (CycledDataLoader)
| Name | Type | Required | Description |
|---|---|---|---|
| dataloader | DataLoader | Yes | The original data loader to wrap |
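The restart-on-exhaustion behaviour can be sketched without PyTorch by wrapping any re-iterable object. `CyclingIterator` is a hypothetical stand-in written for illustration, not the CycledDataLoader source; a plain list plays the role of the data loader.

```python
class CyclingIterator:
    """Hypothetical stand-in for a cycling loader; wraps any re-iterable object."""

    def __init__(self, iterable):
        self.iterable = iterable
        self._iterator = None  # created lazily on the first call to next()

    def next(self):
        if self._iterator is None:
            self._iterator = iter(self.iterable)
        try:
            return next(self._iterator)
        except StopIteration:
            # The pass is exhausted: start a fresh pass and return its first item
            self._iterator = iter(self.iterable)
            return next(self._iterator)


batches = CyclingIterator(["batch_a", "batch_b", "batch_c"])
seen = [batches.next() for _ in range(5)]
print(seen)
```

Because a fresh iterator is requested on every restart, a wrapped DataLoader created with shuffle=True would produce a new ordering on each pass.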
Inputs (to_device)
| Name | Type | Required | Description |
|---|---|---|---|
| x | Any | Yes | Input tensor or nested structure of tensors |
| device | torch.device | Yes | Target device to move tensors to |
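A recursive device transfer built on tree_map might look like the sketch below. `move_to_device` is a hypothetical name, and the choice to pass non-tensor leaves through unchanged is an assumption. Note that `torch.utils._pytree` is a private PyTorch module, so its import path is not a stable public API. The example targets the CPU so it runs on machines without a GPU.

```python
import torch
from torch.utils._pytree import tree_map


def move_to_device(x, device):
    """Hypothetical helper mirroring the described behaviour of to_device."""

    def _move(leaf):
        # Only tensors are moved; other leaves (ints, strings, None) pass through
        return leaf.to(device) if isinstance(leaf, torch.Tensor) else leaf

    return tree_map(_move, x)


batch = {"input_ids": torch.zeros(2, 4, dtype=torch.long), "meta": ["a", 3]}
moved = move_to_device(batch, torch.device("cpu"))
print(moved["input_ids"].device, moved["meta"])
```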
Inputs (all_reduce_mean / all_reduce_sum)
| Name | Type | Required | Description |
|---|---|---|---|
| tensor | torch.Tensor | Yes | Input tensor to reduce |
| plugin | Plugin | No | ColossalAI Plugin for data-parallel group scoping (default None) |
Inputs (all_gather_tensors)
| Name | Type | Required | Description |
|---|---|---|---|
| local_tensor_list | torch.Tensor | Yes | List of local tensors to gather |
| plugin | Plugin | No | ColossalAI Plugin for data-parallel group scoping (default None) |
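The shape of the gathered result is easiest to see in a single-process simulation. The sketch below uses made-up per-rank lists of placeholder strings instead of real tensors and a real process group; that the concatenation follows rank order is an assumption of this illustration.

```python
from itertools import chain

# Hypothetical per-rank inputs: each process holds its own list of items
per_rank_lists = [
    ["r0_t0", "r0_t1"],  # rank 0
    ["r1_t0"],           # rank 1
    ["r2_t0", "r2_t1"],  # rank 2
]

# Gathering yields one list per rank; concatenating flattens them into one list
gathered = list(chain.from_iterable(per_rank_lists))
print(gathered)
```

The result is a single flat list whose length is the total number of items across all ranks, so per-rank lists of different lengths pose no problem in this model.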
Outputs
| Name | Type | Description |
|---|---|---|
| AnnealingScheduler.get_temperature return | float | Current temperature value |
| CycledDataLoader.next return | Any | Next batch from the data loader |
| is_rank_0 return | bool | True if current process is rank 0 |
| to_device return | Any | Input structure with all tensors moved to device |
| all_reduce_mean return | torch.Tensor | Tensor with mean computed across all processes |
| all_reduce_sum return | torch.Tensor | Tensor with sum computed across all processes |
| all_gather_tensors return | list | Concatenated list of tensors from all processes |
Usage Examples
from coati.trainer.utils import (
    is_rank_0, to_device, all_reduce_mean, CycledDataLoader, AnnealingScheduler,
)
import torch

# NOTE: `batch`, `loss_tensor`, and `train_dataloader` are placeholders assumed
# to be defined by the surrounding training code.

# Conditional logging on rank 0
if is_rank_0():
    print("Training started")

# Move batch to device
batch = to_device(batch, torch.device("cuda:0"))

# Synchronized metric averaging
loss_avg = all_reduce_mean(loss_tensor)

# Infinite data iteration
cycled_loader = CycledDataLoader(train_dataloader)
for step in range(10000):
    batch = cycled_loader.next()

# Temperature annealing for sampling
scheduler = AnnealingScheduler(start=1.0, end=0.1, warmup_steps=50, annealing_step=1000)
for step in range(1000):
    temp = scheduler.get_temperature()
    scheduler.step_forward()