

Implementation:Deepspeedai DeepSpeed UlyssesSP Forward



Overview

DeepSpeed's implementation of the Ulysses sequence-parallel (SP) forward pass: all-to-all communication around the attention call, plus tiled computation that bounds memory in operations whose cost grows with sequence length.

Description

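UlyssesSPAttentionHF.forward() wraps the original attention function in a pair of all-to-all exchanges. Before attention, each rank holds its local sequence chunk for all heads; the first all-to-all trades this for the full sequence over a subset of heads, so each rank can run ordinary attention on its own heads. After attention, a second all-to-all reverses the exchange and returns each rank its local sequence chunk for all heads. sequence_tiled_compute() applies a function to one sequence tile at a time for memory-efficient processing. Both rely on custom autograd functions so that gradients flow correctly through the communication and the tiling.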
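As a toy illustration of that data movement, the pure-Python sketch below tracks only which (sequence position, head) pairs live on which rank before, during, and after attention. It assumes the head count divides evenly by the number of SP ranks; the helper names are hypothetical, and this is not DeepSpeed code, which exchanges torch tensors through collective communication.

```python
# Toy model of the Ulysses all-to-all pattern (hypothetical helpers, not DeepSpeed code).
# Each "tensor" is a dict mapping (seq_pos, head) -> value; only the layout matters.


def make_local_qkv(rank: int, world_size: int, seq_len: int, num_heads: int) -> dict:
    """Rank `rank` starts with its contiguous sequence chunk for ALL heads."""
    local = seq_len // world_size
    start = rank * local
    return {(s, h): f"tok{s}_h{h}" for s in range(start, start + local) for h in range(num_heads)}


def seq_to_head_all_to_all(shards: list, world_size: int, num_heads: int) -> list:
    """Before attention: rank r receives the FULL sequence for the heads in its head slice."""
    heads_per_rank = num_heads // world_size
    out = [dict() for _ in range(world_size)]
    for src in shards:  # every rank sends ...
        for (s, h), v in src.items():
            out[h // heads_per_rank][(s, h)] = v  # ... each head to the rank that owns it
    return out


def head_to_seq_all_to_all(shards: list, world_size: int, seq_len: int) -> list:
    """After attention: the inverse exchange, back to local sequence chunks with all heads."""
    local = seq_len // world_size
    out = [dict() for _ in range(world_size)]
    for src in shards:
        for (s, h), v in src.items():
            out[s // local][(s, h)] = v
    return out


WORLD, SEQ, HEADS = 4, 8, 8
before = [make_local_qkv(r, WORLD, SEQ, HEADS) for r in range(WORLD)]
during = seq_to_head_all_to_all(before, WORLD, HEADS)
after = head_to_seq_all_to_all(during, WORLD, SEQ)

# During attention, rank 1 sees every sequence position but only heads 2 and 3.
print(sorted({s for s, _ in during[1]}))  # [0, 1, 2, 3, 4, 5, 6, 7]
print(sorted({h for _, h in during[1]}))  # [2, 3]
print(after == before)  # True: the round trip restores the original layout
```

The point of the pattern is that attention needs every position of a head at once but never mixes heads, so trading sequence shards for head shards lets each rank run unmodified attention on its heads.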

The forward method handles:

  • Eval bypass: If disable_in_eval=True and the module is not training, the original attention function is called directly without any SP communication.
  • Variable sequence lengths: When seq_length_is_variable=True, the local sequence length is read from each incoming query, and the expected global shapes are updated to match.
  • Position ID gathering: Collects position_ids from all SP ranks to provide the full positional context to the attention function.
  • GQA/MQA support: Adjusts num_key_value_groups on the module when KV head replication is active.
  • Debug mode: An optional skip_all_but_last_attention_debug_mode skips core attention for all but the last layer to speed up memory fitting tests.
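For the GQA/MQA case, the per-rank head bookkeeping can be sketched as plain arithmetic. This is an assumption-laden illustration rather than DeepSpeed's logic: it assumes query heads split evenly across SP ranks and that, when there are fewer KV heads than ranks, each KV head is replicated until every rank holds exactly one. `local_head_layout` and `LocalHeads` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class LocalHeads:
    q_heads: int               # query heads held by each rank during attention
    kv_heads: int              # key/value heads held by each rank during attention
    kv_replication: int        # copies made of each KV head (1 = no replication)
    num_key_value_groups: int  # query heads sharing one KV head on this rank


def local_head_layout(num_q_heads: int, num_kv_heads: int, sp_world_size: int) -> LocalHeads:
    """Hypothetical helper: per-rank head counts after the sequence-to-head all-to-all."""
    if num_q_heads % sp_world_size:
        raise ValueError("query heads must divide evenly across SP ranks")
    q_local = num_q_heads // sp_world_size
    if num_kv_heads >= sp_world_size:
        if num_kv_heads % sp_world_size:
            raise ValueError("KV heads must divide evenly across SP ranks")
        replication = 1
        kv_local = num_kv_heads // sp_world_size
    else:
        # Fewer KV heads than ranks (e.g. MQA): assume replication so each rank gets one.
        if sp_world_size % num_kv_heads:
            raise ValueError("SP world size must be a multiple of the KV head count")
        replication = sp_world_size // num_kv_heads
        kv_local = 1
    return LocalHeads(q_local, kv_local, replication, q_local // kv_local)


# GQA with 32 query heads and 8 KV heads, at two SP degrees.
print(local_head_layout(32, 8, sp_world_size=4))
# LocalHeads(q_heads=8, kv_heads=2, kv_replication=1, num_key_value_groups=4)
print(local_head_layout(32, 8, sp_world_size=16))
# LocalHeads(q_heads=2, kv_heads=1, kv_replication=2, num_key_value_groups=2)
```

The second call shows why a module-level num_key_value_groups may need adjusting: the original ratio of 4 query heads per KV head no longer matches what a single rank holds once KV heads are replicated.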

The sequence_tiled_compute() function is a wrapper around the SequenceTiledCompute autograd function. It manages:

  • Sharding: Splits kwargs_to_shard tensors on the sequence dimension (dim=1)
  • Forward: Calls the provided function on each shard with torch.no_grad()
  • Backward: Recomputes forward per shard and calls torch.autograd.backward()
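  • ZeRO integration: Sets ds_grad_is_ready to False on compute_params for every shard except the last, so ZeRO treats their gradients as ready only after all shards have accumulated into them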
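The backward and ZeRO bullets both follow from one fact: a weight used by every tile receives the sum of the per-tile gradient contributions, so its gradient is complete only after the last tile has been recomputed. The pure-Python check below demonstrates this for a toy squared-error loss with a hand-derived gradient instead of torch autograd; `tile_grad_w` and `tiled_backward` are hypothetical names, and the returned flags merely mirror when a ZeRO-style "gradient is ready" signal would be safe to raise.

```python
def tile_grad_w(w: float, xs: list, ys: list) -> float:
    """Hand-derived d/dw of sum((w * x - y) ** 2) over the positions in one tile."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys))


def tiled_backward(w: float, xs: list, ys: list, shards: int):
    """Recompute tile by tile, accumulating into a single gradient buffer for the shared w."""
    step = -(-len(xs) // shards)  # ceil division, so uneven lengths still work
    grad_w = 0.0
    ready_flags = []  # True only once grad_w holds every tile's contribution
    for start in range(0, len(xs), step):
        grad_w += tile_grad_w(w, xs[start:start + step], ys[start:start + step])
        ready_flags.append(start + step >= len(xs))
    return grad_w, ready_flags


xs = [0.5, -1.0, 2.0, 1.5, 3.0, -0.5, 1.0]
ys = [1.0, 0.0, 2.5, 1.0, 4.0, 0.0, 1.5]
w = 1.2

full = tile_grad_w(w, xs, ys)  # gradient computed over the whole sequence at once
tiled, ready = tiled_backward(w, xs, ys, shards=3)
print(abs(full - tiled) < 1e-9, ready)  # True [False, False, True]
```

Reducing or partitioning grad_w after the first or second tile would hand the optimizer a partial gradient, which is the situation the readiness deferral avoids.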

Code Reference

  • Repository: https://github.com/deepspeedai/DeepSpeed
  • File: deepspeed/runtime/sequence_parallel/ulysses_sp.py
  • Lines: L224-353 (forward), L631-851 (sequence_tiled_compute + SequenceTiledCompute autograd)

Forward Signature

def forward(
    self,
    module: torch.nn.Module,
    query: Tensor,        # [bs, hc, sl, hs]
    key: Tensor,          # [bs, hc, sl, hs]
    value: Tensor,        # [bs, hc, sl, hs]
    attention_mask: Tensor,
    *args: Any,
    **kwargs: Any,
) -> Tuple[Tensor, Optional[Tensor]]: ...

sequence_tiled_compute Signature

def sequence_tiled_compute(
    fn,                              # callable to invoke on each shard
    seqlen: int,                     # total sequence length
    shards: int,                     # number of tiles
    kwargs_to_shard: dict,           # tensors to shard on seq dim
    kwargs_to_pass: dict,            # kwargs passed unchanged
    grad_requiring_tensor_key: str,  # key of the grad-requiring tensor
    compute_params=None,             # model weights for ZeRO integration
    output_unshard_dimension: int = 1,
    output_reduction: str = "mean",  # None | "mean" | "sum"
) -> Tensor: ...
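The calling convention in this signature can be approximated without torch. The sketch below is a simplified, hypothetical stand-in for the forward half only: it treats nested lists shaped [bs][seqlen] as tensors, splits them along dim 1 into at most `shards` tiles (an assumed ceil-division tile size), calls `fn` with each tile plus `kwargs_to_pass` as keyword arguments, and then concatenates or reduces the results. It omits autograd, `grad_requiring_tensor_key`, and `compute_params`, and always concatenates along dim 1 instead of honoring `output_unshard_dimension`.

```python
from itertools import chain
from typing import Any, Callable, Dict, List, Optional

Rows = List[list]  # a [bs][seqlen] nested list standing in for a tensor


def split_on_seq_dim(rows: Rows, seqlen: int, shards: int) -> List[Rows]:
    """Split along dim 1 into tiles of ceil(seqlen / shards) positions (the last may be shorter)."""
    step = -(-seqlen // shards)
    return [[row[i:i + step] for row in rows] for i in range(0, seqlen, step)]


def tiled_forward_sketch(
    fn: Callable[..., Any],
    seqlen: int,
    shards: int,
    kwargs_to_shard: Dict[str, Rows],
    kwargs_to_pass: Dict[str, Any],
    output_reduction: Optional[str] = "mean",
) -> Any:
    """Hypothetical, autograd-free approximation of sequence_tiled_compute's forward pass."""
    tiles = {k: split_on_seq_dim(v, seqlen, shards) for k, v in kwargs_to_shard.items()}
    num_tiles = len(next(iter(tiles.values())))
    outputs = [
        fn(**{k: t[i] for k, t in tiles.items()}, **kwargs_to_pass)
        for i in range(num_tiles)
    ]
    if output_reduction is None:
        # Concatenate the per-tile outputs back along dim 1, one batch row at a time.
        return [list(chain.from_iterable(out[b] for out in outputs)) for b in range(len(outputs[0]))]
    if output_reduction == "sum":
        return sum(outputs)
    if output_reduction == "mean":
        return sum(outputs) / len(outputs)  # every tile weighted equally
    raise ValueError(f"unknown output_reduction: {output_reduction!r}")


# Hypothetical per-tile functions; parameter names must match the dict keys.
def scale(hidden_states, factor):
    return [[factor * x for x in row] for row in hidden_states]


def tile_total(values):
    return sum(sum(row) for row in values)


hs = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]  # bs=2, seqlen=5 -> tiles of 3 and 2 positions
print(tiled_forward_sketch(scale, 5, 2, {"hidden_states": hs}, {"factor": 10}, None))
# [[10, 20, 30, 40, 50], [60, 70, 80, 90, 100]]
print(tiled_forward_sketch(tile_total, 5, 2, {"values": hs}, {}, "sum"))  # 55
print(tiled_forward_sketch(tile_total, 5, 2, {"values": hs}, {}, "mean"))  # 27.5
```

Passing everything by keyword is what lets one wrapper serve very different functions (an MLP, a loss) as long as the dict keys line up with the function's parameter names.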

Import

from deepspeed.runtime.sequence_parallel.ulysses_sp import sequence_tiled_compute

I/O Contract

Inputs (forward)

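Parameter | Type | Required | Description
--- | --- | --- | ---
module | torch.nn.Module | Yes | The attention module being wrapped
query | Tensor | Yes | Query tensor with shape [bs, hc, sl_local, hs]
key | Tensor | Yes | Key tensor with shape [bs, hc_kv, sl_local, hs]
value | Tensor | Yes | Value tensor with shape [bs, hc_kv, sl_local, hs]
attention_mask | Tensor | No | Attention mask (set to None internally by the Ulysses wrapper)

Shape notation: bs = batch size, hc = number of query heads, hc_kv = number of key/value heads, sl_local = length of this rank's sequence chunk, hs = head size.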

Outputs (forward)

Output | Type | Description
--- | --- | ---
output | Tensor | Attention output redistributed back to local sequence chunks, shape [bs, sl_local, em], where em = hc × hs
attn_weights | Tensor or None | Attention weights (typically None with flash attention)
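Putting the shape annotations together, the sketch below tracks one rank's tensor shapes through the wrapper. It assumes both hc and hc_kv divide evenly by the SP world size (so no KV replication), that the wrapped attention function receives [bs, heads, seq, head_size] tensors, and that the caller gets back [bs, sl_local, hc × hs]. `ulysses_shapes` is a hypothetical helper, and the model dimensions in the example are illustrative.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def ulysses_shapes(bs: int, hc: int, hc_kv: int, sl_local: int, hs: int, sp: int) -> Dict[str, Shape]:
    """Hypothetical per-rank shape walkthrough; assumes hc % sp == 0 and hc_kv % sp == 0."""
    if hc % sp or hc_kv % sp:
        raise ValueError("this sketch needs head counts divisible by the SP world size")
    sl_global = sl_local * sp
    return {
        "query in": (bs, hc, sl_local, hs),
        "key/value in": (bs, hc_kv, sl_local, hs),
        # after the first all-to-all: full sequence, 1/sp of the heads
        "query to attention": (bs, hc // sp, sl_global, hs),
        "key/value to attention": (bs, hc_kv // sp, sl_global, hs),
        # after the second all-to-all: local chunk again, all heads flattened (em = hc * hs)
        "output": (bs, sl_local, hc * hs),
    }


for name, shape in ulysses_shapes(bs=1, hc=32, hc_kv=8, sl_local=16_384, hs=128, sp=8).items():
    print(f"{name:>24}: {shape}")
```

With these numbers, each rank runs attention over 131,072 positions but only 4 query heads and 1 KV head, which is where the per-rank memory saving of Ulysses SP comes from.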

Inputs (sequence_tiled_compute)

Parameter | Type | Required | Description
--- | --- | --- | ---
fn | callable | Yes | Function to call on sharded inputs
seqlen | int | Yes | Total sequence length of the tensors to shard
shards | int | Yes | Number of tiles to split into
kwargs_to_shard | dict | Yes | Dict of tensors to shard on the sequence dimension
kwargs_to_pass | dict | Yes | Dict of kwargs passed unchanged to fn
grad_requiring_tensor_key | str | Yes | Key identifying which tensor in kwargs_to_shard requires gradients
compute_params | list | No | Model weights for ZeRO gradient coordination (default: None)
output_unshard_dimension | int | No | Dimension along which to concatenate outputs (default: 1)
output_reduction | str | No | Reduction to apply: None, "mean", or "sum" (default: "mean")

Outputs (sequence_tiled_compute)

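Output | Type | Description
--- | --- | ---
result | Tensor | Per-shard outputs of fn combined according to output_reduction: concatenated along output_unshard_dimension for None, averaged for "mean", summed for "sum"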
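One consequence of output_reduction="mean" is worth checking numerically. Suppose each tile returns the mean loss over its own non-ignored tokens, as a loss function with ignore_index=-100 typically would (an assumption here, not a statement about any particular compute_loss_fn). Averaging those per-tile means weights every tile equally, which differs from the token-level mean whenever tiles hold different numbers of valid tokens. The pure-Python example below compares the two; it assumes "mean" averages per-tile outputs with equal weight, and the per-token losses are arbitrary illustrative values.

```python
IGNORE_INDEX = -100

# Arbitrary illustrative per-token losses and labels, split into 2 equal tiles of 4 tokens.
token_losses = [1.0, 2.0, 3.0, 4.0, 0.5, 0.5, 0.5, 0.5]
labels = [5, 7, IGNORE_INDEX, IGNORE_INDEX, 1, 2, 3, 4]
tile = 4


def tile_stats(start: int, stop: int) -> tuple:
    """(sum of losses over non-ignored tokens, count of non-ignored tokens) for one tile."""
    valid = [
        loss
        for loss, label in zip(token_losses[start:stop], labels[start:stop])
        if label != IGNORE_INDEX
    ]
    return sum(valid), len(valid)


stats = [tile_stats(i, i + tile) for i in range(0, len(labels), tile)]

# Each tile returns its own mean; the tiles are then averaged with equal weight.
mean_of_tile_means = sum(s / n for s, n in stats) / len(stats)
# Each tile returns a sum; the total is divided by the overall count of valid tokens.
token_level_mean = sum(s for s, _ in stats) / sum(n for _, n in stats)

print(mean_of_tile_means)  # 1.0
print(token_level_mean)    # 0.8333333333333334
```

The first tile has only two valid tokens yet counts as much as the second tile's four. One option, when the token-level mean is what training should optimize, is to have fn return a summed loss, use output_reduction="sum", and divide by the count of valid tokens afterwards; that also sidesteps the undefined 0/0 mean of a tile whose labels are all ignored.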

Usage Example

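from deepspeed.runtime.sequence_parallel.ulysses_sp import sequence_tiled_compute

# Placeholders assumed to exist: `engine` (the DeepSpeed engine), `logits` [bs, sl, vocab],
# `labels` [bs, sl], `compute_loss_fn`, and a Llama-style `mlp` module.
# fn receives each tile of kwargs_to_shard plus kwargs_to_pass as keyword arguments,
# so the dict keys must match fn's parameter names.

# Memory-efficient loss computation on long sequences.
# labels should already be shifted for next-token prediction: a tile cannot see
# the first label of the following tile.
loss = sequence_tiled_compute(
    fn=compute_loss_fn,  # called as compute_loss_fn(logits=..., labels=..., ignore_index=-100)
    seqlen=logits.shape[1],  # local (per-rank) sequence length
    shards=4,  # split into 4 tiles
    kwargs_to_shard={"logits": logits, "labels": labels},
    kwargs_to_pass={"ignore_index": -100},
    grad_requiring_tensor_key="logits",
    output_reduction="mean",
)
engine.backward(loss)
engine.step()

# Memory-efficient MLP computation
def mlp_forward(hidden_states):
    # thin wrapper so the keyword name matches the kwargs_to_shard key
    return mlp(hidden_states)

hidden_states = sequence_tiled_compute(
    fn=mlp_forward,
    seqlen=hidden_states.shape[1],
    shards=8,
    kwargs_to_shard={"hidden_states": hidden_states},
    kwargs_to_pass={},
    grad_requiring_tensor_key="hidden_states",
    compute_params=[mlp.down_proj.weight, mlp.gate_proj.weight, mlp.up_proj.weight],  # weights used by fn (for ZeRO)
    output_unshard_dimension=1,  # reassemble tiles along the sequence dimension
    output_reduction=None,  # concatenate, don't reduce
)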


Last updated: 2026-02-09 00:00 GMT
