Implementation: _get_batch_logps (Eric Mitchell's direct-preference-optimization)
| Knowledge Sources | |
|---|---|
| Domains | Language_Modeling, Loss_Functions, NLP |
| Last Updated | 2026-02-08 02:00 GMT |
Overview
A concrete utility from the direct-preference-optimization repository for extracting per-sequence log probabilities from autoregressive model logits.
Description
The _get_batch_logps function takes raw model logits and corresponding label tokens, computes per-token log probabilities using log-softmax and gather, then sums (or averages) over non-masked positions to produce a single log probability score per sequence. It handles the standard autoregressive offset where logits at position t predict the token at position t+1.
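The computation described above can be sketched as follows. This is an illustrative reimplementation consistent with the description and docstring (shift-by-one offset, -100 masking, log-softmax + gather), not a verbatim copy of the repository source; the function name here is chosen to avoid confusion with the original.

```python
import torch

def get_batch_logps_sketch(logits: torch.FloatTensor,
                           labels: torch.LongTensor,
                           average_log_prob: bool = False) -> torch.FloatTensor:
    """Sketch of per-sequence log-prob extraction (illustrative, not repo source)."""
    # Autoregressive offset: logits at position t predict the token at t+1,
    # so drop the last logit and the first label.
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    # Positions labeled -100 (e.g. prompt tokens) are excluded from the score.
    loss_mask = labels != -100
    labels[labels == -100] = 0  # placeholder index so gather stays in-bounds
    # Log-softmax over the vocabulary, then pick out each label token's log-prob.
    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)
    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    return (per_token_logps * loss_mask).sum(-1)
```

With uniform logits, each non-masked token contributes log(1/vocab_size), which makes the masking and the sum/average switch easy to verify by hand.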
Usage
Import this function when you need per-sequence log probabilities from model outputs. It is called internally by concatenated_forward to compute log probabilities for both chosen and rejected responses, which are then fed into preference_loss. Also used in the SFT branch of get_batch_metrics to compute the SFT loss.
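To illustrate how these per-sequence log probabilities feed the preference objective, here is a minimal sketch of the standard DPO loss computed from chosen/rejected log-probs. The function name and argument names are hypothetical; the repository's actual preference_loss may differ in signature and options.

```python
import torch
import torch.nn.functional as F

def dpo_loss_sketch(policy_chosen_logps: torch.Tensor,
                    policy_rejected_logps: torch.Tensor,
                    ref_chosen_logps: torch.Tensor,
                    ref_rejected_logps: torch.Tensor,
                    beta: float = 0.1) -> torch.Tensor:
    """Illustrative DPO loss from per-sequence log-probs (sketch, not repo code)."""
    # Log-ratio of chosen vs. rejected under the policy and the reference model.
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    # DPO objective: -log sigmoid(beta * (policy log-ratio - reference log-ratio))
    losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios))
    return losses.mean()
```

When the policy matches the reference model the loss reduces to log 2, and it decreases as the policy assigns relatively more probability to chosen responses.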
Code Reference
Source Location
- Repository: direct-preference-optimization
- File: trainers.py
- Lines: 90-115
Signature
def _get_batch_logps(
logits: torch.FloatTensor,
labels: torch.LongTensor,
average_log_prob: bool = False,
) -> torch.FloatTensor:
"""Compute the log probabilities of the given labels under the given logits.
Args:
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
labels: Labels for which to compute the log probabilities. Label tokens with
a value of -100 are ignored. Shape: (batch_size, sequence_length)
average_log_prob: If True, return the average log probability per (non-masked) token.
Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
Returns:
A tensor of shape (batch_size,) containing the average/sum log probabilities.
"""
Import
from trainers import _get_batch_logps
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| logits | torch.FloatTensor | Yes | Model output logits. Shape: (batch_size, sequence_length, vocab_size) |
| labels | torch.LongTensor | Yes | Target token IDs; -100 marks masked positions. Shape: (batch_size, sequence_length) |
| average_log_prob | bool | No | If True, average over tokens; if False, sum (default False) |
Outputs
| Name | Type | Description |
|---|---|---|
| log_probs | torch.FloatTensor | Per-sequence log probability (sum or average). Shape: (batch_size,) |
Usage Examples
Computing Sequence Log Probabilities
import torch
from trainers import _get_batch_logps
# model_output.logits shape: (batch_size, seq_len, vocab_size)
# labels shape: (batch_size, seq_len), with -100 marking prompt tokens
logits = model(input_ids, attention_mask=attention_mask).logits.to(torch.float32)
log_probs = _get_batch_logps(logits, labels, average_log_prob=False)
# log_probs shape: (batch_size,)
# For the SFT loss, negate the summed log probabilities:
sft_loss = -log_probs