Principle: Log Probability Extraction in Eric Mitchell's Direct Preference Optimization (DPO)
| Knowledge Sources | |
|---|---|
| Domains | Language_Modeling, Probability_Theory, NLP |
| Last Updated | 2026-02-08 02:00 GMT |
Overview
A technique for computing per-sequence log probabilities from autoregressive language model logits by gathering and summing per-token log probabilities over non-masked positions.
Description
Log probability extraction converts raw model logits into per-sequence log probability scores. In autoregressive language models, each token's probability is conditioned on all previous tokens. To get the total log probability of a sequence, we compute the log-softmax of the logits at each position, gather the log probability corresponding to the actual next token, and sum (or average) over all non-masked positions.
This operation is fundamental to DPO training because the loss function operates on sequence-level log probabilities rather than per-token losses. The extraction must correctly handle:
- The offset between logits and labels (logits at position t predict token at position t+1)
- Masked positions (label value -100), typically prompt tokens, which must be excluded from the sum
- The choice between summing or averaging log probabilities across tokens
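As a minimal sketch of the first two points, the snippet below (pure Python, with invented token ids; `IGNORE_INDEX` is an assumed name for the -100 mask value, not a specific library's API) shows how positions are aligned and masked:

```python
IGNORE_INDEX = -100  # assumed mask value for excluded positions

# Hypothetical sequence of 4 tokens: two prompt tokens, two response tokens.
input_ids = [5, 2, 7, 3]
# Labels are input_ids with prompt positions masked out.
labels = [IGNORE_INDEX, IGNORE_INDEX, 7, 3]

# Logits at position t predict the token at position t+1, so logit
# positions 0..T-2 align with labels 1..T-1.
shift_positions = list(range(len(input_ids) - 1))
shift_labels = labels[1:]

# Only non-masked (response) positions contribute to the sequence log prob.
contributing = [(t, tok) for t, tok in zip(shift_positions, shift_labels)
                if tok != IGNORE_INDEX]
# contributing holds (logit position, response token id) pairs
```

Only the logits at positions 1 and 2 end up scoring response tokens here; the logits at the final position predict a token beyond the sequence and are discarded by the shift.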
Usage
Use this principle whenever you need to compute how likely a model considers a complete response sequence, given a prompt. This is essential for:
- Computing the DPO loss (requires log probs from both policy and reference models)
- Computing the SFT loss (negative log probability of target sequence)
- Evaluating model confidence on specific completions
Theoretical Basis
For an autoregressive model with parameters $\theta$, the log probability of a response $y$ given a prompt $x$ is:

$$\log \pi_\theta(y \mid x) = \sum_{t=1}^{|y|} \log \pi_\theta(y_t \mid x, y_{<t})$$
Each term $\log \pi_\theta(y_t \mid x, y_{<t})$ is obtained by:
- Computing the logits $z_t = f_\theta(x, y_{<t})$ over the vocabulary
- Applying log-softmax: $\log p_t = z_t - \operatorname{logsumexp}(z_t)$
- Gathering the entry $\log p_t[y_t]$ for the actual token $y_t$
- Summing over response tokens only (excluding prompt tokens via masking)
Pseudo-code:
# Abstract algorithm (NOT actual implementation)
# Logits at position t predict the token at position t+1, so shift first.
shift_logits = logits[:, :-1]
shift_labels = labels[:, 1:]
per_token_logps = log_softmax(shift_logits)[shift_labels]
mask = (shift_labels != IGNORE_INDEX)
sequence_logp = sum(per_token_logps * mask)
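A runnable version of the same algorithm in plain Python follows, under the assumption that logits are given as a list of vocab-sized rows (one per input position); a real implementation would use a tensor library, but the shift, gather, and mask logic is identical:

```python
import math

IGNORE_INDEX = -100  # assumed mask value, matching the pseudo-code above

def log_softmax(row):
    """Numerically stable log-softmax over one logits row."""
    m = max(row)
    lse = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - lse for v in row]

def sequence_logprob(logits, labels, average=False):
    """Sum (or mean) of per-token log probs over non-masked positions.

    logits: T rows of vocab-sized floats, one per input position.
    labels: T token ids, with IGNORE_INDEX at masked (prompt) positions.
    Logits at position t predict labels[t + 1], hence the shift below.
    """
    shift_logits = logits[:-1]
    shift_labels = labels[1:]
    total, count = 0.0, 0
    for row, tok in zip(shift_logits, shift_labels):
        if tok == IGNORE_INDEX:
            continue  # prompt/padding positions do not contribute
        total += log_softmax(row)[tok]
        count += 1
    return total / count if average and count else total

# Tiny invented example: vocab of 3, sequence of 3 tokens, first label masked.
logits = [[0.0, 1.0, 0.0],
          [2.0, 0.0, 0.0],
          [0.0, 0.0, 0.0]]  # last row is unused after the shift
labels = [IGNORE_INDEX, 1, 0]
lp = sequence_logprob(logits, labels)
```

The `average=True` branch corresponds to the sum-versus-average choice noted in the Description; averaging normalizes out response length, while summing preserves the true sequence log probability.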