Implementation: CarperAI Trlx Convert LLaMA to NeMo
| Knowledge Sources | |
|---|---|
| Domains | Model_Conversion, NLP, Megatron |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
Concrete tool for converting HuggingFace LLaMA model weights into NVIDIA NeMo/Megatron-GPT checkpoint format with tensor model parallelism support.
Description
This conversion script loads a HuggingFace LLaMA model (`AutoModelForCausalLM`), maps its weight tensors to NeMo/Megatron naming conventions, slices the embedding, attention (Q/K/V), MLP, and output layers across tensor-parallel (TP) ranks, and saves one checkpoint file per TP rank. It also generates a corresponding NeMo YAML config file from a template. Internal helper functions handle the tensor slicing: `get_self_attention_weight` partitions QKV projections along the output dimension, and `get_mlp_weight` partitions MLP layers across TP ranks.
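The slicing the helpers perform can be illustrated with a minimal NumPy sketch (hypothetical function names and tiny shapes; the real script operates on the model's torch tensors). Column-parallel layers such as the QKV projection are split along the output dimension, while row-parallel layers are split along the input dimension:

```python
import numpy as np

def shard_column_parallel(weight: np.ndarray, total_tp: int) -> list:
    """Split a [out_features, in_features] weight along the output dim (axis 0)."""
    assert weight.shape[0] % total_tp == 0, "output dim must divide evenly across TP ranks"
    return np.split(weight, total_tp, axis=0)

def shard_row_parallel(weight: np.ndarray, total_tp: int) -> list:
    """Split a [out_features, in_features] weight along the input dim (axis 1)."""
    assert weight.shape[1] % total_tp == 0, "input dim must divide evenly across TP ranks"
    return np.split(weight, total_tp, axis=1)

# Illustrative fused-QKV weight for a tiny hidden size of 8, sharded across TP=4
qkv = np.arange(3 * 8 * 8, dtype=np.float32).reshape(3 * 8, 8)
shards = shard_column_parallel(qkv, 4)
print(shards[0].shape)  # (6, 8): each rank holds 1/4 of the output rows
```

Concatenating the shards back along the split axis recovers the original tensor, which is a useful sanity check after any conversion of this kind.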
Usage
Use this script as a CLI tool before NeMo training to convert HuggingFace LLaMA checkpoints into the NeMo format required by the NeMo PPO/ILQL/SFT trainers. Required when using LLaMA models with NeMo's Megatron backend.
Code Reference
Source Location
- Repository: CarperAI_Trlx
- File: examples/llama_nemo/convert_llama_to_nemo.py
- Lines: 1-144
Signature
```python
def main(args: argparse.Namespace) -> None:
    """
    Convert HuggingFace LLaMA weights to NeMo Megatron format.

    Args:
        args.model_path: Path to HuggingFace LLaMA model.
        args.output_folder: Directory for NeMo checkpoint output.
        args.total_tp: Number of tensor-parallel ranks.
        args.name: Name for config/output files.
    """
```
Import
```shell
# CLI usage:
python examples/llama_nemo/convert_llama_to_nemo.py \
    --model_path /path/to/llama \
    --output_folder /path/to/output \
    --total_tp 4 \
    --name llama2_7b
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| --model_path | str (CLI) | Yes | Path to HuggingFace LLaMA model directory |
| --output_folder | str (CLI) | Yes | Output directory for NeMo checkpoint |
| --total_tp | int (CLI) | Yes | Number of tensor-parallel ranks to shard across |
| --name | str (CLI) | Yes | Name for output config and checkpoint files |
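The CLI surface in the table maps onto a straightforward `argparse` parser. A hedged sketch (flag names taken from the table above; the actual script may declare defaults or extra options):

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    # Sketch of the CLI described in the I/O contract; not the script's exact parser.
    parser = argparse.ArgumentParser(
        description="Convert HuggingFace LLaMA weights to NeMo Megatron format"
    )
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to HuggingFace LLaMA model directory")
    parser.add_argument("--output_folder", type=str, required=True,
                        help="Output directory for NeMo checkpoint")
    parser.add_argument("--total_tp", type=int, required=True,
                        help="Number of tensor-parallel ranks to shard across")
    parser.add_argument("--name", type=str, required=True,
                        help="Name for output config and checkpoint files")
    return parser

# Parse the example invocation from the Usage Examples section
args = build_parser().parse_args([
    "--model_path", "/models/llama-2-7b-hf",
    "--output_folder", "/nemo_checkpoints/llama2_7b",
    "--total_tp", "4",
    "--name", "llama2_7b",
])
print(args.total_tp)  # 4 (note: argparse converts the string to int)
```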
Outputs
| Name | Type | Description |
|---|---|---|
| Checkpoint files | .pt files | One state dict per TP rank at {output_folder}/tp_rank_{i}/model_weights.pt |
| Config file | .yaml | NeMo config at {output_folder}/{name}.yaml |
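The output layout in the table can be sketched with stdlib calls only (the real script writes actual torch state dicts via `torch.save` and a rendered YAML config; here we only create empty placeholders to show the directory structure):

```python
import tempfile
from pathlib import Path

def layout_outputs(output_folder: str, name: str, total_tp: int) -> list:
    """Create the output skeleton from the table: {name}.yaml plus one
    tp_rank_{i}/model_weights.pt per TP rank (placeholder files only)."""
    root = Path(output_folder)
    root.mkdir(parents=True, exist_ok=True)
    config = root / f"{name}.yaml"
    config.touch()
    created = [config]
    for i in range(total_tp):
        rank_dir = root / f"tp_rank_{i}"
        rank_dir.mkdir(exist_ok=True)
        ckpt = rank_dir / "model_weights.pt"
        ckpt.touch()
        created.append(ckpt)
    return created

with tempfile.TemporaryDirectory() as tmp:
    files = layout_outputs(tmp, "llama2_7b", 4)
    print(len(files))  # 5: one YAML config plus four per-rank checkpoints
```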
Usage Examples
Convert LLaMA 2 7B to NeMo with TP=4
```shell
python examples/llama_nemo/convert_llama_to_nemo.py \
    --model_path /models/llama-2-7b-hf \
    --output_folder /nemo_checkpoints/llama2_7b \
    --total_tp 4 \
    --name llama2_7b

# Output structure:
# /nemo_checkpoints/llama2_7b/
#     llama2_7b.yaml
#     tp_rank_0/model_weights.pt
#     tp_rank_1/model_weights.pt
#     tp_rank_2/model_weights.pt
#     tp_rank_3/model_weights.pt
```