Implementation:Hpcaitech ColossalAI AutoModelForCausalLM SFT
| Knowledge Sources | |
|---|---|
| Domains | NLP, Model_Architecture |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Wrapper for loading pretrained causal language models with lazy initialization and optional LoRA injection for SFT training in ColossalChat.
Description
This implementation combines AutoModelForCausalLM.from_pretrained() from HuggingFace Transformers with ColossalAI's LazyInitContext for memory-efficient model loading. When LoRA is enabled, it uses ColossalChat's convert_to_lora_module() to inject low-rank adapters into the model.
Usage
Use this pattern when loading a model for SFT training. Lazy initialization defers weight materialization, so the model is only fully instantiated on the correct device once it is wrapped by the Booster.
Code Reference
Source Location
- Repository: ColossalAI
- File: applications/ColossalChat/examples/training_scripts/train_sft.py
- Lines: 117-145
Signature
```python
# Model loading pattern in train_sft.py
with LazyInitContext(default_device=get_current_device()):
    model = AutoModelForCausalLM.from_pretrained(
        pretrained: str,           # model path or HuggingFace ID
        torch_dtype: torch.dtype,  # torch.bfloat16 or torch.float16
        trust_remote_code: bool = True,
    )
if lora_rank > 0:
    model = convert_to_lora_module(model, lora_config=lora_config)  # lora_config: a LoraConfig instance
```
Import
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.lazy import LazyInitContext
from colossalai.utils import get_current_device
from coati.models.lora import convert_to_lora_module, LoraConfig
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| pretrained | str | Yes | Model path or HuggingFace model ID |
| torch_dtype | torch.dtype | Yes | Model precision (torch.bfloat16 or torch.float16) |
| lora_rank | int | No | LoRA rank (0 disables LoRA). Default: 0 |
| lora_config | str | No | Path to LoRA JSON config file |
Outputs
| Name | Type | Description |
|---|---|---|
| model | nn.Module | Loaded causal LM (optionally with LoRA adapters) |
| tokenizer | PreTrainedTokenizer | Configured tokenizer with special tokens |
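The `lora_config` input names a JSON file on disk. A minimal stdlib sketch of round-tripping such a file, as a training script might do before constructing a `LoraConfig` (the field names below are illustrative assumptions, not the actual coati schema):

```python
import json
import tempfile

# Write a small LoRA config to disk; field names here are assumptions,
# not the actual coati LoraConfig schema.
config = {"r": 8, "lora_alpha": 16, "lora_dropout": 0.05}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(config, f)
    path = f.name

# Load it back before building the LoRA configuration object
with open(path) as f:
    loaded = json.load(f)

print(loaded["r"])  # 8
```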
Usage Examples
Load Model for SFT
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.lazy import LazyInitContext
from colossalai.accelerator import get_accelerator

# Load with lazy initialization; weights are materialized later, after Booster wrapping
with LazyInitContext(default_device=get_accelerator().get_current_device()):
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token  # Llama models define no pad token by default
```
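The LoRA branch is guarded solely by `lora_rank`. A dependency-free sketch of that gate, where `inject_lora` is a hypothetical stand-in for `convert_to_lora_module` (it only tags the model rather than injecting real adapters):

```python
def inject_lora(model: dict, lora_rank: int) -> dict:
    # Hypothetical stand-in for coati's convert_to_lora_module:
    # tag the model instead of injecting low-rank adapters.
    model["lora_rank"] = lora_rank
    return model

def maybe_apply_lora(model: dict, lora_rank: int = 0) -> dict:
    # Mirrors the `if lora_rank > 0` gate in the train_sft.py pattern:
    # rank 0 disables LoRA and returns the model unchanged.
    return inject_lora(model, lora_rank) if lora_rank > 0 else model

print(maybe_apply_lora({"name": "llama"}, lora_rank=8))  # {'name': 'llama', 'lora_rank': 8}
print(maybe_apply_lora({"name": "llama"}, lora_rank=0))  # {'name': 'llama'}
```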
Related Pages
Implements Principle
Requires Environment