
Heuristic:Bigscience workshop Petals Batch Splitting Threshold

From Leeroopedia



Knowledge Sources
Domains Optimization, Distributed_Computing
Last Updated 2026-02-09 13:00 GMT

Overview

Client-side batch-splitting heuristic: split training batches so each sub-batch carries at most `MAX_TOKENS_IN_BATCH = 1024` tokens (or a single sample, whichever is larger) for efficient parallel remote processing.

Description

During distributed forward and backward passes, the client splits input batches into smaller sub-batches based on a token budget of 1024. Each sub-batch is sent to remote servers in parallel via `asyncio.gather`. This prevents overwhelming any single server with a very large batch while enabling concurrent processing across the swarm.

Usage

Applied automatically in `_RemoteSequentialAutogradFunction.forward()` and `.backward()`. The per-request batch size is computed as `max(MAX_TOKENS_IN_BATCH // seq_length, 1)`. For sequences of length 512, this yields 2 samples per sub-batch; for sequences of length 2048 or longer, each sample is sent individually.
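The rule above can be sketched as a one-line function. This is an illustrative reimplementation of the formula quoted from the article, not the actual Petals code; the helper name `batch_size_per_request` is invented for clarity.

```python
# Token budget per remote request, as documented in
# src/petals/client/sequential_autograd.py.
MAX_TOKENS_IN_BATCH = 1024

def batch_size_per_request(seq_length: int) -> int:
    """Samples per sub-batch: at most MAX_TOKENS_IN_BATCH tokens,
    but never fewer than one sample, even for very long sequences."""
    return max(MAX_TOKENS_IN_BATCH // seq_length, 1)

for seq_length in (128, 512, 1024, 2048):
    print(seq_length, "->", batch_size_per_request(seq_length))
# 128 -> 8, 512 -> 2, 1024 -> 1, 2048 -> 1
```

Note that for `seq_length > 1024` the floor of 1 sample means a single sub-batch can exceed the nominal token budget; the heuristic caps parallel fan-out rather than enforcing a hard memory limit.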

The Insight (Rule of Thumb)

  • Action: The client automatically splits batches based on `MAX_TOKENS_IN_BATCH = 1024`.
  • Value: `batch_size_per_request = max(1024 // sequence_length, 1)`
  • Trade-off: Smaller sub-batches increase parallelism across servers but add more RPC overhead. Larger sub-batches are more compute-efficient but limit parallelism.

Reasoning

Splitting by token count (rather than sample count) ensures consistent memory pressure on remote servers regardless of sequence length. The value of 1024 balances between network overhead (too many small requests) and memory constraints (too few large requests). Sub-batches are processed in parallel via `asyncio.gather`, improving overall throughput.

Code Evidence

From `src/petals/client/sequential_autograd.py:23`:

MAX_TOKENS_IN_BATCH = 1024

From `src/petals/client/sequential_autograd.py:231-232`:

batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
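To show what the `split(batch_size)` call produces without requiring PyTorch, here is a pure-Python analogue that reports only the resulting sub-batch sizes; the helper `split_batch` is invented for illustration and works on (batch, seq_length) dimensions rather than real tensors.

```python
MAX_TOKENS_IN_BATCH = 1024

def split_batch(num_samples: int, seq_length: int) -> list[int]:
    """Sub-batch sizes produced by tensor.split(batch_size) along dim 0:
    full chunks of `batch_size`, plus a smaller remainder chunk if any."""
    batch_size = max(MAX_TOKENS_IN_BATCH // seq_length, 1)
    sizes = []
    remaining = num_samples
    while remaining > 0:
        sizes.append(min(batch_size, remaining))
        remaining -= sizes[-1]
    return sizes

# 7 samples of length 512: budget allows 2 per request -> 2, 2, 2, 1
print(split_batch(7, 512))
```

Like `torch.Tensor.split`, the final chunk is simply whatever remains, so every sample is sent exactly once.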
