
Implementation:Bigscience workshop Petals RemoteSequentialAutogradFunction

From Leeroopedia


Knowledge Sources
Domains Distributed_Computing, Deep_Learning, Training
Last Updated 2026-02-09 14:00 GMT

Overview

Concrete tool for computing distributed forward and backward passes through remote transformer blocks, provided by Petals as a custom PyTorch autograd function.

Description

_RemoteSequentialAutogradFunction is a custom torch.autograd.Function that enables gradient computation through the distributed transformer block pipeline. It is the core mechanism that makes standard PyTorch training loops work with Petals' distributed architecture.

Key implementation details:

  • Batch splitting: If the batch exceeds MAX_TOKENS_IN_BATCH (1024), it is split into smaller sub-batches processed independently for memory efficiency
  • Sequential forward: sequential_forward() iterates through server spans, calling run_remote_forward() on each
  • Sequential backward: sequential_backward() reverses through the same server chain, calling run_remote_backward()
  • Routing mode: Uses "max_throughput" routing (not "min_latency") during training for better utilization
  • Fault tolerance: Retries with re-routing on server failures via RemoteSequenceManager
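The batch-splitting rule in the first bullet can be illustrated with a small self-contained sketch. This is an assumed version of the shape logic (the helper `split_batch` is hypothetical, not a Petals function): each sub-batch is sized so that its `batch_size * seq_len` token count stays within the budget.

```python
# Illustrative sketch of the batch-splitting heuristic; split_batch is a
# hypothetical helper, not Petals' actual code.

MAX_TOKENS_IN_BATCH = 1024

def split_batch(batch_size, seq_len):
    """Return sub-batch sizes whose token counts each fit the budget."""
    # How many full sequences fit in one sub-batch (at least one).
    per_chunk = max(1, MAX_TOKENS_IN_BATCH // seq_len)
    sizes = []
    remaining = batch_size
    while remaining > 0:
        take = min(per_chunk, remaining)
        sizes.append(take)
        remaining -= take
    return sizes

split_batch(10, 512)  # 512-token sequences -> 2 sequences per sub-batch
split_batch(3, 2048)  # sequence longer than the budget -> 1 per sub-batch
```

Sub-batches produced this way are processed independently, which bounds peak activation memory on both the client and the servers.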

Usage

This function is called automatically when RemoteSequential.forward() is invoked during a training forward pass (when gradients are required). Users do not call it directly — they simply write standard PyTorch training code with loss.backward().
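The sequential span traversal described above can be sketched in miniature. This is a toy analogue, not Petals' implementation: the function and class names (`sequential_forward`, `Scale`) are illustrative, each `Scale` stands in for a block of layers hosted on one server, and forward caches each span's input so backward can replay the chain in reverse, as the real `sequential_backward()` does.

```python
# Toy sketch of the forward/backward span traversal pattern.
# Names and structure are illustrative, not Petals' actual code.

def sequential_forward(x, spans):
    """Run x through each span in order, caching every span's input."""
    cached_inputs = []
    for span in spans:
        cached_inputs.append(x)
        x = span.forward(x)
    return x, cached_inputs

def sequential_backward(grad_out, spans, cached_inputs):
    """Propagate grad_out back through the spans in reverse order."""
    for span, x_in in zip(reversed(spans), reversed(cached_inputs)):
        grad_out = span.backward(x_in, grad_out)
    return grad_out

class Scale:
    """Stand-in for a remote block: y = w * x, so dy/dx = w."""
    def __init__(self, w):
        self.w = w
    def forward(self, x):
        return self.w * x
    def backward(self, x_in, grad):
        return self.w * grad

spans = [Scale(2.0), Scale(3.0)]
y, cache = sequential_forward(1.0, spans)   # 1.0 * 2 * 3 = 6.0
g = sequential_backward(1.0, spans, cache)  # chain rule: 3 * 2 = 6.0
```

In the real system each span's forward and backward run on a remote server via `run_remote_forward()` / `run_remote_backward()`, and a failed span is retried on a re-routed server rather than raising to the user.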

Code Reference

Source Location

  • Repository: petals
  • File: src/petals/client/sequential_autograd.py (L223-277)
  • File: src/petals/client/sequential_autograd.py (L26-110, sequential_forward)
  • File: src/petals/client/sequential_autograd.py (L113-196, sequential_backward)
  • File: src/petals/client/remote_forward_backward.py (L67-149, run_remote_forward/backward)

Signature

MAX_TOKENS_IN_BATCH = 1024

class _RemoteSequentialAutogradFunction(torch.autograd.Function):
    """PyTorch autograd function for remote sequential processing"""

    @staticmethod
    def forward(
        ctx,
        inputs: torch.Tensor,
        prompts: torch.Tensor,
        sequence_manager: RemoteSequenceManager,
    ) -> torch.Tensor:
        """
        Forward pass through all remote transformer blocks.

        Args:
            ctx: Autograd context for saving tensors needed in backward
            inputs: Hidden state tensor [batch_size, seq_len, hidden_size]
            prompts: Prompt tuning embeddings or dummy tensor
            sequence_manager: Server routing manager
        Returns:
            Output hidden states after all remote blocks
        """

    @staticmethod
    def backward(
        ctx,
        grad_outputs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, None]:
        """
        Backward pass computing gradients through remote blocks.

        Args:
            ctx: Autograd context with saved tensors
            grad_outputs: Gradient of loss w.r.t. forward output
        Returns:
            Tuple of (grad_inputs, grad_prompts, None)
        """

Import

# Invoked automatically via RemoteSequential.forward() during training.
# Users don't import this directly.
# The function is used when:
model.train()
outputs = model(**batch)  # Triggers _RemoteSequentialAutogradFunction.forward
loss = outputs.loss
loss.backward()           # Triggers _RemoteSequentialAutogradFunction.backward

I/O Contract

Inputs

Name Type Required Description
inputs torch.Tensor Yes Hidden states [batch_size, seq_len, hidden_size] after embedding
prompts torch.Tensor Yes Prompt embeddings or empty tensor if no prompt tuning
sequence_manager RemoteSequenceManager Yes Manages server discovery and routing for the forward/backward chain

Outputs

Name Type Description
forward() returns torch.Tensor Output hidden states [batch_size, seq_len, hidden_size] after all blocks
backward() returns Tuple (grad_inputs, grad_prompts, None) — gradients for inputs and prompts

Usage Examples

Standard Training Loop (Autograd is Transparent)

import torch
from petals.models.llama.model import DistributedLlamaForSequenceClassification
from transformers import AutoTokenizer

model_name = "enoch/llama-65b-hf"
model = DistributedLlamaForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Standard PyTorch training - autograd handles distributed backward automatically
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=1e-3,
)

batch = tokenizer("This movie is great", return_tensors="pt", padding="max_length", max_length=128)
batch["labels"] = torch.tensor([1])

# Forward pass goes through _RemoteSequentialAutogradFunction.forward
outputs = model(**batch)
loss = outputs.loss

# Backward pass goes through _RemoteSequentialAutogradFunction.backward
loss.backward()
optimizer.step()
optimizer.zero_grad()

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
