
Principle:Alibaba ROLL Vocabulary Parallel Computation

From Leeroopedia


Knowledge Sources
Domains Distributed_Computing, Tensor_Parallelism
Last Updated 2026-02-07 20:00 GMT

Overview

Numerically stable computation of token-level log-probabilities when the vocabulary logit vector is partitioned across tensor-parallel ranks, using coordinated reductions with custom gradient functions.

Description

In tensor-parallel training, the output projection (language model head) produces logits that are sharded across GPUs along the vocabulary dimension. Each GPU holds logits for only V/T vocabulary entries, where V is the total vocabulary size and T is the tensor-parallel degree. Computing log-probabilities requires the log-softmax function, which involves both a maximum (for numerical stability) and a sum of exponentials -- both of which require knowledge of the full logit vector.

The naive approach of gathering all logits to every rank before computing log-softmax would be prohibitively expensive in both memory (materializing the full V-dimensional vector on each rank) and communication. Instead, this principle performs the computation in place on the sharded logits, interleaving local work with three coordinated all-reduce operations:

  1. All-reduce max: Find the global maximum logit across all ranks for numerical stability.
  2. Local exp and sum: Each rank subtracts the global max, exponentiates its local partition, and computes a local partial sum.
  3. All-reduce sum of exponentials: Sum the partial sums to get the global normalization constant.
  4. All-reduce predicted logit: Sum the target token's logit (which is non-zero on exactly one rank).
  5. Compute log-prob: log p(y) = (z_y − m) − log S, where z_y is the target token's logit, m is the global max from step 1, and S is the normalization constant from step 3.
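The five steps above can be sketched in a single process by holding each rank's logit shard in a Python list and replacing every collective with the corresponding reduction over that list. This is a minimal illustration of the principle, not ROLL's actual implementation; the function and argument names are invented for the example.

```python
import numpy as np

def vocab_parallel_logprob(shards, target, ranges):
    """Compute log p(target) from vocabulary-sharded logits.

    shards:  list of 1-D arrays; shards[r] holds rank r's logit slice
    target:  global vocabulary index of the target token
    ranges:  list of (start, end) global index ranges, one per rank

    The max/sum reductions over the list below stand in for the
    all-reduce collectives a real tensor-parallel run would issue.
    """
    # Step 1: all-reduce MAX -- global maximum for numerical stability.
    m = max(s.max() for s in shards)
    # Steps 2-3: local exp-and-sum, then all-reduce SUM of the partials.
    S = sum(np.exp(s - m).sum() for s in shards)
    # Step 4: all-reduce SUM of the target logit (non-zero on one rank).
    z_y = 0.0
    for r, (start, end) in enumerate(ranges):
        if start <= target < end:
            z_y += shards[r][target - start]
    # Step 5: log p(y) = (z_y - m) - log S.
    return (z_y - m) - np.log(S)
```

Note that no array of the full vocabulary size is ever built: each shard is exponentiated in place and only scalars (m, S, z_y) cross the simulated rank boundary.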

The backward pass is implemented as a custom autograd function that computes the gradient of the log-softmax without materializing the full Jacobian, using the identity that the gradient of log-softmax with respect to logit i is \delta_{iy} - \mathrm{softmax}(z)_i.

Usage

Use this principle when:

  • Computing per-token log-probabilities for reinforcement learning objectives (e.g., policy gradient, DPO) where the model uses tensor-parallel output projection.
  • You need gradients through the log-probability computation without gathering the full vocabulary logits.
  • The vocabulary is large enough that gathering it to each rank would cause out-of-memory errors.

Theoretical Basis

Numerically stable log-softmax on sharded logits:

Let z^{(r)} denote the logit partition on rank r:

m = \max_r \big( \max_i z_i^{(r)} \big) \quad \text{(all-reduce MAX)}

S^{(r)} = \sum_i \exp\!\big(z_i^{(r)} - m\big)

S = \sum_r S^{(r)} \quad \text{(all-reduce SUM)}

For target token y:

z_y = \sum_r z_y^{(r)} \, \mathbb{1}\!\left[y \in \mathrm{vocab}(r)\right] \quad \text{(all-reduce SUM)}

\log p(y) = (z_y - m) - \log S

Backward pass:

Given upstream gradient g = \partial L / \partial \log p(y):

\frac{\partial L}{\partial z_i^{(r)}} = g \left( \mathbb{1}[i = y_{\text{local}}] \, \mathbb{1}\!\left[y \in \mathrm{vocab}(r)\right] - \frac{\exp\!\big(z_i^{(r)} - m\big)}{S} \right)

This is computed entirely locally on each rank without additional communication, because each rank already has S and m from the forward pass (saved for backward).
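The locality of the backward pass can be checked with a short sketch: each simulated rank applies the formula above to its own shard using only the saved scalars m and S, and the per-rank gradients of log p(y) sum to zero globally (the +1 delta on the owning rank cancels the softmax mass spread across all ranks). The function name and arguments are illustrative, not from ROLL.

```python
import numpy as np

def local_logprob_grad(shard, m, S, local_target=None):
    """Per-rank backward for the sharded log-prob, with upstream grad g = 1.

    shard:        this rank's logit slice z^{(r)} (1-D array)
    m, S:         global max and global exp-sum saved from the forward pass
    local_target: local index of target y if it lives on this rank, else
                  None (this encodes both indicator terms in the formula)

    Purely local: no communication is issued in the backward pass.
    """
    grad = -np.exp(shard - m) / S          # -softmax(z_i) term
    if local_target is not None:
        grad[local_target] += 1.0          # +1 on the rank that owns y
    return grad
```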

Memory efficiency:

The tensors saved for backward are \exp(z^{(r)} - m) (size V/T), the target mask, the sum S, and the masked target index -- all of which are local to each rank. No full-vocabulary tensor is ever materialized.
