
Implementation:Alibaba ROLL VocabParallelLogprobs

From Leeroopedia


Knowledge Sources
Domains Distributed_Computing, Tensor_Parallelism
Last Updated 2026-02-07 20:00 GMT

Overview

Vocabulary parallelism utilities for computing log probabilities and target token rankings across sharded vocabulary partitions in tensor-parallel training.

Description

This module provides numerically stable, differentiable implementations for computing log probabilities and target rankings when the vocabulary dimension is sharded across tensor model parallel ranks. It contains five main components:

VocabUtility (lines 10-29): A static utility class with two methods for computing vocabulary index ranges. vocab_range_from_per_partition_vocab_size computes the [start, end) range for a given rank from the per-partition size, while vocab_range_from_global_vocab_size derives the per-partition size first using integer division.
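The range arithmetic can be sketched as a minimal standalone re-implementation (for illustration; the real class exposes the same two static methods and returns the same [start, end) pairs):

```python
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
    # Each rank owns a contiguous [start, end) slice of the global vocabulary.
    index_first = rank * per_partition_vocab_size
    index_last = index_first + per_partition_vocab_size
    return index_first, index_last


def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
    # Derive the per-partition size by integer division, then reuse the
    # per-partition variant (assumes the vocab divides evenly across ranks).
    per_partition_vocab_size = global_vocab_size // world_size
    return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)
```

For example, with a global vocabulary of 128 tokens over 4 ranks, rank 1 owns indices [32, 64).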

_VocabParallelHelper (lines 32-68): A helper class that computes predicted logits for target tokens in a numerically stable manner. After subtracting the max logit (for numerical stability), it creates a target mask to identify which tokens belong to the current partition, gathers the predicted logits for those targets, and computes the exponential sum for log-sum-exp normalization.
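The masking-and-gather step can be illustrated with a single-process sketch. Here the partition bounds are passed in explicitly, whereas the real helper derives them from the tensor-parallel rank, and the return order of the five tensors is an assumption:

```python
import torch

def calculate_predicted_logits(vocab_parallel_logits, target, logits_max, vocab_start, vocab_end):
    # Hypothetical single-rank sketch of the helper's logic.
    # 1) Subtract the (global) max logit for numerical stability.
    logits = vocab_parallel_logits - logits_max.unsqueeze(-1)
    # 2) Mask targets that fall outside this partition's vocabulary slice.
    target_mask = (target < vocab_start) | (target >= vocab_end)
    masked_target = target - vocab_start
    masked_target[target_mask] = 0  # dummy local index for non-owned targets
    # 3) Gather each position's logit at its (local) target index.
    predicted_logits = logits.gather(-1, masked_target.unsqueeze(-1)).squeeze(-1)
    predicted_logits[target_mask] = 0.0  # non-owners contribute zero
    # 4) Partition-local sum of exponentials for log-sum-exp.
    exp_logits = logits.exp()
    sum_exp_logits = exp_logits.sum(dim=-1)
    return predicted_logits, sum_exp_logits, exp_logits, target_mask, masked_target
```

Positions whose target lives on another rank contribute zero to the gathered logit, so a subsequent all-reduce SUM recovers the true value from the owning rank.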

_VocabParallelLogProbs (lines 71-109): A custom torch.autograd.Function that computes log probabilities across sharded vocabulary. The forward pass: (1) finds the global max logit via all-reduce, (2) computes partition-local predicted logits and exponential sums via the helper, (3) all-reduces both predicted logits and sum of exponentials, (4) computes log probabilities as predicted_logits - log(sum_exp_logits). The backward pass computes gradients efficiently using the saved softmax probabilities.
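The four forward-pass steps can be simulated in a single process by holding all shards in a list and replacing each all-reduce with a max or sum over the list. This is an illustrative sketch, not the distributed implementation:

```python
import torch

def simulated_vocab_parallel_logprobs(partitions, target):
    # `partitions` is a list of per-rank logit shards along the vocab dim.
    # (1) Global max logit over all shards (stands in for all-reduce MAX).
    logits_max = torch.stack([p.max(dim=-1).values for p in partitions]).max(dim=0).values
    predicted_logits, sum_exp_logits = 0.0, 0.0
    vocab_start = 0
    for shard in partitions:
        vocab_end = vocab_start + shard.size(-1)
        shifted = shard - logits_max.unsqueeze(-1)
        # (2) Partition-local predicted logit: zero for targets owned elsewhere.
        mask = (target < vocab_start) | (target >= vocab_end)
        local_target = torch.where(mask, torch.zeros_like(target), target - vocab_start)
        pred = shifted.gather(-1, local_target.unsqueeze(-1)).squeeze(-1)
        pred = torch.where(mask, torch.zeros_like(pred), pred)
        # (3) Accumulate across shards (stands in for all-reduce SUM).
        predicted_logits = predicted_logits + pred
        sum_exp_logits = sum_exp_logits + shifted.exp().sum(dim=-1)
        vocab_start = vocab_end
    # (4) log p(target) = predicted_logit - log(sum of exponentials)
    return predicted_logits - sum_exp_logits.log()
```

Because log p = (x_t - m) - log Σ_j exp(x_j - m), subtracting the global max m cancels exactly in the result and serves only to prevent overflow in the exponentials.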

vocab_parallel_logprobs (lines 112-126): Public function that wraps _VocabParallelLogProbs.apply, returning per-token log probabilities.

vocab_parallel_target_rank (lines 129-171): Computes the ranking (0-indexed position) of target tokens within the full vocabulary by counting how many tokens have higher logits. Uses all-reduce to aggregate counts across partitions.
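The same single-process simulation style illustrates the ranking logic: each shard counts its tokens with strictly higher logits than the target's, and summing those counts (the all-reduce) yields the target's 0-indexed rank. A sketch under those assumptions:

```python
import torch

def simulated_vocab_parallel_target_rank(partitions, target):
    # First recover each position's target logit (as in the log-prob path):
    # non-owning shards contribute zero, so the sum is the owner's value.
    vocab_start = 0
    target_logit = 0.0
    for shard in partitions:
        vocab_end = vocab_start + shard.size(-1)
        mask = (target < vocab_start) | (target >= vocab_end)
        local_target = torch.where(mask, torch.zeros_like(target), target - vocab_start)
        pred = shard.gather(-1, local_target.unsqueeze(-1)).squeeze(-1)
        target_logit = target_logit + torch.where(mask, torch.zeros_like(pred), pred)
        vocab_start = vocab_end
    # Count strictly-higher logits per shard; the sum over shards stands in
    # for the all-reduce and equals the target's 0-indexed rank.
    return sum((shard > target_logit.unsqueeze(-1)).sum(dim=-1) for shard in partitions)
```

By this definition the highest-scoring token always has rank 0, since no vocabulary entry has a strictly higher logit.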

Usage

Use vocab_parallel_logprobs in training loops where the model output logits are sharded across tensor parallel ranks (e.g., in DPO or PPO training). Use vocab_parallel_target_rank for computing token-level ranking metrics. Both functions handle the distributed communication internally.

Code Reference

Source Location

Signature

class VocabUtility:
    @staticmethod
    def vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size: int, rank: int, world_size: int
    ) -> Sequence[int]: ...

    @staticmethod
    def vocab_range_from_global_vocab_size(
        global_vocab_size: int, rank: int, world_size: int
    ) -> Sequence[int]: ...

class _VocabParallelHelper:
    @staticmethod
    def calculate_predicted_logits(
        vocab_parallel_logits: torch.Tensor,
        target: torch.Tensor,
        logits_max: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ...

class _VocabParallelLogProbs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, vocab_parallel_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ...
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: ...

def vocab_parallel_logprobs(vocab_parallel_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ...

def vocab_parallel_target_rank(vocab_parallel_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ...

Import

from mcore_adapter.parallel_functions.vocab_parallel import (
    vocab_parallel_logprobs,
    vocab_parallel_target_rank,
    VocabUtility,
)

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| vocab_parallel_logits | torch.Tensor | Yes | Logits split across tensor parallel ranks, shape [batch_size, sequence_length, vocab_size // num_parallel_ranks] |
| target | torch.Tensor | Yes | Target token IDs, shape [batch_size, sequence_length], using global vocabulary indices |

Outputs

| Name | Type | Description |
|---|---|---|
| (vocab_parallel_logprobs) | torch.Tensor | Log probabilities of target tokens, shape [batch_size, sequence_length] |
| (vocab_parallel_target_rank) | torch.Tensor | Ranking index of target tokens across the full vocabulary, shape [batch_size, sequence_length] |

Usage Examples

from mcore_adapter.parallel_functions.vocab_parallel import (
    vocab_parallel_logprobs,
    vocab_parallel_target_rank,
)

# model outputs logits sharded across tensor parallel ranks
# logits shape: [batch_size, seq_len, vocab_size // tp_world_size]
logits = model(input_ids)

# Compute log probabilities across sharded vocab
# labels shape: [batch_size, seq_len] with global vocab indices
logprobs = vocab_parallel_logprobs(logits, labels)
# logprobs shape: [batch_size, seq_len]

# Compute target token ranking
rankings = vocab_parallel_target_rank(logits, labels)
# rankings shape: [batch_size, seq_len]
