Implementation: lmsys/FastChat `get_peft_state_maybe_zero_3`
| Knowledge Sources | |
|---|---|
| Domains | NLP, Training, Model Persistence |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
API for extracting LoRA adapter parameters from a trained model, with transparent handling of DeepSpeed ZeRO-3 partitioned parameters and configurable bias saving strategies.
Description
FastChat's train_lora.py defines two helper functions and a save logic block for persisting LoRA adapter weights after training:
- `maybe_zero_3(param)` -- A utility that transparently handles parameter gathering for ZeRO-3. If the parameter has a `ds_id` attribute (indicating it is a DeepSpeed-managed partitioned parameter), it uses `zero.GatheredParameters()` to all-gather the full tensor before detaching, moving to CPU, and cloning. For non-ZeRO-3 parameters, it simply detaches and clones to CPU.
- `get_peft_state_maybe_zero_3(named_params, bias)` -- Adapted from `peft.utils.get_peft_model_state_dict`, this function filters the model's named parameters to extract only LoRA-related weights. It supports three bias modes: `"none"` (only `lora_` keys), `"all"` (`lora_` and `bias` keys), and `"lora_only"` (LoRA keys plus biases from LoRA-modified layers only). Each extracted parameter is passed through `maybe_zero_3()` for safe gathering.
- Save logic (lines 201-218) -- The actual save path branches on whether ZeRO-3 is active:
  - ZeRO-3: Uses `trainer.model_wrapped._zero3_consolidated_16bit_state_dict()` to gather all parameters efficiently, then passes the full state dict to `model.save_pretrained()`. PEFT internally filters to LoRA-only keys.
  - Non-ZeRO-3: Uses `get_peft_state_maybe_zero_3()` to extract only LoRA parameters, then passes this filtered state dict to `model.save_pretrained()`.
  - Both paths gate the file write on `local_rank == 0` to prevent concurrent writes.
The output consists of adapter_model.bin (serialized LoRA weights) and adapter_config.json (LoRA hyperparameters) in the specified output directory.
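To make the three bias modes concrete, here is a minimal, standalone sketch of the key-filtering behavior. The parameter names below are hypothetical (not taken from a real model), and tensors are replaced by placeholder strings so the filtering logic can run on its own; actual gathering via `maybe_zero_3()` is omitted.

```python
# Hypothetical (name, value) pairs standing in for model.named_parameters();
# values are dummy strings instead of tensors.
named_params = [
    ("layers.0.q_proj.lora_A.weight", "A0"),
    ("layers.0.q_proj.lora_B.weight", "B0"),
    ("layers.0.q_proj.bias", "qbias"),       # bias of a LoRA-modified layer
    ("layers.0.mlp.up_proj.bias", "ubias"),  # bias of an unmodified layer
    ("layers.0.mlp.up_proj.weight", "uw"),   # base weight, never saved
]

def filter_keys(named_params, bias):
    """Mirror of get_peft_state_maybe_zero_3's filtering (gathering omitted)."""
    if bias == "none":
        return {k: t for k, t in named_params if "lora_" in k}
    if bias == "all":
        return {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    if bias == "lora_only":
        to_return, maybe_lora_bias, lora_bias_names = {}, {}, set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # "layers.0.q_proj.lora_A.weight" -> "layers.0.q_proj.bias"
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
        return to_return
    raise NotImplementedError

print(sorted(filter_keys(named_params, "none")))       # lora_A / lora_B only
print(sorted(filter_keys(named_params, "all")))        # both biases included
print(sorted(filter_keys(named_params, "lora_only")))  # only q_proj's bias joins
```

Note how `"lora_only"` derives each expected bias key from the LoRA key's prefix, so the unmodified `up_proj` bias is dropped while the `q_proj` bias is kept.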
Usage
Use these helpers when saving LoRA adapter weights after training in FastChat, especially when training may use DeepSpeed ZeRO-3 and partitioned parameters must be gathered before serialization.
Code Reference
Source Location
- Repository: FastChat
- File: `fastchat/train/train_lora.py` (lines 68-75, `maybe_zero_3` function)
- File: `fastchat/train/train_lora.py` (lines 79-101, `get_peft_state_maybe_zero_3` function)
- File: `fastchat/train/train_lora.py` (lines 201-218, save logic)
Signature
```python
# maybe_zero_3 helper (lines 68-75)
def maybe_zero_3(param):
    if hasattr(param, "ds_id"):
        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param
```
```python
# get_peft_state_maybe_zero_3 (lines 79-101)
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Fixed here: iterate (key, tensor) pairs and match each bias key
        # against the collected names (upstream iterated the dict directly
        # and reused a stale bias_name from the previous loop).
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
    return to_return
```
```python
# Save logic (lines 201-218)
if deepspeed.is_deepspeed_zero3_enabled():
    state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
    if training_args.local_rank == 0:
        state_dict = state_dict_zero3
else:
    state_dict = get_peft_state_maybe_zero_3(
        model.named_parameters(), lora_args.lora_bias
    )
if training_args.local_rank == 0:
    model.save_pretrained(training_args.output_dir, state_dict=state_dict)
```
Import
```python
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from transformers import deepspeed

# For reuse in other scripts:
from fastchat.train.train_lora import get_peft_state_maybe_zero_3
```
I/O Contract
Inputs (get_peft_state_maybe_zero_3)
| Name | Type | Required | Description |
|---|---|---|---|
| `named_params` | `Iterator[Tuple[str, Tensor]]` | Yes | Output of `model.named_parameters()`, yielding `(name, parameter)` pairs |
| `bias` | `str` | Yes | Bias saving strategy: `"none"`, `"all"`, or `"lora_only"` |
Inputs (maybe_zero_3)
| Name | Type | Required | Description |
|---|---|---|---|
| `param` | `torch.Tensor` (possibly DeepSpeed-managed) | Yes | A model parameter, potentially partitioned under ZeRO-3 |
Outputs
| Name | Type | Description |
|---|---|---|
| state_dict (from `get_peft_state_maybe_zero_3`) | `dict[str, Tensor]` | Dictionary mapping LoRA parameter names to their CPU-resident, fully gathered tensor values |
| param (from `maybe_zero_3`) | `torch.Tensor` | CPU-resident, detached clone of the (possibly gathered) parameter |
| `adapter_model.bin` | file | Serialized LoRA adapter weights saved to `output_dir` |
| `adapter_config.json` | file | LoRA configuration JSON saved to `output_dir` |
Usage Examples
Extracting LoRA State Dict (Non-ZeRO-3)
```python
from fastchat.train.train_lora import get_peft_state_maybe_zero_3

# After training, extract only LoRA parameters
state_dict = get_peft_state_maybe_zero_3(
    model.named_parameters(), bias="none"
)

# Save adapter weights
model.save_pretrained("output_lora/", state_dict=state_dict)
# Creates: output_lora/adapter_model.bin, output_lora/adapter_config.json
```
Saving with ZeRO-3 Consolidated State Dict
```python
from transformers import deepspeed

if deepspeed.is_deepspeed_zero3_enabled():
    # Efficient bulk all-gather of all parameters
    state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
    if training_args.local_rank == 0:
        # PEFT's save_pretrained automatically filters to LoRA keys
        model.save_pretrained(training_args.output_dir, state_dict=state_dict)
```
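The claim that PEFT pares the full consolidated state dict down to adapter weights can be illustrated with a simplified, standalone sketch. The key names are hypothetical, and real PEFT key handling also strips model/adapter-name prefixes; only the substring filter is shown here.

```python
# Toy stand-in for a ZeRO-3 consolidated state dict: every parameter in the
# wrapped model, LoRA and base weights alike (dummy string values).
full_state_dict = {
    "base_model.model.layers.0.q_proj.weight": "base",
    "base_model.model.layers.0.q_proj.lora_A.weight": "A",
    "base_model.model.layers.0.q_proj.lora_B.weight": "B",
}

# Simplified version of the filtering applied before writing
# adapter_model.bin: keep only adapter ("lora_") entries.
adapter_state_dict = {k: v for k, v in full_state_dict.items() if "lora_" in k}
print(sorted(adapter_state_dict))  # only the lora_A / lora_B keys survive
```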
Gathering a Single ZeRO-3 Parameter
```python
from fastchat.train.train_lora import maybe_zero_3

# Gather a potentially partitioned parameter to CPU
for name, param in model.named_parameters():
    if "lora_A" in name:
        full_param = maybe_zero_3(param)
        print(f"{name}: shape={full_param.shape}, device={full_param.device}")
```
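Branch selection in `maybe_zero_3` hinges solely on whether the parameter carries a `ds_id` attribute. The dispatch can be sketched without torch or DeepSpeed installed, using a hypothetical stub class (`FakeParam` is a stand-in for illustration, not a real torch or DeepSpeed type):

```python
class FakeParam:
    """Hypothetical stand-in for a model parameter (not a real torch type)."""
    def __init__(self, partitioned=False):
        if partitioned:
            # DeepSpeed tags ZeRO-3 partitioned parameters with ds_id
            self.ds_id = 0

def which_branch(param):
    """Mirror of maybe_zero_3's dispatch, minus the actual all-gather."""
    return "gather-under-zero3" if hasattr(param, "ds_id") else "plain-detach"

print(which_branch(FakeParam()))                  # plain-detach
print(which_branch(FakeParam(partitioned=True)))  # gather-under-zero3
```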
Post-Training Merge with apply_lora
```shell
python3 -m fastchat.model.apply_lora \
    --base-model-path meta-llama/Llama-2-7b-hf \
    --target-model-path output_merged/ \
    --lora-path output_lora/
```