Workflow:Hpcaitech ColossalAI LLaMA Continual Pretraining
| Knowledge Sources | |
|---|---|
| Domains | LLMs, Pretraining, Distributed_Training |
| Last Updated | 2026-02-09 03:00 GMT |
Overview
End-to-end process for continual pretraining and supervised fine-tuning of LLaMA models on custom domain data using ColossalAI's distributed training infrastructure.
Description
This workflow covers the continual pretraining pipeline for LLaMA models, enabling domain adaptation through additional training on custom text corpora. It supports both continual pretraining (next-token prediction on raw text) and supervised fine-tuning (instruction following) modes. The pipeline includes tokenizer vocabulary expansion for non-English languages, spliced tokenization that concatenates tokenized samples into sequences of up to max_length tokens to reduce padding, and NEFTune noise injection for improved generalization. Training is distributed across multiple GPUs using ColossalAI's plugin system, which supports ZeRO-2, Gemini, 3D parallelism, and standard DDP strategies.
Usage
Execute this workflow when you need to adapt a base LLaMA model to a new domain (e.g., Chinese language, legal, medical) through continual pretraining on domain-specific text data, or when you want to fine-tune the adapted model on instruction-following tasks. This is the first stage in a multi-stage training pipeline before alignment.
Execution Steps
Step 1: Tokenizer Preparation
Optionally expand the LLaMA tokenizer vocabulary with domain-specific tokens. This is particularly important for non-English languages where the default LLaMA tokenizer has limited coverage.
Key considerations:
- Chinese language support requires adding Chinese tokens to the vocabulary
- Vocabulary expansion requires resizing the model's input embedding layer and output (LM head) layer to the new vocabulary size
- The init_tokenizer utility handles vocabulary merging and saving
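The merge-and-resize idea can be sketched in plain Python. This is a minimal illustration, not the init_tokenizer implementation: `merge_vocab` and `resize_embeddings` are hypothetical helpers, and initializing new rows to the mean of the existing embedding rows is an assumed heuristic. In a Hugging Face model the resize itself is usually done with `model.resize_token_embeddings(len(tokenizer))`, after which the new rows can be re-initialized.

```python
from __future__ import annotations

from statistics import fmean


def merge_vocab(base_vocab: dict[str, int], new_tokens: list[str]) -> dict[str, int]:
    """Append tokens missing from base_vocab, giving them consecutive new ids.

    Assumes the base ids are contiguous from 0, as in a typical tokenizer vocab.
    """
    merged = dict(base_vocab)
    next_id = max(merged.values(), default=-1) + 1
    for tok in new_tokens:
        if tok not in merged:  # skip tokens the base tokenizer already knows
            merged[tok] = next_id
            next_id += 1
    return merged


def resize_embeddings(weights: list[list[float]], new_vocab_size: int) -> list[list[float]]:
    """Grow an embedding table (one row per token id) to new_vocab_size rows.

    New rows start at the mean of the existing rows (an assumed initialization).
    """
    old_size = len(weights)
    if new_vocab_size < old_size:
        raise ValueError("cannot shrink the embedding table")
    dim = len(weights[0])
    mean_row = [fmean(row[j] for row in weights) for j in range(dim)]
    new_rows = [list(mean_row) for _ in range(new_vocab_size - old_size)]
    return [list(row) for row in weights] + new_rows
```

The same resize would be applied to the LM head, since LLaMA does not tie its input and output embeddings.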
Step 2: Dataset Preparation
Preprocess raw text data into tokenized format using the spliced tokenization pipeline. For pretraining, raw JSONL files are tokenized and concatenated into sequences of up to max_length tokens. For SFT, instruction-response pairs are formatted with a conversation template and then tokenized.
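Spliced packing can be sketched as a greedy pass over already-tokenized documents. This is an assumed strategy for illustration: whole documents are appended to the current sequence until the next one would overflow `max_length`, and documents longer than `max_length` are truncated. `splice` is a hypothetical name, and the actual prepare_pretrain_dataset.py logic may differ, for example in how it handles oversized documents or separator tokens.

```python
from __future__ import annotations


def splice(docs: list[list[int]], max_length: int) -> list[list[int]]:
    """Greedily pack tokenized documents into sequences of at most max_length tokens."""
    if max_length < 1:
        raise ValueError("max_length must be positive")
    packed: list[list[int]] = []
    current: list[int] = []
    for doc in docs:
        doc = doc[:max_length]  # assumption: oversized documents are truncated
        if current and len(current) + len(doc) > max_length:
            packed.append(current)  # next document would overflow, so close this sequence
            current = []
        current.extend(doc)
    if current:
        packed.append(current)
    return packed
```

For example, four documents of lengths 3, 2, 4, and 1 with `max_length=5` pack into two full sequences instead of four padded ones.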
What happens:
- For pretraining: run prepare_pretrain_dataset.py to tokenize raw JSONL into Arrow format with spliced sequences
- For SFT: run prepare_sft_dataset.py to apply conversation templates and tokenize
- Data is split into multiple shards for parallel loading across workers
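Sharding can be as simple as round-robin assignment. The sketch below assumes records are already tokenized; `split_into_shards` and `shards_for_rank` are hypothetical helpers, and the strided shard-to-rank mapping is one common convention rather than the scripts' documented behavior.

```python
from __future__ import annotations

from typing import Sequence, TypeVar

T = TypeVar("T")


def split_into_shards(records: Sequence[T], num_shards: int) -> list[list[T]]:
    """Distribute records round-robin so shard sizes differ by at most one."""
    if num_shards < 1:
        raise ValueError("num_shards must be >= 1")
    shards: list[list[T]] = [[] for _ in range(num_shards)]
    for i, record in enumerate(records):
        shards[i % num_shards].append(record)
    return shards


def shards_for_rank(shards: list[list[T]], rank: int, world_size: int) -> list[list[T]]:
    """Give each data-parallel rank every world_size-th shard, starting at its own rank."""
    return shards[rank::world_size]
```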
Step 3: Environment and Model Initialization
Initialize the distributed training environment and load the pretrained LLaMA model. Configure gradient checkpointing and Flash Attention for memory and compute efficiency.
What happens:
- Initialize ColossalAI distributed environment with launch_from_torch()
- Load model with AutoModelForCausalLM.from_pretrained()
- Enable gradient checkpointing to reduce activation memory
- Enable Flash Attention 2 for efficient attention computation
- Optionally apply LoRA for parameter-efficient continual pretraining
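For the optional LoRA step, the underlying idea is that a frozen weight W is augmented with a trainable low-rank product, y = W x + (alpha / r) · B A x. The pure-Python sketch below shows only that arithmetic with hypothetical helper names; it is not how ColossalAI or any LoRA library wires adapters into a model. Initializing B to zeros, as is standard for LoRA, makes the adapted layer start out identical to the base layer.

```python
from __future__ import annotations


def matvec(m: list[list[float]], v: list[float]) -> list[float]:
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(
    W: list[list[float]],
    A: list[list[float]],
    B: list[list[float]],
    x: list[float],
    alpha: float,
    r: int,
) -> list[float]:
    """Compute W x + (alpha / r) * B (A x).

    Shapes: W is d_out x d_in (frozen), A is r x d_in and B is d_out x r (both trainable).
    """
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))  # low-rank path: project down to r dims, then back up
    scale = alpha / r
    return [b + scale * d for b, d in zip(base, delta)]


def trainable_params(d_in: int, d_out: int, r: int) -> tuple[int, int]:
    """Return (full fine-tuning params, LoRA params) for one d_out x d_in weight."""
    return d_in * d_out, r * (d_in + d_out)
```

For a 4096 × 4096 projection with r = 8, LoRA trains 65,536 parameters instead of 16,777,216.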
Step 4: Plugin and Booster Configuration
Select the parallelism strategy and create the Booster that wraps the model, optimizer, dataloader, and learning rate scheduler for distributed training.
Available strategies:
- ZeRO Stage 2: Shards optimizer states and gradients
- Gemini: CPU-GPU heterogeneous memory management
- 3D Parallelism: Tensor + Pipeline + Data parallelism combined
- DDP: Standard distributed data parallelism
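A rough way to compare ZeRO-2 with DDP is the model-state accounting from the ZeRO paper for mixed-precision Adam: 2 bytes per parameter for fp16 weights, 2 for fp16 gradients, and 12 for fp32 optimizer states (master weights, momentum, variance). DDP replicates all 16 bytes on every GPU, while ZeRO-2 keeps the weights replicated but shards gradients and optimizer states across the data-parallel group. The sketch ignores activations, buffers, and fragmentation, and does not model Gemini's CPU offloading; the function names are hypothetical. It also encodes the usual 3D-parallelism constraint that the world size factors into tensor × pipeline × data parallel degrees.

```python
from __future__ import annotations


def per_gpu_model_state_bytes(num_params: int, dp_size: int, strategy: str) -> float:
    """Approximate per-GPU model-state memory (bytes) for mixed-precision Adam."""
    weights = 2 * num_params    # fp16/bf16 parameters
    grads = 2 * num_params      # fp16/bf16 gradients
    optim = 12 * num_params     # fp32 master weights + momentum + variance
    if strategy == "ddp":
        return weights + grads + optim             # everything replicated
    if strategy == "zero2":
        return weights + (grads + optim) / dp_size  # grads and optimizer states sharded
    raise ValueError(f"unknown strategy: {strategy}")


def data_parallel_size(world_size: int, tp_size: int, pp_size: int) -> int:
    """Data-parallel degree left over after tensor and pipeline parallelism."""
    if world_size % (tp_size * pp_size):
        raise ValueError("world_size must be divisible by tp_size * pp_size")
    return world_size // (tp_size * pp_size)
```

For a 7B-parameter model on 8 GPUs this gives 112 GB per GPU under DDP versus 26.25 GB under ZeRO-2 (decimal gigabytes, model states only).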
Step 5: Training Loop
Execute the main training loop with next-token prediction loss (cross-entropy). The loop supports gradient accumulation, periodic checkpoint saving, and TensorBoard logging.
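The accumulation and checkpoint-interval logic can be illustrated with a toy loop on a single scalar weight. This is a hypothetical sketch, not the ColossalAI training script: each micro-batch gradient is divided by `accumulation_steps`, which mirrors scaling the loss before backward, so one optimizer step applies the mean gradient, and a checkpoint is recorded every `save_interval` optimizer steps.

```python
from __future__ import annotations


def train_with_accumulation(
    micro_grads: list[float], accumulation_steps: int, lr: float, save_interval: int
) -> tuple[float, int, list[tuple[int, float]]]:
    """Toy SGD on one scalar weight with gradient accumulation and periodic checkpoints.

    Returns (final_weight, optimizer_steps, checkpoints), where checkpoints holds
    (step, weight) pairs. Trailing micro-batches that do not fill a whole
    accumulation window are discarded.
    """
    weight = 0.0
    grad_buffer = 0.0
    step = 0
    checkpoints: list[tuple[int, float]] = []
    for i, grad in enumerate(micro_grads, start=1):
        # Scale each micro-batch gradient, like dividing the loss by accumulation_steps.
        grad_buffer += grad / accumulation_steps
        if i % accumulation_steps == 0:
            weight -= lr * grad_buffer  # optimizer step with the averaged gradient
            grad_buffer = 0.0           # zero the accumulated gradient
            step += 1
            if step % save_interval == 0:
                checkpoints.append((step, weight))
    return weight, step, checkpoints
```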
What happens:
- For standard plugins: iterate through batches, compute the loss, and call booster.backward(loss, optimizer)
- For pipeline parallelism: use booster.execute_pipeline() for staged forward/backward
- Accumulate gradients over a configurable number of micro-batches
- Apply NEFTune noise injection to the input embeddings during training (no noise at inference) as a regularizer
- Log loss metrics to TensorBoard at each step
- Save checkpoints at configurable intervals with full training state
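NEFTune, as described by its authors, adds noise drawn from Uniform(-1, 1) and scaled by alpha / sqrt(L · d) to the input embeddings of a sequence with L tokens and embedding dimension d, and only during training. The pure-Python sketch below applies that rule to a list-of-lists embedding matrix; `neftune` and its parameters are hypothetical names, and real implementations operate on batched tensors, typically through a forward hook on the embedding layer.

```python
from __future__ import annotations

import math
import random


def neftune(
    embeddings: list[list[float]], alpha: float, rng: random.Random, training: bool = True
) -> list[list[float]]:
    """Return a copy of an L x d embedding matrix with NEFTune noise added."""
    if not training:
        return [list(row) for row in embeddings]  # no noise at evaluation/inference time
    seq_len, dim = len(embeddings), len(embeddings[0])
    scale = alpha / math.sqrt(seq_len * dim)  # noise magnitude shrinks with L * d
    return [[x + scale * rng.uniform(-1.0, 1.0) for x in row] for row in embeddings]
```

With L = 16, d = 4, and alpha = 5, every noise value lies within ±0.625.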
Step 6: Checkpoint Saving
Save the final trained model as sharded weights, together with the optimizer and learning-rate scheduler states so that training can be resumed.
Key considerations:
- Checkpoints include model, optimizer, scheduler, epoch, and step metadata
- Supports resuming from any saved checkpoint
- Model can be saved in HuggingFace-compatible format
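The tensor state (model, optimizer, scheduler) would be written through the Booster's checkpoint methods such as `booster.save_model`, `booster.save_optimizer`, and `booster.save_lr_scheduler`. The small metadata record that makes resumption possible can be sketched with the standard library; the directory naming scheme, the `running_states.json` file name, and the helper functions below are assumptions for illustration.

```python
from __future__ import annotations

import json
import tempfile
from pathlib import Path

STATE_FILE = "running_states.json"  # assumed file name


def save_running_state(root: str | Path, epoch: int, step: int) -> Path:
    """Write epoch/step metadata into a per-checkpoint directory and return that directory."""
    ckpt_dir = Path(root) / f"epoch-{epoch}_step-{step}"  # hypothetical naming scheme
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    (ckpt_dir / STATE_FILE).write_text(json.dumps({"epoch": epoch, "step": step}))
    return ckpt_dir


def load_running_state(ckpt_dir: str | Path) -> dict:
    """Read the epoch/step metadata saved alongside a checkpoint."""
    return json.loads((Path(ckpt_dir) / STATE_FILE).read_text())


def latest_checkpoint(root: str | Path) -> Path | None:
    """Return the checkpoint directory with the highest (epoch, step), or None if there is none."""
    candidates = [p for p in Path(root).iterdir() if (p / STATE_FILE).is_file()]
    if not candidates:
        return None

    def order(p: Path) -> tuple[int, int]:
        state = load_running_state(p)
        return state["epoch"], state["step"]

    return max(candidates, key=order)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        save_running_state(tmp, epoch=0, step=500)
        save_running_state(tmp, epoch=1, step=250)
        latest = latest_checkpoint(tmp)
        print(latest.name, load_running_state(latest))
```

On resume, the loaded epoch and step tell the training loop which epoch to restart and how many batches of it to skip.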