Principle: Alibaba ROLL Vocabulary Parallel Computation
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, Tensor_Parallelism |
| Last Updated | 2026-02-07 20:00 GMT |
Overview
Numerically stable computation of token-level log-probabilities when the vocabulary logit vector is partitioned across tensor-parallel ranks, using coordinated reductions with custom gradient functions.
Description
In tensor-parallel training, the output projection (language model head) produces logits that are sharded across GPUs along the vocabulary dimension. Each GPU holds logits for only $V/P$ vocabulary entries, where $V$ is the total vocabulary size and $P$ is the tensor-parallel degree. Computing log-probabilities requires the log-softmax function, which involves both a maximum (for numerical stability) and a sum of exponentials -- both of which require knowledge of the full logit vector.
The naive approach of gathering all logits to every rank before computing log-softmax would be prohibitively expensive in memory (materializing the full $V$-dimensional vector) and communication. Instead, this principle performs the computation in-place on the sharded logits using three coordinated all-reduce operations:
- All-reduce max: Find the global maximum logit across all ranks for numerical stability.
- Local exp and sum: Each rank subtracts the global max, exponentiates its local partition, and computes a local partial sum.
- All-reduce sum of exponentials: Sum the partial sums to get the global normalization constant.
- All-reduce predicted logit: Sum the target token's logit (which is non-zero on exactly one rank).
- Compute log-prob: $\log p_y = (z_y - m) - \log S$, where $z_y$ is the target token's logit, $m$ is the global maximum, and $S$ is the global sum of exponentials.
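The steps above can be sketched in a single process by emulating the tensor-parallel ranks as vocabulary shards. This is an illustrative NumPy sketch, not ROLL's actual API: the `sharded_log_prob` helper and the plain Python `max`/`sum` standing in for the all-reduces are assumptions for demonstration.

```python
import numpy as np

def sharded_log_prob(shards, target):
    """Log-prob of `target` over logits split into per-rank vocabulary shards.

    Python max/sum over the shard list emulates the three all-reduces.
    """
    # Step 1: all-reduce max for numerical stability.
    m = max(s.max() for s in shards)
    # Step 2: each "rank" subtracts the global max and exponentiates locally.
    partials = [np.exp(s - m) for s in shards]
    # Step 3: all-reduce sum of exponentials (global normalization constant).
    S = sum(p.sum() for p in partials)
    # Step 4: all-reduce the target logit (non-zero on exactly one rank).
    z_y, offset = 0.0, 0
    for s in shards:
        if offset <= target < offset + len(s):
            z_y = s[target - offset]
        offset += len(s)
    # Step 5: log-prob = (z_y - m) - log(S).
    return (z_y - m) - np.log(S)

# Usage: a vocabulary of 8 logits split across P = 2 emulated ranks.
rng = np.random.default_rng(0)
logits = rng.normal(size=8)
shards = np.split(logits, 2)
lp = sharded_log_prob(shards, target=3)

# Unsharded reference log-softmax for comparison.
ref = logits[3] - logits.max() - np.log(np.exp(logits - logits.max()).sum())
```

The sharded result agrees with the unsharded reference to floating-point precision, since the two only differ in where the partial sums are accumulated.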
The backward pass is implemented as a custom autograd function that computes the gradient of the log-softmax without materializing the full Jacobian, using the identity that the gradient of log-softmax with respect to logit $z_j$ is $\mathbb{1}[j = y] - \mathrm{softmax}(z)_j$.
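The gradient identity can be checked numerically against central finite differences (a NumPy sketch; function names are illustrative):

```python
import numpy as np

def log_prob(z, y):
    """Numerically stable log-softmax of z, evaluated at index y."""
    m = z.max()
    return z[y] - m - np.log(np.exp(z - m).sum())

def log_prob_grad(z, y):
    """Analytic gradient: d log p_y / d z_j = 1[j = y] - softmax(z)_j."""
    p = np.exp(z - z.max())
    p /= p.sum()
    g = -p
    g[y] += 1.0
    return g

rng = np.random.default_rng(1)
z = rng.normal(size=6)
y = 2
analytic = log_prob_grad(z, y)

# Central finite differences confirm the identity component by component.
eps = 1e-6
fd = np.array([
    (log_prob(z + eps * e, y) - log_prob(z - eps * e, y)) / (2 * eps)
    for e in np.eye(len(z))
])
```

Note the gradient is dense (every logit receives $-\mathrm{softmax}(z)_j$), yet it never requires the $V \times V$ Jacobian.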
Usage
Use this principle when:
- Computing per-token log-probabilities for reinforcement learning objectives (e.g., policy gradient, DPO) where the model uses tensor-parallel output projection.
- You need gradients through the log-probability computation without gathering the full vocabulary logits.
- The vocabulary is large enough that gathering it to each rank would cause out-of-memory errors.
Theoretical Basis
Numerically stable log-softmax on sharded logits:
Let $z^{(r)} \in \mathbb{R}^{V/P}$ denote the logit partition on rank $r$, so the full logit vector is the concatenation $z = [z^{(0)}; z^{(1)}; \ldots; z^{(P-1)}]$. The global maximum is obtained by an all-reduce:

$$m = \max_{r} \max_{i} z^{(r)}_i$$

For target token $y$:

$$\log p_y = (z_y - m) - \log S, \qquad S = \sum_{r} \sum_{i} e^{z^{(r)}_i - m}$$

where $z_y$ and $S$ are each obtained by an all-reduce sum ($z_y$ is non-zero on exactly one rank after masking).
Backward pass:
Given upstream gradient $g$:

$$\frac{\partial \mathcal{L}}{\partial z^{(r)}_i} = g \left( \mathbb{1}[i = y] - \frac{e^{z^{(r)}_i - m}}{S} \right)$$

This is computed entirely locally on each rank without additional communication, because each rank already has $e^{z^{(r)} - m}$ and $S$ from the forward pass (saved for backward).
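A per-rank sketch of this local backward (NumPy; the loop over shards stands in for independent GPUs, and all names are illustrative):

```python
import numpy as np

def local_backward(partial_exp, target_mask, S, g):
    """Per-rank gradient g * (1[i = y] - exp(z_i - m) / S), using only
    tensors this rank saved during its own forward pass."""
    return g * (target_mask - partial_exp / S)

rng = np.random.default_rng(2)
logits = rng.normal(size=8)
shards = np.split(logits, 2)   # emulate P = 2 tensor-parallel ranks
y, g = 5, 1.0

# Forward-pass quantities each rank would have saved for backward.
m = max(s.max() for s in shards)
partials = [np.exp(s - m) for s in shards]
S = sum(p.sum() for p in partials)

# Each rank computes its gradient slice with no further communication.
grads, offset = [], 0
for s, p in zip(shards, partials):
    mask = np.zeros_like(s)
    if offset <= y < offset + len(s):
        mask[y - offset] = 1.0
    grads.append(local_backward(p, mask, S, g))
    offset += len(s)
grad = np.concatenate(grads)

# Reference: full-vocabulary gradient 1[j = y] - softmax(logits)_j.
p_full = np.exp(logits - logits.max())
p_full /= p_full.sum()
expected = -p_full
expected[y] += g
```

Concatenating the per-rank slices reproduces the full-vocabulary gradient, which is why no backward-time collective is needed.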
Memory efficiency:
The saved tensors for backward are the local exponentials $e^{z^{(r)} - m}$ (size $V/P$), the target mask, the sum $S$, and the masked target index -- all of which are local to each rank. No full-vocabulary tensors are materialized.
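To put the saving in concrete terms (illustrative sizes, not figures from ROLL): with a 128K vocabulary, 8-way tensor parallelism, 4 sequences of length 2048, and fp32 logits, gathering full logits to each rank costs $P$ times the sharded footprint:

```python
# Illustrative sizes: 128K vocabulary, P = 8 ranks, 4 x 2048 tokens, fp32.
V, P = 131_072, 8
tokens = 4 * 2048
bytes_fp32 = 4

full_per_rank = tokens * V * bytes_fp32          # gather-everything approach
shard_per_rank = tokens * (V // P) * bytes_fp32  # sharded computation

print(full_per_rank / 2**30, shard_per_rank / 2**30)  # prints 4.0 0.5 (GiB)
```

At these sizes the gathered logits alone would consume 4 GiB per rank, versus 0.5 GiB for the shard each rank already holds.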