
Implementation:Bentoml BentoML Framework Flax

From Leeroopedia
Knowledge Sources
Domains: ML Framework, JAX, Model Serialization
Last Updated: 2026-02-13 15:00 GMT

Overview

The bentoml.flax module provides BentoML integration for Flax (Linen), Google's neural network library built on JAX, so that Flax models can be saved, loaded, and served.

Description

This module implements the BentoML framework adapter for flax.linen.Module models. Flax models are saved as msgpack-serialized state dictionaries, with the module structure preserved as a custom object via cloudpickle.

Key implementation details:

  • save_model(): Serializes the model state (parameters, batch statistics, etc.) using flax.serialization.to_bytes() and stores the nn.Module structure as a custom object.
  • load_model(): Deserializes the state dictionary from msgpack and optionally initializes arrays as jnp.ndarray or places them on a specific device (TPU/GPU/CPU). It also hides GPUs from TensorFlow so that TensorFlow does not reserve GPU memory that JAX needs.
  • get_runnable(): Creates a FlaxRunnable that supports TPU, GPU, and CPU resources. It runs inference through the Flax module's apply() method and automatically converts NumPy arrays and Pandas DataFrames in the inputs to JAX arrays. Method results are cached for performance.
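The input conversion mentioned for get_runnable() can be pictured roughly as follows. This is a sketch under assumptions: to_jax_array and apply_fn are hypothetical names, apply_fn is only a stand-in for a Flax module's apply(), and the DataFrame case is handled by duck typing on to_numpy() so the example does not require pandas.

```python
import numpy as np
import jax
import jax.numpy as jnp


def to_jax_array(item):
    """Hypothetical helper: turn array-like inputs into JAX arrays."""
    # Objects exposing to_numpy(), such as a pandas DataFrame, go via NumPy.
    if hasattr(item, "to_numpy"):
        item = item.to_numpy()
    if isinstance(item, np.ndarray):
        return jnp.asarray(item)
    # Anything else is passed through unchanged.
    return item


def apply_fn(params, x):
    """Stand-in for module.apply(): a single dense layer."""
    return x @ params["kernel"] + params["bias"]


params = {"kernel": jnp.ones((3, 2)), "bias": jnp.zeros((2,))}
batch = np.ones((4, 3), dtype=np.float32)

# Convert the input first, then run the jitted function.
out = jax.jit(apply_fn)(params, to_jax_array(batch))
```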

The module requires both flax and tensorflow (for XLA backend support).

Usage

Use this module to save and serve Flax/Linen models within BentoML. It suits JAX-based ML workflows such as image classification, NLP, or any other task for which a Flax model has been trained.

Code Reference

Source Location

Signature

def get(tag_like: str | Tag) -> bentoml.Model: ...

def load_model(bento_model: str | Tag | bentoml.Model,
               init: bool = True,
               device: str | XlaBackend = "cpu") -> tuple[nn.Module, dict[str, Any]]: ...

def save_model(name: Tag | str,
               module: nn.Module,
               state: dict[str, Any] | FrozenDict[str, Any] | struct.PyTreeNode,
               *, signatures: ModelSignaturesType | None = None,
               labels: dict[str, str] | None = None,
               custom_objects: dict[str, Any] | None = None,
               external_modules: List[ModuleType] | None = None,
               metadata: dict[str, Any] | None = None) -> bentoml.Model: ...

def get_runnable(bento_model: bentoml.Model) -> type[bentoml.legacy.Runnable]: ...

Import

import bentoml

# Via public API
model = bentoml.flax.save_model(...)
module, state_dict = bentoml.flax.load_model(...)

I/O Contract

Inputs

save_model()

Name | Type | Required | Description
name | Tag or str | Yes | Name/tag for the model in the BentoML store
module | flax.linen.Module | Yes | The Flax Linen module to save
state | dict, FrozenDict, or PyTreeNode | Yes | The model state dictionary (params, batch stats, etc.)
signatures | ModelSignaturesType or None | No | Inference method signatures (default: {"__call__": {"batchable": False}})
labels | dict[str, str] or None | No | User-defined labels for model management
custom_objects | dict[str, Any] or None | No | Additional objects to serialize
external_modules | List[ModuleType] or None | No | Additional Python modules to save alongside
metadata | dict[str, Any] or None | No | Custom metadata for the model
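As an illustration of the optional keyword arguments, the sketch below collects them in one place before calling save_model(). The helper names and the label and metadata values are hypothetical; the batchable and batch_dim keys are assumed to follow BentoML's usual model signature options.

```python
def build_save_options():
    """Hypothetical helper collecting optional save_model() arguments."""
    return {
        # Mark __call__ as batchable along the first axis (assumed option names).
        "signatures": {"__call__": {"batchable": True, "batch_dim": 0}},
        # Free-form labels and metadata; the values are illustrative only.
        "labels": {"owner": "ml-team", "stage": "dev"},
        "metadata": {"dataset": "mnist", "input_shape": [1, 784]},
    }


def save_flax_model(name, module, state):
    """Hypothetical wrapper around bentoml.flax.save_model()."""
    # Imported lazily so build_save_options() works without BentoML installed.
    import bentoml

    return bentoml.flax.save_model(name, module, state, **build_save_options())
```

Calling save_flax_model("mnist", module, state) would then return the bentoml.Model listed under Outputs.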

load_model()

Name | Type | Required | Description
bento_model | str, Tag, or Model | Yes | Tag or Model instance to load from the store
init | bool | No | Whether to convert state to jnp.ndarray (default: True)
device | str or XlaBackend | No | Device to place parameters on when init=False (default: "cpu")
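The two treatments of the state that init and device control can be pictured with plain JAX tree utilities. This is only a sketch of the operations named in the table, not the body of load_model(); the state dictionary is a made-up example.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical restored state holding NumPy leaves.
state = {"params": {"kernel": np.ones((3, 2), dtype=np.float32)}}

# Convert every leaf of the pytree to a JAX array.
as_jax = jax.tree_util.tree_map(jnp.asarray, state)

# Alternatively, place every leaf on the first device of a named backend.
cpu = jax.devices("cpu")[0]
on_cpu = jax.tree_util.tree_map(lambda leaf: jax.device_put(leaf, cpu), state)
```

Passing "gpu" or "tpu" to jax.devices() works the same way when such a backend is available, and raises an error otherwise.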

Outputs

Method | Return Type | Description
save_model() | bentoml.Model | A BentoML Model with the saved Flax module and state
load_model() | tuple[nn.Module, dict] | A tuple of (Flax module, state dictionary)
get() | bentoml.Model | The BentoML Model reference from the store
get_runnable() | type[Runnable] | A FlaxRunnable class supporting TPU/GPU/CPU inference

Usage Examples

import bentoml
import jax
import jax.numpy as jnp
from flax import linen as nn

# Define a simple Flax model (a two-layer MLP)
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x

# Initialize parameters and save (training is omitted here)
rng = jax.random.PRNGKey(0)
model = MLP()
params = model.init(rng, jnp.ones((1, 784)))  # {"params": {...}}
bento_model = bentoml.flax.save_model("mnist", model, params)  # returns a bentoml.Model

# Load the module and its state back from the store
net, state_dict = bentoml.flax.load_model("mnist:latest")
predict_fn = jax.jit(lambda x: net.apply({"params": state_dict["params"]}, x))
results = predict_fn(jnp.ones((1, 784)))  # logits with shape (1, 10)
