Implementation:Deepspeedai DeepSpeed Random LTD Binding
| Knowledge Sources | |
|---|---|
| Domains | Training, PyTorch_Bindings, Sparse_Training, Token_Selection |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
PyTorch bindings for Random Layer-wise Token Dropping (LTD) operations including token sorting, gathering, scattering, and attention mask manipulation for length-adaptive training.
Description
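This module implements the CUDA infrastructure for Random-LTD (random layer-wise token dropping). In this technique, selected layers process only a subset of the tokens during training, which reduces computation while maintaining model quality. Dropped tokens bypass the layer and are recombined with the processed tokens afterwards. Random-LTD samples the retained tokens uniformly at random for each layer. The kernels only see token positions, so any selection rule can drive them.

The core operations are:
- token_sort_, which sorts each row of retained positions into ascending (original sequence) order in place.
- token_gather, which collects the retained tokens into a compact tensor.
- token_scatter_, which writes processed tokens back to their original positions in place.
- mask_gather_bert and mask_gather_gpt, which shrink attention masks to match the retained tokens.

The gather and scatter kernels support both batch-first ([batch, seq, hidden]) and sequence-first ([seq, batch, hidden]) layouts by reading the batch and sequence strides of the tensors. They are written for float32 and float16 activations. token_sort_ operates on a [layers, batch, retained] tensor, so every layer receives an independent selection, but all layers in a single call share the same retained count. The sort is deterministic and needs no element comparisons. Each row holds unique positions drawn from the known range [0, original_tokens), so the sorted order can be recovered by indexing rather than comparing.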
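Because each row contains unique positions from a bounded range, sorting can be done by marking which positions occur and reading them back in index order. The sketch below is a plain-Python reference for the semantics of token_sort_ under that assumption. It is not a transcription of the CUDA kernel, and `token_sort_reference` is a hypothetical name.

```python
def token_sort_reference(unsorted_ids, original_tokens):
    """Sort every [layer][batch] row of unique positions in [0, original_tokens)
    ascending without comparing elements: flag presence, then scan in order."""
    result = []
    for layer in unsorted_ids:
        layer_out = []
        for row in layer:
            present = [False] * original_tokens  # one flag per original position
            for pos in row:
                present[pos] = True
            # Scanning positions in index order emits the retained ones already sorted
            layer_out.append([pos for pos in range(original_tokens) if present[pos]])
        result.append(layer_out)
    return result


ids = [[[7, 2, 5], [0, 6, 3]]]  # [layers=1][batch=2][retained=3]
print(token_sort_reference(ids, 8))  # [[[2, 5, 7], [0, 3, 6]]]
```

The cost is proportional to original_tokens per row rather than retained × log(retained), which suits the short, bounded sequences this op targets.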
Usage
Use these operations when implementing layer-wise token dropping, where each layer processes only a subset of the sequence and the dropped tokens skip that layer. Random-LTD chooses the subset at random. Importance-based schemes can reuse the same kernels by supplying their own positions. This is particularly valuable for long-sequence training, where attention cost grows quadratically with length, or whenever not every token needs to pass through every layer.
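In Random-LTD, the selection step is uniform random sampling of distinct positions, done independently for every layer and batch element, followed by a sort. A plain-Python sketch of that selection rule follows. `sample_token_ids` is a hypothetical helper, `random.sample` stands in for the tensor-level sampler, and `sorted` stands in for token_sort_.

```python
import random


def sample_token_ids(layers, batch, seq_len, retained, seed=None):
    """Draw `retained` distinct positions per (layer, batch) row, uniformly at
    random and independently per layer, then sort each row ascending."""
    rng = random.Random(seed)
    return [[sorted(rng.sample(range(seq_len), retained)) for _ in range(batch)]
            for _ in range(layers)]


ids = sample_token_ids(layers=12, batch=2, seq_len=512, retained=384, seed=0)
print(len(ids), len(ids[0]), len(ids[0][0]))  # 12 2 384
```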
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/random_ltd/pt_binding.cpp
Signature
// In-place, comparison-free sort of retained token positions (int32 [layers, batch, retained])
torch::Tensor token_sort_(torch::Tensor& unsorted_token_ids,
                          int64_t original_tokens);
// Gather retained tokens into a compact tensor
torch::Tensor token_gather(torch::Tensor& activations,
                           torch::Tensor& sorted_indices,
                           bool batch_first);
// Scatter processed tokens back into the full sequence (in place)
torch::Tensor token_scatter_(torch::Tensor& all_activations,
                             torch::Tensor& layer_activations,
                             torch::Tensor& sorted_indices,
                             bool batch_first);
// Restrict a BERT (bidirectional) attention mask to each layer's retained tokens
torch::Tensor mask_gather_bert(torch::Tensor& dense_mask,
                               torch::Tensor& sorted_indices);
// Truncate a GPT (causal) attention mask to the retained length
torch::Tensor mask_gather_gpt(torch::Tensor dense_mask,
                              int truncated_seq_len);
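Ignoring dtype and memory layout, the gather/scatter pair has simple semantics. Below is a plain-Python reference for the batch-first case, with nested lists standing in for tensors. The helper names are hypothetical. Like token_scatter_, the scatter writes in place, returns its first argument, and leaves positions that were not retained untouched.

```python
def token_gather_reference(activations, sorted_indices):
    """activations: [batch][seq][hidden]; sorted_indices: [batch][retained].
    Returns the retained tokens as [batch][retained][hidden]."""
    return [[list(activations[b][pos]) for pos in row]
            for b, row in enumerate(sorted_indices)]


def token_scatter_reference(all_activations, layer_activations, sorted_indices):
    """Write layer_activations[b][i] back to position sorted_indices[b][i] of
    all_activations in place; other positions keep their old values."""
    for b, row in enumerate(sorted_indices):
        for i, pos in enumerate(row):
            all_activations[b][pos] = list(layer_activations[b][i])
    return all_activations


x = [[[float(10 * b + s)] for s in range(4)] for b in range(2)]  # [2][4][1]
idx = [[1, 3], [0, 2]]
compact = token_gather_reference(x, idx)
print(compact)  # [[[1.0], [3.0]], [[10.0], [12.0]]]
processed = [[[-v[0]] for v in row] for row in compact]  # stand-in for a layer
token_scatter_reference(x, processed, idx)
print(x[0])  # [[0.0], [-1.0], [2.0], [-3.0]]
```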
Import
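from deepspeed.ops.op_builder import RandomLTDBuilder
ltd_ops = RandomLTDBuilder().load()  # compiles or loads the extension exposing these bindings
The bindings live in the compiled extension rather than as attributes of deepspeed.ops.random_ltd. That package instead provides Python helpers built on them, including the autograd-aware GatherTokens and ScatterTokens. The extension functions accept positional arguments only, because they are registered without argument names. They also do not record autograd history, so training code should go through the autograd wrappers.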
I/O Contract
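| Input | Type | Description |
|---|---|---|
| activations | torch.Tensor | Token activations, [batch×seq×hidden] (batch_first=True) or [seq×batch×hidden] |
| all_activations / layer_activations | torch.Tensor | Full-length destination and compact source for token_scatter_ |
| unsorted_token_ids | torch.Tensor | Contiguous int32 CUDA tensor [layers×batch×retained] of unique token positions, sorted in place |
| original_tokens | int | Full sequence length (exclusive upper bound on positions) |
| sorted_indices | torch.Tensor | [batch×retained] for token_gather / token_scatter_ (one layer of the token_sort_ output); [layers×batch×retained] for mask_gather_bert |
| dense_mask | torch.Tensor | Full additive attention mask [batch×1×seq×seq] |
| truncated_seq_len | int | Retained sequence length for mask_gather_gpt |
| batch_first | bool | Layout flag (True for batch-first) |

| Output | Type | Description |
|---|---|---|
| sorted token IDs | torch.Tensor | The token_sort_ input, sorted ascending in place and returned |
| gathered | torch.Tensor | [batch×retained×hidden] or [retained×batch×hidden] |
| scattered | torch.Tensor | all_activations, updated in place and returned |
| adapted_mask | torch.Tensor | [layers×batch×1×retained×retained] (BERT) or [batch×1×retained×retained] (GPT) |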
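When wiring these ops into a model, it helps to know the shape each one produces before running anything on a GPU. The small helper below encodes the expected output shapes, assuming a [batch, 1, seq, seq] dense mask for both mask ops. `expected_shapes` is hypothetical and only mirrors the shape rules described for each op.

```python
def expected_shapes(batch, seq_len, hidden, layers, retained, batch_first=True):
    """Expected output shape of each binding for the given input sizes."""
    full = (batch, seq_len, hidden) if batch_first else (seq_len, batch, hidden)
    compact = (batch, retained, hidden) if batch_first else (retained, batch, hidden)
    return {
        "token_sort_": (layers, batch, retained),  # same tensor, sorted in place
        "token_gather": compact,
        "token_scatter_": full,                    # returns all_activations
        "mask_gather_bert": (layers, batch, 1, retained, retained),
        "mask_gather_gpt": (batch, 1, retained, retained),
    }


print(expected_shapes(16, 512, 768, 12, 384)["token_gather"])  # (16, 384, 768)
```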
Usage Examples
Basic Token Dropping Pipeline:
import torch
from deepspeed.ops.op_builder import RandomLTDBuilder
# Build (or load the cached) CUDA extension that exposes the bindings
ltd = RandomLTDBuilder().load()
batch = 16
seq_len = 512
hidden = 768
num_layers = 12
retained_tokens = 384  # Keep 75% of tokens
# Per-layer token scores; Random-LTD itself samples uniformly at random,
# but the kernels accept positions chosen by any rule
importance = torch.randn(num_layers, batch, seq_len, device='cuda')
# Keep the top retained_tokens positions per layer (int32, as token_sort_ expects)
token_ids = torch.argsort(importance, dim=-1, descending=True)
token_ids = token_ids[:, :, :retained_tokens].int().contiguous()
# Sort positions ascending in place (comparison-free); returns the same tensor
sorted_ids = ltd.token_sort_(token_ids, seq_len)
# Original activations [batch, seq, hidden]
x = torch.randn(batch, seq_len, hidden, device='cuda')
# Gather layer 0's retained tokens (batch_first passed positionally)
x_compact = ltd.token_gather(x, sorted_ids[0], True)
print(f"Compacted shape: {x_compact.shape}")  # [batch, retained, hidden]
# Process through layer...
# Scatter back; dropped positions keep their original values
output_full = x.clone()
ltd.token_scatter_(output_full, x_compact, sorted_ids[0], True)
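Keeping 384 of 512 tokens does not cut a layer's cost uniformly. The back-of-the-envelope sketch below assumes the usual transformer cost model: projections and MLP scale with the token count, and attention score and value products scale with its square. It ignores gather/scatter overhead, and `layer_cost_ratio` is a hypothetical helper.

```python
def layer_cost_ratio(seq_len, retained):
    """Rough fraction of a layer's work left after dropping tokens:
    (per-token linear work, attention score/value work)."""
    linear = retained / seq_len
    attention = (retained / seq_len) ** 2
    return linear, attention


linear, attention = layer_cost_ratio(512, 384)
print(linear, attention)  # 0.75 0.5625
```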
BERT-Style Bidirectional Attention:
def bert_with_token_dropping(embeddings, attention_mask, model_layers,
                             token_importance, retained_per_layer):
    """
    BERT forward pass with layer-wise token dropping.
    attention_mask: additive dense mask [batch, 1, seq, seq] (float32/float16);
    a [batch, seq] padding mask can be expanded with
    mask[:, None, None, :].expand(-1, 1, seq, -1).contiguous().
    token_importance: [num_layers, batch, seq] scores.
    """
    batch, seq_len, hidden = embeddings.shape
    # Rank positions per layer, most important first
    order = torch.argsort(token_importance, dim=-1, descending=True)
    x = embeddings.clone()  # token_scatter_ writes in place
    for layer_idx, layer in enumerate(model_layers):
        k = retained_per_layer[layer_idx]
        # Pick this layer's top-k first, then sort the positions ascending;
        # token_sort_ needs a contiguous int32 [layers, batch, k] tensor
        ids = order[layer_idx:layer_idx + 1, :, :k].int().contiguous()
        ids = ltd.token_sort_(ids, seq_len)                          # [1, batch, k]
        x_layer = ltd.token_gather(x, ids[0], True)                  # [batch, k, hidden]
        # Keep only the mask rows/columns of retained tokens
        mask_layer = ltd.mask_gather_bert(attention_mask, ids)[0]    # [batch, 1, k, k]
        # Process layer with reduced tokens
        x_layer = layer(x_layer, mask_layer)
        # Write results back; dropped tokens keep their previous hidden states
        ltd.token_scatter_(x, x_layer, ids[0], True)
    return x
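For each layer and batch element, mask_gather_bert keeps only the rows and columns of the bidirectional mask that belong to retained tokens. Below is a plain-Python sketch of that selection. The helper name is hypothetical, and the singleton head dimension of the mask is dropped.

```python
def mask_gather_bert_reference(dense_mask, sorted_indices):
    """dense_mask: [batch][seq][seq]; sorted_indices: [layers][batch][retained].
    Returns [layers][batch][retained][retained]: rows (queries) and columns
    (keys) restricted to each layer's retained positions."""
    return [[[[dense_mask[b][q][k] for k in row] for q in row]
             for b, row in enumerate(layer)]
            for layer in sorted_indices]


NEG = float("-inf")
mask = [[[0.0, 0.0, NEG] for _ in range(3)]]  # batch=1, seq=3; key 2 is padding
print(mask_gather_bert_reference(mask, [[[0, 2]]]))  # [[[[0.0, -inf], [0.0, -inf]]]]
```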
GPT-Style Autoregressive with Token Dropping:
def gpt_with_dynamic_length(embeddings, model_layers,
                            token_importance, keep_ratios):
    """
    GPT-style (causal) forward pass with layer-wise token dropping
    """
    batch, seq_len, hidden = embeddings.shape
    # Retained tokens per layer
    retained = [int(seq_len * r) for r in keep_ratios]
    # Rank positions per layer, most important first
    order = torch.argsort(token_importance, dim=-1, descending=True)
    # Additive causal mask [1, 1, seq, seq]: 0 on/below the diagonal, -inf above
    causal_mask = torch.triu(
        torch.full((seq_len, seq_len), float('-inf'), device=embeddings.device),
        diagonal=1
    ).unsqueeze(0).unsqueeze(0)
    x = embeddings.clone()  # token_scatter_ writes in place
    for layer_idx, layer in enumerate(model_layers):
        k = retained[layer_idx]
        # Top-k positions for this layer, sorted ascending: [1, batch, k]
        ids = order[layer_idx:layer_idx + 1, :, :k].int().contiguous()
        ids = ltd.token_sort_(ids, seq_len)
        x_layer = ltd.token_gather(x, ids[0], True)
        # Sorted positions preserve causal order, so the retained tokens' mask
        # is the top-left k x k block of the full causal mask
        mask_layer = ltd.mask_gather_gpt(causal_mask, k)
        # Layer computation
        x_layer = layer(x_layer, mask_layer)
        # Dropped tokens bypass the layer and keep their previous hidden states
        ltd.token_scatter_(x, x_layer, ids[0], True)
    return x
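mask_gather_gpt needs only truncated_seq_len, not the indices. This works because token_sort_ leaves the retained positions in ascending order. Restricting a causal mask to ascending positions p_0 < p_1 < ... gives exactly the standard k×k causal mask, since p_j <= p_i holds precisely when j <= i. A plain-Python check of that identity follows; the helper names are hypothetical.

```python
def causal_mask(n):
    """Additive causal mask: 0.0 where key <= query, -inf above the diagonal."""
    return [[0.0 if k <= q else float("-inf") for k in range(n)] for q in range(n)]


def gathered_causal_mask(seq_len, positions):
    """Restrict the full causal mask to the given positions (rows and columns)."""
    full = causal_mask(seq_len)
    return [[full[q][k] for k in positions] for q in positions]


positions = [1, 4, 5, 9]  # retained positions, sorted ascending
print(gathered_causal_mask(10, positions) == causal_mask(len(positions)))  # True
```

An unsorted selection such as [4, 1] breaks the identity. Token 1 would sit after token 4 in the compact sequence, yet the top-left block would still let it attend only to itself.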
Sequence-First Layout Support:
# Handle [seq, batch, hidden] layout (common in some frameworks)
seq_len, batch, hidden = 512, 32, 768
# Activations in sequence-first format
x_seq_first = torch.randn(seq_len, batch, hidden, device='cuda')
# Token importance per layer
importance = torch.randn(12, batch, seq_len, device='cuda')
token_ids = torch.argsort(importance, dim=-1, descending=True)[:, :, :384].int()
sorted_indices = ltd.token_sort_(token_ids.contiguous(), seq_len)
# Gather with batch_first=False (passed positionally) for seq-first layout
x_compact = ltd.token_gather(x_seq_first, sorted_indices[0], False)
print(f"Seq-first compact: {x_compact.shape}")  # [384, batch, hidden]
# Process and scatter back
# ... layer computation ...
ltd.token_scatter_(x_seq_first, x_compact, sorted_indices[0], False)
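Both layouts go through the same kernels because the kernels take the batch and sequence strides from the tensors. For a contiguous tensor, the element offset of (batch b, position s, channel c) follows directly from those strides. The plain-Python sketch below uses hypothetical helper names and assumes row-major (contiguous) storage.

```python
def contiguous_strides(shape):
    """Row-major strides, in elements, for a shape tuple."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def token_offset(b, s, c, shape, batch_first):
    """Flat element offset of channel c of token s in batch element b."""
    st = contiguous_strides(shape)
    # Swap which dimension plays "batch" and which plays "sequence"
    batch_stride, seq_stride = (st[0], st[1]) if batch_first else (st[1], st[0])
    return b * batch_stride + s * seq_stride + c * st[2]


seq_len, batch, hidden = 512, 32, 768
print(token_offset(1, 2, 3, (batch, seq_len, hidden), batch_first=True))   # 394755
print(token_offset(1, 2, 3, (seq_len, batch, hidden), batch_first=False))  # 49923
```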
Adaptive Training with Variable Retention:
from deepspeed.ops.random_ltd import GatherTokens, ScatterTokens
class AdaptiveLengthTrainer:
    def __init__(self, model, base_seq_len=512):
        self.model = model
        self.base_seq_len = base_seq_len
        # Retention schedule: more tokens in early layers
        self.retention_schedule = [
            1.0, 0.9, 0.8, 0.75, 0.75, 0.7,   # Layers 0-5
            0.7, 0.65, 0.6, 0.55, 0.5, 0.5    # Layers 6-11
        ]
    def forward(self, input_ids, attention_mask):
        # Embeddings [batch, seq, hidden]; seq must equal base_seq_len
        x = self.model.embed(input_ids)
        # Compute token importance without tracking gradients
        with torch.no_grad():
            importance = self.compute_importance(x)
        order = torch.argsort(importance, dim=-1, descending=True)
        for i, layer in enumerate(self.model.layers):
            k = int(self.base_seq_len * self.retention_schedule[i])
            # Top-k positions for layer i, sorted ascending: [1, batch, k]
            ids = order[i:i + 1, :, :k].int().contiguous()
            ids = ltd.token_sort_(ids, self.base_seq_len)
            # Autograd-aware gather; the second output holds the retained tokens
            _, x_layer = GatherTokens.apply(x, ids[0], True)
            # Compute with reduced tokens
            x_layer = layer(x_layer)
            # Autograd-aware scatter into a new full-length tensor;
            # dropped tokens pass through unchanged
            x = ScatterTokens.apply(x, x_layer, ids[0], True)
        return self.model.head(x)
    def compute_importance(self, embeddings):
        # Placeholder: per-token importance [num_layers, batch, seq]
        return torch.rand(len(self.model.layers), embeddings.size(0),
                          embeddings.size(1), device=embeddings.device)
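token_sort_ takes one [layers, batch, retained] tensor, so all layers in a single call share the same retained count. With a per-layer schedule like the one in this trainer, layers can be grouped by count, and each group can then be sorted in a single call instead of one call per layer. Below is a plain-Python sketch of that grouping. `group_layers_by_retained` is hypothetical, and the schedule is copied from the example.

```python
def group_layers_by_retained(base_seq_len, schedule):
    """Map each retained-token count to the list of layers that use it."""
    groups = {}
    for layer, ratio in enumerate(schedule):
        retained = int(base_seq_len * ratio)  # same truncation as the trainer
        groups.setdefault(retained, []).append(layer)
    return groups


schedule = [1.0, 0.9, 0.8, 0.75, 0.75, 0.7, 0.7, 0.65, 0.6, 0.55, 0.5, 0.5]
print(group_layers_by_retained(512, schedule))
```

Layers in a group need not be adjacent, so their importance rows would be stacked (for example with an index list) before the shared sort.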
Memory-Efficient Long Sequences:
# Process very long sequences by aggressively dropping tokens
def efficient_long_sequence(embeddings, model, max_seq=8192, target_seq=512):
    """
    Process 8K tokens but reduce to 512 for most computation
    """
    batch, seq_len, hidden = embeddings.shape
    assert seq_len == max_seq
    num_layers = len(model.layers)
    # Dramatic reduction after the initial layers
    retention = [1.0, 0.8, 0.6] + [target_seq / max_seq] * (num_layers - 3)
    retained = [int(max_seq * r) for r in retention]
    # Layer 0 runs on the full sequence
    x = model.layers[0](embeddings)
    # One importance score per token, shared by all later layers: [batch, seq]
    importance = x.abs().mean(dim=-1)
    order = torch.argsort(importance, dim=-1, descending=True)
    # Remaining layers with token dropping
    for i in range(1, num_layers):
        # Top-k positions for layer i, sorted ascending: [1, batch, k]
        ids = order[:, :retained[i]].unsqueeze(0).int().contiguous()
        ids = ltd.token_sort_(ids, max_seq)
        x_layer = ltd.token_gather(x, ids[0], True)
        x_layer = model.layers[i](x_layer)
        # Dropped tokens keep their previous hidden states
        ltd.token_scatter_(x, x_layer, ids[0], True)
    return x
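Counting token-layer passes gives a rough sense of how much work this schedule removes. The plain-Python sketch below reproduces the schedule's retained counts and the fraction of token-layer passes kept relative to running every layer on all 8192 tokens. It ignores the quadratic attention term and the gather/scatter overhead, and `token_layer_fraction` is a hypothetical helper.

```python
def token_layer_fraction(max_seq, target_seq, num_layers):
    """Retained counts for the long-sequence schedule (full, 80%, 60%, then
    target_seq / max_seq) and the fraction of token-layer passes they keep."""
    retention = [1.0, 0.8, 0.6] + [target_seq / max_seq] * (num_layers - 3)
    retained = [int(max_seq * r) for r in retention]
    return retained, sum(retained) / (num_layers * max_seq)


retained, frac = token_layer_fraction(8192, 512, 12)
print(retained[:4], round(frac, 3))  # [8192, 6553, 4915, 512] 0.247
```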
Related Pages
- Custom CUDA Layers - Underlying gather/scatter kernels
- Sparse Attention Utils - Related sparse computation
- Inference PT Binding - Attention mask handling