Principle:Huggingface Trl SFT Training Execution
| Knowledge Sources | |
|---|---|
| Domains | NLP, Training |
| Last Updated | 2026-02-06 17:00 GMT |
Overview
Gradient-based optimization loop with cross-entropy loss computation, token accuracy tracking, entropy monitoring, and optional activation offloading for memory-efficient supervised fine-tuning.
Description
The training execution phase is where the actual parameter updates happen. The SFTTrainer inherits the core training loop from transformers.Trainer but overrides the loss computation and training step to add SFT-specific metrics and activation offloading.
The execution proceeds as follows:
- Forward pass -- The model receives a batch of tokenized sequences (input_ids, attention_mask, labels) and produces logits. For language modeling, the loss is computed as token-level cross-entropy between shifted logits and shifted labels.
- Loss computation -- The compute_loss() method delegates to the parent Trainer's loss computation but adds several metric tracking steps:
  - Token accuracy -- The fraction of non-masked tokens where the argmax of the logits matches the true label. This gives a directly interpretable signal alongside the loss.
  - Entropy -- Shannon entropy of the output distribution, computed in chunks for memory efficiency. This measures the model's confidence: decreasing entropy indicates the model is becoming more certain.
  - Auxiliary loss -- For Mixture-of-Experts models with output_router_logits=True, the load-balancing auxiliary loss is tracked.
- DFT loss (optional) -- When loss_type="dft" (Dynamic Fine-Tuning), the loss function weights each token's log-probability by the detached exponential of that log-probability (that is, by the token's current probability), giving more weight to tokens the model already partially understands.
- Gradient accumulation -- The Trainer accumulates gradients over multiple micro-batches before performing an optimizer step, controlled by gradient_accumulation_steps.
- Activation offloading -- When activation_offloading=True, the training_step() override wraps the forward/backward pass in an activation offloading context manager that moves intermediate activations to CPU memory, reducing GPU memory pressure at the cost of increased data transfer.
- Gradient checkpointing -- Enabled by default in SFTConfig, this trades compute for memory by recomputing intermediate activations during the backward pass instead of storing them.
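The gradient accumulation step in the list above can be illustrated without any framework. The following is a minimal sketch, assuming a toy one-parameter least-squares model; it is not the Trainer's implementation, and every function name and data value is hypothetical. It shows that summing micro-batch gradients, each scaled by 1/gradient_accumulation_steps, reproduces the full-batch gradient when the micro-batches have equal size.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to the scalar weight w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, gradient_accumulation_steps):
    """Accumulate gradients over equal-sized micro-batches before one optimizer step."""
    micro = len(xs) // gradient_accumulation_steps
    total = 0.0
    for i in range(gradient_accumulation_steps):
        lo, hi = i * micro, (i + 1) * micro
        # Scale each micro-batch gradient by 1/steps so the sum matches the full-batch mean.
        total += grad_mse(w, xs[lo:hi], ys[lo:hi]) / gradient_accumulation_steps
    return total


# Hypothetical toy data, roughly y = 2x.
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 13.9, 16.1]
w = 0.5

full = grad_mse(w, xs, ys)
accum = accumulated_grad(w, xs, ys, gradient_accumulation_steps=4)

lr = 0.01
w_next = w - lr * accum  # one optimizer step after accumulating 4 micro-batches
```

With mean-reduced token losses, micro-batches that contain different numbers of non-masked tokens make this equality only approximate unless the loss is normalized by the total token count.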
Usage
Use this after the SFTTrainer is fully initialized. Call trainer.train() to start the optimization loop. Optionally pass resume_from_checkpoint to continue from a saved state.
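A minimal sketch of the call sequence, assuming TRL is installed and that its SFTConfig accepts the options named on this page. The helper run_sft, its parameters, and the value 8 are hypothetical and only for illustration; the model identifier and dataset are left to the caller.

```python
def run_sft(model_id, train_dataset, output_dir, resume_from_checkpoint=None):
    """Hypothetical helper: build an SFTTrainer and start (or resume) training."""
    # Imported lazily so the sketch can be loaded without TRL installed.
    from trl import SFTConfig, SFTTrainer

    args = SFTConfig(
        output_dir=output_dir,
        gradient_accumulation_steps=8,  # illustrative value, not a recommendation
        activation_offloading=False,    # set True to offload activations to CPU
    )
    trainer = SFTTrainer(model=model_id, args=args, train_dataset=train_dataset)
    # resume_from_checkpoint may be a checkpoint path, True (latest in output_dir), or None.
    return trainer.train(resume_from_checkpoint=resume_from_checkpoint)
```

Passing resume_from_checkpoint=None starts a fresh run; passing a path or True continues from saved state, as described above.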
Theoretical Basis
Cross-Entropy Loss for Language Modeling:
The standard next-token prediction loss is:
L = -(1/T) * sum_{t=1}^{T} log P(y_t | y_{<t})
where T is the number of non-masked tokens and P is the softmax over the model's logits.
When completion_only_loss=True, the sum is restricted to completion tokens (labels for prompt tokens are set to -100, which PyTorch's cross-entropy ignores).
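The masked loss can be worked through by hand. The sketch below is framework-free and assumes the usual shifting convention, where the logits at position t score the token at position t + 1; the function names and toy values are hypothetical, and it illustrates the formula rather than the trainer's code.

```python
import math

IGNORE_INDEX = -100  # label value skipped by the loss, as described above


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary-sized list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def masked_next_token_loss(logits, labels):
    """Mean negative log-likelihood over non-masked next-token targets.

    logits[t] scores the token at position t + 1, so logits[:-1] pairs with labels[1:].
    """
    total, count = 0.0, 0
    for step_logits, target in zip(logits[:-1], labels[1:]):
        if target == IGNORE_INDEX:
            continue  # prompt (or otherwise masked) tokens contribute nothing
        total -= log_softmax(step_logits)[target]
        count += 1
    return total / count if count else 0.0


# Toy sequence of 4 positions over a 3-token vocabulary.
logits = [[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]
labels = [IGNORE_INDEX, IGNORE_INDEX, 1, 2]  # first two positions are prompt tokens

loss = masked_next_token_loss(logits, labels)
```

Here T = 2: only the two completion tokens enter the average, so the prompt positions affect neither the numerator nor the denominator.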
Token Accuracy:
accuracy = (1/T) * sum_{t=1}^{T} 1[argmax(logits_t) == y_t]
where the sum is over non-masked tokens only.
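A small framework-free sketch of this metric, assuming the same shift between logits and labels as in the loss and -100 as the mask value; the names and toy values are hypothetical.

```python
IGNORE_INDEX = -100  # masked positions are excluded from the metric


def token_accuracy(logits, labels):
    """Fraction of non-masked targets whose label equals the argmax of the shifted logits."""
    correct, total = 0, 0
    for step_logits, target in zip(logits[:-1], labels[1:]):
        if target == IGNORE_INDEX:
            continue
        prediction = max(range(len(step_logits)), key=step_logits.__getitem__)
        correct += int(prediction == target)
        total += 1
    return correct / total if total else 0.0


logits = [[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]
labels = [IGNORE_INDEX, IGNORE_INDEX, 1, 0]

accuracy = token_accuracy(logits, labels)  # one of the two unmasked targets is predicted
```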
Shannon Entropy:
H = -sum_{v} P(v) * log P(v)
computed per token position and averaged across the sequence. The computation runs in chunks so that the softmax over all positions is never materialized at once, which saves memory.
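The chunked computation can be sketched as follows. This is an illustrative pure-Python version, not the library's tensor code; the function names and the chunk_size parameter are assumptions. Entropy is computed directly from logits using H = log Z - sum_v P(v) * (x_v - m), where m is the maximum logit and Z is the sum of exp(x_v - m).

```python
import math


def entropy_from_logits(logits):
    """Shannon entropy (in nats) of softmax(logits) for one position."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    # H = log Z - sum_v p_v * (x_v - m), with p_v = exps[v] / z
    return math.log(z) - sum(e * (x - m) for e, x in zip(exps, logits)) / z


def mean_entropy_chunked(all_logits, chunk_size=2):
    """Average per-position entropy, processing chunk_size positions at a time."""
    total, count = 0.0, 0
    for start in range(0, len(all_logits), chunk_size):
        chunk = all_logits[start:start + chunk_size]
        # Softmax terms are created for this chunk only and reduced to one scalar per position.
        chunk_entropies = [entropy_from_logits(row) for row in chunk]
        total += sum(chunk_entropies)
        count += len(chunk)
    return total / count


rows = [[0.0, 0.0, 0.0, 0.0], [10.0, 0.0, 0.0, 0.0], [1.0, 2.0, 3.0, 4.0]]
h = mean_entropy_chunked(rows, chunk_size=2)
```

The result does not depend on chunk_size; only the peak amount of intermediate data does.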
DFT Loss (Dynamic Fine-Tuning):
L_dft = -(1/T) * sum_{t} exp(log P(y_t)).detach() * log P(y_t)
= -(1/T) * sum_{t} P(y_t).detach() * log P(y_t)
This weights each token by the model's current probability for it (detached from the gradient graph), effectively giving more gradient signal to tokens the model partially knows rather than completely unknown tokens.
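To make the weighting concrete, the sketch below compares the two losses on hypothetical token probabilities. It relies on the softmax identity d log P(y) / d z_y = 1 - P(y), so the gradient magnitude on the target logit is 1 - P(y) for plain cross-entropy and P(y) * (1 - P(y)) for DFT, where the detached weight is treated as a constant. The function names are illustrative, and no autograd framework is used.

```python
import math


def nll_loss(token_probs):
    """Standard loss: mean of -log P(y_t) over non-masked tokens."""
    return -sum(math.log(p) for p in token_probs) / len(token_probs)


def dft_loss(token_probs):
    """DFT loss value: mean of -P(y_t) * log P(y_t).

    In an autograd framework the factor P(y_t) would be detached, so it acts as a
    constant per-token weight on the gradient of log P(y_t).
    """
    return -sum(p * math.log(p) for p in token_probs) / len(token_probs)


def grad_magnitude_on_target_logit(p, loss_type="nll"):
    """Per-token |d loss / d z_y| for a softmax output with target probability p."""
    base = 1.0 - p  # |d(-log p) / d z_y|
    return p * base if loss_type == "dft" else base


probs = [0.9, 0.5, 0.01]  # hypothetical probabilities of the correct tokens
nll_grads = [grad_magnitude_on_target_logit(p, "nll") for p in probs]
dft_grads = [grad_magnitude_on_target_logit(p, "dft") for p in probs]
```

Under these assumptions the token with probability 0.5 receives the largest DFT gradient, while the token with probability 0.01 dominates under plain cross-entropy but is almost ignored by DFT.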
Gradient Checkpointing: Instead of storing all activations during the forward pass (memory cost O(L) for L layers), only a subset of checkpoint activations is kept and the rest are recomputed during backpropagation. With checkpoints placed roughly every sqrt(L) layers, activation memory drops to O(sqrt(L)) at the cost of roughly one additional forward pass.
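The sqrt(L) figure follows from a simple count, sketched below under a simplified model in which every layer's activation has the same size and checkpoints are evenly spaced. The function name and the 64-layer example are hypothetical.

```python
import math


def stored_activations(num_layers, segment_size):
    """Peak number of layer activations held when checkpointing every segment_size layers.

    One checkpoint per segment is kept for the whole backward pass, plus the
    activations of the single segment currently being recomputed.
    """
    num_segments = math.ceil(num_layers / segment_size)
    return num_segments + segment_size


num_layers = 64
no_checkpointing = num_layers  # every layer's activation is stored
best_segment = min(
    range(1, num_layers + 1),
    key=lambda k: stored_activations(num_layers, k),
)
with_checkpointing = stored_activations(num_layers, best_segment)
```

For 64 layers the minimum is reached at a segment size of 8, holding 16 activations (2 * sqrt(64)) instead of 64.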
Activation Offloading: Moves intermediate activations from GPU to CPU RAM during the forward pass and retrieves them during the backward pass. This extends the effective memory budget but introduces host-device transfer overhead (for example, over PCIe).