

Principle:Deepspeedai DeepSpeed Sequence Parallel Computation

From Leeroopedia


Overview

Executing forward and backward passes with sequence-parallel all-to-all communication in attention layers and optional tiled computation for memory-efficient MLP and loss computation.

Detailed Description

During training, the UlyssesSPAttentionHF.forward() method performs all-to-all communication to redistribute tensors between sequence-partitioned and head-partitioned forms around the attention computation. For MLP and loss computation on long sequences, sequence_tiled_compute() enables memory-efficient processing by breaking the sequence into smaller tiles, computing each tile separately, and combining results. This prevents OOM on very long sequences where the full MLP activation would exceed GPU memory.

Attention Forward Pass

The forward pass through the Ulysses SP attention layer follows this sequence:

  1. Input transformation: HuggingFace provides tensors in [bs, hc, sl, hs] format (batch size, head count, sequence length, head size); these are rearranged to [sl, bs, hc, hs].
  2. Variable sequence length handling: If seq_length_is_variable=True, the local and global sequence lengths are dynamically computed from the current input shape.
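  3. Combine local sequences (all-to-all): Transforms [sl_l, bs, hc, hs] into [sl, bs, hc_l, hs], where S is the global sequence length, P the sequence-parallel world size, and H the number of attention heads. The local sequence sl_l = S/P becomes the global sequence sl = S, and the full set of heads hc = H is reduced to the local heads hc_l = H/P.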
  4. Core attention: Standard attention is computed on the full sequence with local heads.
  5. Partition global sequence (all-to-all): Transforms the attention output, with the local heads flattened into em_l = hc_l * hs, from [sl, bs, em_l] back to [sl_l, bs, em], where em = hc * hs.
  6. Output transformation: Rearranges back to [bs, sl_l, em] format for HuggingFace.
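The shape bookkeeping in steps 1 to 6 can be traced without any communication. Below is a minimal sketch that records the shape after each step. `ulysses_shape_trace` is a hypothetical helper and not a DeepSpeed function. It assumes the head count is divisible by P and that the attention output's local heads are flattened into em_l = hc_l * hs before step 5. As in step 2, the global length is derived from the current local input.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def ulysses_shape_trace(bs: int, hc: int, sl_l: int, hs: int, P: int) -> Dict[str, Shape]:
    """Return the tensor shape after each forward step (shapes only, no communication).

    bs: batch size, hc: total attention heads, sl_l: local sequence length on this
    rank, hs: head size, P: sequence-parallel world size.
    """
    if hc % P != 0:
        raise ValueError("the number of heads must be divisible by P")
    sl = sl_l * P            # step 2: global length derived from the current local input
    hc_l = hc // P           # heads kept on each rank after the first all-to-all
    em, em_l = hc * hs, hc_l * hs
    return {
        "hf_input": (bs, hc, sl_l, hs),       # step 1: layout received from HuggingFace
        "seq_first": (sl_l, bs, hc, hs),      # step 1: after the permute
        "after_combine": (sl, bs, hc_l, hs),  # step 3: full sequence, local heads
        "attn_out": (sl, bs, em_l),           # step 4: local heads flattened (assumption)
        "after_partition": (sl_l, bs, em),    # step 5: local sequence, all heads
        "hf_output": (bs, sl_l, em),          # step 6: layout returned to HuggingFace
    }


trace = ulysses_shape_trace(bs=1, hc=32, sl_l=2048, hs=128, P=4)
print(trace["after_combine"])  # (8192, 1, 8, 128)
print(trace["hf_output"])      # (1, 2048, 4096)
```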

Additionally, position_ids are gathered across all SP ranks before attention to ensure correct positional encoding on the full sequence.
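The gather can be simulated in a few lines. This sketch assumes each rank holds one contiguous chunk of S/P positions, so rank r holds positions r*S/P through (r+1)*S/P - 1. `simulated_all_gather` is a hypothetical stand-in for a collective all-gather, not a DeepSpeed or PyTorch call. Each rank receives every shard concatenated in rank order, so every rank ends up with the positions of the full sequence.

```python
from typing import List


def simulated_all_gather(shards: List[List[int]]) -> List[List[int]]:
    """Every rank receives all shards concatenated in rank order."""
    gathered = [pid for shard in shards for pid in shard]
    return [list(gathered) for _ in shards]  # one independent copy per rank


S, P = 8, 4
sl_l = S // P
# Assumed contiguous sharding: rank r owns positions [r * sl_l, (r + 1) * sl_l).
local_position_ids = [list(range(r * sl_l, (r + 1) * sl_l)) for r in range(P)]
full_position_ids = simulated_all_gather(local_position_ids)
print(local_position_ids[1])  # [2, 3]
print(full_position_ids[1])   # [0, 1, 2, 3, 4, 5, 6, 7]
```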

Tiled Computation

For operations that would consume too much memory on long sequences (MLP layers, loss computation), tiled computation breaks the sequence into smaller tiles:

  1. Split the sequence dimension into T tiles of size S/T
  2. Compute the function f on each tile independently
  3. Aggregate results via concatenation, mean, or sum

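This trades compute time for memory. The forward pass runs T times on smaller inputs, and the backward pass recomputes the forward for each tile because the tile activations are not stored. Peak memory for the intermediate activations of f drops from O(S * D) for the whole sequence to O(S/T * D) for a single tile, where D is the width of those activations.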
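The arithmetic below makes the trade-off concrete. The sizes are illustrative assumptions and do not come from this page: a 131,072-token sequence, an intermediate activation width of 14,336, batch size 1, bf16 (2 bytes per element), and T = 8. It counts a single activation tensor of f and ignores all other training buffers.

```python
def activation_bytes(seq_len: int, width: int, bytes_per_elem: int = 2) -> int:
    """Size in bytes of one [seq_len, width] activation at batch size 1."""
    return seq_len * width * bytes_per_elem


GiB = 1024 ** 3
S, D, T = 131_072, 14_336, 8  # illustrative sizes (assumptions)
full = activation_bytes(S, D)
per_tile = activation_bytes(S // T, D)
print(full / GiB, per_tile / GiB)  # 3.5 0.4375
```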

Theoretical Basis

Ulysses All-to-All

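Before attention (B = batch size, S = global sequence length, P = sequence-parallel world size, H = number of attention heads, D = head size):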

[B, S/P, H, D] -> all-to-all -> [B, S, H/P, D]

After attention:

[B, S, H/P, D] -> all-to-all -> [B, S/P, H, D]

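The all-to-all is implemented via _DimZeroAllToAll, a custom autograd function. Its forward performs the all-to-all communication. Its backward applies the same all-to-all to the incoming gradients, which routes each gradient back to the rank that held the corresponding input.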
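The redistribution and its gradient can be checked on flat Python lists. In this sketch the batch and head-size dimensions are omitted, and the helper names are hypothetical rather than DeepSpeed functions. Before the exchange, rank r holds tokens r*S/P through (r+1)*S/P - 1 for all H heads. After it, rank r holds all S tokens for heads r*H/P through (r+1)*H/P - 1. The redistribution only moves elements, so its vector-Jacobian product is the reverse redistribution. The last line checks the adjoint identity dot(A x, g) == dot(x, A^T g).

```python
from typing import List


def seq_to_head(x: List[int], S: int, H: int, P: int) -> List[int]:
    """Rank-major [P][S/P][H] -> [P][S][H/P]: gather the sequence, scatter the heads."""
    h_l = H // P
    out = []
    for dst in range(P):
        for t in range(S):
            for h in range(dst * h_l, (dst + 1) * h_l):
                # x is rank-major [P][S/P][H], so token t, head h sits at offset t * H + h
                out.append(x[t * H + h])
    return out


def head_to_seq(y: List[int], S: int, H: int, P: int) -> List[int]:
    """Rank-major [P][S][H/P] -> [P][S/P][H]: the reverse redistribution."""
    s_l, h_l = S // P, H // P
    out = []
    for dst in range(P):
        for t in range(dst * s_l, (dst + 1) * s_l):
            for h in range(H):
                src, h_local = divmod(h, h_l)  # rank that holds head h after seq_to_head
                out.append(y[(src * S + t) * h_l + h_local])
    return out


def dot(a: List[int], b: List[int]) -> int:
    return sum(u * v for u, v in zip(a, b))


S, H, P = 6, 4, 2
x = list(range(S * H))                 # x[t * H + h] encodes (token t, head h)
y = seq_to_head(x, S, H, P)
print(y[S * (H // P):][:2])            # [2, 3]: rank 1 starts with heads 2 and 3 of token 0
print(head_to_seq(y, S, H, P) == x)    # True: the reverse exchange restores the layout

g = [3 * i + 1 for i in range(S * H)]  # stand-in upstream gradient w.r.t. y
print(dot(y, g) == dot(x, head_to_seq(g, S, H, P)))  # True: backward = reverse exchange
```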

Tiled Computation

For a function f applied over the sequence dimension (dimension 1 of x), with tile size step = S/T:

output = reduce([f(x[:, i*step:(i+1)*step]) for i in range(T)])

Where reduce is one of:

  • None: Concatenation along the sequence dimension (for MLP hidden states)
  • "mean": Scalar mean of all tile outputs (for loss)
  • "sum": Scalar sum of all tile outputs
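Below is a pure-Python sketch of the formula above. It assumes T divides the sequence length. `tiled_compute` is a hypothetical name and does not reproduce the signature of DeepSpeed's sequence_tiled_compute(). With reduce=None, f returns one output per token and the tiles are concatenated. With "mean" or "sum", f returns one scalar per tile.

```python
from typing import Callable, List, Optional, Union


def tiled_compute(f: Callable[[List[float]], Union[float, List[float]]],
                  x: List[float], T: int, reduce: Optional[str] = None):
    """Apply f to T equal sequence tiles of x and combine the per-tile results."""
    S = len(x)
    if S % T != 0:
        raise ValueError("this sketch assumes T divides the sequence length")
    step = S // T
    outs = [f(x[i * step:(i + 1) * step]) for i in range(T)]
    if reduce is None:        # concatenate along the sequence (e.g. MLP outputs)
        return [v for out in outs for v in out]
    if reduce == "sum":       # scalar sum of the tile outputs
        return sum(outs)
    if reduce == "mean":      # scalar mean of the tile outputs
        return sum(outs) / T
    raise ValueError(f"unknown reduce: {reduce!r}")


def toy_mlp(tile: List[float]) -> List[float]:
    return [2.0 * v + 1.0 for v in tile]  # token-wise, so tiling does not change the result


def tile_mean_sq_loss(tile: List[float]) -> float:
    return sum(v * v for v in tile) / len(tile)


x = [float(i) for i in range(8)]
print(tiled_compute(toy_mlp, x, T=4) == toy_mlp(x))             # True
print(tiled_compute(tile_mean_sq_loss, x, T=4, reduce="mean"))  # 17.5
```

In this example, the mean of the per-tile means equals the mean over the whole sequence only because every tile has the same length. Uneven tiles, or a loss that counts a different number of tokens per tile, would need a weighted combination.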

Memory per tile: O(S/T * D) instead of O(S * D).

The backward pass uses recomputation. For each tile, the forward is re-executed under torch.enable_grad(), and torch.autograd.backward() is called on the tile output with the matching slice of the incoming gradient. When used with DeepSpeed ZeRO, ds_grad_is_ready is set to False on the parameters for every tile except the last. As a result, ZeRO does not treat their gradients as complete until all tiles have accumulated their contributions.
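The recomputation and deferred-readiness pattern can be illustrated with a scalar toy MLP, y = w2 * relu(w1 * x), applied token-wise. This is a hand-written sketch, not DeepSpeed code. `Param` and its `grad_is_ready` field are hypothetical stand-ins for a parameter carrying ds_grad_is_ready, and ZeRO's gradient hooks are not modeled. Each tile recomputes its hidden activation and adds its contribution to the parameter gradients, and only the last tile marks the gradients ready. The accumulated result matches an untiled backward pass.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Param:
    value: float
    grad: float = 0.0
    grad_is_ready: bool = False  # hypothetical stand-in for ds_grad_is_ready


def backward_tile(w1: Param, w2: Param, x: List[float], g: List[float]) -> List[float]:
    """Recompute h = relu(w1 * x) for one tile, accumulate param grads, return dL/dx."""
    h = [max(0.0, w1.value * v) for v in x]        # activation recomputed, not stored
    grad_x = []
    for v, hv, gv in zip(x, h, g):
        gh = gv * w2.value                         # dL/dh
        active = 1.0 if w1.value * v > 0 else 0.0  # relu derivative
        w2.grad += gv * hv                         # accumulate across tiles
        w1.grad += gh * active * v
        grad_x.append(gh * active * w1.value)
    return grad_x


def tiled_backward(w1: Param, w2: Param, x: List[float], g: List[float],
                   T: int) -> Tuple[List[float], List[int]]:
    """Run backward_tile on T tiles; return dL/dx and the tiles after which grads were ready."""
    step = len(x) // T
    grad_x, ready_after = [], []
    for i in range(T):
        for p in (w1, w2):
            p.grad_is_ready = i == T - 1           # only the last tile marks grads complete
        grad_x += backward_tile(w1, w2, x[i * step:(i + 1) * step], g[i * step:(i + 1) * step])
        if w1.grad_is_ready:
            ready_after.append(i)
    return grad_x, ready_after


x = [-2.0, -1.0, 0.5, 1.0, 1.5, 2.0, 3.0, -0.5]
g = [1.0, 0.5, -1.0, 2.0, 1.0, -0.5, 0.25, 1.0]

w1, w2 = Param(1.5), Param(-2.0)
grad_x, ready_after = tiled_backward(w1, w2, x, g, T=4)

ref_w1, ref_w2 = Param(1.5), Param(-2.0)
ref_grad_x = backward_tile(ref_w1, ref_w2, x, g)  # untiled reference

print(ready_after)  # [3]
print(grad_x == ref_grad_x, (w1.grad, w2.grad) == (ref_w1.grad, ref_w2.grad))  # True True
```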


Last updated: 2026-02-09 00:00 GMT
