
Implementation:Eric Mitchell Direct preference optimization Get Batch Logps

From Leeroopedia


Knowledge Sources
Domains Language_Modeling, Loss_Functions, NLP
Last Updated 2026-02-08 02:00 GMT

Overview

A concrete utility from the direct-preference-optimization repository for extracting per-sequence log probabilities from autoregressive model logits.

Description

The _get_batch_logps function takes raw model logits and corresponding label tokens, computes per-token log probabilities using log-softmax and gather, then sums (or averages) over non-masked positions to produce a single log probability score per sequence. It handles the standard autoregressive offset where logits at position t predict the token at position t+1.
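The computation described above can be sketched as follows. This is a minimal reimplementation based on that description (shift, log-softmax, gather, mask, sum or average), not necessarily the repository's exact code; the function name here is illustrative.

```python
import torch


def get_batch_logps_sketch(
    logits: torch.FloatTensor,   # (batch, seq_len, vocab)
    labels: torch.LongTensor,    # (batch, seq_len); -100 marks masked positions
    average_log_prob: bool = False,
) -> torch.FloatTensor:
    # Autoregressive offset: logits at position t predict the token at t+1,
    # so drop the last logit position and the first label position.
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]

    loss_mask = labels != -100
    # Replace masked labels with a valid index so gather does not fail;
    # their contribution is zeroed out by the mask below.
    labels[~loss_mask] = 0

    # Per-token log probability of each label token under the model.
    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)

    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    return (per_token_logps * loss_mask).sum(-1)
```

With uniform logits over a vocabulary of 4, each non-masked token contributes log(1/4), so a sequence with two scored tokens sums to 2·log(1/4) and averages to log(1/4).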

Usage

Import this function when you need per-sequence log probabilities from model outputs. It is called internally by concatenated_forward to compute log probabilities for both chosen and rejected responses, which are then fed into preference_loss. It is also used in the SFT branch of get_batch_metrics to compute the SFT loss.
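The per-sequence log probabilities are the only model-dependent quantities the DPO objective needs. A hedged sketch of how the chosen/rejected log probabilities feed into the standard DPO loss (function and variable names here are illustrative, not the repository's exact API):

```python
import torch
import torch.nn.functional as F


def dpo_loss_sketch(
    policy_chosen_logps: torch.FloatTensor,    # (batch,)
    policy_rejected_logps: torch.FloatTensor,  # (batch,)
    ref_chosen_logps: torch.FloatTensor,       # (batch,)
    ref_rejected_logps: torch.FloatTensor,     # (batch,)
    beta: float = 0.1,
) -> torch.FloatTensor:
    # DPO loss: -log sigmoid(beta * (policy log-ratio - reference log-ratio)),
    # where each log-ratio is (chosen logp - rejected logp).
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    return -F.logsigmoid(beta * (pi_logratios - ref_logratios))
```

Each of the four inputs would come from a call like `_get_batch_logps(logits, labels)` on the policy and reference models; when all four log probabilities are equal, the inner term is zero and the loss reduces to -log(0.5) per example.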

Code Reference

Source Location

trainers.py in the direct-preference-optimization repository (consistent with the import below).

Signature

def _get_batch_logps(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,
    average_log_prob: bool = False,
) -> torch.FloatTensor:
    """Compute the log probabilities of the given labels under the given logits.

    Args:
        logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
        labels: Labels for which to compute the log probabilities. Label tokens with
                a value of -100 are ignored. Shape: (batch_size, sequence_length)
        average_log_prob: If True, return the average log probability per (non-masked) token.
                         Otherwise, return the sum of the log probabilities of the (non-masked) tokens.

    Returns:
        A tensor of shape (batch_size,) containing the average/sum log probabilities.
    """

Import

from trainers import _get_batch_logps

I/O Contract

Inputs

Name Type Required Description
logits torch.FloatTensor Yes Model output logits. Shape: (batch_size, sequence_length, vocab_size)
labels torch.LongTensor Yes Target token IDs; -100 marks masked positions. Shape: (batch_size, sequence_length)
average_log_prob bool No If True, average over tokens; if False, sum (default False)

Outputs

Name Type Description
log_probs torch.FloatTensor Per-sequence log probability (sum or average). Shape: (batch_size,)

Usage Examples

Computing Sequence Log Probabilities

import torch

from trainers import _get_batch_logps

# Assumes `model`, `input_ids`, `attention_mask`, and `labels` are already prepared.
# model(...).logits shape: (batch_size, seq_len, vocab_size)
# labels shape: (batch_size, seq_len) with -100 for prompt tokens
logits = model(input_ids, attention_mask=attention_mask).logits.to(torch.float32)
log_probs = _get_batch_logps(logits, labels, average_log_prob=False)
# log_probs shape: (batch_size,)

# For SFT loss: negate the log probabilities
sft_loss = -log_probs

Related Pages

Implements Principle

Requires Environment
