Implementation:Mlflow Mlflow Model Training Interface
| Knowledge Sources | |
|---|---|
| Domains | ML_Ops, Model_Management |
| Last Updated | 2026-02-13 20:00 GMT |
Overview
Concrete pattern for training machine learning models with any supported framework, as a prerequisite for logging and registering them in MLflow.
Description
Model training in the MLflow ecosystem is performed by user-authored code that invokes the fitting routines of an external machine learning library. MLflow does not prescribe a particular training API; instead it expects the practitioner to produce a fitted model object using whichever framework is appropriate for the problem. Popular choices include scikit-learn, PyTorch, TensorFlow, XGBoost, LightGBM, and Keras, among others.
The training code typically runs inside an MLflow run context (mlflow.start_run) so that hyperparameters, metrics, and the resulting model artifact can be captured together. Once the model object has been fitted, it is handed to a logging function such as mlflow.pyfunc.log_model or a flavour-specific logger (e.g., mlflow.sklearn.log_model) that serialises the model and records it as a run artifact.
Because MLflow treats training as external user code, there is no single source file or function signature to reference. The interface is defined by the contract: the training step must produce an in-memory model object that a downstream MLflow logging function can accept.
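The contract described above can be sketched without any particular framework. The following is a minimal illustration only: `SupportsPredict`, `MeanRegressor`, and `train` are hypothetical names invented for this sketch and are not part of MLflow. MLflow itself does not define such a protocol, and a flavour-specific logger still expects an object from its own framework; a hand-written object like this one would typically need to be wrapped for the generic pyfunc flavour before logging.

```python
from typing import Protocol, Sequence, runtime_checkable


@runtime_checkable
class SupportsPredict(Protocol):
    """Hypothetical structural type: any object exposing a predict method."""

    def predict(self, X: Sequence[Sequence[float]]) -> Sequence[float]: ...


class MeanRegressor:
    """Toy stand-in for a framework model; always predicts the training mean."""

    def __init__(self) -> None:
        self.mean_ = None  # populated by fit()

    def fit(self, X, y):
        # "Training" here is just computing the mean of the targets.
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        if self.mean_ is None:
            raise RuntimeError("call fit() before predict()")
        return [self.mean_ for _ in X]


def train(X, y) -> SupportsPredict:
    """Training step: returns a fitted, in-memory model object."""
    return MeanRegressor().fit(X, y)


model = train([[0.0], [1.0], [2.0]], [1.0, 2.0, 3.0])
predictions = model.predict([[5.0], [6.0]])
```

The point of the sketch is the shape of the hand-off: whatever happens inside `train`, the caller receives a fitted object that lives in memory and can make predictions, which is what a downstream logging call consumes.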
Usage
Use this pattern at the beginning of any MLflow-managed model lifecycle workflow. Wrap training code inside an mlflow.start_run() context manager, log hyperparameters with mlflow.log_params, record metrics with mlflow.log_metrics, and then pass the resulting model object to the appropriate logging function.
Code Reference
Source Location
- Repository: mlflow
- File: N/A (user-defined training code)
- Lines: N/A
Signature
# Typical training pattern (scikit-learn example)
# X_train and y_train are the training features and targets prepared by the caller.
from sklearn.ensemble import RandomForestClassifier
model = RandomForestClassifier(n_estimators=100, max_depth=5)
model.fit(X_train, y_train)
Import
# Framework-specific imports (examples)
from sklearn.ensemble import RandomForestClassifier
import torch
import tensorflow as tf
import xgboost as xgb
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| X_train | array-like or DataFrame | Yes | Feature matrix containing the training samples. |
| y_train | array-like or Series | Yes | Target vector containing the labels or values to predict. |
| hyperparameters | dict or keyword arguments | No | Configuration values such as learning rate, number of estimators, max depth, etc. |
Outputs
| Name | Type | Description |
|---|---|---|
| model | Fitted model object | A trained model object with a prediction method (e.g., .predict(), .forward()) ready for logging to MLflow. |
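As a rough illustration of the I/O contract in the tables above, the training step can be wrapped in a single function that takes the features, the targets, and an optional hyperparameter dict, and returns the fitted object. This is a sketch under stated assumptions: `train_model` is a hypothetical helper, not an MLflow or scikit-learn API, and the default hyperparameters are arbitrary values chosen to keep the example fast.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier


def train_model(X_train, y_train, hyperparameters=None):
    """Hypothetical helper: fit a classifier and return the fitted object."""
    # Arbitrary defaults for this sketch; caller-supplied values override them.
    params = {"n_estimators": 10, "random_state": 0}
    params.update(hyperparameters or {})
    model = RandomForestClassifier(**params)
    model.fit(X_train, y_train)
    return model


X, y = load_iris(return_X_y=True)
model = train_model(X, y, {"max_depth": 3})
predictions = model.predict(X[:5])
```

Keeping training behind one function makes the same hyperparameter dict available both for fitting and for recording with mlflow.log_params, so the logged values cannot drift from the ones actually used.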
Usage Examples
Basic Usage
import mlflow
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
X, y = load_iris(return_X_y=True)
# Seed the split so the reported accuracy is reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
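with mlflow.start_run():
    # Record the hyperparameters before training so they are tied to this run.
    params = {"n_estimators": 100, "max_depth": 5, "random_state": 42}
    mlflow.log_params(params)
    model = RandomForestClassifier(**params)
    model.fit(X_train, y_train)
    # Evaluate on the held-out split and record the result as a run metric.
    accuracy = model.score(X_test, y_test)
    mlflow.log_metric("accuracy", accuracy)
    # Model object is now ready for logging.
    # `name` is the MLflow 3 parameter; older releases use `artifact_path`.
    mlflow.sklearn.log_model(model, name="iris-classifier")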