Implementation:Bentoml BentoML Framework Flax
| Knowledge Sources | |
|---|---|
| Domains | ML Framework, JAX, Model Serialization |
| Last Updated | 2026-02-13 15:00 GMT |
Overview
The bentoml.flax module integrates BentoML with Flax (Linen), Google's neural network library built on JAX, so that Flax models can be saved, loaded, and served.
Description
This module implements the BentoML framework adapter for flax.linen.Module models. Flax models are saved as msgpack-serialized state dictionaries, with the module structure preserved as a custom object via cloudpickle.
Key implementation details:
- save_model(): Serializes the model state (parameters, batch statistics, etc.) using flax.serialization.to_bytes() and stores the nn.Module structure as a custom object.
- load_model(): Deserializes the state dictionary from msgpack and optionally initializes arrays as jnp.ndarray or places them on a specific device (TPU/GPU/CPU). Hides GPUs from TensorFlow to prevent memory conflicts with JAX.
- get_runnable(): Creates a FlaxRunnable that supports TPU, GPU, and CPU resources. It uses JAX's model.apply() for inference with automatic input conversion from NumPy arrays and Pandas DataFrames to JAX arrays. Method results are cached for performance.
The module requires both flax and tensorflow (for XLA backend support).
Usage
Use this module to save and serve Flax/Linen models within BentoML. It suits JAX-based ML workflows such as image classification, NLP, or any other task for which a Flax model has been trained.
Code Reference
Source Location
- Repository: Bentoml_BentoML
- File: src/bentoml/_internal/frameworks/flax.py
- Lines: 1-323
Signature
def get(tag_like: str | Tag) -> bentoml.Model: ...

def load_model(
    bento_model: str | Tag | bentoml.Model,
    init: bool = True,
    device: str | XlaBackend = "cpu",
) -> tuple[nn.Module, dict[str, Any]]: ...

def save_model(
    name: Tag | str,
    module: nn.Module,
    state: dict[str, Any] | FrozenDict[str, Any] | struct.PyTreeNode,
    *,
    signatures: ModelSignaturesType | None = None,
    labels: dict[str, str] | None = None,
    custom_objects: dict[str, Any] | None = None,
    external_modules: List[ModuleType] | None = None,
    metadata: dict[str, Any] | None = None,
) -> bentoml.Model: ...

def get_runnable(bento_model: bentoml.Model) -> type[bentoml.legacy.Runnable]: ...
Import
import bentoml
# Via public API
model = bentoml.flax.save_model(...)
module, state_dict = bentoml.flax.load_model(...)
I/O Contract
Inputs
save_model()
| Name | Type | Required | Description |
|---|---|---|---|
| name | Tag or str | Yes | Name/tag for the model in the BentoML store |
| module | flax.linen.Module | Yes | The Flax Linen module to save |
| state | dict, FrozenDict, or PyTreeNode | Yes | The model state dictionary (params, batch stats, etc.) |
| signatures | ModelSignaturesType or None | No | Inference method signatures (default: {"__call__": {"batchable": False}}) |
| labels | dict[str, str] or None | No | User-defined labels for model management |
| custom_objects | dict[str, Any] or None | No | Additional objects to serialize |
| external_modules | List[ModuleType] or None | No | Additional Python modules to save alongside |
| metadata | dict[str, Any] or None | No | Custom metadata for the model |
load_model()
| Name | Type | Required | Description |
|---|---|---|---|
| bento_model | str, Tag, or Model | Yes | Tag or Model instance to load from the store |
| init | bool | No | Whether to convert state to jnp.ndarray (default: True) |
| device | str or XlaBackend | No | Device to place parameters on when init=False (default: "cpu") |
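To make the two load-time options more concrete, the sketch below shows, in plain JAX, what "convert state to jnp.ndarray" and "place parameters on a device" can look like for a restored state dictionary. This is an illustration under assumptions, not the code inside load_model(): the restored dictionary is a hand-built stand-in, and the variable names are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Stand-in for a state dict freshly restored from msgpack (NumPy leaves).
restored = {
    "params": {
        "Dense_0": {
            "kernel": np.ones((3, 4), dtype=np.float32),
            "bias": np.zeros((4,), dtype=np.float32),
        }
    }
}

# Option 1: convert every leaf to a JAX array on the default device.
as_jax = jax.tree_util.tree_map(jnp.asarray, restored)

# Option 2: place every leaf on an explicitly chosen device.
# "cpu" is always available; "gpu" or "tpu" require matching hardware.
cpu_device = jax.devices("cpu")[0]
on_cpu = jax.tree_util.tree_map(
    lambda leaf: jax.device_put(leaf, cpu_device), restored
)
```

Both variants keep the nested dictionary structure intact, so either result can be passed to the module's apply() in the same way.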
Outputs
| Method | Return Type | Description |
|---|---|---|
| save_model() | bentoml.Model | A BentoML Model with the saved Flax module and state |
| load_model() | tuple[nn.Module, dict] | A tuple of (Flax module, state dictionary) |
| get() | bentoml.Model | The BentoML Model reference from the store |
| get_runnable() | type[Runnable] | A FlaxRunnable class supporting TPU/GPU/CPU inference |
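The Description notes that the runnable converts NumPy arrays and Pandas DataFrames to JAX arrays before calling apply(). A rough sketch of such a conversion step follows. The helper name to_jax_array is hypothetical and is not part of bentoml.flax; DataFrame-like inputs are detected by duck typing on to_numpy(), so the sketch runs without pandas installed.

```python
import jax.numpy as jnp
import numpy as np


def to_jax_array(value):
    """Hypothetical helper: coerce common input types to a JAX array."""
    if isinstance(value, jnp.ndarray):
        # Already a JAX array, nothing to do.
        return value
    if hasattr(value, "to_numpy"):
        # pandas DataFrame and Series objects expose to_numpy().
        value = value.to_numpy()
    # NumPy arrays, lists, and scalars are all accepted by jnp.asarray.
    return jnp.asarray(value)


numpy_batch = np.ones((2, 784), dtype=np.float32)
jax_batch = to_jax_array(numpy_batch)
list_batch = to_jax_array([[1.0, 2.0], [3.0, 4.0]])
```

Normalizing inputs once at the boundary keeps the inference function itself free of type checks, which also makes it easier to wrap in jax.jit.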
Usage Examples
import bentoml
import jax
import jax.numpy as jnp
from flax import linen as nn

# Define a simple Flax model (a two-layer dense network, despite the class name)
class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x

# Initialize the parameters (training loop omitted) and save
rng = jax.random.PRNGKey(0)
model = CNN()
params = model.init(rng, jnp.ones((1, 784)))  # variables dict with a "params" collection
bento_model = bentoml.flax.save_model("mnist", model, params)

# Load the model back and run jitted inference
net, state_dict = bentoml.flax.load_model("mnist:latest")
predict_fn = jax.jit(lambda x: net.apply({"params": state_dict["params"]}, x))
results = predict_fn(jnp.ones((1, 784)))