Principle:Deepspeedai DeepSpeed Sequence Parallel Computation
Overview
Executing forward and backward passes with sequence-parallel all-to-all communication in attention layers and optional tiled computation for memory-efficient MLP and loss computation.
Detailed Description
During training, the UlyssesSPAttentionHF.forward() method performs all-to-all communication to redistribute tensors between sequence-partitioned and head-partitioned layouts around the attention computation. For MLP and loss computation on long sequences, sequence_tiled_compute() reduces memory usage by splitting the sequence into smaller tiles, computing each tile separately, and combining the results. This avoids out-of-memory (OOM) errors on very long sequences, where the full MLP activations would exceed GPU memory.
Attention Forward Pass
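The forward pass through the Ulysses SP attention layer follows this sequence (bs: batch size, hc: number of attention heads, sl: sequence length, hs: head size, em: embedding size; an _l suffix marks a per-rank quantity, S and H are the global sequence length and head count, and P is the number of SP ranks):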
- Input transformation: HuggingFace provides tensors in [bs, hc, sl_l, hs] format; these are rearranged to [sl_l, bs, hc, hs].
- Variable sequence length handling: If seq_length_is_variable=True, the local and global sequence lengths are dynamically computed from the current input shape.
- Combine local sequences (all-to-all): Transforms [sl_l, bs, hc, hs] into [sl, bs, hc_l, hs]. The local sequence sl_l = S/P becomes the global sequence sl = S, and all heads hc become the local heads hc_l = H/P.
- Core attention: Standard attention is computed over the full sequence with the local heads.
- Partition global sequence (all-to-all): Transforms back from [sl, bs, em_l] to [sl_l, bs, em], where em = hc * hs and em_l = hc_l * hs.
- Output transformation: Rearranges the result back to [bs, sl_l, em] format for HuggingFace.
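The shape bookkeeping in these steps can be traced with a short Python sketch. It only computes shapes and performs no communication. It assumes em = hc * hs. The function name ulysses_shape_trace is hypothetical and is not a DeepSpeed API.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def ulysses_shape_trace(hf_query_shape: Shape, sp_world_size: int) -> Dict[str, Shape]:
    """Hypothetical helper: tensor shapes through one Ulysses SP attention call."""
    bs, hc, sl_l, hs = hf_query_shape  # HF layout: [bs, hc, sl_l, hs]
    if hc % sp_world_size != 0:
        raise ValueError("head count must be divisible by the SP degree")
    # Derived from the current input, as in the seq_length_is_variable=True case.
    sl = sl_l * sp_world_size      # global sequence length S
    hc_l = hc // sp_world_size     # local heads H/P
    em, em_l = hc * hs, hc_l * hs  # assumed: flattened embedding = heads * head size
    return {
        "hf_input": (bs, hc, sl_l, hs),
        "rearranged": (sl_l, bs, hc, hs),
        "after_combine_all_to_all": (sl, bs, hc_l, hs),
        "attention_output": (sl, bs, em_l),
        "after_partition_all_to_all": (sl_l, bs, em),
        "hf_output": (bs, sl_l, em),
    }


trace = ulysses_shape_trace((2, 32, 1024, 128), sp_world_size=4)
for step, shape in trace.items():
    print(f"{step:>28}: {shape}")
```

Take 4 ranks as an example. Each rank starts with 1,024 local tokens and all 32 heads. During attention it holds all 4,096 tokens for only 8 heads, and it ends with its 1,024 tokens again.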
Additionally, position_ids are gathered across all SP ranks before attention to ensure correct positional encoding on the full sequence.
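After the all-to-all, each rank attends over all S tokens. Any position-dependent logic inside attention therefore needs position_ids for the whole sequence, not just the rank's S/P shard. The sketch below simulates the gather in plain Python. Detecting packed-sample boundaries from ids that restart at 0 is an assumed example of such logic, and both helpers are hypothetical.

```python
from typing import List


def all_gather(per_rank: List[List[int]]) -> List[List[int]]:
    """Simulated all-gather: every rank receives the rank-ordered concatenation."""
    full = [value for shard in per_rank for value in shard]
    return [list(full) for _ in per_rank]


def sample_starts(position_ids: List[int]) -> List[int]:
    """Indices where a packed sample begins, assuming position ids restart at 0."""
    return [i for i, pos in enumerate(position_ids) if pos == 0]


# Two packed samples of lengths 5 and 3, sharded across P = 2 ranks.
local_ids = [[0, 1, 2, 3], [4, 0, 1, 2]]
gathered = all_gather(local_ids)
print(sample_starts(local_ids[1]))  # [1]: an index within rank 1's shard only
print(sample_starts(gathered[1]))   # [0, 5]: boundaries in the full sequence
```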
Tiled Computation
For operations that would consume too much memory on long sequences (MLP layers, loss computation), tiled computation breaks the sequence into smaller tiles:
- Split the sequence dimension into T tiles of size S/T
- Compute the function f on each tile independently
- Aggregate the results via concatenation, mean, or sum
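This trades compute for memory. The forward pass processes the T tiles one after another. Because intermediate activations are not stored, the backward pass re-runs the forward on each tile before back-propagating through it. Peak memory for the intermediates of f drops from O(S * D) to O(S/T * D), since only one tile's intermediates are alive at a time.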
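As a rough worked example, the sketch below sizes one MLP intermediate activation with and without tiling. All four inputs are illustrative assumptions, not values taken from DeepSpeed: a sequence length of 131,072 tokens, an intermediate width of 14,336, bf16 storage (2 bytes per element), and T = 16.

```python
def activation_bytes(seq_len: int, width: int, bytes_per_elem: int = 2) -> int:
    """Size of one [seq_len, width] activation tensor (batch size 1)."""
    return seq_len * width * bytes_per_elem


GiB, MiB = 1024 ** 3, 1024 ** 2
S, INTERMEDIATE, T = 131_072, 14_336, 16  # illustrative values (assumption)

full = activation_bytes(S, INTERMEDIATE)
per_tile = activation_bytes(S // T, INTERMEDIATE)
print(f"untiled: {full / GiB:.1f} GiB, per tile: {per_tile / MiB:.0f} MiB")
# untiled: 3.5 GiB, per tile: 224 MiB
```

This covers a single intermediate tensor in one layer. The concatenated output of f still has all S rows, so tiling helps most when the intermediates of f are much wider than its input and output.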
Theoretical Basis
Ulysses All-to-All
Before attention:
[B, S/P, H, D] -> all-to-all -> [B, S, H/P, D]
After attention:
[B, S, H/P, D] -> all-to-all -> [B, S/P, H, D]
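The all-to-all is implemented via _DimZeroAllToAll, a custom autograd function that performs the communication in the forward pass and supports gradient computation in the backward pass. An equal-split all-to-all along dimension 0 is its own inverse, so the gradient can be computed by applying the same all-to-all to the incoming gradient.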
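The data movement behind the two all-to-all steps can be simulated in a single process. This plain-Python sketch drops the batch and head-size dimensions and tags every element with its (token, head) index. The helpers all_to_all, seq_to_head, and head_to_seq are hypothetical. Since _DimZeroAllToAll exchanges along dimension 0, the real implementation presumably does these reorderings with reshapes and transposes around the collective rather than inside it.

```python
from typing import List, Tuple

Cell = Tuple[int, int]    # (global token index, global head index)
Shard = List[List[Cell]]  # rows are tokens, columns are heads


def all_to_all(send: List[List[Shard]]) -> List[List[Shard]]:
    """Simulated all-to-all: chunk send[src][dst] arrives as recv[dst][src]."""
    world = len(send)
    return [[send[src][dst] for src in range(world)] for dst in range(world)]


def seq_to_head(shards: List[Shard]) -> List[Shard]:
    """[S/P tokens, H heads] per rank -> [S tokens, H/P heads] per rank."""
    world = len(shards)
    hc_l = len(shards[0][0]) // world
    # Rank src cuts its heads into `world` groups; group dst is sent to rank dst.
    send = [[[row[d * hc_l:(d + 1) * hc_l] for row in shard] for d in range(world)]
            for shard in shards]
    # Received chunks are stacked along the sequence in source-rank order.
    return [[row for chunk in chunks for row in chunk] for chunks in all_to_all(send)]


def head_to_seq(shards: List[Shard]) -> List[Shard]:
    """[S tokens, H/P heads] per rank -> [S/P tokens, H heads] per rank."""
    world = len(shards)
    sl_l = len(shards[0]) // world
    # Rank src cuts the sequence into `world` blocks; block dst is sent to rank dst.
    send = [[shard[d * sl_l:(d + 1) * sl_l] for d in range(world)] for shard in shards]
    # Received chunks are joined along the head dimension, token by token.
    return [[[cell for chunk in chunks for cell in chunk[t]] for t in range(sl_l)]
            for chunks in all_to_all(send)]


S, H, P = 8, 4, 2
local = [[[(r * (S // P) + t, h) for h in range(H)] for t in range(S // P)]
         for r in range(P)]
by_head = seq_to_head(local)
print(by_head[1][5])                  # [(5, 2), (5, 3)]: rank 1 holds heads 2 and 3
print(head_to_seq(by_head) == local)  # True: the round trip restores the layout
```

all_to_all only transposes the grid of source and destination chunks, so applying it twice returns the original chunks. This is why the same exchange can serve as its own backward.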
Tiled Computation
For a function f applied along the sequence dimension (dimension 1 of x), with tile size step = S/T:
output = reduce([f(x[:, i*step:(i+1)*step]) for i in range(T)])
Where reduce is one of:
- None: Concatenation along the sequence dimension (for MLP hidden states)
- "mean": Scalar mean of all tile outputs (for loss)
- "sum": Scalar sum of all tile outputs
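The tiling formula and its three reduce modes can be sketched in plain Python over a 1-D "sequence". The name tiled_compute and its signature are hypothetical and do not reproduce DeepSpeed's sequence_tiled_compute(). Tiles use ceil(S/T) so that a sequence not divisible by T is still fully covered.

```python
from typing import Callable, List, Optional, Sequence, Union

TileOut = Union[List[float], float]


def tiled_compute(
    f: Callable[[Sequence[float]], TileOut],
    x: Sequence[float],
    num_tiles: int,
    reduce: Optional[str] = None,
) -> TileOut:
    """Apply f to num_tiles slices of the sequence x and combine the results."""
    step = -(-len(x) // num_tiles)  # ceil(S / T) so every token lands in a tile
    outs = [f(x[i:i + step]) for i in range(0, len(x), step)]
    if reduce is None:    # per-token outputs: concatenate along the sequence
        return [v for out in outs for v in out]
    if reduce == "sum":   # one scalar per tile: add them up
        return sum(outs)
    if reduce == "mean":  # one scalar per tile: average over tiles
        return sum(outs) / len(outs)
    raise ValueError(f"unknown reduce: {reduce!r}")


def mlp_like(tile: Sequence[float]) -> List[float]:
    """Token-wise toy stand-in for an MLP."""
    return [2.0 * v + 1.0 for v in tile]


x = [0.0, 1.0, 2.0, 3.0, 4.0]
print(tiled_compute(mlp_like, x, num_tiles=2))  # [1.0, 3.0, 5.0, 7.0, 9.0]
print(tiled_compute(lambda t: sum(v * v for v in t), x, 2, reduce="sum"))  # 30.0
```

In this sketch, "mean" averages per-tile scalars. That matches the mean over all tokens only when tiles are equal in size. With 5 tokens and T = 2, the tiles hold 3 and 2 tokens, so averaging per-tile means gives (1.0 + 3.5) / 2 = 2.25 rather than the token mean of 2.0.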
Memory per tile: O(S/T * D) instead of O(S * D).
The backward pass uses recomputation. For each tile, the forward is re-executed under torch.enable_grad() and torch.autograd.backward() is called on the tile output, so only one tile's activations exist at a time. When used with DeepSpeed ZeRO, the ds_grad_is_ready flag on parameters defers ZeRO's gradient reduction until the last tile, so the parameter gradients from all tiles are accumulated first.
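The recompute-then-accumulate pattern can be illustrated without autograd using a scalar parameter and a hand-derived gradient. In this sketch, ToyParam, grad_is_ready, and the reduction counter are hypothetical stand-ins for a ZeRO parameter, ds_grad_is_ready, and ZeRO's gradient reduction hook. The loss L = sum((w * x_i)^2) is an assumed toy example.

```python
from typing import List


class ToyParam:
    """Scalar parameter; grad_is_ready is a hypothetical stand-in for ds_grad_is_ready."""

    def __init__(self, value: float) -> None:
        self.value = value
        self.grad = 0.0
        self.grad_is_ready = True
        self.reductions = 0  # how often the simulated ZeRO reduction ran

    def accumulate_grad(self, g: float) -> None:
        self.grad += g
        if self.grad_is_ready:
            self.reductions += 1  # stand-in for reducing/partitioning the gradient


def tiled_backward(w: ToyParam, x: List[float], num_tiles: int) -> None:
    """Backward of L = sum((w * x_i) ** 2), one tile at a time with recomputation."""
    step = -(-len(x) // num_tiles)  # ceil division
    tiles = [x[i:i + step] for i in range(0, len(x), step)]
    for i, tile in enumerate(tiles):
        y = [w.value * xi for xi in tile]  # recompute this tile's forward activations
        g = sum(2.0 * yi * xi for yi, xi in zip(y, tile))  # dL/dw for this tile
        w.grad_is_ready = i == len(tiles) - 1  # only the last tile triggers reduction
        w.accumulate_grad(g)


w = ToyParam(2.0)
tiled_backward(w, [1.0, 2.0, 3.0, 4.0, 5.0], num_tiles=2)
print(w.grad, w.reductions)  # 220.0 1
```

The accumulated gradient equals the untiled gradient 2w * sum(x_i^2) = 4 * 55 = 220. The simulated reduction runs once instead of once per tile.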
Reference
- DeepSpeed-Ulysses (https://arxiv.org/abs/2309.14509)
- Arctic Long Sequence Training (https://arxiv.org/abs/2506.13996)
Knowledge Sources
- https://github.com/deepspeedai/DeepSpeed
- https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism/
- https://arxiv.org/abs/2309.14509
- https://arxiv.org/abs/2506.13996
Last updated: 2026-02-09 00:00 GMT