
Implementation:Lm sys FastChat Get Peft State Maybe Zero 3

From Leeroopedia


Knowledge Sources
Domains: NLP, Training, Model Persistence
Last Updated: 2026-02-07 14:00 GMT

Overview

API for extracting LoRA adapter parameters from a trained model, with transparent handling of DeepSpeed ZeRO-3 partitioned parameters and configurable bias saving strategies.

Description

FastChat's train_lora.py defines two helper functions and a save logic block for persisting LoRA adapter weights after training:

  1. maybe_zero_3(param) -- A utility that transparently handles parameter gathering for ZeRO-3. If the parameter has a ds_id attribute (marking it as a DeepSpeed ZeRO-3 partitioned parameter), it asserts that the parameter is not currently gathered (ds_status == ZeroParamStatus.NOT_AVAILABLE), then enters zero.GatheredParameters() to all-gather the full tensor and detaches, moves to CPU, and clones it inside that context. For non-ZeRO-3 parameters, it simply detaches the tensor and clones it to CPU.
  2. get_peft_state_maybe_zero_3(named_params, bias) -- Adapted from peft.utils.get_peft_model_state_dict, this function filters the model's named parameters to extract only LoRA-related weights. It supports three bias modes: "none" (only lora_ keys), "all" (lora_ and bias keys), and "lora_only" (LoRA keys plus biases from LoRA-modified layers only). Each extracted parameter is passed through maybe_zero_3() for safe gathering.
  3. Save logic (lines 201-218) -- The actual save path branches on whether ZeRO-3 is active:
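    • ZeRO-3: Calls trainer.model_wrapped._zero3_consolidated_16bit_state_dict() on every rank to consolidate the full 16-bit weights (base model and adapters) on rank 0, then passes that full state dict to model.save_pretrained(). PEFT filters it down to the adapter keys internally.
    • Non-ZeRO-3: Uses get_peft_state_maybe_zero_3() to extract only the LoRA parameters (plus any biases selected by lora_bias), then passes this filtered state dict to model.save_pretrained().
    • Both paths gate the file write on training_args.local_rank == 0 to prevent concurrent writes. Because local_rank indexes processes within a node, one process per node passes this check on multi-node runs.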
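
The three bias modes can be illustrated on parameter names alone. The following is a minimal pure-Python sketch, not FastChat's code: select_keys and the example names are hypothetical, tensor values are omitted, and the lora_only branch keeps a bias only when its name equals the "<prefix>bias" derived from some LoRA key.

```python
from typing import Iterable, List

# Hypothetical parameter names in the style of a PEFT-wrapped model
EXAMPLE_NAMES = [
    "base_model.model.layers.0.self_attn.q_proj.weight",
    "base_model.model.layers.0.self_attn.q_proj.bias",
    "base_model.model.layers.0.self_attn.q_proj.lora_A.default.weight",
    "base_model.model.layers.0.self_attn.q_proj.lora_B.default.weight",
    "base_model.model.layers.0.mlp.up_proj.bias",
]


def select_keys(names: Iterable[str], bias: str) -> List[str]:
    """Return the parameter names that each bias mode would keep."""
    names = list(names)
    if bias == "none":
        return [k for k in names if "lora_" in k]
    if bias == "all":
        return [k for k in names if "lora_" in k or "bias" in k]
    if bias == "lora_only":
        # "...q_proj.lora_A.default.weight" -> "...q_proj.bias"
        lora_bias_names = {k.split("lora_")[0] + "bias" for k in names if "lora_" in k}
        return [k for k in names if "lora_" in k or k in lora_bias_names]
    raise NotImplementedError(bias)


for mode in ("none", "all", "lora_only"):
    print(mode, len(select_keys(EXAMPLE_NAMES, mode)))
```

Here q_proj.bias survives lora_only because q_proj carries LoRA weights, while up_proj.bias is kept only under "all". Since matching is by substring, any parameter whose name contains "bias" counts as a bias.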

The output consists of the serialized LoRA weights (adapter_model.bin, or adapter_model.safetensors under newer PEFT defaults) and adapter_config.json (the LoRA hyperparameters) in the specified output directory.

Usage

Use these helpers when saving LoRA adapter weights after FastChat training, particularly when DeepSpeed ZeRO-3 is enabled and partitioned parameters must be gathered before saving.

Code Reference

Source Location

  • Repository: FastChat
  • File: fastchat/train/train_lora.py (lines 68-75, maybe_zero_3 function)
  • File: fastchat/train/train_lora.py (lines 79-101, get_peft_state_maybe_zero_3 function)
  • File: fastchat/train/train_lora.py (lines 201-218, save logic)

Signature

# maybe_zero_3 helper (lines 68-75)
def maybe_zero_3(param):
    if hasattr(param, "ds_id"):
        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# get_peft_state_maybe_zero_3 (lines 79-101)
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep only biases that belong to LoRA-wrapped modules
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
    return to_return


# Save logic (lines 201-218)
if deepspeed.is_deepspeed_zero3_enabled():
    state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
    if training_args.local_rank == 0:
        state_dict = state_dict_zero3
else:
    state_dict = get_peft_state_maybe_zero_3(
        model.named_parameters(), lora_args.lora_bias
    )

if training_args.local_rank == 0:
    model.save_pretrained(training_args.output_dir, state_dict=state_dict)
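
The reason maybe_zero_3 clones inside the with block: under ZeRO-3 each rank holds only a shard, and zero.GatheredParameters materializes the full tensor only for the duration of the context. The sketch below models that behavior with hypothetical stand-ins (FakeZero3Param, fake_gathered_parameters, maybe_gather_copy) and plain lists in place of tensors, so it needs neither DeepSpeed nor PyTorch.

```python
from contextlib import contextmanager
from types import SimpleNamespace


class FakeZero3Param:
    """Hypothetical stand-in for a ZeRO-3 partitioned parameter."""

    def __init__(self, full_values):
        self.ds_id = 0                   # ZeRO-3 parameters carry a ds_id attribute
        self._full = list(full_values)   # stands in for shards spread across ranks
        self.data = []                   # outside a gather, only a placeholder is visible


@contextmanager
def fake_gathered_parameters(params):
    """Stand-in for zero.GatheredParameters: materialize, then re-partition on exit."""
    for p in params:
        p.data = list(p._full)
    try:
        yield
    finally:
        for p in params:
            p.data = []


def maybe_gather_copy(param):
    """Same dispatch as maybe_zero_3, with list copies in place of detach().cpu().clone()."""
    if hasattr(param, "ds_id"):
        with fake_gathered_parameters([param]):
            return list(param.data)  # copy while the full values are visible
    return list(param.data)


demo_param = FakeZero3Param([1.0, 2.0])
demo_copy = maybe_gather_copy(demo_param)
print(demo_copy, demo_param.data)  # [1.0, 2.0] []
```

Copying after leaving the context would capture only the empty placeholder, which is why the clone happens inside the with block.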

Import

from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from transformers import deepspeed  # Transformers' DeepSpeed integration, not the deepspeed package

# For reuse in other scripts (train_lora imports DeepSpeed and PEFT at module level, so both must be installed):
from fastchat.train.train_lora import get_peft_state_maybe_zero_3
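
The is_deepspeed_zero3_enabled check used in the save logic comes from Transformers. In more recent Transformers releases it lives in transformers.integrations.deepspeed, and the older transformers.deepspeed path is deprecated. A version-tolerant lookup might look like the sketch below; the helper names and the fallback to False when Transformers is not installed are assumptions for illustration.

```python
def _resolve_zero3_check():
    """Return a callable reporting whether ZeRO-3 is active (hypothetical helper)."""
    try:
        # Newer Transformers releases
        from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
        return is_deepspeed_zero3_enabled
    except ImportError:
        pass
    try:
        # Older releases, the module FastChat reaches via `from transformers import deepspeed`
        from transformers.deepspeed import is_deepspeed_zero3_enabled
        return is_deepspeed_zero3_enabled
    except ImportError:
        # Transformers not installed: treat ZeRO-3 as disabled (assumption)
        return lambda: False


zero3_enabled = _resolve_zero3_check()
print(zero3_enabled())  # False outside a DeepSpeed ZeRO-3 run
```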

I/O Contract

Inputs (get_peft_state_maybe_zero_3)

  • named_params (Iterator[Tuple[str, Tensor]], required): output of model.named_parameters(), yielding (name, parameter) pairs
  • bias (str, required): bias saving strategy, one of "none", "all", or "lora_only"
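
Because named_params is consumed as an iterator, each call should receive a fresh model.named_parameters(). The sketch below uses a hypothetical fake_named_parameters generator and a simplified lora_keys filter to show what happens when an exhausted generator is reused.

```python
def fake_named_parameters():
    """Hypothetical stand-in for model.named_parameters(), which returns a generator."""
    yield "q_proj.lora_A.default.weight", [0.1]
    yield "q_proj.weight", [1.0]


def lora_keys(named_params):
    """Simplified version of the bias="none" filter."""
    return [k for k, _ in named_params if "lora_" in k]


params = fake_named_parameters()
first_pass = lora_keys(params)
second_pass = lora_keys(params)  # the generator is already exhausted
fresh_pass = lora_keys(fake_named_parameters())
print(first_pass, second_pass, fresh_pass)
```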

Inputs (maybe_zero_3)

  • param (torch.nn.Parameter, required): a model parameter; under ZeRO-3 it is partitioned and carries DeepSpeed attributes such as ds_id and ds_status

Outputs

  • state_dict (dict[str, Tensor], returned by get_peft_state_maybe_zero_3): maps LoRA parameter names, plus any biases selected by the bias mode, to CPU-resident, fully gathered tensors
  • param (torch.Tensor, returned by maybe_zero_3): CPU-resident, detached clone of the (possibly gathered) parameter
  • adapter_model.bin (file): serialized LoRA adapter weights saved to output_dir
  • adapter_config.json (file): LoRA configuration JSON saved to output_dir
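
After saving, the output directory can be sanity-checked. The helpers below (pick_weights_file, summarize_adapter_config, inspect_adapter_dir) are hypothetical and not part of FastChat or PEFT. They accept either weight filename, since newer PEFT releases default to safetensors, and read a few fields from adapter_config.json. The field names peft_type, r, lora_alpha, and bias follow PEFT's LoraConfig; bias should match lora_bias, assuming the same value was passed to LoraConfig during training.

```python
import json
from pathlib import Path

# Checked in this order; which file exists depends on the PEFT version and settings
WEIGHT_FILES = ("adapter_model.safetensors", "adapter_model.bin")


def pick_weights_file(filenames):
    """Return the first known adapter weight filename present, else None."""
    for name in WEIGHT_FILES:
        if name in filenames:
            return name
    return None


def summarize_adapter_config(config_text):
    """Pull a few LoraConfig fields out of adapter_config.json text."""
    config = json.loads(config_text)
    return {key: config.get(key) for key in ("peft_type", "r", "lora_alpha", "bias")}


def inspect_adapter_dir(output_dir):
    """Verify that a saved adapter directory has weights and a config, then summarize it."""
    out = Path(output_dir)
    weights = pick_weights_file({p.name for p in out.iterdir()})
    if weights is None:
        raise FileNotFoundError(f"no adapter weights found in {out}")
    summary = summarize_adapter_config((out / "adapter_config.json").read_text())
    summary["weights_file"] = weights
    return summary
```

Usage: inspect_adapter_dir("output_lora/") after model.save_pretrained("output_lora/", ...).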

Usage Examples

Extracting LoRA State Dict (Non-ZeRO-3)

from fastchat.train.train_lora import get_peft_state_maybe_zero_3

# After training, extract only the LoRA parameters from the PEFT-wrapped model
state_dict = get_peft_state_maybe_zero_3(
    model.named_parameters(), bias="none"
)

# Save adapter weights
model.save_pretrained("output_lora/", state_dict=state_dict)
# Creates output_lora/adapter_config.json plus adapter_model.bin (adapter_model.safetensors in newer PEFT releases)

Saving with ZeRO-3 Consolidated State Dict

from transformers import deepspeed

if deepspeed.is_deepspeed_zero3_enabled():
    # Collective call: run it on every rank; only rank 0 receives the full 16-bit state dict
    state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
    if training_args.local_rank == 0:
        # PEFT's save_pretrained automatically filters to LoRA keys
        model.save_pretrained(training_args.output_dir, state_dict=state_dict)
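
The placement of the consolidation call matters: DeepSpeed's docstring for _zero3_consolidated_16bit_state_dict() says it must be called on all ranks, and it returns the full dictionary on rank 0 and None elsewhere. The sketch below simulates that contract sequentially. fake_consolidated_state_dict, peft_like_filter, and the four-process world are hypothetical, and the filter is a simplification of PEFT's key selection for bias="none".

```python
WORLD_SIZE = 4  # hypothetical number of training processes

EXAMPLE_FULL_STATE = {
    "layers.0.q_proj.weight": [1.0, 0.0],
    "layers.0.q_proj.lora_A.default.weight": [0.1, 0.2],
    "layers.0.q_proj.lora_B.default.weight": [0.3],
}


def fake_consolidated_state_dict(rank, full_state):
    """Collective stand-in: every rank calls it, only rank 0 receives the dict."""
    return dict(full_state) if rank == 0 else None


def peft_like_filter(state_dict):
    """Simplified stand-in for PEFT's filtering with bias="none"."""
    return {k: v for k, v in state_dict.items() if "lora_" in k}


def run_save(full_state):
    """Simulate all ranks one after another and collect what would be written."""
    writes = []
    for rank in range(WORLD_SIZE):
        state = fake_consolidated_state_dict(rank, full_state)  # outside the rank check
        if rank == 0:
            writes.append(peft_like_filter(state))
    return writes


print(run_save(EXAMPLE_FULL_STATE))
```

If the consolidation call were moved inside the rank check, the other ranks would never join the collective gathers, and rank 0 would typically hang waiting for them.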

Gathering a Single ZeRO-3 Parameter

from fastchat.train.train_lora import maybe_zero_3

# Under ZeRO-3 each gather is collective, so every rank must run this same loop
for name, param in model.named_parameters():
    if "lora_A" in name:
        full_param = maybe_zero_3(param)
        print(f"{name}: shape={full_param.shape}, device={full_param.device}")

Post-Training Merge with apply_lora

python3 -m fastchat.model.apply_lora \
    --base-model-path meta-llama/Llama-2-7b-hf \
    --target-model-path output_merged/ \
    --lora-path output_lora/
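
Conceptually, merging folds each adapter pair into its base weight as W' = W + (lora_alpha / r) * (B @ A), where lora_A has shape (r, in_features) and lora_B has shape (out_features, r). The pure-Python sketch below assumes PEFT's default scaling (no rsLoRA) and non-transposed weights; merge_lora is a hypothetical helper, not FastChat's apply_lora code.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def merge_lora(weight, lora_a, lora_b, lora_alpha, r):
    """Return weight + (lora_alpha / r) * (lora_b @ lora_a)."""
    scaling = lora_alpha / r
    delta = matmul(lora_b, lora_a)  # (out x r) @ (r x in) -> out x in
    return [
        [weight[i][j] + scaling * delta[i][j] for j in range(len(weight[0]))]
        for i in range(len(weight))
    ]


# 2x2 base weight, rank-1 adapter, lora_alpha=2 -> scaling 2.0
merged = merge_lora([[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0]], [[1.0], [0.0]], lora_alpha=2, r=1)
print(merged)  # [[3.0, 4.0], [0.0, 1.0]]
```

Once merged, the model no longer needs the separate adapter files at inference time.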

Related Pages

Implements Principle
