Implementation:Mlflow Mlflow Predict Function Interface
| Knowledge Sources | |
|---|---|
| Domains | ML_Ops, LLM_Evaluation |
| Last Updated | 2026-02-13 20:00 GMT |
Overview
A concrete tool, provided by the MLflow library, for defining and wrapping user predict functions so they can be used in evaluation.
Description
The predict function interface consists of two parts: (1) a user-defined callable that implements the application under test, and (2) the internal convert_predict_fn wrapper that normalises the callable for use by the evaluation harness.
The user defines a function whose keyword parameters match the keys of the inputs dictionary in the evaluation dataset. The function can be synchronous or asynchronous and should return the model's output (any type). It does not need to accept self or conform to any base class.
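The name-matching rule can be checked before an evaluation run. The following is a minimal sketch using only the Python standard library; the helper inputs_match_signature and the function answer_question are hypothetical names introduced for illustration and are not part of MLflow.

```python
import inspect
from typing import Any, Callable


def inputs_match_signature(predict_fn: Callable[..., Any], inputs: dict) -> bool:
    """Return True if predict_fn(**inputs) would bind without a TypeError.

    Hypothetical pre-flight helper; MLflow does not expose this function.
    """
    try:
        # bind() applies the same argument-matching rules as a real call,
        # but does not execute the function body.
        inspect.signature(predict_fn).bind(**inputs)
    except TypeError:
        return False
    return True


def answer_question(question: str, context: str = "") -> str:
    # Stand-in for an application under test.
    return f"{context} {question}".strip()


# Keys match the parameter names (context has a default, so it may be omitted).
print(inputs_match_signature(answer_question, {"question": "What is MLflow?"}))  # True
# "query" is not a parameter name and "question" is missing.
print(inputs_match_signature(answer_question, {"query": "What is MLflow?"}))  # False
```

Running a check like this over every row's inputs dictionary surfaces key mismatches early, rather than as a TypeError in the middle of an evaluation.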
Internally, convert_predict_fn performs three transformations:
- Async wrapping: If the function is an async coroutine, it is wrapped with asyncio.run() using a configurable timeout (defaulting to 300 seconds, controlled by MLFLOW_GENAI_EVAL_ASYNC_TIMEOUT).
- Trace validation: The harness calls the function once with a sample input using a no-op tracer to detect whether the function already emits traces. If no trace is emitted, the function is wrapped with @mlflow.trace to ensure one trace per call.
- Input unpacking: The final lambda wraps the function so that the inputs dictionary is unpacked into keyword arguments: lambda request: predict_fn(**request).
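The async-wrapping and input-unpacking steps can be sketched without MLflow installed. This is a simplified stand-in, not the library's actual code: sketch_convert_predict_fn and its timeout_seconds parameter are hypothetical, using asyncio.wait_for to enforce the timeout is an assumption, and the trace-validation step is only marked by a comment.

```python
import asyncio
import inspect
from typing import Any, Callable


def sketch_convert_predict_fn(
    predict_fn: Callable[..., Any],
    timeout_seconds: float = 300.0,  # assumed to mirror the 300-second default
) -> Callable[[dict], Any]:
    """Simplified illustration of the wrapper; not MLflow's implementation."""
    if inspect.iscoroutinefunction(predict_fn):
        async_fn = predict_fn

        def sync_fn(**kwargs: Any) -> Any:
            # Run the coroutine to completion, bounded by the timeout.
            return asyncio.run(
                asyncio.wait_for(async_fn(**kwargs), timeout_seconds)
            )

        predict_fn = sync_fn

    # Trace validation is omitted here. Per the description above, the real
    # wrapper probes the function with a sample input and applies
    # @mlflow.trace when no trace is emitted.

    # Input unpacking: the inputs dictionary becomes keyword arguments.
    return lambda request: predict_fn(**request)


async def shout(question: str) -> str:
    # Hypothetical async application under test.
    return question.upper()


wrapped = sketch_convert_predict_fn(shout)
print(wrapped({"question": "what is mlflow?"}))  # WHAT IS MLFLOW?
```

The returned callable takes a single dictionary, which is why the evaluation harness can pass each row's inputs directly regardless of how many parameters the user function declares.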
Usage
Provide a predict function to mlflow.genai.evaluate() when the evaluation dataset contains only inputs (and optionally expectations) but no pre-computed outputs or trace columns. This is the standard workflow during iterative development and CI-based evaluation.
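As a rough illustration of this rule, the sketch below checks whether a list-of-dicts dataset would need a predict function. The helper needs_predict_fn is hypothetical, and the row keys "outputs" and "trace" are assumed names for the pre-computed outputs and trace columns.

```python
from typing import Any, Mapping, Sequence


def needs_predict_fn(rows: Sequence[Mapping[str, Any]]) -> bool:
    """Return True if any row lacks both pre-computed outputs and a trace.

    Hypothetical helper; the key names "outputs" and "trace" are assumptions.
    """
    return any("outputs" not in row and "trace" not in row for row in rows)


inputs_only = [{"inputs": {"question": "What is MLflow?"}}]
with_outputs = [
    {"inputs": {"question": "What is MLflow?"}, "outputs": "A platform."}
]

print(needs_predict_fn(inputs_only))   # True
print(needs_predict_fn(with_outputs))  # False
```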
Code Reference
Source Location
- Repository: mlflow
- File: mlflow/genai/utils/trace_utils.py
- Lines: L503-529
Signature
# User-defined (no fixed signature -- parameter names must match inputs dict keys):
def predict_fn(**kwargs) -> Any: ...
# Internal wrapper:
def convert_predict_fn(
    predict_fn: Callable[..., Any],
    sample_input: Any,
) -> Callable[..., Any]:
    """
    Check the predict_fn is callable and add trace decorator if it is not
    already traced. If the predict_fn is an async function, wrap it to make
    it synchronous.
    """
Import
# User-defined -- no import needed; just define a plain function.
# Internal (used automatically by mlflow.genai.evaluate):
from mlflow.genai.utils.trace_utils import convert_predict_fn
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| predict_fn | Callable[..., Any] | Yes | User-defined function. Its keyword parameters must match the keys in the evaluation dataset's inputs column. |
| sample_input | Any | Yes (internal) | A sample inputs dictionary used to probe whether the function emits traces. |
User function parameter contract:
| Constraint | Description |
|---|---|
| Keyword arguments | Parameter names must match keys in the inputs dictionary of each evaluation row. |
| Return type | Any value; the harness captures it as the row's outputs. |
| Trace emission | Must emit exactly one MLflow trace per call. Auto-wrapped with @mlflow.trace if needed. |
| Async support | Async functions (async def) are automatically wrapped with asyncio.run(). |
Outputs
| Name | Type | Description |
|---|---|---|
| wrapped_fn | Callable[..., Any] | A synchronous, traced callable that accepts an inputs dictionary and unpacks it into keyword arguments for the user function. |
Usage Examples
Basic Usage
import mlflow.genai
import openai
# Define a predict function -- parameter names match inputs dict keys
def predict_fn(question: str) -> str:
    response = openai.OpenAI().chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": question}],
    )
    return response.choices[0].message.content

# Evaluation dataset with inputs only (no pre-computed outputs)
data = [
    {"inputs": {"question": "What is MLflow?"}},
    {"inputs": {"question": "What is Spark?"}},
]

result = mlflow.genai.evaluate(
    data=data,
    predict_fn=predict_fn,
    scorers=[...],
)
Async Predict Function
import mlflow.genai
import openai
async def async_predict(question: str) -> str:
    client = openai.AsyncOpenAI()
    response = await client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": question}],
    )
    return response.choices[0].message.content

# Async functions are automatically detected and wrapped
result = mlflow.genai.evaluate(
    data=[{"inputs": {"question": "What is MLflow?"}}],
    predict_fn=async_predict,
    scorers=[...],
)
Multi-Parameter Predict Function
import mlflow.genai
def predict_fn(question: str, context: str) -> str:
    prompt = f"Context: {context}\nQuestion: {question}"
    # Placeholder: replace the next line with a real LLM call that uses prompt.
    response = f"[LLM response to: {prompt}]"
    return response

# Inputs dict keys must match parameter names
data = [
    {
        "inputs": {
            "question": "What is MLflow?",
            "context": "MLflow is an open-source platform for ML lifecycle.",
        },
    },
]

result = mlflow.genai.evaluate(
    data=data,
    predict_fn=predict_fn,
    scorers=[...],
)