Implementation:Deepspeedai DeepSpeed UlyssesSP Forward
Overview
Concrete tool, provided by the DeepSpeed library, for executing sequence-parallel forward passes with all-to-all communication and tiled computation.
Description
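UlyssesSPAttentionHF.forward() wraps the original attention function with all-to-all operations: before attention it exchanges sequence shards for head shards, so each SP rank sees the full sequence for a subset of heads, and after attention it reverses the exchange so each rank again holds all heads for its local sequence chunk. sequence_tiled_compute() applies a function to its inputs one sequence tile at a time for memory-efficient processing. Both use custom autograd functions so that gradients are computed correctly.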
The forward method handles:
- Eval bypass: If disable_in_eval=True and the module is not training, the original attention function is called directly without any SP communication.
- Variable sequence lengths: Dynamically updates shape expectations when seq_length_is_variable=True.
- Position ID gathering: Collects position_ids from all SP ranks to provide the full positional context to the attention function.
- GQA/MQA support: Adjusts num_key_value_groups on the module when KV head replication is active.
- Debug mode: An optional skip_all_but_last_attention_debug_mode skips core attention for all but the last layer to speed up memory fitting tests.
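The layout change performed by the two all-to-all steps can be illustrated without any distributed setup. The sketch below is a single-process simulation in plain Python, not DeepSpeed code: the function names are hypothetical, each "tensor" is a nested list of (token, head) labels, and it assumes the sequence length and head count both divide evenly by the SP size.

```python
# Single-process illustration only: no real communication happens here.
# All function names are hypothetical and do not exist in DeepSpeed.

def seq_sharded_inputs(seq_len, num_heads, sp_size):
    """Rank r holds tokens [r*chunk, (r+1)*chunk) for ALL heads."""
    chunk = seq_len // sp_size
    return [
        [[(t, h) for h in range(num_heads)] for t in range(r * chunk, (r + 1) * chunk)]
        for r in range(sp_size)
    ]

def all_to_all_seq_to_head(per_rank, sp_size):
    """Pre-attention exchange: gather the full sequence, keep a slice of heads."""
    num_heads = len(per_rank[0][0])
    heads_per_rank = num_heads // sp_size
    out = []
    for dst in range(sp_size):
        lo, hi = dst * heads_per_rank, (dst + 1) * heads_per_rank
        # Concatenate token rows from every source rank in rank order.
        out.append([row[lo:hi] for src in range(sp_size) for row in per_rank[src]])
    return out

def all_to_all_head_to_seq(per_rank, sp_size):
    """Post-attention exchange: gather all heads, keep a slice of the sequence."""
    seq_len = len(per_rank[0])
    chunk = seq_len // sp_size
    out = []
    for dst in range(sp_size):
        rows = []
        for t in range(dst * chunk, (dst + 1) * chunk):
            # Stitch together the head slices that each rank holds for token t.
            rows.append([cell for src in range(sp_size) for cell in per_rank[src][t]])
        out.append(rows)
    return out

local = seq_sharded_inputs(seq_len=8, num_heads=4, sp_size=2)
head_sharded = all_to_all_seq_to_head(local, sp_size=2)
restored = all_to_all_head_to_seq(head_sharded, sp_size=2)
```

After the first exchange every simulated rank holds all 8 tokens but only 2 of the 4 heads, which is what lets an unmodified attention function run over the full sequence. The second exchange returns each rank to its original local chunk with all heads.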
The sequence_tiled_compute() function is a wrapper around the SequenceTiledCompute autograd function. It manages:
- Sharding: Splits kwargs_to_shard tensors on the sequence dimension (dim=1)
- Forward: Calls the provided function on each shard under torch.no_grad()
- Backward: Recomputes forward per shard and calls torch.autograd.backward()
- ZeRO integration: Defers gradient readiness via ds_grad_is_ready until the last shard
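The shard, no-grad forward, and recompute-in-backward steps can be sketched with a small custom autograd function. This is a simplified illustration under stated assumptions, not DeepSpeed's SequenceTiledCompute: the class name TiledApply is hypothetical, it accepts a single tensor instead of kwargs dicts, it always concatenates outputs on dim=1, and it omits output reduction and the ds_grad_is_ready deferral used for ZeRO.

```python
import torch

class TiledApply(torch.autograd.Function):
    """Hypothetical, simplified stand-in for a sequence-tiled autograd function."""

    @staticmethod
    def forward(ctx, fn, shards, x):
        ctx.fn, ctx.shards = fn, shards
        ctx.save_for_backward(x)
        with torch.no_grad():
            # No graph is recorded, so per-tile intermediates can be freed at once.
            outs = [fn(tile) for tile in torch.chunk(x, shards, dim=1)]
        return torch.cat(outs, dim=1)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        x_tiles = torch.chunk(x, ctx.shards, dim=1)
        sizes = [t.shape[1] for t in x_tiles]
        g_tiles = torch.split(grad_out, sizes, dim=1)
        grads = []
        for x_tile, g_tile in zip(x_tiles, g_tiles):
            x_tile = x_tile.detach().requires_grad_(True)
            with torch.enable_grad():
                out = ctx.fn(x_tile)  # recompute this tile's forward with a graph
            # Weight gradients accumulate into .grad as a side effect of this call.
            torch.autograd.backward(out, g_tile)
            grads.append(x_tile.grad)
        # One return value per forward input: fn, shards, x.
        return None, None, torch.cat(grads, dim=1)

torch.manual_seed(0)
mlp = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.GELU(), torch.nn.Linear(16, 8))
x = torch.randn(2, 12, 8, requires_grad=True)

tiled = TiledApply.apply(mlp, 4, x)
tiled.sum().backward()
tiled_x_grad = x.grad.clone()
tiled_w_grad = mlp[0].weight.grad.clone()

# Reference: the same computation without tiling.
x.grad = None
mlp.zero_grad()
ref = mlp(x)
ref.sum().backward()
```

Only one tile's activations are alive at any time during backward, at the cost of running the forward computation twice. Note that in this sketch the backward pass only fires if the input tensor itself requires gradients.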
Code Reference
- Repository: https://github.com/deepspeedai/DeepSpeed
- File: deepspeed/runtime/sequence_parallel/ulysses_sp.py
- Lines: L224-353 (forward), L631-851 (sequence_tiled_compute + SequenceTiledCompute autograd)
Forward Signature
def forward(
    self,
    module: torch.nn.Module,
    query: Tensor,  # [bs, hc, sl, hs]
    key: Tensor,  # [bs, hc, sl, hs]
    value: Tensor,  # [bs, hc, sl, hs]
    attention_mask: Tensor,
    *args: Any,
    **kwargs: Any,
) -> Tuple[Tensor, Optional[Tensor]]
sequence_tiled_compute Signature
def sequence_tiled_compute(
    fn,  # callable to invoke on each shard
    seqlen: int,  # total sequence length
    shards: int,  # number of tiles
    kwargs_to_shard: dict,  # tensors to shard on seq dim
    kwargs_to_pass: dict,  # kwargs passed unchanged
    grad_requiring_tensor_key: str,  # key of the grad-requiring tensor
    compute_params=None,  # model weights for ZeRO integration
    output_unshard_dimension: int = 1,
    output_reduction: str = "mean",  # None | "mean" | "sum"
) -> Tensor
Import
from deepspeed.runtime.sequence_parallel.ulysses_sp import sequence_tiled_compute
I/O Contract
Inputs (forward)
| Parameter | Type | Required | Description |
|---|---|---|---|
| module | torch.nn.Module | Yes | The attention module being wrapped |
| query | Tensor | Yes | Query tensor with shape [bs, hc, sl_local, hs] |
| key | Tensor | Yes | Key tensor with shape [bs, hc_kv, sl_local, hs] |
| value | Tensor | Yes | Value tensor with shape [bs, hc_kv, sl_local, hs] |
| attention_mask | Tensor | No | Attention mask (set to None internally by the Ulysses wrapper) |
Outputs (forward)
| Output | Type | Description |
|---|---|---|
| output | Tensor | Attention output redistributed back to local sequence chunks, shape [bs, sl_local, em] |
| attn_weights | Tensor or None | Attention weights (typically None with flash attention) |
Inputs (sequence_tiled_compute)
| Parameter | Type | Required | Description |
|---|---|---|---|
| fn | callable | Yes | Function to call on sharded inputs |
| seqlen | int | Yes | Total sequence length of the tensors to shard |
| shards | int | Yes | Number of tiles to split into |
| kwargs_to_shard | dict | Yes | Dict of tensors to shard on sequence dimension |
| kwargs_to_pass | dict | Yes | Dict of kwargs passed unchanged to fn |
| grad_requiring_tensor_key | str | Yes | Key identifying which tensor in kwargs_to_shard requires gradients |
| compute_params | list | No | Model weights for ZeRO gradient coordination (default: None) |
| output_unshard_dimension | int | No | Dimension to concatenate outputs (default: 1) |
| output_reduction | str | No | Reduction to apply: None, "mean", or "sum" (default: "mean") |
Outputs (sequence_tiled_compute)
| Output | Type | Description |
|---|---|---|
| result | Tensor | Aggregated result tensor after applying the specified reduction |
Usage Example
from deepspeed.runtime.sequence_parallel.ulysses_sp import sequence_tiled_compute

# Memory-efficient loss computation on long sequences
loss = sequence_tiled_compute(
    fn=compute_loss_fn,
    seqlen=local_seq_length,
    shards=4,  # split into 4 tiles
    kwargs_to_shard={"logits": logits, "labels": labels},
    kwargs_to_pass={"ignore_index": -100},
    grad_requiring_tensor_key="logits",
    output_reduction="mean",
)
engine.backward(loss)
engine.step()

# Memory-efficient MLP computation
hidden_states = sequence_tiled_compute(
    fn=mlp_forward,
    seqlen=hidden_states.shape[1],
    shards=8,
    kwargs_to_shard={"hidden_states": hidden_states},
    kwargs_to_pass={},
    grad_requiring_tensor_key="hidden_states",
    compute_params=[mlp.down_proj.weight, mlp.gate_proj.weight, mlp.up_proj.weight],
    output_unshard_dimension=1,
    output_reduction=None,  # concatenate, don't reduce
)
Related Pages
- Principle:Deepspeedai_DeepSpeed_Sequence_Parallel_Computation
- Heuristic:Deepspeedai_DeepSpeed_Sequence_Parallel_PyTorch_Version
Knowledge Sources
- https://github.com/deepspeedai/DeepSpeed
- https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism/
- https://arxiv.org/abs/2309.14509
- https://arxiv.org/abs/2506.13996
Last updated: 2026-02-09 00:00 GMT