Implementation:Hiyouga LLaMA Factory V1 Accelerator Interface
| Knowledge Sources | |
|---|---|
| Domains | Distributed Training, Model Parallelism, Data Parallelism |
| Last Updated | 2026-02-06 19:00 GMT |
Overview
DistributedInterface is a singleton class that provides a unified API for model parallelism, data parallelism, and collective communication operations.
Description
The module implements three core types: Dim (an enum of parallelism dimensions: MP_REPLICATE, MP_SHARD, DP, CP), DistributedStrategy (a dataclass that validates the mesh dimension sizes against the world size), and DistributedInterface (a singleton that initializes process groups and device meshes using torch.distributed.device_mesh). The interface exposes high-level methods for all_gather, all_reduce, broadcast, sync, and barrier across a specified parallelism dimension. It keeps two separate device meshes: a model-parallel mesh with the replicate and shard dimensions (MP_REPLICATE, MP_SHARD), and a data-parallel mesh with the data and context parallelism dimensions (DP, CP).
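To make the two-mesh layout concrete, the sketch below maps a global rank to its coordinates in each mesh. It assumes row-major rank placement (ranks 0..N-1 reshaped to the mesh shape) and assumes the dimension order (mp_replicate, mp_shard) and (dp, cp); the helper name mesh_coords is hypothetical and is not part of the module.

```python
def mesh_coords(rank: int, shape: tuple) -> tuple:
    """Return the (outer, inner) coordinate of `rank` in a row-major 2-D mesh."""
    outer, inner = shape
    if not 0 <= rank < outer * inner:
        raise ValueError(f"rank {rank} is outside a mesh of shape {shape}")
    return divmod(rank, inner)


# Hypothetical 8-process job; both meshes cover the same 8 ranks.
WORLD_SIZE = 8
MODEL_MESH = (2, 4)  # assumed order: (mp_replicate, mp_shard)
DATA_MESH = (4, 2)   # assumed order: (dp, cp)

for rank in range(WORLD_SIZE):
    mp_replicate, mp_shard = mesh_coords(rank, MODEL_MESH)
    dp, cp = mesh_coords(rank, DATA_MESH)
    print(f"rank={rank} mp_replicate={mp_replicate} mp_shard={mp_shard} dp={dp} cp={cp}")
```

The same global rank therefore has one position per mesh, which is why rank and world-size queries take a Dim argument.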
Usage
Use DistributedInterface as the central coordination point for all distributed training in LLaMA-Factory V1. Instantiate it once with an optional DistributedConfig to configure parallelism strategy. All subsequent instantiations return the same singleton. Access collective operations through its methods, specifying the parallelism dimension to operate on.
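The singleton behaviour described above can be sketched in plain Python. This is a minimal illustration of the pattern, not the module's actual code; the class name SingletonInterface and its config attribute are hypothetical.

```python
from __future__ import annotations


class SingletonInterface:
    """Illustrative singleton: the first construction configures, later ones are no-ops."""

    _instance = None
    _initialized = False

    def __new__(cls, *args, **kwargs):
        # Create the single instance lazily and hand it back on every call.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config: dict | None = None) -> None:
        # Python calls __init__ on every construction, so guard re-initialization.
        if self._initialized:
            return
        self.config = dict(config or {})
        type(self)._initialized = True


first = SingletonInterface({"dp_size": 2})
second = SingletonInterface({"dp_size": 8})  # this config is ignored

print(first is second)  # True
print(second.config)    # {'dp_size': 2}
```

One consequence of this pattern is that a config passed to any later instantiation has no effect, so the first call site should be the one that carries the intended parallelism settings.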
Code Reference
Source Location
- Repository: Hiyouga_LLaMA_Factory
- File: src/llamafactory/v1/accelerator/interface.py
- Lines: 1-260
Signature
class Dim(StrEnum):
    MP_REPLICATE = "mp_replicate"
    MP_SHARD = "mp_shard"
    DP = "dp"
    CP = "cp"

@dataclass
class DistributedStrategy:
    mp_replicate_size: int = 1
    mp_shard_size: int | None = None
    dp_size: int | None = None
    cp_size: int = 1

class DistributedInterface:
    def __init__(self, config: DistributedConfig | None = None) -> None
    def get_device_mesh(self, dim: Dim | None = None) -> DeviceMesh | None
    def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]
    def get_rank(self, dim: Dim | None = None) -> int
    def get_world_size(self, dim: Dim | None = None) -> int
    def get_local_rank(self) -> int
    def get_local_world_size(self) -> int
    def all_gather(self, data: TensorLike, dim: Dim | None = Dim.DP) -> TensorLike
    def all_reduce(self, data: TensorLike, op: ReduceOp = ReduceOp.MEAN, dim: Dim | None = Dim.DP) -> TensorLike
    def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike
    def sync(self) -> None
    def barrier(self) -> None
    def destroy(self) -> None
Import
from llamafactory.v1.accelerator.interface import DistributedInterface, DistributedStrategy, Dim
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| config (DistributedInterface) | DistributedConfig or None | No | Configuration dict with mp_replicate_size, mp_shard_size, dp_size, cp_size, timeout keys |
| mp_replicate_size (DistributedStrategy) | int | No | Model parallel replicate size (default 1) |
| mp_shard_size (DistributedStrategy) | int or None | No | Model parallel shard size (default world_size // mp_replicate_size) |
| dp_size (DistributedStrategy) | int or None | No | Data parallel size (default world_size // cp_size) |
| cp_size (DistributedStrategy) | int | No | Context parallel size (default 1) |
| dim (collective ops) | Dim or None | No | Parallelism dimension to operate on (default Dim.DP) |
| data (collective ops) | TensorLike | Yes | Input tensor, numpy array, or scalar for collective operations |
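The defaults in the table can be read as a small resolution rule: unset sizes are derived from the world size, and each mesh must then multiply out to the world size. The sketch below illustrates that rule under stated assumptions. StrategySketch is a hypothetical stand-in, and its explicit world_size field exists only to keep the example self-contained (the table implies the real class obtains the world size from the runtime).

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class StrategySketch:
    world_size: int  # hypothetical field, only here so the sketch runs standalone
    mp_replicate_size: int = 1
    mp_shard_size: int | None = None
    dp_size: int | None = None
    cp_size: int = 1

    def __post_init__(self) -> None:
        # Fill in unset sizes using the defaults listed in the Inputs table.
        if self.mp_shard_size is None:
            self.mp_shard_size = self.world_size // self.mp_replicate_size
        if self.dp_size is None:
            self.dp_size = self.world_size // self.cp_size

        # Each mesh has to account for every rank exactly once.
        if self.mp_replicate_size * self.mp_shard_size != self.world_size:
            raise ValueError("mp_replicate_size * mp_shard_size must equal world_size")
        if self.dp_size * self.cp_size != self.world_size:
            raise ValueError("dp_size * cp_size must equal world_size")


strategy = StrategySketch(world_size=8, cp_size=2)
print(strategy.mp_shard_size, strategy.dp_size)  # 8 4
```

Under this rule, a configuration such as mp_shard_size=4 with dp_size=2 and cp_size=1 would be rejected on 4 processes, because the data-parallel mesh would cover only 2 ranks.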
Outputs
| Name | Type | Description |
|---|---|---|
| DistributedInterface instance | DistributedInterface | Singleton instance with initialized process groups and device meshes |
| get_device_mesh result | DeviceMesh or None | PyTorch DeviceMesh for the specified dimension; None if not distributed |
| get_rank result | int | Rank within the specified parallelism dimension |
| get_world_size result | int | World size within the specified parallelism dimension |
| all_gather result | TensorLike | Gathered data from all ranks in the specified group |
| all_reduce result | TensorLike | Reduced data across all ranks in the specified group |
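For readers new to collectives, the following single-process simulation shows what each rank ends up holding after all_gather, a mean all_reduce, and broadcast. It models ranks as positions in a Python list and uses no process groups; the simulate_* functions are hypothetical, and the exact layout of gathered data in the real interface may differ.

```python
def simulate_all_gather(per_rank: list) -> list:
    """Every rank receives the contributions of all ranks."""
    return [list(per_rank) for _ in per_rank]


def simulate_all_reduce_mean(per_rank: list) -> list:
    """Every rank receives the mean of all contributions."""
    mean = sum(per_rank) / len(per_rank)
    return [mean for _ in per_rank]


def simulate_broadcast(per_rank: list, src: int = 0) -> list:
    """Every rank receives the value held by rank `src`."""
    return [per_rank[src] for _ in per_rank]


# One scalar loss per simulated rank in a 4-rank group.
losses = [2.0, 4.0, 6.0, 8.0]

print(simulate_all_gather(losses)[0])    # [2.0, 4.0, 6.0, 8.0]
print(simulate_all_reduce_mean(losses))  # [5.0, 5.0, 5.0, 5.0]
print(simulate_broadcast(losses, src=2)) # [6.0, 6.0, 6.0, 6.0]
```

The mean reduction mirrors the ReduceOp.MEAN default in the all_reduce signature, which is the usual choice when averaging a loss across data-parallel ranks.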
Usage Examples
import torch

from llamafactory.v1.accelerator.interface import DistributedInterface, Dim

# Initialize singleton (first call configures, subsequent calls return same instance).
# Assumes a 4-process launch: per the defaults above, each mesh's sizes should
# multiply to the world size (1 * 4 for model parallelism, 4 * 1 for data parallelism).
dist_interface = DistributedInterface(config={
    "mp_replicate_size": 1,
    "mp_shard_size": 4,
    "dp_size": 4,
    "cp_size": 1,
    "timeout": 18000,
})

# Query rank and world size within the data parallel group
dp_rank = dist_interface.get_rank(dim=Dim.DP)
dp_world = dist_interface.get_world_size(dim=Dim.DP)

# Collective operations: average a scalar loss across data parallel ranks
# (all_reduce defaults to ReduceOp.MEAN)
loss = torch.tensor(2.5, device=dist_interface.current_device)
avg_loss = dist_interface.all_reduce(loss, dim=Dim.DP)

# Synchronize and clean up
dist_interface.barrier()
dist_interface.destroy()
Related Pages
- Hiyouga_LLaMA_Factory_V1_Accelerator_Helper - Low-level helper functions consumed by DistributedInterface
- Hiyouga_LLaMA_Factory_V1_Utils_Types - DistributedConfig, TensorLike, and ProcessGroup type definitions
- Hiyouga_LLaMA_Factory_V1_Trainer_Plugins_FSDP2 - FSDP2 plugin that uses DistributedInterface for sharded training
- Hiyouga_LLaMA_Factory_V1_Trainer_Plugins_DeepSpeed - DeepSpeed plugin that uses DistributedInterface