

Implementation:Deepspeedai DeepSpeed Random LTD Binding

From Leeroopedia
Revision as of 14:47, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Deepspeedai_DeepSpeed_Random_LTD_Binding.md)


Knowledge Sources
Domains: Training, PyTorch_Bindings, Sparse_Training, Token_Selection
Last Updated: 2026-02-09 00:00 GMT

Overview

PyTorch bindings for Random Layer-wise Token Dropping (LTD) operations including token sorting, gathering, scattering, and attention mask manipulation for length-adaptive training.

Description

This module implements the infrastructure for Random-LTD, a technique that reduces computation during training while maintaining model quality. In selected layers, a subset of tokens skips the layer, and the activations of those dropped tokens are carried forward unchanged. In DeepSpeed's Random-LTD the retained tokens are sampled uniformly at random, but the kernels accept any set of selected positions.

The core operations are:
- token_sort_ sorts the selected token indices into ascending (original) order, in place and without pairwise comparisons.
- token_gather collects the retained tokens into a compact tensor.
- token_scatter_ writes processed tokens back to their original positions, in place.
- mask_gather_bert and mask_gather_gpt shrink the attention mask to match the retained token set.

The implementation handles both batch-first and sequence-first activation layouts and supports arbitrary stride patterns for memory-layout flexibility. The index tensor carries a leading layer dimension, so each layer can drop a different subset of tokens. Sorting is deterministic and avoids comparisons: the selected indices are distinct integers in [0, original_tokens), so their order can be recovered by marking which positions are present rather than by comparing indices against each other.
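One way to order distinct positions without comparing them to each other is the mark-and-scan scheme sketched below. It flags every selected position, takes an exclusive prefix sum of the flags to give each flagged position its output slot, then writes each position into its slot. This is a pure-Python model of the idea, not DeepSpeed's CUDA kernel. `comparison_free_sort` is a hypothetical name, and each loop stands for a step a GPU could run in parallel.

```python
def comparison_free_sort(selected, original_tokens):
    """Return the distinct positions in `selected` in ascending order.

    Positions are never compared with one another: they are marked in a
    presence array, an exclusive prefix sum assigns output slots, and each
    marked position is written into its slot.
    """
    present = [0] * original_tokens
    for pos in selected:                      # mark step
        if not 0 <= pos < original_tokens:
            raise ValueError(f"position {pos} outside [0, {original_tokens})")
        present[pos] = 1

    slot, running = [0] * original_tokens, 0  # exclusive prefix sum
    for p in range(original_tokens):
        slot[p] = running
        running += present[p]

    out = [0] * running                       # scatter step
    for p in range(original_tokens):
        if present[p]:
            out[slot[p]] = p
    return out


# Four randomly chosen positions out of an 8-token sequence
print(comparison_free_sort([6, 1, 4, 0], original_tokens=8))  # [0, 1, 4, 6]
```

The cost is linear in original_tokens rather than in the number of selected tokens. This trade-off suits the case where the retained set is a large fraction of a bounded sequence.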

Usage

Use these operations to implement training schemes in which some layers process only a subset of the sequence. They are most valuable for long sequences, where attention cost grows quadratically with length. More generally, they cut training cost when not every token needs to pass through every layer.
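As a rough guide to the savings, consider how a transformer layer's cost depends on the number of retained tokens k. The token-wise parts (projections, MLP) scale linearly with k, while the attention products scale with k². The helper below is a back-of-the-envelope sketch under that simplified cost model. It assumes an MLP expansion factor of 4 and ignores constants, kernel efficiency and memory traffic. `relative_layer_cost` is a hypothetical helper.

```python
def relative_layer_cost(seq_len, retained, hidden, ffn_mult=4):
    """Approximate multiply-add ratio of a layer run on `retained` tokens
    versus the full `seq_len` tokens (simplified model, assumption)."""
    def cost(n):
        linear = n * hidden * hidden * (4 + 2 * ffn_mult)  # QKV + output proj + 2 MLP matmuls
        attention = 2 * n * n * hidden                    # QK^T and attention @ V
        return linear + attention
    return cost(retained) / cost(seq_len)


print(round(relative_layer_cost(512, 384, 768), 3))    # 0.731
print(round(relative_layer_cost(8192, 6144, 768), 3))  # 0.63
```

Keeping the same 75% of tokens saves more at 8K than at 512. At long lengths the quadratic attention term dominates, and that term shrinks to (0.75)² of its full-length value.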

Code Reference

Source Location

Signature

// Comparison-free sort of selected token indices into ascending order (in place)
torch::Tensor token_sort_(torch::Tensor& unsorted_token_ids,
                         int64_t original_tokens);

// Gather retained tokens into compact tensor
torch::Tensor token_gather(torch::Tensor& activations,
                          torch::Tensor& sorted_indices,
                          bool batch_first);

// Scatter processed tokens back into the full sequence (in place on all_activations)
torch::Tensor token_scatter_(torch::Tensor& all_activations,
                            torch::Tensor& layer_activations,
                            torch::Tensor& sorted_indices,
                            bool batch_first);

// Adapt BERT attention mask for retained tokens
torch::Tensor mask_gather_bert(torch::Tensor& dense_mask,
                              torch::Tensor& sorted_indices);

// Adapt GPT attention mask (causal) for retained tokens
torch::Tensor mask_gather_gpt(torch::Tensor dense_mask,
                             int truncated_seq_len);

Import

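# The raw kernels come from the compiled extension built by DeepSpeed's op builder
from deepspeed.ops.op_builder import RandomLTDBuilder
ltd_ops = RandomLTDBuilder().load()

# Autograd-aware wrappers (GatherTokens, ScatterTokens) live in deepspeed.ops.random_ltd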

I/O Contract

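Input Type Description
activations torch.Tensor Token activations, [batch×seq×hidden] (batch_first=True) or [seq×batch×hidden] (batch_first=False)
all_activations torch.Tensor Full-length activations that token_scatter_ writes into
layer_activations torch.Tensor Processed retained tokens, in the same layout as the token_gather output
sorted_indices torch.Tensor int32 retained-token positions in ascending order: [batch×retained] for token_gather/token_scatter_ (one layer), [layers×batch×retained] for mask_gather_bert
unsorted_token_ids torch.Tensor int32 selected token positions, [layers×batch×retained], sorted in place by token_sort_
original_tokens int Original sequence length (exclusive upper bound on token positions)
dense_mask torch.Tensor Full 4-D attention mask, batch in dim 0 and the original sequence length in dim 3
truncated_seq_len int Number of retained tokens (mask_gather_gpt)
batch_first bool Activation layout flag
Output Type Description
sorted ids torch.Tensor token_sort_ result: the input index tensor, sorted in place
gathered torch.Tensor token_gather result: [batch×retained×hidden] or [retained×batch×hidden]
scattered torch.Tensor token_scatter_ result: all_activations, updated in place
adapted_mask torch.Tensor Mask restricted to the retained tokens (retained×retained in the last two dimensions)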

Usage Examples

Basic Token Dropping Pipeline:

import torch
from deepspeed.ops.op_builder import RandomLTDBuilder

ltd = RandomLTDBuilder().load()  # compiled random_ltd extension (CUDA)

batch = 16
seq_len = 512
hidden = 768
num_layers = 12
retained_tokens = 384  # Keep 75% of tokens

# Illustrative selection by score (Random-LTD itself samples positions uniformly)
importance = torch.randn(num_layers, batch, seq_len, device='cuda')

# Keep the top retained_tokens positions per (layer, batch) row
token_ids = torch.argsort(importance, dim=-1, descending=True)
token_ids = token_ids[:, :, :retained_tokens].int().contiguous()  # int32 indices

# Sort the selected positions into ascending order (in place, comparison-free)
sorted_ids = ltd.token_sort_(token_ids, seq_len)

# Original activations
x = torch.randn(batch, seq_len, hidden, device='cuda')

# Gather only retained tokens (arguments passed positionally)
x_compact = ltd.token_gather(x, sorted_ids[0], True)
print(f"Compacted shape: {x_compact.shape}")  # [batch, retained, hidden]

# Process through layer...
# Scatter into a copy of x so dropped tokens keep their original activations
output_full = ltd.token_scatter_(x.clone(), x_compact, sorted_ids[0], True)
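The example above selects tokens by score. DeepSpeed's Random-LTD samplers instead draw the retained positions uniformly at random, without replacement, for every (layer, batch) row. Below is a standard-library sketch of that selection step. It produces the [layers][batch][retained] layout that token_sort_ consumes. `sample_sorted_token_ids` is hypothetical, and Python's `sorted` stands in for the comparison-free kernel sort.

```python
import random


def sample_sorted_token_ids(layers, batch, seq_len, retained, seed=None):
    """Uniformly sample `retained` distinct positions per (layer, batch) row
    and return them in ascending order, shaped [layers][batch][retained]."""
    rng = random.Random(seed)
    return [[sorted(rng.sample(range(seq_len), retained)) for _ in range(batch)]
            for _ in range(layers)]


ids = sample_sorted_token_ids(layers=2, batch=3, seq_len=512, retained=384, seed=0)
print(len(ids), len(ids[0]), len(ids[0][0]))  # 2 3 384
```

Because each row is drawn independently, every layer and every sample skips a different random subset, and no importance computation is needed.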

BERT-Style Bidirectional Attention:

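def bert_with_token_dropping(embeddings, attention_mask, model_layers,
                             token_importance, retained_per_layer):
    """
    BERT forward pass with layer-wise token dropping.

    embeddings:         [batch, seq, hidden]
    attention_mask:     [batch, seq] float padding mask
    token_importance:   [num_layers, batch, seq]
    retained_per_layer: number of tokens each layer keeps
    """
    batch, seq_len, hidden = embeddings.shape

    # 4-D mask with one row per query token: [batch, 1, seq, seq]
    dense_mask = attention_mask[:, None, None, :].expand(
        batch, 1, seq_len, seq_len).contiguous()

    # Work on a copy; token_scatter_ writes in place
    x = embeddings.clone()

    for layer_idx, layer in enumerate(model_layers):
        num_retained = retained_per_layer[layer_idx]

        # Top-k for this layer, then ascending positional order. Sorting per
        # layer lets each layer keep a different count; slicing one shared
        # sorted tensor would keep the leftmost tokens instead of the top-k.
        token_ids = torch.argsort(token_importance[layer_idx], dim=-1, descending=True)
        token_ids = token_ids[:, :num_retained].int().unsqueeze(0).contiguous()
        sorted_indices = ltd.token_sort_(token_ids, seq_len)  # [1, batch, retained]
        indices = sorted_indices[0]                           # [batch, retained]

        # Gather tokens for this layer
        x_layer = ltd.token_gather(x, indices, True)

        # Keep only rows and columns of retained tokens; [0] selects the one layer
        mask_layer = ltd.mask_gather_bert(dense_mask, sorted_indices)[0]

        # Process layer with reduced tokens
        x_layer = layer(x_layer, mask_layer)

        # Write processed tokens back; dropped tokens keep their activations.
        # Raw extension calls are not tracked by autograd.
        ltd.token_scatter_(x, x_layer, indices, True)

    return x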
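Conceptually, restricting a bidirectional mask to the retained tokens means keeping, for each layer and sample, only the rows and columns at the retained positions. The pure-Python model below illustrates that selection on a mask with the head dimension omitted. `ref_mask_gather_bert` is a hypothetical reference helper and does not reproduce the exact tensor shapes of the CUDA op.

```python
def ref_mask_gather_bert(dense_mask, sorted_indices):
    """dense_mask: [B][S][S]; sorted_indices: [L][B][K].
    Returns [L][B][K][K]: for each layer and sample, the sub-matrix whose
    rows (queries) and columns (keys) are the retained positions."""
    return [[[[dense_mask[b][i][j] for j in row] for i in row]
             for b, row in enumerate(layer)]
            for layer in sorted_indices]


# 1 = may attend, 0 = masked; position 3 is padding in a 4-token sequence
dense = [[[1, 1, 1, 0] for _ in range(4)]]
print(ref_mask_gather_bert(dense, [[[0, 3]]]))  # [[[[1, 0], [1, 0]]]]
```

Padding information survives the gather: the retained padding token at position 3 is still masked as a key.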

GPT-Style Autoregressive with Token Dropping:

def gpt_with_dynamic_length(embeddings, model_layers,
                            token_importance, keep_ratios):
    """
    GPT-style (causal) training forward pass with progressive token dropping.

    embeddings:       [batch, seq, hidden]
    token_importance: [num_layers, batch, seq]
    keep_ratios:      fraction of tokens each layer keeps
    """
    batch, seq_len, hidden = embeddings.shape

    # Retained token count per layer
    retained = [int(seq_len * r) for r in keep_ratios]

    # Full causal mask: 0 on/below the diagonal, -inf above; [1, 1, seq, seq]
    causal_mask = torch.triu(
        torch.full((seq_len, seq_len), float('-inf'), device=embeddings.device),
        diagonal=1
    ).unsqueeze(0).unsqueeze(0)

    x = embeddings.clone()  # token_scatter_ writes in place

    for layer_idx, layer in enumerate(model_layers):
        num_kept = retained[layer_idx]

        # Top-k positions for this layer, sorted into ascending order
        token_ids = torch.argsort(token_importance[layer_idx], dim=-1, descending=True)
        token_ids = token_ids[:, :num_kept].int().unsqueeze(0).contiguous()
        indices = ltd.token_sort_(token_ids, seq_len)[0]  # [batch, num_kept]

        # Compact to retained tokens
        x_layer = ltd.token_gather(x, indices, True)

        # Retained tokens stay in original order, so the causal pattern among
        # them is again lower-triangular and only the length is needed
        mask_layer = ltd.mask_gather_gpt(causal_mask, num_kept)

        # Layer computation
        x_layer = layer(x_layer, mask_layer)

        # Expand back: dropped positions keep their previous activations
        ltd.token_scatter_(x, x_layer, indices, True)

    return x
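A length is enough for the causal case once the retained positions are in ascending order. Retained token i may then attend to retained token j exactly when j ≤ i. That is the same lower-triangular pattern as the top-left block of the full causal mask. The standard-library check below illustrates this; the helper names are hypothetical. It also shows that an unsorted selection breaks the pattern, which is one reason the indices are sorted before gathering.

```python
def causal(n):
    """1 where query i may see key j (j <= i), else 0."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]


def causal_among(positions):
    """Causal visibility between retained tokens, judged by original positions."""
    return [[1 if pj <= pi else 0 for pj in positions] for pi in positions]


def top_left_block(mask, k):
    """Leading k x k block of a square mask."""
    return [row[:k] for row in mask[:k]]


retained = [2, 5, 6, 9]  # sorted original positions, seq_len = 12
print(causal_among(retained) == top_left_block(causal(12), len(retained)))  # True
print(causal_among([5, 2, 9, 6]) == causal(4))                              # False
```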

Sequence-First Layout Support:

# Handle [seq, batch, hidden] layout (common in some frameworks)
seq_len, batch, hidden = 512, 32, 768

# Activations in sequence-first format
x_seq_first = torch.randn(seq_len, batch, hidden, device='cuda')

# Token importance per layer
importance = torch.randn(12, batch, seq_len, device='cuda')
token_ids = torch.argsort(importance, dim=-1, descending=True)[:, :, :384].int()
sorted_indices = ltd.token_sort_(token_ids.contiguous(), seq_len)

# Gather with batch_first=False for the seq-first layout (positional arguments)
x_compact = ltd.token_gather(x_seq_first, sorted_indices[0], False)
print(f"Seq-first compact: {x_compact.shape}")  # [384, batch, hidden]

# Process and scatter back (in place on x_seq_first)
# ... layer computation ...
ltd.token_scatter_(x_seq_first, x_compact, sorted_indices[0], False)
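The batch_first flag only changes how a (sample, position) pair maps to a row of hidden values. The sketch below shows that arithmetic for contiguous tensors, using a hypothetical helper name. As the description notes, the real kernels also handle other stride patterns.

```python
def token_row_offset(b, s, batch, seq_len, hidden, batch_first):
    """Start offset of token s of sample b in a contiguous activation buffer.

    Each token's hidden vector occupies `hidden` consecutive elements; only
    the order of the (batch, position) rows differs between the layouts.
    """
    if not (0 <= b < batch and 0 <= s < seq_len):
        raise IndexError("token outside the tensor")
    if batch_first:                    # [batch, seq, hidden]
        return (b * seq_len + s) * hidden
    return (s * batch + b) * hidden    # [seq, batch, hidden]


print(token_row_offset(1, 2, batch=32, seq_len=512, hidden=768, batch_first=True))   # 394752
print(token_row_offset(1, 2, batch=32, seq_len=512, hidden=768, batch_first=False))  # 49920
```

Either way, gathering one token copies `hidden` consecutive values. Only the row index computation depends on the layout.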

Adaptive Training with Variable Retention:

class AdaptiveLengthTrainer:
    def __init__(self, model, base_seq_len=512):
        self.model = model
        self.base_seq_len = base_seq_len  # nominal length the schedule targets
        # Retention schedule: more tokens in early layers
        self.retention_schedule = [
            1.0, 0.9, 0.8, 0.75, 0.75, 0.7,  # Layers 0-5
            0.7, 0.65, 0.6, 0.55, 0.5, 0.5   # Layers 6-11
        ]

    def forward(self, input_ids, attention_mask):
        seq_len = input_ids.size(1)  # actual length of this batch

        # Get embeddings
        embeddings = self.model.embed(input_ids)

        # Compute token importance (e.g., attention entropy): [layers, batch, seq]
        with torch.no_grad():
            importance = self.compute_importance(embeddings)
            order = torch.argsort(importance, dim=-1, descending=True)

        retained = [max(1, int(seq_len * r)) for r in self.retention_schedule]

        # token_scatter_ writes in place, so start from a copy.
        # Raw extension calls are not tracked by autograd.
        x = embeddings.clone()
        for i, layer in enumerate(self.model.layers):
            # Each layer keeps a different count, so select and sort per layer
            token_ids = order[i, :, :retained[i]].int().unsqueeze(0).contiguous()
            indices = ltd.token_sort_(token_ids, seq_len)[0]
            x_layer = ltd.token_gather(x, indices, True)

            # Compute with reduced tokens
            x_layer = layer(x_layer)

            # Restore full length for next layer
            ltd.token_scatter_(x, x_layer, indices, True)

        return self.model.head(x)

    def compute_importance(self, embeddings):
        # Placeholder: random per-token importance, one row per layer
        return torch.rand(len(self.model.layers), embeddings.size(0),
                          embeddings.size(1), device=embeddings.device)
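The schedule translates into concrete per-layer token counts. Summing them gives the fraction of token-layer passes that remain, a rough proxy for the savings in the token-wise (linear) parts of the model. The helper name below is hypothetical, and it rounds down like the trainer.

```python
def schedule_token_counts(seq_len, schedule):
    """Tokens kept per layer (rounded down, at least one)."""
    return [max(1, int(seq_len * r)) for r in schedule]


schedule = [1.0, 0.9, 0.8, 0.75, 0.75, 0.7,
            0.7, 0.65, 0.6, 0.55, 0.5, 0.5]
counts = schedule_token_counts(512, schedule)
print(counts)  # [512, 460, 409, 384, 384, 358, 358, 332, 307, 281, 256, 256]
print(f"{sum(counts) / (len(counts) * 512):.3f}")  # 0.699
```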

Memory-Efficient Long Sequences:

# Process very long sequences by aggressively dropping tokens
def efficient_long_sequence(embeddings, model, max_seq=8192, target_seq=512):
    """
    Process 8K tokens but reduce to 512 for most computation
    """
    batch, seq_len, hidden = embeddings.shape
    assert seq_len == max_seq

    num_layers = len(model.layers)

    # Dramatic reduction after initial layers
    retention = [1.0, 0.8, 0.6] + [target_seq / max_seq] * (num_layers - 3)
    retained = [int(max_seq * r) for r in retention]

    # Compute importance on full sequence (first layer)
    x = model.layers[0](embeddings)  # Full computation

    # Importance from the first layer's output: [batch, seq]
    importance = x.abs().mean(dim=-1)
    order = torch.argsort(importance, dim=-1, descending=True)

    # Remaining layers; layer i keeps its own top retained[i] tokens
    for i in range(1, num_layers):
        token_ids = order[:, :retained[i]].int().unsqueeze(0).contiguous()
        # Assumption: the sort kernel accepts original_tokens = max_seq
        indices = ltd.token_sort_(token_ids, max_seq)[0]
        x_layer = ltd.token_gather(x, indices, True)
        x_layer = model.layers[i](x_layer)
        ltd.token_scatter_(x, x_layer, indices, True)

    return x
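For a 12-layer model, the retention rule above keeps the token counts computed below. The hypothetical helper assumes at least three layers, as the example does. It also estimates the fraction of token-layer work and of quadratic attention work that remains, under a simplified cost model in which one term scales linearly with k and the other with k².

```python
def long_sequence_budget(num_layers, max_seq=8192, target_seq=512):
    """Per-layer token counts and remaining work fractions for the schedule
    [1.0, 0.8, 0.6] followed by target_seq / max_seq for the other layers."""
    retention = [1.0, 0.8, 0.6] + [target_seq / max_seq] * (num_layers - 3)
    retained = [int(max_seq * r) for r in retention]
    linear_frac = sum(retained) / (num_layers * max_seq)
    attention_frac = sum(k * k for k in retained) / (num_layers * max_seq ** 2)
    return retained, linear_frac, attention_frac


retained, lin, att = long_sequence_budget(num_layers=12)
print(retained[:4])            # [8192, 6553, 4915, 512]
print(f"{lin:.3f} {att:.3f}")  # 0.247 0.170
```

Most of the remaining attention cost sits in the first three layers, which still see thousands of tokens.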
