Implementation:Huggingface Datatrove InferenceRunner

From Leeroopedia
Knowledge Sources
Domains: Inference, Synthetic_Data, Pipeline Orchestration
Last Updated: 2026-02-14 00:00 GMT

Overview

Pipeline step that runs user-defined rollout functions on documents using inference servers, managing concurrency, caching, metrics, and result aggregation.

Description

InferenceRunner is a PipelineStep subclass that orchestrates the complete inference workflow. It initializes the appropriate inference server based on the server_type set in InferenceConfig, converts synchronous document iterators to async streams, dispatches documents to rollout functions with controlled concurrency, and writes results to both an output writer and local checkpoints.
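
The dispatch pattern described above (a synchronous iterator consumed as an async stream, with a cap on in-flight work) can be sketched with the standard library alone. This is a minimal illustration, not the runner's actual code: as_async_stream and process_with_limit are hypothetical names, and the use of asyncio.Semaphore is an assumption about how such a limit might be enforced.

```python
import asyncio
from typing import AsyncIterator, Awaitable, Callable, Iterable, TypeVar

T = TypeVar("T")
R = TypeVar("R")


async def as_async_stream(items: Iterable[T]) -> AsyncIterator[T]:
    """Wrap a synchronous iterable so it can be consumed with `async for`."""
    for item in items:
        yield item
        await asyncio.sleep(0)  # hand control back so in-flight tasks can progress


async def process_with_limit(
    items: Iterable[T],
    worker: Callable[[T], Awaitable[R]],
    max_concurrent: int,
) -> list[R]:
    """Run `worker` on every item with at most `max_concurrent` calls in flight."""
    semaphore = asyncio.Semaphore(max_concurrent)
    tasks: list[asyncio.Task] = []

    async def guarded(item: T) -> R:
        try:
            return await worker(item)
        finally:
            semaphore.release()  # free a slot whether the worker succeeded or failed

    async for item in as_async_stream(items):
        await semaphore.acquire()  # pauses the producer once the limit is reached
        tasks.append(asyncio.create_task(guarded(item)))

    # gather preserves submission order in its result list
    return list(await asyncio.gather(*tasks))
```

Acquiring the slot before creating the task means the input iterator is only advanced when capacity is available, so a very large input is never materialized as pending tasks all at once.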

The runner is configured through two main objects:

  • InferenceConfig: A dataclass containing all server and processing parameters (server type, model, parallelism, concurrency limits, generation defaults).
  • RolloutFunction: A user-provided async callable that defines the application-specific inference logic.

Internally, the runner uses:

  • MetricsKeeper: Tracks token throughput and request counts with 5-minute windowed and lifetime rates
  • QueueSizesKeeper: Tracks current queue sizes for waiting and running requests
  • CheckpointManager: Saves processed documents in chunks for fault-tolerant recovery
  • RequestCache: SQLite-backed cache that stores individual request/response pairs to avoid redundant inference on restart
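
To make the request-caching idea concrete, here is a small sketch of a SQLite-backed request/response cache. It is an assumption-laden stand-in, not the library's RequestCache: the class name SketchRequestCache, the table layout, and the choice to key on document id, rollout index, and payload hash are all hypothetical.

```python
import hashlib
import json
import sqlite3


class SketchRequestCache:
    """Hypothetical stand-in for a SQLite-backed request/response cache."""

    def __init__(self, path: str = ":memory:") -> None:
        self.conn = sqlite3.connect(path)
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS cache (key TEXT PRIMARY KEY, response TEXT NOT NULL)"
        )

    @staticmethod
    def make_key(doc_id: str, rollout_idx: int, payload: dict) -> str:
        # Canonical JSON (sorted keys) so logically equal payloads hash identically.
        blob = json.dumps([doc_id, rollout_idx, payload], sort_keys=True)
        return hashlib.sha256(blob.encode("utf-8")).hexdigest()

    def get(self, key: str):
        row = self.conn.execute(
            "SELECT response FROM cache WHERE key = ?", (key,)
        ).fetchone()
        return json.loads(row[0]) if row else None

    def put(self, key: str, response: dict) -> None:
        with self.conn:  # commits on success, rolls back on error
            self.conn.execute(
                "INSERT OR REPLACE INTO cache (key, response) VALUES (?, ?)",
                (key, json.dumps(response)),
            )


cache = SketchRequestCache()
key = SketchRequestCache.make_key("doc-1", 0, {"max_tokens": 8, "messages": []})
if cache.get(key) is None:  # cache miss: this is where inference would run
    cache.put(key, {"text": "cached answer"})
```

On restart, a request whose key is already present can be answered from the database instead of being sent to the server again.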

The synchronous run() method (required by the PipelineStep interface) delegates to run_async() by running it on an asyncio event loop, which manages the full async lifecycle.
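
A synchronous entry point that hands off to an async implementation might look like the sketch below. SketchStep is a hypothetical class, and using asyncio.run as the bridge is an assumption about the delegation mechanism rather than a statement about the library's code.

```python
import asyncio
from typing import Iterable


class SketchStep:
    """Hypothetical pipeline step with a synchronous entry point."""

    def __init__(self) -> None:
        self.processed: list[str] = []

    async def run_async(self, data: Iterable[str], rank: int = 0) -> None:
        for item in data:
            await asyncio.sleep(0)  # placeholder for awaiting an inference call
            self.processed.append(f"rank{rank}:{item}")

    def run(self, data: Iterable[str], rank: int = 0, world_size: int = 1) -> None:
        # asyncio.run creates a fresh event loop, runs the coroutine to
        # completion, then closes the loop (assumed delegation mechanism).
        asyncio.run(self.run_async(data, rank))


step = SketchStep()
step.run(["a", "b"], rank=1)
```

Because asyncio.run owns the loop for the duration of the call, callers of run() need no knowledge of the async internals.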

Usage

Use InferenceRunner when:

  • Building a synthetic data generation pipeline step in datatrove
  • Processing documents through one or more LLM calls with configurable concurrency
  • Running long jobs that need checkpointing and request caching for fault tolerance

Code Reference

Source Location

  • Repository: huggingface/datatrove
  • InferenceConfig: src/datatrove/pipeline/inference/run_inference.py:L42-96
  • InferenceRunner: src/datatrove/pipeline/inference/run_inference.py:L101-628

Signature

@dataclass
class InferenceConfig:
    server_type: Literal["sglang", "vllm", "dummy", "custom", "endpoint"]
    model_name_or_path: str
    model_max_context: int = 8192
    use_chat: bool = True
    endpoint_url: str | None = None
    api_key: str | None = None
    metric_interval: int = 120
    tp: int = 1
    dp: int = 1
    pp: int = 1
    default_generation_params: dict = field(default_factory=dict)
    rollouts_per_document: int = 1
    max_concurrent_generations: int = 500
    max_concurrent_documents: int | None = None
    request_timeout: float | None = None
    model_kwargs: dict | None = None
    server_log_folder: str | None = None
    master_port: int = 9810


class InferenceRunner(PipelineStep):
    def __init__(
        self,
        rollout_fn: RolloutFunction,
        config: InferenceConfig,
        output_writer: DiskWriter,
        shared_context: (dict | Callable[[], dict] | ContextManager[dict] | None) = None,
        checkpoints_local_dir: str | None = None,
        records_per_chunk: int = 6000,
        metadata_key: str = "rollout_results",
    ):
        ...

    def run(
        self,
        data: Iterable[Document],
        rank: int = 0,
        world_size: int = 1,
    ) -> None:
        ...

    async def run_async(
        self,
        data_gen: Iterable[Document],
        rank: int = 0,
    ) -> None:
        ...

Import

from datatrove.pipeline.inference.run_inference import InferenceRunner, InferenceConfig
from datatrove.pipeline.inference.types import InferenceResult, RolloutFunction

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| rollout_fn | RolloutFunction | Yes | Async callable that processes a document using the generate callback. As used in the examples on this page, it receives the document, the generate callback, and any shared context as keyword arguments, and returns an InferenceResult or a JSON-serializable value |
| config | InferenceConfig | Yes | Server type, model, parallelism, concurrency, and generation parameters |
| output_writer | DiskWriter | Yes | Writer for saving processed documents (e.g., ParquetWriter, JsonlWriter) |
| shared_context | dict / Callable / ContextManager / None | No | Shared context passed as keyword arguments to the rollout function |
| checkpoints_local_dir | str / None | No | Local directory for checkpoint files; enables fault-tolerant recovery |
| records_per_chunk | int | No | Number of documents per checkpoint chunk (default: 6000) |
| metadata_key | str | No | Key for storing rollout results in document metadata (default: "rollout_results") |
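
Since shared_context accepts several forms, a runner has to normalize them into a plain dict before passing keyword arguments to the rollout function. The sketch below shows one way this could be done; resolve_shared_context and tokenizer_context are hypothetical names, and the real resolution logic may differ (for example in how it treats a callable that returns a context manager).

```python
from contextlib import ExitStack, contextmanager
from typing import Any, Callable, ContextManager, Union

SharedContext = Union[dict, Callable[[], dict], ContextManager[dict], None]


def resolve_shared_context(shared: SharedContext, stack: ExitStack) -> dict[str, Any]:
    """Normalize the accepted forms into a plain dict of keyword arguments.

    Context managers are entered on `stack`, so they are cleaned up when the
    caller closes the stack.
    """
    if shared is None:
        return {}
    if isinstance(shared, dict):
        return shared
    # Check for a context manager before callable(): objects produced by
    # @contextmanager are also callable, because they double as decorators.
    if hasattr(shared, "__enter__") and hasattr(shared, "__exit__"):
        return stack.enter_context(shared)
    if callable(shared):
        return shared()
    raise TypeError(f"Unsupported shared_context type: {type(shared)!r}")


@contextmanager
def tokenizer_context():
    # Hypothetical resource that needs setup and teardown around the run.
    yield {"categories": "science, technology"}


with ExitStack() as stack:
    kwargs = resolve_shared_context(tokenizer_context(), stack)
    # kwargs can now be splatted into the rollout call: rollout_fn(doc, generate, **kwargs)
```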

Outputs

| Name | Type | Description |
|---|---|---|
| Documents with rollout_results | Document objects | Input documents with metadata[metadata_key] (by default metadata["rollout_results"]) populated with a list of rollout results (InferenceResult or JSON-serializable values) |
| Checkpoint files | JSONL files | Local JSONL checkpoint files (when checkpoints_local_dir is set) containing processed documents in chunks |
| Metrics logs | Log output | Periodic metrics reports showing token throughput, request rates, and queue sizes |
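
The chunked JSONL checkpoints can be pictured with the following sketch, which groups processed documents into files of records_per_chunk entries and counts what is already on disk. The file naming scheme and the helper names (chunk_path, append_to_checkpoint, count_checkpointed) are assumptions for illustration; the actual on-disk layout may differ.

```python
import json
import tempfile
from pathlib import Path


def chunk_path(checkpoint_dir: Path, rank: int, chunk_index: int) -> Path:
    # Hypothetical naming scheme; the real file layout may differ.
    return checkpoint_dir / f"{rank:05d}" / f"chunk_{chunk_index:05d}.jsonl"


def append_to_checkpoint(
    checkpoint_dir: Path, rank: int, doc_index: int, doc: dict, records_per_chunk: int = 6000
) -> Path:
    """Append one processed document to the chunk file it belongs to."""
    path = chunk_path(checkpoint_dir, rank, doc_index // records_per_chunk)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("a", encoding="utf-8") as f:
        f.write(json.dumps(doc) + "\n")
    return path


def count_checkpointed(checkpoint_dir: Path, rank: int) -> int:
    """Count documents already saved, e.g. to decide how many inputs to skip on restart."""
    rank_dir = checkpoint_dir / f"{rank:05d}"
    if not rank_dir.exists():
        return 0
    total = 0
    for path in sorted(rank_dir.glob("chunk_*.jsonl")):
        with path.open(encoding="utf-8") as f:
            total += sum(1 for _ in f)
    return total


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for i in range(7):
        append_to_checkpoint(root, rank=0, doc_index=i, doc={"id": str(i)}, records_per_chunk=3)
    print(count_checkpointed(root, rank=0))  # 7 documents spread over 3 chunk files
```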

Usage Examples

Example: Basic synthetic data generation

from datatrove.pipeline.inference.run_inference import InferenceRunner, InferenceConfig
from datatrove.pipeline.inference.types import InferenceResult
from datatrove.pipeline.writers.jsonl import JsonlWriter
from datatrove.data import Document


async def summarize_rollout(doc: Document, generate) -> InferenceResult:
    """Rollout function that summarizes each document."""
    result = await generate({
        "messages": [
            {"role": "system", "content": "Summarize the following text concisely."},
            {"role": "user", "content": doc.text},
        ],
        "max_tokens": 512,
        "temperature": 0.3,
    })
    return result


config = InferenceConfig(
    server_type="vllm",
    model_name_or_path="meta-llama/Llama-3.1-8B-Instruct",
    model_max_context=8192,
    tp=2,
    max_concurrent_generations=200,
    default_generation_params={"temperature": 0.7},
)

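# The runner is a PipelineStep: constructing it does not start inference.
# Place it after a reader step in a datatrove pipeline to execute it.
runner = InferenceRunner(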
    rollout_fn=summarize_rollout,
    config=config,
    output_writer=JsonlWriter("output/summaries"),
    checkpoints_local_dir="/tmp/checkpoints",
)
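
Because a rollout function only depends on the document and the generate callback, it can be exercised offline with a stub before any server is started. The sketch below assumes stand-in types: FakeDocument and FakeResult are hypothetical, with the text and finish_reason attributes modeled on how results are accessed elsewhere on this page, and fake_generate replaces the real callback.

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class FakeDocument:
    # Stand-in for datatrove's Document with only the fields used here.
    text: str
    id: str = "doc-0"
    metadata: dict = field(default_factory=dict)


@dataclass
class FakeResult:
    # Hypothetical result shape; `text` and `finish_reason` are assumptions.
    text: str
    finish_reason: str = "stop"


async def fake_generate(payload: dict) -> FakeResult:
    """Echo the first 20 characters of the user message instead of calling a model."""
    user_message = next(m["content"] for m in payload["messages"] if m["role"] == "user")
    return FakeResult(text=user_message[:20])


async def summarize_rollout(doc, generate):
    # Same body as the example above, repeated so this sketch is self-contained.
    return await generate({
        "messages": [
            {"role": "system", "content": "Summarize the following text concisely."},
            {"role": "user", "content": doc.text},
        ],
        "max_tokens": 512,
        "temperature": 0.3,
    })


result = asyncio.run(
    summarize_rollout(FakeDocument(text="A long article about tides."), fake_generate)
)
print(result.text)
```

A stub like this catches payload mistakes (missing roles, wrong keys) without consuming GPU time.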

Example: Multi-rollout with shared context

from datatrove.pipeline.inference.run_inference import InferenceRunner, InferenceConfig
from datatrove.pipeline.writers.parquet import ParquetWriter


async def classify_rollout(doc, generate, categories=None):
    """Classify a document; called once per rollout. No explicit seed is set here."""
    result = await generate({
        "messages": [
            {"role": "system", "content": f"Classify into: {categories}"},
            {"role": "user", "content": doc.text},
        ],
        "max_tokens": 64,
    })
    return {"label": result.text.strip(), "finish_reason": result.finish_reason}


config = InferenceConfig(
    server_type="endpoint",
    model_name_or_path="meta-llama/Llama-3.1-70B-Instruct",
    endpoint_url="https://api-inference.huggingface.co/v1",
    api_key="hf_...",
    rollouts_per_document=3,  # 3 independent classifications per document
    max_concurrent_generations=100,
)

runner = InferenceRunner(
    rollout_fn=classify_rollout,
    config=config,
    output_writer=ParquetWriter("output/classifications"),
    shared_context={"categories": "science, technology, arts, sports"},
)
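
How several rollouts per document could end up as a list under the metadata key can be sketched as follows. This is an illustration under assumptions, not the runner's implementation: run_rollouts, FakeDocument, and fake_generate are hypothetical, the stubbed result is a plain dict, and running the rollouts concurrently with asyncio.gather is a guess at the scheduling.

```python
import asyncio
from dataclasses import dataclass, field
from functools import partial


@dataclass
class FakeDocument:
    # Stand-in for datatrove's Document with only the fields used here.
    text: str
    metadata: dict = field(default_factory=dict)


async def fake_generate(payload: dict) -> dict:
    # Stub for the generate callback; no model is called.
    return {"text": "science", "finish_reason": "stop"}


async def classify_rollout(doc, generate, categories=None):
    result = await generate({"messages": [{"role": "user", "content": doc.text}]})
    return {"label": result["text"], "categories": categories}


async def run_rollouts(
    doc,
    rollout_fn,
    generate,
    rollouts_per_document: int,
    shared_context: dict,
    metadata_key: str = "rollout_results",
):
    """Run the rollout N times and store the list of results under metadata_key."""
    bound = partial(rollout_fn, doc, generate, **shared_context)
    results = await asyncio.gather(*(bound() for _ in range(rollouts_per_document)))
    doc.metadata[metadata_key] = list(results)
    return doc


doc = asyncio.run(
    run_rollouts(
        FakeDocument("Quarks and gluons"),
        classify_rollout,
        fake_generate,
        rollouts_per_document=3,
        shared_context={"categories": "science, arts"},
    )
)
```

With a real sampling model, the three entries could differ, which is what makes majority voting or agreement checks over the list meaningful.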
