Implementation: bigscience-workshop Petals _RemoteSequentialAutogradFunction
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, Deep_Learning, Training |
| Last Updated | 2026-02-09 14:00 GMT |
Overview
Concrete tool for computing distributed forward and backward passes through remote transformer blocks, provided by Petals as a custom PyTorch autograd function.
Description
_RemoteSequentialAutogradFunction is a custom torch.autograd.Function that enables gradient computation through the distributed transformer block pipeline. It is the core mechanism that makes standard PyTorch training loops work with Petals' distributed architecture.
Key implementation details:
- Batch splitting: if a batch contains more than MAX_TOKENS_IN_BATCH (1024) tokens, it is split into sub-batches processed independently for memory efficiency (see the sketch after this list)
- Sequential forward: sequential_forward() iterates through server spans, calling run_remote_forward() on each
- Sequential backward: sequential_backward() reverses through the same server chain, calling run_remote_backward()
- Routing mode: Uses "max_throughput" routing (not "min_latency") during training for better utilization
- Fault tolerance: Retries with re-routing on server failures via RemoteSequenceManager
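The control flow these bullets describe can be sketched compactly. The sketch below is an illustrative reconstruction, not the Petals source: sequential_forward stands in for the real coroutine of the same name (which the actual client drives through a worker thread and which re-routes on failure), and the dim=1 split of prompts assumes Petals' [num_layers, batch_size, prefix_len, hidden_size] prompt layout.
import torch

MAX_TOKENS_IN_BATCH = 1024

def sequential_forward(inputs, prompts, sequence_manager):
    # Placeholder for petals.client.sequential_autograd.sequential_forward,
    # which walks the server spans chosen by sequence_manager, calls
    # run_remote_forward() per span, and re-routes on failure.
    # Identity here so the sketch runs standalone.
    return inputs

def dispatch_forward(inputs, prompts, sequence_manager):
    # Cap each sub-batch at MAX_TOKENS_IN_BATCH tokens, but keep >= 1 sequence
    sub_batch = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
    outputs = []
    for inp, prm in zip(inputs.split(sub_batch, dim=0), prompts.split(sub_batch, dim=1)):
        outputs.append(sequential_forward(inp, prm, sequence_manager))
    return torch.cat(outputs, dim=0)

x = torch.randn(8, 256, 4096)   # 8 * 256 = 2048 tokens -> split into two sub-batches
p = torch.zeros(4, 8, 0, 4096)  # dummy prompts; batch dimension is dim=1
print(dispatch_forward(x, p, None).shape)  # torch.Size([8, 256, 4096])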
Usage
This function is called automatically when RemoteSequential.forward() is invoked during a training forward pass (when gradients are required). Users do not call it directly — they simply write standard PyTorch training code with loss.backward().
Code Reference
Source Location
- Repository: petals
- File: src/petals/client/sequential_autograd.py (L223-277)
- File: src/petals/client/sequential_autograd.py (L26-110, sequential_forward)
- File: src/petals/client/sequential_autograd.py (L113-196, sequential_backward)
- File: src/petals/client/remote_forward_backward.py (L67-149, run_remote_forward/backward)
Signature
MAX_TOKENS_IN_BATCH = 1024

class _RemoteSequentialAutogradFunction(torch.autograd.Function):
    """PyTorch autograd function for remote sequential processing"""

    @staticmethod
    def forward(
        ctx,
        inputs: torch.Tensor,
        prompts: torch.Tensor,
        sequence_manager: RemoteSequenceManager,
    ) -> torch.Tensor:
        """
        Forward pass through all remote transformer blocks.

        Args:
            ctx: Autograd context for saving tensors needed in backward
            inputs: Hidden state tensor [batch_size, seq_len, hidden_size]
            prompts: Prompt tuning embeddings or dummy tensor
            sequence_manager: Server routing manager

        Returns:
            Output hidden states after all remote blocks
        """

    @staticmethod
    def backward(
        ctx,
        grad_outputs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, None]:
        """
        Backward pass computing gradients through remote blocks.

        Args:
            ctx: Autograd context with saved tensors
            grad_outputs: Gradient of loss w.r.t. forward output

        Returns:
            Tuple of (grad_inputs, grad_prompts, None)
        """
Import
# Invoked automatically via RemoteSequential.forward() during training.
# Users don't import this directly.
# The function is used when:
model.train()
outputs = model(**batch) # Triggers _RemoteSequentialAutogradFunction.forward
loss = outputs.loss
loss.backward() # Triggers _RemoteSequentialAutogradFunction.backward
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inputs | torch.Tensor | Yes | Hidden states [batch_size, seq_len, hidden_size] after embedding |
| prompts | torch.Tensor | Yes | Prompt embeddings or empty tensor if no prompt tuning |
| sequence_manager | RemoteSequenceManager | Yes | Manages server discovery and routing for the forward/backward chain |
Outputs
| Name | Type | Description |
|---|---|---|
| forward() returns | torch.Tensor | Output hidden states [batch_size, seq_len, hidden_size] after all blocks |
| backward() returns | Tuple | (grad_inputs, grad_prompts, None) — gradients for inputs and prompts |
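The grad_prompts slot matters in practice only when prompt tuning is enabled; otherwise a dummy prompts tensor is passed and its gradient is discarded. A brief sketch, following the pattern of Petals' published prompt-tuning examples (the tuning_mode and pre_seq_len arguments come from those examples; the model name may differ from the current public swarm):
import torch
from petals import DistributedBloomForCausalLM

# tuning_mode="ptune" attaches trainable prompt embeddings on the client;
# their gradients arrive via the grad_prompts slot of backward() above,
# while the remote transformer blocks stay frozen on the servers.
model = DistributedBloomForCausalLM.from_pretrained(
    "bigscience/bloom-petals", tuning_mode="ptune", pre_seq_len=16
)
print([name for name, p in model.named_parameters() if p.requires_grad])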
Usage Examples
Standard Training Loop (Autograd is Transparent)
import torch
from petals.models.llama.model import DistributedLlamaForSequenceClassification
from transformers import AutoTokenizer
model_name = "enoch/llama-65b-hf"
model = DistributedLlamaForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # LLaMA tokenizers define no pad token by default
# Standard PyTorch training - autograd handles distributed backward automatically
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],  # only client-side params are trainable
    lr=1e-3,
)
batch = tokenizer("This movie is great", return_tensors="pt", padding="max_length", max_length=128)
batch["labels"] = torch.tensor([1])
# Forward pass goes through _RemoteSequentialAutogradFunction.forward
outputs = model(**batch)
loss = outputs.loss
# Backward pass goes through _RemoteSequentialAutogradFunction.backward
loss.backward()
optimizer.step()
optimizer.zero_grad()
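A quick sanity check after loss.backward() (illustrative, not required) confirms that gradients reached every locally trainable parameter, i.e. that the distributed backward traversed the remote chain back to the client-side weights:
# After loss.backward(): every trainable (client-side) parameter should now
# hold a gradient; the remote blocks themselves receive no client updates.
for name, param in model.named_parameters():
    if param.requires_grad:
        assert param.grad is not None, f"no gradient for {name}"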