Heuristic: BigScience Workshop Petals Batch Splitting Threshold
| Knowledge Sources | Details |
|---|---|
| Domains | Optimization, Distributed_Computing |
| Last Updated | 2026-02-09 13:00 GMT |
Overview
Client-side batch splitting heuristic: split training batches so each sub-batch contains at most 1024 tokens (MAX_TOKENS_IN_BATCH) for efficient parallel remote processing.
Description
During distributed forward and backward passes, the client splits input batches into smaller sub-batches based on a token budget of 1024. Each sub-batch is sent to remote servers in parallel via `asyncio.gather`. This prevents overwhelming any single server with a very large batch while enabling concurrent processing across the swarm.
Usage
Applied automatically in `_RemoteSequentialAutogradFunction.forward()` and `.backward()`. The per-request batch size is computed as `max(MAX_TOKENS_IN_BATCH // seq_length, 1)`. For sequences of length 512 this gives 2 samples per sub-batch; for sequences of length 2048 the floor division yields 0, so the `max` clamps the result to 1 and each sample is sent individually.
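The batch-size computation can be sketched as a small standalone function; `sub_batch_size` is an illustrative name, not part of the Petals API:

```python
# Mirrors the heuristic described above: each sub-batch holds at most
# ~MAX_TOKENS_IN_BATCH tokens, clamped to at least one sample.
MAX_TOKENS_IN_BATCH = 1024

def sub_batch_size(seq_length: int) -> int:
    """Number of samples per sub-batch for a given sequence length."""
    return max(MAX_TOKENS_IN_BATCH // seq_length, 1)

print(sub_batch_size(512))   # 2 samples per sub-batch
print(sub_batch_size(2048))  # 1 sample (floor division gives 0, clamped to 1)
```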
The Insight (Rule of Thumb)
- Action: The client automatically splits batches based on `MAX_TOKENS_IN_BATCH = 1024`.
- Value: `batch_size_per_request = max(1024 // sequence_length, 1)`
- Trade-off: Smaller sub-batches increase parallelism across servers but add more RPC overhead. Larger sub-batches are more compute-efficient but limit parallelism.
Reasoning
Splitting by token count (rather than sample count) keeps memory pressure on remote servers roughly constant regardless of sequence length. The value of 1024 balances network overhead (too many small requests) against memory constraints (too few large requests). Sub-batches are processed in parallel via `asyncio.gather`, improving overall throughput.
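The parallel dispatch pattern can be sketched with `asyncio.gather`; here `fake_remote_forward` is a hypothetical stand-in for a real RPC to a Petals server:

```python
import asyncio

async def fake_remote_forward(sub_batch):
    """Simulated remote forward pass: network latency plus a trivial transform."""
    await asyncio.sleep(0.01)          # stands in for RPC round-trip time
    return [x * 2 for x in sub_batch]  # stands in for the layer output

async def run_forward(batch, batch_size):
    # Split the batch into sub-batches of at most `batch_size` samples,
    # dispatch them concurrently, then reassemble the outputs in order.
    sub_batches = [batch[i:i + batch_size] for i in range(0, len(batch), batch_size)]
    outputs = await asyncio.gather(*(fake_remote_forward(b) for b in sub_batches))
    return [x for out in outputs for x in out]

result = asyncio.run(run_forward(list(range(8)), batch_size=2))
```

Because `asyncio.gather` preserves argument order, the concatenated outputs line up with the original batch even though the sub-batches complete concurrently.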
Code Evidence
From `src/petals/client/sequential_autograd.py:23`:

```python
MAX_TOKENS_IN_BATCH = 1024
```

From `src/petals/client/sequential_autograd.py:231-232`:

```python
batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
```
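The splitting step can be mimicked in pure Python (no torch) to show how `split(batch_size)` partitions a batch along the first dimension; `split_rows` here is a hypothetical analogue of `torch.Tensor.split`:

```python
def split_rows(rows, batch_size):
    """Partition rows into chunks of `batch_size`; the last chunk may be smaller,
    matching torch.Tensor.split semantics along dim 0."""
    return [rows[i:i + batch_size] for i in range(0, len(rows), batch_size)]

inputs = [[0] * 512 for _ in range(5)]   # 5 samples, seq_length 512
batch_size = max(1024 // 512, 1)         # -> 2 samples per sub-batch
chunks = split_rows(inputs, batch_size)
print([len(c) for c in chunks])          # [2, 2, 1]
```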