Implementation:Huggingface Datatrove InferenceTypes
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning Inference, Type System |
| Last Updated | 2026-02-14 17:00 GMT |
Overview
The InferenceTypes module defines the core data types, exception classes, and protocol definitions used across the Datatrove inference pipeline.
Description
This module establishes the type system for the inference subsystem through a collection of dataclasses, exceptions, type aliases, and protocols. The InferenceResult dataclass encapsulates a successful inference response, containing the generated text, the finish reason, and token usage statistics. It serves as the standardized return type throughout the inference pipeline.
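As a minimal sketch of how downstream code might consume this dataclass, the block below uses a local stand-in with the same three fields, so it runs without Datatrove installed. The helper `was_truncated` is hypothetical, and the `"length"` finish reason and the `completion_tokens` key are assumptions based on common OpenAI-style responses rather than guarantees of this module.

```python
from dataclasses import dataclass


@dataclass
class InferenceResult:
    """Local stand-in mirroring the three documented fields."""

    text: str
    finish_reason: str
    usage: dict


def was_truncated(result: InferenceResult) -> bool:
    """Hypothetical helper: True when generation stopped at the token limit."""
    return result.finish_reason == "length"


result = InferenceResult(
    text="The answer is",
    finish_reason="length",
    usage={"prompt_tokens": 10, "completion_tokens": 3, "total_tokens": 13},
)
truncated = was_truncated(result)
# Read usage defensively, since the dict's keys depend on the server.
completion_tokens = result.usage.get("completion_tokens", 0)
```

Reading `usage` with `.get` keeps the consumer tolerant of servers that omit some counters.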
Two exception classes are defined: InferenceError wraps failures that occur during document processing, carrying the original document reference, the error details, and the request payload for debugging; ServerError captures unrecoverable server-level failures. Both exceptions provide structured error messages that include relevant context for troubleshooting.
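The two failure modes can be handled separately by a caller. The sketch below defines local stand-ins with the same constructor parameters as the documented classes; the stored attribute names and the message formats are assumptions for illustration, and the real classes may format their messages differently. `process` is a hypothetical function, and a plain string stands in for the document.

```python
class InferenceError(Exception):
    """Stand-in: per-document failure carrying context for debugging."""

    def __init__(self, document, error, payload=None):
        self.document = document
        self.error = error
        self.payload = payload
        super().__init__(f"Failed to process document {document!r}: {error}")


class ServerError(Exception):
    """Stand-in: unrecoverable server-level failure."""

    def __init__(self, error):
        self.error = error
        super().__init__(f"Server error: {error}")


def process(doc_id: str, server_alive: bool) -> str:
    """Hypothetical step that fails in one of the two documented ways."""
    if not server_alive:
        raise ServerError("server process exited")
    payload = {"prompt": doc_id, "max_tokens": 100}
    try:
        raise TimeoutError("request timed out")  # simulated request failure
    except TimeoutError as exc:
        raise InferenceError(document=doc_id, error=exc, payload=payload) from exc


outcomes = []
for alive in (True, False):
    try:
        process("doc-0", server_alive=alive)
    except InferenceError as exc:
        outcomes.append(("skip document", exc.payload))  # per-document failure
    except ServerError as exc:
        outcomes.append(("abort run", str(exc)))  # server-level failure
```

Catching the two exceptions in separate branches lets a caller skip a single bad document while still stopping on a server failure; whether Datatrove itself applies this policy is not stated here.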
The module also defines two critical type aliases and a protocol. GenerateFunction is a callable type that takes a payload dictionary and returns an awaitable InferenceResult. RolloutResult is a union type encompassing InferenceResult and common JSON-serializable types. The RolloutFunction protocol defines the interface for document processing functions that receive a document, a generate callback, and arbitrary keyword arguments from a shared context.
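The relationship between the generate callback and a rollout function can be sketched as follows. To stay runnable without Datatrove, the block declares local stand-ins for `InferenceResult`, `Document`, and `GenerateFunction`; `fake_generate`, `summarize_rollout`, and the `prefix` keyword argument are hypothetical names introduced only for this illustration.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable


@dataclass
class InferenceResult:
    text: str
    finish_reason: str
    usage: dict


@dataclass
class Document:
    """Minimal stand-in for the Datatrove document type."""

    text: str
    id: str


GenerateFunction = Callable[[dict], Awaitable[InferenceResult]]


async def fake_generate(payload: dict) -> InferenceResult:
    """Hypothetical generate callback that echoes the prompt in upper case."""
    text = payload["prompt"].upper()
    usage = {"total_tokens": len(text.split())}
    return InferenceResult(text=text, finish_reason="stop", usage=usage)


async def summarize_rollout(
    document: Document,
    generate: GenerateFunction,
    **kwargs: Any,
) -> dict:
    """Rollout that calls generate once and returns a JSON-serializable dict."""
    prefix = kwargs.get("prefix", "")  # value supplied through keyword arguments
    result = await generate({"prompt": prefix + document.text, "max_tokens": 100})
    return {"id": document.id, "output": result.text, "finish_reason": result.finish_reason}


output = asyncio.run(
    summarize_rollout(Document(text="hello world", id="doc-0"), fake_generate, prefix="say: ")
)
```

Because the rollout only depends on the callable's shape, a fake callback like this one can replace a real server in unit tests.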
Usage
Use these types when implementing custom inference rollout functions, building new inference server integrations, or handling inference results in downstream pipeline steps. They give the inference subsystem a shared vocabulary, so type checkers and readers can rely on consistent signatures across components.
Code Reference
Source Location
- Repository: Huggingface_Datatrove
- File: src/datatrove/pipeline/inference/types.py
- Lines: 1-83
Signature
```python
@dataclass
class InferenceResult:
    text: str
    finish_reason: str
    usage: dict


class InferenceError(Exception):
    def __init__(self, document: Document | None, error: str | Exception, payload: dict | None = None):
        ...


class ServerError(Exception):
    def __init__(self, error: str | Exception):
        ...


GenerateFunction = Callable[[dict], Awaitable[InferenceResult]]

RolloutResult = InferenceResult | dict | list | str | float | int | bool | None


class RolloutFunction(Protocol):
    def __call__(
        self,
        document: Document,
        generate: GenerateFunction,
        **kwargs: Any,
    ) -> Awaitable[RolloutResult]: ...
```
Import
```python
from datatrove.pipeline.inference.types import (
    GenerateFunction,
    InferenceError,
    InferenceResult,
    RolloutFunction,
    RolloutResult,
    ServerError,
)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| text | str | Yes (InferenceResult) | Generated text from the model |
| finish_reason | str | Yes (InferenceResult) | Reason why generation finished (e.g., "stop", "length") |
| usage | dict | Yes (InferenceResult) | Token usage statistics dictionary |
| document | Document or None | Yes (InferenceError) | The original document that failed processing |
| error | str or Exception | Yes (InferenceError/ServerError) | The underlying error that caused the failure |
| payload | dict or None | No (InferenceError) | The request payload that caused the failure |
Outputs
| Name | Type | Description |
|---|---|---|
| InferenceResult | dataclass | Encapsulated successful inference response with text, finish_reason, and usage |
| RolloutResult | Union type | InferenceResult or any JSON-serializable value returned by rollout functions |
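Since a rollout may return either an `InferenceResult` or a plain JSON-compatible value, downstream code often needs to normalize both cases before writing them out. The sketch below shows one way to do that with a local stand-in dataclass; `to_jsonable` is a hypothetical helper, and how Datatrove itself stores rollout results is not specified here.

```python
import json
from dataclasses import asdict, dataclass, is_dataclass
from typing import Any


@dataclass
class InferenceResult:
    text: str
    finish_reason: str
    usage: dict


def to_jsonable(value: Any) -> Any:
    """Hypothetical helper: unwrap dataclass instances, pass other values through."""
    if is_dataclass(value) and not isinstance(value, type):
        return asdict(value)
    return value


result = InferenceResult(text="42", finish_reason="stop", usage={"total_tokens": 15})
# One value per member kind of the union: dataclass, dict, str, None.
encoded = [
    json.dumps(to_jsonable(value), sort_keys=True)
    for value in (result, {"score": 0.5}, "ok", None)
]
```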
Usage Examples
Basic Usage
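```python
from datatrove.data import Document
from datatrove.pipeline.inference.types import InferenceError, InferenceResult

# Creating a successful inference result
result = InferenceResult(
    text="The answer is 42.",
    finish_reason="stop",
    usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
)

# Handling an inference error: attach the document and payload for debugging
doc = Document(text="What is 6 x 7?", id="doc-0")  # illustrative document
request_payload = {"prompt": doc.text, "max_tokens": 100}  # illustrative payload
try:
    # ... inference logic ...
    pass
except Exception as e:
    # Chain the original exception so its traceback is preserved
    raise InferenceError(document=doc, error=str(e), payload=request_payload) from e
```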
Implementing a RolloutFunction
```python
from datatrove.data import Document
from datatrove.pipeline.inference.types import GenerateFunction, RolloutResult


async def my_rollout(
    document: Document,
    generate: GenerateFunction,
    **kwargs,
) -> RolloutResult:
    # Build the request payload from the document and await the generate callback
    payload = {"prompt": document.text, "max_tokens": 100}
    result = await generate(payload)
    return result
```