
Principle:Bigscience workshop Petals Distributed Autograd

From Leeroopedia


Knowledge Sources
Domains Distributed_Computing, Deep_Learning, Training
Last Updated 2026-02-09 14:00 GMT

Overview

A mechanism for computing gradients through remotely-hosted transformer blocks by implementing a custom PyTorch autograd function that transparently handles forward and backward passes over the network.

Description

Distributed Autograd enables training (fine-tuning) with models whose transformer blocks are hosted on remote servers. Standard PyTorch autograd cannot natively differentiate through network RPC calls, so Petals implements a custom torch.autograd.Function that:

Forward pass:

  1. Splits the batch into sub-batches if it exceeds MAX_TOKENS_IN_BATCH (1024 tokens)
  2. Sends each sub-batch through remote servers via run_remote_forward()
  3. Saves activations needed for the backward pass in the autograd context

Backward pass:

  1. Receives gradient of the loss with respect to the output
  2. Sends gradients back through the same servers via run_remote_backward()
  3. Servers compute local gradients through their transformer blocks and return input gradients
  4. The returned gradients propagate to local trainable parameters (prompt embeddings, classification head)
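The forward and backward steps above can be sketched as a custom `torch.autograd.Function`. This is a self-contained, local sketch, not the Petals implementation: the class name `_RemoteSequential` and the two helper functions are illustrative stand-ins for `run_remote_forward()` / `run_remote_backward()`, with a single matmul playing the role of the remote transformer blocks.

```python
import torch

# Local stand-ins for the network calls: a single matmul plays the
# role of the remote transformer blocks.
def remote_forward(x, weight):
    return x @ weight

def remote_backward(grad_output, saved_x, weight):
    # Server-side: compute input gradient and local weight gradient
    return grad_output @ weight.T, saved_x.T @ grad_output

class _RemoteSequential(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, weight):
        ctx.save_for_backward(inputs, weight)   # save activations for backward
        return remote_forward(inputs, weight)   # "send" through remote servers

    @staticmethod
    def backward(ctx, grad_output):             # receives dL/d(output)
        inputs, weight = ctx.saved_tensors
        grad_input, grad_weight = remote_backward(grad_output, inputs, weight)
        return grad_input, grad_weight          # propagates to local parameters

x = torch.randn(2, 4, requires_grad=True)
w = torch.randn(4, 3, requires_grad=True)
_RemoteSequential.apply(x, w).sum().backward()
```

Because the custom Function returns correct vector-Jacobian products, the gradients it produces match what native autograd would compute for the same computation.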

Fault tolerance: Both forward and backward support retries with automatic re-routing if a server fails mid-computation.
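A minimal sketch of this retry-with-rerouting behavior, under illustrative assumptions: the function names, the `build_chain(exclude)` signature, and the toy servers below are hypothetical, not the Petals API.

```python
def run_with_rerouting(build_chain, remote_call, max_attempts=3):
    """Retry a remote call, re-routing around servers that failed."""
    failed = set()
    for attempt in range(max_attempts):
        chain = build_chain(failed)          # route avoiding known-bad servers
        try:
            return remote_call(chain)
        except ConnectionError as err:
            failed.add(err.args[0])          # remember the failed server
    raise RuntimeError(f"no healthy route after {max_attempts} attempts")

# Toy demo: "server-a" always fails mid-computation, "server-b" succeeds.
def build_chain(exclude):
    return [s for s in ("server-a", "server-b") if s not in exclude]

def remote_call(chain):
    if chain[0] == "server-a":
        raise ConnectionError("server-a")
    return "ok via " + chain[0]

result = run_with_rerouting(build_chain, remote_call)
```

The first attempt fails on `server-a`; the second rebuilds the chain without it and succeeds.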

Usage

Use this principle when training any local parameters (prompt embeddings, classification heads) with a distributed Petals model. It is automatically invoked when calling model(**batch) followed by loss.backward() during a standard PyTorch training loop. The distributed autograd is transparent — users write normal PyTorch training code.
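A sketch of the "normal PyTorch code" a user writes. Here `model` is a tiny local `nn.Linear` standing in for a Petals distributed model with a trainable head; with a real Petals model, the same `forward` / `backward` lines would transparently trigger the distributed autograd.

```python
import torch

# Local stand-in for the trainable head of a distributed Petals model
model = torch.nn.Linear(8, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

batch = {"inputs": torch.randn(4, 8), "labels": torch.tensor([0, 1, 0, 1])}
logits = model(batch["inputs"])      # forward: would traverse remote blocks
loss = torch.nn.functional.cross_entropy(logits, batch["labels"])
loss.backward()                      # backward: distributed autograd runs here
opt.step()
opt.zero_grad()
```

Nothing in the loop refers to servers or routing; the custom autograd Function hides all of it.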

Theoretical Basis

Chain rule through distributed blocks:

Given a loss function L and a sequence of N transformer blocks f_1, f_2, ..., f_N distributed across servers:

Forward: h_N = f_N(f_{N-1}(... f_1(h_0)))

Backward: ∂L/∂h_0 = (∂f_1/∂h_0)^T (∂f_2/∂h_1)^T ... ∂L/∂h_N

Each server computes its local Jacobian-vector product and passes the result to the previous server.
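This chain rule can be checked numerically. In the sketch below, each "server" is a linear block, so its Jacobian is just the block's weight matrix; the reverse sweep multiplies the incoming gradient by each Jacobian transpose, exactly as each server would.

```python
import torch

torch.manual_seed(0)
blocks = [torch.randn(3, 3) for _ in range(3)]  # Jacobians of f_1, f_2, f_3

h0 = torch.randn(3, requires_grad=True)
h = h0
for W in blocks:            # forward: h_N = f_N(... f_1(h_0))
    h = h @ W
loss = h.sum()
loss.backward()             # autograd's dL/dh_0, for comparison

# Manual backward: start from dL/dh_N = ones and sweep servers in reverse;
# each "server" returns its input gradient (a Jacobian-vector product)
grad = torch.ones(3)
for W in reversed(blocks):
    grad = grad @ W.T
```

The manually propagated `grad` equals `h0.grad` computed by autograd.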

Batch splitting for efficiency:

# Abstract distributed autograd algorithm (pseudocode)
MAX_TOKENS_IN_BATCH = 1024

def forward(inputs, prompts, sequence_manager):
    # Total tokens = batch_size * sequence_length; split oversized batches
    if inputs.shape[0] * inputs.shape[1] > MAX_TOKENS_IN_BATCH:
        sub_batches = split(inputs, MAX_TOKENS_IN_BATCH)
        outputs = [remote_forward(sub, prompts, sequence_manager) for sub in sub_batches]
        return concatenate(outputs)
    return remote_forward(inputs, prompts, sequence_manager)

def backward(grad_outputs):
    # Reverse through the same server chain used in forward
    grad_inputs = remote_backward(grad_outputs, sequence_manager)
    # None corresponds to the non-differentiable sequence_manager argument
    return grad_inputs, grad_prompts, None

Related Pages

Implemented By
