Principle: BigScience Workshop Petals Distributed Autograd
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, Deep_Learning, Training |
| Last Updated | 2026-02-09 14:00 GMT |
Overview
A mechanism for computing gradients through remotely-hosted transformer blocks by implementing a custom PyTorch autograd function that transparently handles forward and backward passes over the network.
Description
Distributed Autograd enables training (fine-tuning) of models whose transformer blocks are hosted on remote servers. Standard PyTorch autograd cannot natively differentiate through network RPC calls, so Petals implements a custom torch.autograd.Function that:
Forward pass:
- Splits the batch into sub-batches if it exceeds MAX_TOKENS_IN_BATCH (1024 tokens)
- Sends each sub-batch through remote servers via run_remote_forward()
- Saves activations needed for the backward pass in the autograd context
Backward pass:
- Receives gradient of the loss with respect to the output
- Sends gradients back through the same servers via run_remote_backward()
- Servers compute local gradients through their transformer blocks and return input gradients
- The returned gradients propagate to local trainable parameters (prompt embeddings, classification head)
Fault tolerance: both the forward and backward passes support retries, with automatic re-routing if a server fails mid-computation.
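The forward/backward bridging described above can be sketched as a custom torch.autograd.Function. This is a minimal, runnable illustration, not Petals' actual implementation: the `run_remote_forward` / `run_remote_backward` stand-ins simulate a remote server locally with a fixed linear map, and the real code additionally handles sub-batch splitting, retries, and re-routing.

```python
import torch

# Hypothetical stand-ins for Petals' network RPCs; simulated locally with a
# fixed linear transform so the sketch runs offline.
def run_remote_forward(inputs, weight):
    # The "server" applies its block (here: a simple linear map).
    return inputs @ weight

def run_remote_backward(grad_outputs, weight):
    # The "server" returns the gradient w.r.t. its inputs (a VJP).
    return grad_outputs @ weight.T

class RemoteBlock(torch.autograd.Function):
    """Sketch of an autograd.Function bridging network calls."""

    @staticmethod
    def forward(ctx, inputs, weight):
        ctx.save_for_backward(weight)   # state needed for the backward pass
        with torch.no_grad():           # the remote call is opaque to autograd
            return run_remote_forward(inputs, weight)

    @staticmethod
    def backward(ctx, grad_outputs):
        (weight,) = ctx.saved_tensors
        grad_inputs = run_remote_backward(grad_outputs, weight)
        return grad_inputs, None        # no local gradient for the remote weight

x = torch.randn(2, 4, requires_grad=True)
w = torch.eye(4)
y = RemoteBlock.apply(x, w)
y.sum().backward()
# With an identity "block", the gradient of the sum w.r.t. x is all ones.
```

The key point is that `forward` runs the remote call outside of autograd's tracking and `backward` re-enters the network to obtain input gradients, so the rest of the local graph differentiates normally.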
Usage
Use this principle when training any local parameters (prompt embeddings, classification heads) with a distributed Petals model. It is automatically invoked when calling model(**batch) followed by loss.backward() during a standard PyTorch training loop. The distributed autograd is transparent — users write normal PyTorch training code.
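A training loop therefore looks like ordinary PyTorch. The sketch below uses a small local module as a stand-in for a distributed Petals model (so it runs offline); only the shape of the loop is the point, and the model/optimizer choices are illustrative.

```python
import torch

# Stand-in for a Petals model with a local trainable head; in Petals,
# model(...) would route hidden states through remote transformer blocks.
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

inputs = torch.randn(4, 8)
labels = torch.tensor([0, 1, 0, 1])

for _ in range(3):
    optimizer.zero_grad()
    logits = model(inputs)      # forward would traverse remote servers in Petals
    loss = loss_fn(logits, labels)
    loss.backward()             # distributed autograd is invoked transparently here
    optimizer.step()
```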
Theoretical Basis
Chain rule through distributed blocks:
Given a loss function L and a sequence of N transformer blocks f_1, ..., f_N distributed across servers, with h_0 the input activations:

Forward: h_i = f_i(h_{i-1}), for i = 1, ..., N

Backward: ∂L/∂h_{i-1} = (∂h_i/∂h_{i-1})^T · ∂L/∂h_i, applied from i = N down to i = 1
Each server computes its local Jacobian-vector product and passes the result to the previous server.
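This composition of per-block vector-Jacobian products can be verified numerically in PyTorch. The sketch below chains three local blocks (names and shapes are illustrative) and checks that applying each block's VJP back to front reproduces the end-to-end gradient that a single backward pass computes.

```python
import torch

torch.manual_seed(0)
blocks = [torch.nn.Linear(4, 4) for _ in range(3)]

# End-to-end gradient via a single backward pass.
x = torch.randn(1, 4, requires_grad=True)
h = x
for block in blocks:
    h = torch.tanh(block(h))
h.sum().backward()
end_to_end = x.grad.clone()

# Manual per-block VJPs, back to front, as each "server" would compute them.
x2 = x.detach().requires_grad_(True)
activations = [x2]
h = x2
for block in blocks:
    h = torch.tanh(block(h))
    activations.append(h)
grad = torch.ones_like(h)  # dL/dh_N for L = sum(h_N)
for inp, out in zip(reversed(activations[:-1]), reversed(activations[1:])):
    (grad,) = torch.autograd.grad(out, inp, grad_outputs=grad, retain_graph=True)
# grad now equals end_to_end: chaining per-block VJPs is the full chain rule.
```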
Batch splitting for efficiency:
```
# Abstract distributed autograd algorithm (pseudocode)
MAX_TOKENS_IN_BATCH = 1024

def forward(inputs, prompts, sequence_manager):
    # inputs: [batch, seq_len, hidden]; split whenever the token count is too large
    if inputs.shape[0] * inputs.shape[1] > MAX_TOKENS_IN_BATCH:
        sub_batches = split(inputs, MAX_TOKENS_IN_BATCH)
        outputs = [remote_forward(sub, prompts, sequence_manager) for sub in sub_batches]
        return concatenate(outputs)
    return remote_forward(inputs, prompts, sequence_manager)

def backward(grad_outputs):
    # Reverse through the same server chain; servers also return
    # gradients for the trainable prompts
    grad_inputs, grad_prompts = remote_backward(grad_outputs, sequence_manager)
    return grad_inputs, grad_prompts, None  # None: no gradient for sequence_manager
```
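The splitting step above can be made concrete. The helper below is a hypothetical sketch (not Petals' API): it splits a `[batch, seq_len, hidden]` activation tensor along the batch dimension so each sub-batch carries at most `MAX_TOKENS_IN_BATCH` tokens, and concatenating the sub-batches recovers the original batch.

```python
import torch

MAX_TOKENS_IN_BATCH = 1024  # token budget per remote request

def split_into_sub_batches(inputs: torch.Tensor, max_tokens: int):
    """Hypothetical helper: split [batch, seq_len, hidden] activations so
    each sub-batch holds at most max_tokens tokens."""
    seq_len = inputs.shape[1]
    rows_per_sub_batch = max(1, max_tokens // seq_len)
    return list(torch.split(inputs, rows_per_sub_batch, dim=0))

inputs = torch.randn(8, 512, 16)   # 8 * 512 = 4096 tokens total
subs = split_into_sub_batches(inputs, MAX_TOKENS_IN_BATCH)
# 1024 // 512 = 2 rows per sub-batch, so 4 sub-batches in total;
# concatenating them restores the original batch unchanged.
recombined = torch.cat(subs, dim=0)
```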