Implementation:Huggingface Datatrove InferenceRunner
| Knowledge Sources | |
|---|---|
| Domains | Inference, Synthetic_Data, Pipeline Orchestration |
| Last Updated | 2026-02-14 00:00 GMT |
Overview
Pipeline step that runs user-defined rollout functions on documents using inference servers, managing concurrency, caching, metrics, and result aggregation.
Description
`InferenceRunner` is a `PipelineStep` subclass that orchestrates the complete inference workflow. It starts the inference server selected by `config.server_type`, converts the synchronous document iterator into an async stream, dispatches documents to the rollout function with bounded concurrency, and writes results to the output writer and, when `checkpoints_local_dir` is set, to local checkpoints.
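The intake-and-dispatch loop can be pictured with plain asyncio. This is not datatrove's code. `to_async`, `dispatch`, and `fake_rollout` are hypothetical names, and the `generate` stand-in just upper-cases its input. Modeling `max_concurrent_documents` and `max_concurrent_generations` as two semaphores is an assumption about how such limits can be enforced.

```python
import asyncio
from typing import AsyncIterator, Iterable


async def to_async(items: Iterable[str]) -> AsyncIterator[str]:
    """Wrap a synchronous iterator so it can be consumed with `async for`."""
    for item in items:
        yield item
        await asyncio.sleep(0)  # give already-scheduled tasks a chance to run


async def dispatch(docs, rollout, max_docs: int, max_generations: int) -> list:
    doc_slots = asyncio.Semaphore(max_docs)  # caps documents in flight
    gen_slots = asyncio.Semaphore(max_generations)  # caps concurrent generate() calls

    async def generate(prompt: str) -> str:
        async with gen_slots:
            await asyncio.sleep(0.01)  # stand-in for an inference server round trip
            return prompt.upper()

    async def handle(doc: str) -> str:
        try:
            return await rollout(doc, generate)
        finally:
            doc_slots.release()  # free the slot even if the rollout raises

    tasks = []
    async for doc in to_async(docs):
        await doc_slots.acquire()  # stop pulling input until a document slot frees up
        tasks.append(asyncio.create_task(handle(doc)))
    return await asyncio.gather(*tasks)  # results come back in input order


async def fake_rollout(doc, generate):
    return await generate(doc)


results = asyncio.run(dispatch(["a", "b", "c"], fake_rollout, max_docs=2, max_generations=1))
print(results)  # ['A', 'B', 'C']
```

Acquiring the document slot before creating a task means the sketch never reads more input than it can process. That keeps memory bounded when the input iterator is large.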
The runner is configured through two main objects:
- InferenceConfig: A dataclass containing all server and processing parameters (server type, model, parallelism, concurrency limits, generation defaults).
- RolloutFunction: A user-provided async callable that defines the application-specific inference logic.
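To make the roles of `rollout_fn`, `rollouts_per_document`, and `metadata_key` concrete, the sketch below fans one document out into several rollouts and stores their results as a list under the metadata key. `process_document`, `echo_generate`, and `reverse_rollout` are hypothetical. Documents here are plain dicts rather than datatrove `Document` objects, and running the rollouts of a single document concurrently is an assumption.

```python
import asyncio


async def process_document(doc, rollout_fn, generate, rollouts_per_document=1,
                           metadata_key="rollout_results", **shared_context):
    """Hypothetical helper: run rollout_fn several times and store the results on the document."""
    # shared_context entries are forwarded to the rollout function as keyword arguments
    coros = [rollout_fn(doc, generate, **shared_context) for _ in range(rollouts_per_document)]
    results = await asyncio.gather(*coros)  # the rollouts of one document run concurrently
    doc["metadata"][metadata_key] = list(results)
    return doc


async def echo_generate(payload: dict) -> str:
    return payload["prompt"][::-1]  # stand-in for a real model call


async def reverse_rollout(doc, generate, suffix=""):
    return await generate({"prompt": doc["text"]}) + suffix


doc = {"text": "abc", "metadata": {}}
asyncio.run(process_document(doc, reverse_rollout, echo_generate, rollouts_per_document=3, suffix="!"))
print(doc["metadata"])  # {'rollout_results': ['cba!', 'cba!', 'cba!']}
```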
Internally, the runner uses:
- MetricsKeeper: Tracks token throughput and request counts with 5-minute windowed and lifetime rates
- QueueSizesKeeper: Tracks current queue sizes for waiting and running requests
- CheckpointManager: Saves processed documents in chunks for fault-tolerant recovery
- RequestCache: SQLite-backed cache that stores individual request/response pairs to avoid redundant inference on restart
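The page describes `RequestCache` only as SQLite-backed. The toy class below, `ToyRequestCache` (hypothetical), shows the general idea. It hashes a canonical JSON form of the request payload and stores the response under that key, so a restarted job can skip requests it has already completed. The schema, key derivation, and method names are assumptions, not datatrove's.

```python
import hashlib
import json
import sqlite3


class ToyRequestCache:
    """Hypothetical SQLite cache keyed by a hash of the request payload."""

    def __init__(self, path: str = ":memory:"):
        self.conn = sqlite3.connect(path)
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS cache (key TEXT PRIMARY KEY, response TEXT NOT NULL)"
        )

    @staticmethod
    def key(payload: dict) -> str:
        # sort_keys makes logically equal payloads serialize, and therefore hash, identically
        return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()

    def get(self, payload: dict):
        row = self.conn.execute(
            "SELECT response FROM cache WHERE key = ?", (self.key(payload),)
        ).fetchone()
        return None if row is None else json.loads(row[0])

    def put(self, payload: dict, response) -> None:
        self.conn.execute(
            "INSERT OR REPLACE INTO cache (key, response) VALUES (?, ?)",
            (self.key(payload), json.dumps(response)),
        )
        self.conn.commit()


cache = ToyRequestCache()
payload = {"messages": [{"role": "user", "content": "hi"}], "max_tokens": 8}
calls = 0


def cached_generate(payload):
    global calls
    hit = cache.get(payload)
    if hit is not None:
        return hit
    calls += 1
    response = {"text": "hello"}  # stand-in for a real inference call
    cache.put(payload, response)
    return response


cached_generate(payload)
cached_generate(payload)  # served from the cache
print(calls)  # 1
```

A real cache must also handle a case this toy ignores. With `rollouts_per_document > 1`, the same payload is sent several times on purpose. The key would then need something extra, such as a rollout index, to keep those rollouts from collapsing into a single cached answer.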
The synchronous `run()` method (required by the `PipelineStep` interface) delegates to `run_async()` through `asyncio.run()`. `run_async()` then manages the full async lifecycle.
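The sync-to-async bridge can be illustrated with a toy class. `ToyStep` is hypothetical and only mirrors the shape of the `run`/`run_async` pair from the signature. It assumes the delegation is a plain `asyncio.run` call.

```python
import asyncio
from typing import Iterable


class ToyStep:
    """Hypothetical step with a sync `run` that drives an async `run_async`."""

    def __init__(self):
        self.processed = []

    async def run_async(self, data: Iterable[str], rank: int = 0) -> None:
        for doc in data:
            await asyncio.sleep(0)  # placeholder for awaiting inference results
            self.processed.append(f"rank{rank}:{doc}")

    def run(self, data: Iterable[str], rank: int = 0, world_size: int = 1) -> None:
        # asyncio.run creates an event loop, runs the coroutine to completion, then closes the loop
        asyncio.run(self.run_async(data, rank))


step = ToyStep()
step.run(["x", "y"], rank=3)
print(step.processed)  # ['rank3:x', 'rank3:y']
```

`asyncio.run()` raises `RuntimeError` if it is called while another event loop is already running in the same thread. For that reason, the synchronous `run()` entry point is meant to be called from synchronous code.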
Usage
Use InferenceRunner when:
- Building a synthetic data generation pipeline step in datatrove
- Processing documents through one or more LLM calls with configurable concurrency
- Running long jobs that need checkpointing and request caching for fault tolerance
Code Reference
Source Location
- Repository: huggingface/datatrove
- InferenceConfig: src/datatrove/pipeline/inference/run_inference.py:L42-96
- InferenceRunner: src/datatrove/pipeline/inference/run_inference.py:L101-628
Signature
```python
@dataclass
class InferenceConfig:
    server_type: Literal["sglang", "vllm", "dummy", "custom", "endpoint"]
    model_name_or_path: str
    model_max_context: int = 8192
    use_chat: bool = True
    endpoint_url: str | None = None
    api_key: str | None = None
    metric_interval: int = 120
    tp: int = 1
    dp: int = 1
    pp: int = 1
    default_generation_params: dict = field(default_factory=dict)
    rollouts_per_document: int = 1
    max_concurrent_generations: int = 500
    max_concurrent_documents: int | None = None
    request_timeout: float | None = None
    model_kwargs: dict | None = None
    server_log_folder: str | None = None
    master_port: int = 9810


class InferenceRunner(PipelineStep):
    def __init__(
        self,
        rollout_fn: RolloutFunction,
        config: InferenceConfig,
        output_writer: DiskWriter,
        shared_context: (dict | Callable[[], dict] | ContextManager[dict] | None) = None,
        checkpoints_local_dir: str | None = None,
        records_per_chunk: int = 6000,
        metadata_key: str = "rollout_results",
    ):
        ...

    def run(
        self,
        data: Iterable[Document],
        rank: int = 0,
        world_size: int = 1,
    ) -> None:
        ...

    async def run_async(
        self,
        data_gen: Iterable[Document],
        rank: int = 0,
    ) -> None:
        ...
```
Import
```python
from datatrove.pipeline.inference.run_inference import InferenceRunner, InferenceConfig
from datatrove.pipeline.inference.types import InferenceResult, RolloutFunction
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
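| rollout_fn | RolloutFunction | Yes | Async callable that processes one document using the `generate` callback. It is called as `rollout_fn(document, generate, **shared_context)` and returns an `InferenceResult` or another JSON-serializable value |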
| config | InferenceConfig | Yes | Server type, model, parallelism, concurrency, and generation parameters |
| output_writer | DiskWriter | Yes | Writer for saving processed documents (e.g., ParquetWriter, JsonlWriter) |
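| shared_context | dict / Callable / ContextManager / None | No | Context whose entries are passed as keyword arguments to the rollout function. It can be a dict, a zero-argument callable returning a dict, or a context manager yielding a dict |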
| checkpoints_local_dir | str / None | No | Local directory for checkpoint files; enables fault-tolerant recovery |
| records_per_chunk | int | No | Number of documents per checkpoint chunk (default: 6000) |
| metadata_key | str | No | Key for storing rollout results in document metadata (default: "rollout_results") |
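The `shared_context` parameter accepts several forms. One plausible way to normalize them into the keyword arguments passed to the rollout function is sketched below. `enter_shared_context` is a hypothetical helper. Two behaviors are our own assumptions, not documented ones: a callable is invoked once, and a context manager stays open for the whole run (held by `contextlib.ExitStack`).

```python
from contextlib import ExitStack, contextmanager
from typing import Callable, ContextManager, Union

SharedContext = Union[dict, Callable[[], dict], ContextManager[dict], None]


def enter_shared_context(stack: ExitStack, shared_context: SharedContext) -> dict:
    """Hypothetical helper: turn any accepted shared_context form into a kwargs dict."""
    if shared_context is None:
        return {}
    if isinstance(shared_context, dict):
        return shared_context
    if hasattr(shared_context, "__enter__") and hasattr(shared_context, "__exit__"):
        # the context manager stays open until the ExitStack is closed
        return stack.enter_context(shared_context)
    if callable(shared_context):
        return shared_context()
    raise TypeError(f"unsupported shared_context: {type(shared_context)!r}")


opened = []


@contextmanager
def categories_context():
    opened.append("open")
    try:
        yield {"categories": "science, arts"}
    finally:
        opened.append("closed")


with ExitStack() as stack:
    kwargs = enter_shared_context(stack, categories_context())
    print(kwargs, opened)  # {'categories': 'science, arts'} ['open']
print(opened)  # ['open', 'closed']
```

The context-manager check comes before the `callable()` check on purpose. Objects produced by `@contextmanager` are themselves callable, because they can also be used as decorators. Testing `callable()` first would therefore call them instead of entering them.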
Outputs
| Name | Type | Description |
|---|---|---|
| Documents with rollout_results | Document objects | Input documents with `metadata[metadata_key]` (by default `metadata["rollout_results"]`) populated with a list of rollout results (`InferenceResult` or JSON-serializable values) |
| Checkpoint files | JSONL files | Local JSONL checkpoint files (when checkpoints_local_dir is set) containing processed documents in chunks |
| Metrics logs | Log output | Periodic metrics reports showing token throughput, request rates, and queue sizes |
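The checkpoint output, processed documents saved in chunks of `records_per_chunk`, can be sketched with the standard library. Several details are assumptions made for this illustration: the file naming (`chunk_00000.jsonl`, `chunk_00001.jsonl`, and so on) and the rule that only full chunks count as finished on restart. The sketch also ignores the fact that concurrently processed documents may complete out of order. Both functions are hypothetical.

```python
import json
import tempfile
from pathlib import Path


def write_checkpoint_chunks(docs, out_dir, records_per_chunk):
    """Hypothetical writer: each chunk file holds at most records_per_chunk JSONL records."""
    paths = []
    for start in range(0, len(docs), records_per_chunk):
        chunk_idx = start // records_per_chunk
        path = Path(out_dir) / f"chunk_{chunk_idx:05d}.jsonl"
        with path.open("w", encoding="utf-8") as f:
            for doc in docs[start:start + records_per_chunk]:
                f.write(json.dumps(doc) + "\n")
        paths.append(path)
    return paths


def completed_records(out_dir, records_per_chunk):
    """Count records in chunks that reached records_per_chunk lines (treated as finished)."""
    total = 0
    for path in sorted(Path(out_dir).glob("chunk_*.jsonl")):
        n = len(path.read_text(encoding="utf-8").splitlines())
        if n == records_per_chunk:
            total += n
    return total


with tempfile.TemporaryDirectory() as tmp:
    docs = [{"id": str(i), "metadata": {"rollout_results": [i]}} for i in range(5)]
    chunks = write_checkpoint_chunks(docs, tmp, records_per_chunk=2)
    line_counts = [len(p.read_text(encoding="utf-8").splitlines()) for p in chunks]
    resumed = completed_records(tmp, records_per_chunk=2)
    print(line_counts, resumed)  # [2, 2, 1] 4
```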
Usage Examples
Example: Basic synthetic data generation
```python
from datatrove.pipeline.inference.run_inference import InferenceRunner, InferenceConfig
from datatrove.pipeline.inference.types import InferenceResult
from datatrove.pipeline.writers.jsonl import JsonlWriter
from datatrove.data import Document


async def summarize_rollout(doc: Document, generate) -> InferenceResult:
    """Rollout function that summarizes each document."""
    result = await generate({
        "messages": [
            {"role": "system", "content": "Summarize the following text concisely."},
            {"role": "user", "content": doc.text},
        ],
        "max_tokens": 512,
        "temperature": 0.3,
    })
    return result


config = InferenceConfig(
    server_type="vllm",
    model_name_or_path="meta-llama/Llama-3.1-8B-Instruct",
    model_max_context=8192,
    tp=2,
    max_concurrent_generations=200,
    default_generation_params={"temperature": 0.7},
)

runner = InferenceRunner(
    rollout_fn=summarize_rollout,
    config=config,
    output_writer=JsonlWriter("output/summaries"),
    checkpoints_local_dir="/tmp/checkpoints",
)
```
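A rollout function is just an async callable that receives `generate`, so it can be exercised without any inference server by passing a fake callback. In the sketch below, `StubResult`, `StubDoc`, and `stub_generate` are hypothetical test doubles, not datatrove classes. They expose only the `text` and `finish_reason` fields that the examples on this page read. `summarize_rollout` is repeated so the block runs on its own.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class StubResult:
    # hypothetical stand-in exposing only the fields the examples use
    text: str
    finish_reason: str


@dataclass
class StubDoc:
    text: str


captured = []  # payloads the rollout sent to generate


async def stub_generate(payload: dict) -> StubResult:
    """Fake generate callback: records the payload and returns a canned answer."""
    captured.append(payload)
    return StubResult(text="a short summary", finish_reason="stop")


async def summarize_rollout(doc, generate):
    # same request shape as the basic example
    return await generate({
        "messages": [
            {"role": "system", "content": "Summarize the following text concisely."},
            {"role": "user", "content": doc.text},
        ],
        "max_tokens": 512,
        "temperature": 0.3,
    })


result = asyncio.run(summarize_rollout(StubDoc("Long article text"), stub_generate))
print(result.text, captured[0]["max_tokens"])  # a short summary 512
```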
Example: Multiple classifications per document with shared context
```python
from datatrove.pipeline.inference.run_inference import InferenceRunner, InferenceConfig
from datatrove.pipeline.writers.parquet import ParquetWriter


async def classify_rollout(doc, generate, categories=None):
    """Classify a document; called once for each of the rollouts_per_document rollouts."""
    # `categories` arrives as a keyword argument taken from shared_context
    result = await generate({
        "messages": [
            {"role": "system", "content": f"Classify into: {categories}"},
            {"role": "user", "content": doc.text},
        ],
        "max_tokens": 64,
    })
    return {"label": result.text.strip(), "finish_reason": result.finish_reason}


config = InferenceConfig(
    server_type="endpoint",
    model_name_or_path="meta-llama/Llama-3.1-70B-Instruct",
    endpoint_url="https://api-inference.huggingface.co/v1",
    api_key="hf_...",
    rollouts_per_document=3,  # 3 independent classifications per document
    max_concurrent_generations=100,
)

runner = InferenceRunner(
    rollout_fn=classify_rollout,
    config=config,
    output_writer=ParquetWriter("output/classifications"),
    shared_context={"categories": "science, technology, arts, sports"},
)
```
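With `rollouts_per_document=3`, each document ends up with three classification dicts in `doc.metadata["rollout_results"]`. A common follow-up is to reduce them to a single label. The hypothetical `majority_label` helper below performs a majority vote. As an assumption, it discards rollouts whose `finish_reason` is not `"stop"`, the OpenAI-style value for a completed answer. The metadata dict is made up for illustration.

```python
from collections import Counter
from typing import List, Optional


def majority_label(rollout_results: List[dict]) -> Optional[str]:
    """Most common label among completed rollouts; ties go to the label seen first."""
    labels = [r["label"] for r in rollout_results if r.get("finish_reason") == "stop"]
    if not labels:
        return None
    # Counter.most_common orders equal counts by first insertion
    return Counter(labels).most_common(1)[0][0]


metadata = {"rollout_results": [
    {"label": "science", "finish_reason": "stop"},
    {"label": "technology", "finish_reason": "length"},  # truncated, ignored
    {"label": "science", "finish_reason": "stop"},
]}
print(majority_label(metadata["rollout_results"]))  # science
```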