Overview
Shared trainer utilities providing factory functions for custom optimizers, schedulers, loss functions, model loading helpers, and Ray distributed training infrastructure used across all training stages in LLaMA-Factory.
Description
trainer_utils.py is the central utility module for the LLaMA-Factory training subsystem. It provides building blocks consumed by every trainer implementation across all training stages (PT, SFT, RM, PPO, DPO, KTO). The module contains:
- Custom optimizer factories for GaLore, APOLLO, LoRA+, BAdam, Adam-mini, and Muon optimizers with support for both global and per-layer (layerwise) training modes.
- A DummyOptimizer class used as a placeholder for layerwise GaLore/APOLLO training where per-parameter optimizers are managed via gradient hooks.
- Custom scheduler support including warmup-stable-decay scheduling and layerwise scheduler hooks.
- Log probability computation (get_batch_logps) for DPO/KTO alignment training with optional length-difference-aware weighting.
- Custom loss functions (DFT and EAFT) that weight tokens by target probability or approximate entropy for improved training dynamics.
- Model loading helpers for reference models (create_ref_model) and reward models (create_reward_model) used in RLHF pipelines.
- HuggingFace Hub integration for model card creation and pushing.
- SwanLab callback creation for experiment tracking.
- Ray distributed training utilities for placement group management, worker configuration, and node IP sorting.
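As an illustration of the warmup-stable-decay scheduling listed above, the following sketch computes a learning-rate multiplier for each step. It is a dependency-free approximation, not the module's code: the function name `wsd_multiplier`, the linear warmup and linear decay shapes, and the `min_ratio` floor are all assumptions made for this example.

```python
def wsd_multiplier(step: int, warmup_steps: int, stable_steps: int,
                   decay_steps: int, min_ratio: float = 0.0) -> float:
    """Return a learning-rate multiplier for a warmup-stable-decay schedule.

    Hypothetical helper: phase lengths and the linear ramps are assumptions.
    """
    # Phase 1: ramp linearly from 0 up to 1 over the warmup steps.
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    # Phase 2: hold the peak learning rate.
    if step < warmup_steps + stable_steps:
        return 1.0
    # Phase 3: decay linearly from 1 down to min_ratio, then stay there.
    progress = min(1.0, (step - warmup_steps - stable_steps) / max(1, decay_steps))
    return max(min_ratio, 1.0 - (1.0 - min_ratio) * progress)


if __name__ == "__main__":
    # Print the multiplier at a few points of a 40-step schedule.
    for s in (0, 5, 10, 29, 35, 40):
        print(s, wsd_multiplier(s, warmup_steps=10, stable_steps=20, decay_steps=10))
```

A multiplier of this shape could be passed to a lambda-based scheduler; the base learning rate itself would still come from the training arguments.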
Usage
Use this module when implementing or extending training pipelines. The functions are called by the individual stage trainers (PT, SFT, DPO, PPO, KTO, RM) to configure optimizers, schedulers, loss functions, and auxiliary models. The Ray utilities are used by the tuner module when launching distributed training with Ray.
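To make the node IP sorting used for Ray launches more concrete, here is a minimal sketch of ordering placement-group bundles by the IP of the node that hosts them. It does not call Ray: the helper name `sort_bundles_by_node_ip` and its input (a plain list holding one IP string per bundle) are hypothetical stand-ins, and placing the master node's bundles first is an assumption suggested by the `master_addr` parameter.

```python
from typing import Optional


def sort_bundles_by_node_ip(bundle_ips: list[str],
                            master_addr: Optional[str] = None) -> list[int]:
    """Return bundle indices grouped by node IP, master node first if given.

    Hypothetical helper: bundle_ips[i] is the IP of the node hosting bundle i.
    """
    indexed = list(enumerate(bundle_ips))
    # Lexicographic string sort; sorted() is stable, so bundles on the
    # same node keep their original relative order.
    ordered = [idx for idx, _ in sorted(indexed, key=lambda item: item[1])]
    if master_addr is not None:
        # Move every bundle that lives on the master node to the front.
        preferred = [idx for idx, ip in indexed if ip == master_addr]
        ordered = preferred + [idx for idx in ordered if idx not in preferred]
    return ordered


if __name__ == "__main__":
    ips = ["10.0.0.2", "10.0.0.1", "10.0.0.2", "10.0.0.1"]
    print(sort_bundles_by_node_ip(ips))
    print(sort_bundles_by_node_ip(ips, master_addr="10.0.0.2"))
```

Grouping bundles by node keeps consecutive ranks on the same machine, which is the usual motivation for this kind of ordering.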
Code Reference
Source Location
Signature
class DummyOptimizer(torch.optim.Optimizer):
    def __init__(self, lr: float = 1e-3, optimizer_dict: Optional[dict] = None) -> None
    def zero_grad(self, set_to_none: bool = True) -> None
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
def create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) -> None
def create_ref_model(model_args, finetuning_args, add_valuehead: bool = False) -> Optional[Union[PreTrainedModel, AutoModelForCausalLMWithValueHead]]
def create_reward_model(model, model_args, finetuning_args) -> Optional[AutoModelForCausalLMWithValueHead]
def create_custom_optimizer(model, training_args, finetuning_args) -> Optional[torch.optim.Optimizer]
def create_custom_scheduler(training_args, num_training_steps: int, optimizer=None) -> None
def get_batch_logps(logits, labels, label_pad_token_id=IGNORE_INDEX, ld_alpha=None) -> tuple[torch.Tensor, torch.Tensor]
def dft_loss_func(outputs, labels, num_items_in_batch=None)
def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha: float = 1.0) -> torch.Tensor
def nested_detach(tensors, clone: bool = False)
def get_swanlab_callback(finetuning_args) -> TrainerCallback
def get_placement_group(num_workers: int) -> tuple[PlacementGroup, dict[str, int]]
def get_ray_remote_config_for_worker(placement_group, bundle_idx, rank, world_size, master_addr, master_port, env=None) -> dict[str, Any]
def get_ray_head_node_ip() -> str
def sort_placement_group_by_node_ip(placement_group, master_addr=None) -> list[int]
Import
from llamafactory.train.trainer_utils import create_custom_optimizer, create_custom_scheduler
from llamafactory.train.trainer_utils import create_ref_model, create_reward_model
from llamafactory.train.trainer_utils import get_batch_logps, dft_loss_func, eaft_loss_func
from llamafactory.train.trainer_utils import DummyOptimizer, nested_detach
from llamafactory.train.trainer_utils import get_swanlab_callback
from llamafactory.train.trainer_utils import get_placement_group, sort_placement_group_by_node_ip
I/O Contract
Inputs
create_custom_optimizer
| Name | Type | Required | Description |
| --- | --- | --- | --- |
| model | PreTrainedModel | Yes | The model whose parameters will be optimized |
| training_args | TrainingArguments | Yes | HuggingFace training arguments including learning rate, weight decay, and optimizer type |
| finetuning_args | FinetuningArguments | Yes | LLaMA-Factory finetuning arguments specifying which optimizer to use (GaLore, APOLLO, LoRA+, BAdam, Adam-mini, Muon) and their hyperparameters |
get_batch_logps
| Name | Type | Required | Description |
| --- | --- | --- | --- |
| logits | torch.Tensor | Yes | Model output logits of shape (batch_size, seq_len, vocab_size) |
| labels | torch.Tensor | Yes | Target labels of shape (batch_size, seq_len) with IGNORE_INDEX for masked positions |
| label_pad_token_id | int | No | Token ID used for padding in labels (default: IGNORE_INDEX) |
| ld_alpha | float | No | Length-difference alpha for weighting front vs rear log probabilities; None disables LD weighting |
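The front/rear weighting controlled by `ld_alpha` can be pictured with plain Python lists. This is a sketch under assumptions, not the module's implementation: it takes "front" to mean the first `min(len(chosen), len(rejected))` response tokens of each sequence and "rear" to mean the remainder, and the helper name `ld_weighted_logps` is hypothetical.

```python
def ld_weighted_logps(chosen_token_logps: list[float],
                      rejected_token_logps: list[float],
                      ld_alpha: float) -> tuple[float, float]:
    """Sum per-token log-probs, scaling tokens past the shared length by ld_alpha.

    Hypothetical helper: inputs hold only the non-masked response tokens.
    """
    # Length that both responses have in common (assumed "front" region).
    public_len = min(len(chosen_token_logps), len(rejected_token_logps))

    def weighted(token_logps: list[float]) -> float:
        front = sum(token_logps[:public_len])
        rear = sum(token_logps[public_len:])  # empty for the shorter response
        return front + ld_alpha * rear

    return weighted(chosen_token_logps), weighted(rejected_token_logps)


if __name__ == "__main__":
    chosen = [-1.0, -2.0]
    rejected = [-1.0, -1.0, -4.0, -2.0]
    print(ld_weighted_logps(chosen, rejected, ld_alpha=0.5))
```

With `ld_alpha = 1.0` the sketch reduces to an ordinary sum over all tokens, which matches the table's statement that the weighting is optional.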
create_ref_model
| Name | Type | Required | Description |
| --- | --- | --- | --- |
| model_args | ModelArguments | Yes | Model configuration arguments |
| finetuning_args | FinetuningArguments | Yes | Finetuning arguments; if ref_model is set, loads from that path; if LoRA finetuning, returns None (uses implicit reference) |
| add_valuehead | bool | No | Whether to add a value head to the reference model (for PPO training) |
Outputs
create_custom_optimizer
| Name | Type | Description |
| --- | --- | --- |
| optimizer | Optional[torch.optim.Optimizer] | The configured optimizer instance, or None if no custom optimizer is needed |
get_batch_logps
| Name | Type | Description |
| --- | --- | --- |
| logps | torch.Tensor | Sum of per-token log probabilities for each example in the batch, shape (batch_size,) |
| valid_length | torch.Tensor | Number of non-masked tokens per example, shape (batch_size,) |
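The two outputs above can be reproduced on nested Python lists to show what is being summed and counted. This is a minimal sketch rather than the tensor implementation: the name `batch_logps` is hypothetical, `IGNORE_INDEX = -100` is assumed to match the common masking sentinel, and the one-position shift between logits and labels follows the usual causal-LM convention, which this page does not spell out.

```python
import math

IGNORE_INDEX = -100  # assumed sentinel for masked label positions


def log_softmax(row: list[float]) -> list[float]:
    """Numerically stable log-softmax over one vocabulary row."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def batch_logps(logits, labels, label_pad_token_id=IGNORE_INDEX):
    """logits: [batch][seq][vocab], labels: [batch][seq].

    Returns (logps, valid_length), one entry per example.
    """
    logps, valid_length = [], []
    for seq_logits, seq_labels in zip(logits, labels):
        total, count = 0.0, 0
        # Assumed shift: logits at position t score the label at t + 1.
        for step_logits, target in zip(seq_logits[:-1], seq_labels[1:]):
            if target == label_pad_token_id:
                continue  # masked positions contribute nothing
            total += log_softmax(step_logits)[target]
            count += 1
        logps.append(total)
        valid_length.append(count)
    return logps, valid_length


if __name__ == "__main__":
    uniform = [[[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]]
    print(batch_logps(uniform, [[IGNORE_INDEX, 1, 0]]))
```

Dividing `logps` by `valid_length` would give a length-normalized average, which is one reason a caller might want both values.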
Usage Examples
# Creating a custom optimizer with GaLore
from llamafactory.train.trainer_utils import create_custom_optimizer
optimizer = create_custom_optimizer(model, training_args, finetuning_args)
if optimizer is not None:  # None means no custom optimizer is needed
    trainer.optimizer = optimizer

# Computing batch log probabilities for DPO training
from llamafactory.train.trainer_utils import get_batch_logps
logps, valid_length = get_batch_logps(logits, labels)
# The slicing below assumes chosen examples fill the first half of the batch
chosen_logps = logps[:batch_size // 2]
rejected_logps = logps[batch_size // 2:]

# Creating a reference model for PPO/DPO
from llamafactory.train.trainer_utils import create_ref_model
ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=False)

# Using DFT loss function
from llamafactory.train.trainer_utils import dft_loss_func
loss = dft_loss_func(outputs, labels, num_items_in_batch=num_items)
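To show what "weighting tokens by target probability" means for the DFT loss, the sketch below scales each token's cross-entropy by the probability the model assigns to the target token. It is an illustration on plain floats, not `dft_loss_func` itself: the helper names are hypothetical, and the normalization by `num_items_in_batch` is an assumption based on the parameter name.

```python
import math
from typing import Optional


def dft_token_losses(target_probs: list[float]) -> list[float]:
    """Per-token loss -p * log(p) for each target-token probability p.

    In a tensor implementation the weight p would be treated as a
    constant (no gradient); here there are no gradients at all.
    """
    return [-p * math.log(p) for p in target_probs]


def mean_loss(token_losses: list[float],
              num_items_in_batch: Optional[int] = None) -> float:
    """Average over tokens, or over num_items_in_batch when provided."""
    denom = num_items_in_batch if num_items_in_batch is not None else len(token_losses)
    return sum(token_losses) / denom


if __name__ == "__main__":
    probs = [0.9, 0.5, 0.1]
    plain_ce = [-math.log(p) for p in probs]
    print("cross-entropy:", plain_ce)
    print("DFT-weighted: ", dft_token_losses(probs))
```

Compared with plain cross-entropy, low-probability targets contribute far less under this weighting, which is the effect the description attributes to DFT.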
Related Pages