Implementation: Alibaba ROLL VocabParallelLogprobs
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, Tensor_Parallelism |
| Last Updated | 2026-02-07 20:00 GMT |
Overview
Vocabulary parallelism utilities for computing log probabilities and target token rankings across sharded vocabulary partitions in tensor-parallel training.
Description
This module provides numerically stable, differentiable implementations for computing log probabilities and target rankings when the vocabulary dimension is sharded across tensor model parallel ranks. It contains three main components:
VocabUtility (lines 10-29): A static utility class with two methods for computing vocabulary index ranges. vocab_range_from_per_partition_vocab_size computes the [start, end) range for a given rank from the per-partition size, while vocab_range_from_global_vocab_size derives the per-partition size first using integer division.
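The range arithmetic can be sketched as follows. The function names mirror the source, but the bodies are reconstructions from the description above (contiguous [start, end) shards, per-partition size from integer division), not the actual implementation:

```python
from typing import Sequence, Tuple


def vocab_range_from_per_partition_vocab_size(
    per_partition_vocab_size: int, rank: int, world_size: int
) -> Tuple[int, int]:
    # Each rank owns a contiguous [start, end) slice of the global vocabulary.
    index_first = rank * per_partition_vocab_size
    index_last = index_first + per_partition_vocab_size
    return index_first, index_last


def vocab_range_from_global_vocab_size(
    global_vocab_size: int, rank: int, world_size: int
) -> Tuple[int, int]:
    # Derive the per-partition size by integer division, then reuse the
    # per-partition helper (assumes the vocab divides evenly across ranks).
    per_partition_vocab_size = global_vocab_size // world_size
    return vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size, rank, world_size
    )
```

For example, with a global vocabulary of 32,000 split over 4 ranks, rank 1 owns indices [8000, 16000).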
_VocabParallelHelper (lines 32-68): A helper class that computes predicted logits for target tokens in a numerically stable manner. After subtracting the max logit (for numerical stability), it creates a target mask to identify which tokens belong to the current partition, gathers the predicted logits for those targets, and computes the exponential sum for log-sum-exp normalization.
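A minimal single-process sketch of this helper, assuming the partition's [vocab_start, vocab_end) range is passed in explicitly (the source derives it from the tensor-parallel rank). Off-shard targets are masked to zero so that a later all-reduce sums the per-shard contributions:

```python
import torch


def calculate_predicted_logits(vocab_parallel_logits, target, logits_max,
                               vocab_start, vocab_end):
    # Subtract the (globally reduced) max logit for numerical stability.
    logits = vocab_parallel_logits - logits_max.unsqueeze(-1)
    # Targets outside this shard's [vocab_start, vocab_end) range are masked.
    target_mask = (target < vocab_start) | (target >= vocab_end)
    masked_target = target.clone() - vocab_start
    masked_target[target_mask] = 0
    # Gather each position's target logit from the local shard.
    predicted_logits = logits.gather(-1, masked_target.unsqueeze(-1)).squeeze(-1)
    # Zero out off-shard targets; summing across ranks recovers the true value.
    predicted_logits[target_mask] = 0.0
    # Partition-local exp-sum, later all-reduced for the log-sum-exp normalizer.
    exp_logits = logits.exp()
    sum_exp_logits = exp_logits.sum(dim=-1)
    return predicted_logits, sum_exp_logits, exp_logits, target_mask, masked_target
```

With a single shard covering the whole vocabulary, `predicted_logits - sum_exp_logits.log()` reduces to an ordinary per-token log-softmax.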
_VocabParallelLogProbs (lines 71-109): A custom torch.autograd.Function that computes log probabilities across sharded vocabulary. The forward pass: (1) finds the global max logit via all-reduce, (2) computes partition-local predicted logits and exponential sums via the helper, (3) all-reduces both predicted logits and sum of exponentials, (4) computes log probabilities as predicted_logits - log(sum_exp_logits). The backward pass computes gradients efficiently using the saved softmax probabilities.
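The forward math collapses, on a single rank, to a stable log-softmax gather; the backward identity the autograd Function exploits is d log p(t)/d x_j = 1[j = t] - softmax(x)_j. This sketch omits the all-reduce steps and is an assumption-laden illustration of the arithmetic, not the distributed implementation:

```python
import torch


def logprobs_single_rank(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # log p(target) = (x_t - max) - log(sum_j exp(x_j - max)); subtracting the
    # max mirrors step (1) of the forward pass and keeps exp() from overflowing.
    logits_max = logits.max(dim=-1).values
    shifted = logits - logits_max.unsqueeze(-1)
    predicted = shifted.gather(-1, target.unsqueeze(-1)).squeeze(-1)
    sum_exp = shifted.exp().sum(dim=-1)
    # Step (4): predicted_logits - log(sum_exp_logits).
    return predicted - sum_exp.log()
```

In the distributed version, `logits_max`, `predicted`, and `sum_exp` are each all-reduced across tensor-parallel ranks before the final subtraction.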
vocab_parallel_logprobs (lines 112-126): Public function that wraps _VocabParallelLogProbs.apply, returning per-token log probabilities.
vocab_parallel_target_rank (lines 129-171): Computes the ranking (0-indexed position) of target tokens within the full vocabulary by counting how many tokens have higher logits. Uses all-reduce to aggregate counts across partitions.
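The counting scheme can be sketched on a single shard as below; in the real function each rank counts within its own vocabulary slice and the counts are summed with an all-reduce. This is an illustrative reconstruction, not the source code:

```python
import torch


def target_rank_single_rank(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # The 0-indexed rank of the target token is the number of vocabulary
    # entries whose logit is strictly greater than the target's logit.
    target_logits = logits.gather(-1, target.unsqueeze(-1))
    return (logits > target_logits).sum(dim=-1)
```

A target whose logit is the maximum therefore gets rank 0, matching a greedy top-1 prediction.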
Usage
Use vocab_parallel_logprobs in training loops where the model output logits are sharded across tensor parallel ranks (e.g., in DPO or PPO training). Use vocab_parallel_target_rank for computing token-level ranking metrics. Both functions handle the distributed communication internally.
Code Reference
Source Location
- Repository: Alibaba_ROLL
- File: mcore_adapter/src/mcore_adapter/parallel_functions/vocab_parallel.py
- Lines: 1-171
Signature
class VocabUtility:
    @staticmethod
    def vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size: int, rank: int, world_size: int
    ) -> Sequence[int]: ...

    @staticmethod
    def vocab_range_from_global_vocab_size(
        global_vocab_size: int, rank: int, world_size: int
    ) -> Sequence[int]: ...

class _VocabParallelHelper:
    @staticmethod
    def calculate_predicted_logits(
        vocab_parallel_logits: torch.Tensor,
        target: torch.Tensor,
        logits_max: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ...

class _VocabParallelLogProbs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, vocab_parallel_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ...

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: ...

def vocab_parallel_logprobs(vocab_parallel_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ...

def vocab_parallel_target_rank(vocab_parallel_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ...
Import
from mcore_adapter.parallel_functions.vocab_parallel import (
vocab_parallel_logprobs,
vocab_parallel_target_rank,
VocabUtility,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| vocab_parallel_logits | torch.Tensor | Yes | Logits split across tensor parallel ranks, shape [batch_size, sequence_length, vocab_size // num_parallel_ranks] |
| target | torch.Tensor | Yes | Target token IDs, shape [batch_size, sequence_length], using global vocabulary indices |
Outputs
| Name | Type | Description |
|---|---|---|
| (vocab_parallel_logprobs) | torch.Tensor | Log probabilities of target tokens, shape [batch_size, sequence_length] |
| (vocab_parallel_target_rank) | torch.Tensor | Ranking index of target tokens across full vocabulary, shape [batch_size, sequence_length] |
Usage Examples
from mcore_adapter.parallel_functions.vocab_parallel import (
vocab_parallel_logprobs,
vocab_parallel_target_rank,
)
# model outputs logits sharded across tensor parallel ranks
# logits shape: [batch_size, seq_len, vocab_size // tp_world_size]
logits = model(input_ids)
# Compute log probabilities across sharded vocab
# labels shape: [batch_size, seq_len] with global vocab indices
logprobs = vocab_parallel_logprobs(logits, labels)
# logprobs shape: [batch_size, seq_len]
# Compute target token ranking
rankings = vocab_parallel_target_rank(logits, labels)
# rankings shape: [batch_size, seq_len]