Implementation:Bentoml BentoML Framework PyTorch
| Knowledge Sources | |
|---|---|
| Domains | ML Framework, PyTorch, Deep Learning, Model Serialization |
| Last Updated | 2026-02-13 15:00 GMT |
Overview
Provides the BentoML framework integration for PyTorch, enabling saving, loading, and serving torch.nn.Module models through the BentoML model store.
Description
This module implements the standard BentoML framework adapter for PyTorch. It serializes models with torch.save(), using cloudpickle as the pickle module, and deserializes them with torch.load(). Importing the module also registers PyTorchTensorContainer as a side effect.
save_model() validates the input is a torch.nn.Module (using LazyType for deferred type checking), creates a ModelContext with the torch version, and saves the full model (not just state_dict) to saved_model.pt using torch.save() with cloudpickle. The default signature is __call__ (non-batchable). It supports PartialKwargsModelOptions for method-level default arguments.
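The deferred type check mentioned above can be illustrated with a small stdlib-only sketch. The helper name is_instance_lazy is hypothetical and this is not BentoML's LazyType implementation; it only shows the idea of checking against a class named by its module path without importing that module.

```python
import sys
from collections import OrderedDict


def is_instance_lazy(obj: object, module_name: str, qualname: str) -> bool:
    """Check isinstance against a class named by path, without importing its module.

    Hypothetical helper for illustration; not part of BentoML.
    """
    module = sys.modules.get(module_name)
    if module is None:
        # The module has not been imported, so obj is assumed not to be an instance.
        return False
    target = module
    for part in qualname.split("."):
        target = getattr(target, part, None)
        if target is None:
            return False
    return isinstance(target, type) and isinstance(obj, target)


is_ordered = is_instance_lazy(OrderedDict(), "collections", "OrderedDict")
is_plain_dict_ordered = is_instance_lazy({}, "collections", "OrderedDict")
```

The benefit of this pattern is that a heavy dependency such as torch is only touched when the caller has already imported it.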
load_model() retrieves the model from the store and loads it with torch.load(), supporting an optional device_id parameter (defaults to "cpu") for placing the model on a specific device (e.g., "cuda:0"). Additional torch.load() keyword arguments are passed through via **torch_load_args.
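As a rough illustration of device placement at load time, the sketch below round-trips a full module through an in-memory buffer. It assumes that device_id is forwarded to torch.load() as map_location, which is an assumption about the internals rather than something this page confirms. It also uses the default pickle module instead of cloudpickle to keep the sketch dependency-light, and roundtrip_to_device is a hypothetical helper name.

```python
import io

import torch
import torch.nn as nn


def roundtrip_to_device(module: nn.Module, device_id: str = "cpu") -> nn.Module:
    """Serialize a full module to memory and load it back onto device_id.

    Hypothetical helper; assumes device_id maps to torch.load's map_location.
    """
    buffer = io.BytesIO()
    torch.save(module, buffer)  # saves the full module, not just its state_dict
    buffer.seek(0)
    # weights_only=False allows unpickling a full nn.Module; PyTorch 2.6 changed
    # the default to True. The parameter requires PyTorch 1.13 or later.
    return torch.load(buffer, map_location=device_id, weights_only=False)


# Fall back to CPU when no CUDA device is available.
device_id = "cuda:0" if torch.cuda.is_available() else "cpu"
restored = roundtrip_to_device(nn.Linear(4, 2), device_id)
restored_device = next(restored.parameters()).device
```

Extra torch.load() options such as weights_only are the kind of keyword argument that **torch_load_args is described as passing through.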
get() retrieves the BentoML Model metadata for a given tag.
get_runnable() creates a PytorchModelRunnable using the shared PyTorch runnable utilities from common.pytorch. It uses partial_class() for class-level configuration and make_pytorch_runnable_method() for method generation, supporting partial kwargs per method signature.
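The two mechanisms named above, class-level configuration and per-method partial kwargs, can be sketched with the standard library alone. This is a minimal illustration, not BentoML's code: EchoRunnable and with_partial_kwargs are hypothetical names, the partial_class body is one plausible way to write such a helper, and the sketch assumes caller-supplied keyword arguments take precedence over the stored defaults.

```python
import functools


class EchoRunnable:
    """Hypothetical stand-in for a runnable base class."""

    def __init__(self, model_name: str, device_id: str = "cpu") -> None:
        self.model_name = model_name
        self.device_id = device_id


def partial_class(cls, *args, **kwargs):
    """Return a subclass of cls whose __init__ has some arguments pre-bound."""

    class Bound(cls):
        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)

    return Bound


def with_partial_kwargs(method, partial_kwargs):
    """Wrap method so partial_kwargs act as defaults that callers may override."""

    @functools.wraps(method)
    def wrapper(*args, **kwargs):
        return method(*args, **{**partial_kwargs, **kwargs})

    return wrapper


def predict(x: float, temperature: float = 1.0) -> float:
    return x / temperature


# Class-level configuration: device_id is fixed before instantiation.
GpuRunnable = partial_class(EchoRunnable, device_id="cuda:0")
runnable = GpuRunnable("ngrams")

# Method-level defaults: temperature defaults to 2.0 unless overridden.
scaled_predict = with_partial_kwargs(predict, {"temperature": 2.0})
```

Binding configuration into a subclass rather than an instance keeps the result usable wherever a class (not an object) is expected, which matches the type[PytorchModelRunnable] return type of get_runnable().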
Usage
Use this module to save PyTorch nn.Module models (including Torch Hub models) to the BentoML model store and serve them via BentoML services. Access via bentoml.pytorch.
Code Reference
Source Location
- Repository: Bentoml_BentoML
- File: src/bentoml/_internal/frameworks/pytorch.py
- Lines: 1-219
Signature
def get(tag_like: str | Tag) -> Model: ...

def load_model(
    bentoml_model: str | Tag | Model,
    device_id: t.Optional[str] = "cpu",
    **torch_load_args: Any,
) -> torch.nn.Module: ...

def save_model(
    name: Tag | str,
    model: torch.nn.Module,
    *,
    signatures: ModelSignaturesType | None = None,
    labels: t.Dict[str, str] | None = None,
    custom_objects: t.Dict[str, t.Any] | None = None,
    external_modules: t.List[ModuleType] | None = None,
    metadata: t.Dict[str, t.Any] | None = None,
) -> bentoml.Model: ...

def get_runnable(bento_model: Model) -> type[PytorchModelRunnable]: ...
Import
import bentoml
# Via the public API
bento_model = bentoml.pytorch.save_model("my_model", model)
loaded = bentoml.pytorch.load_model("my_model:latest", device_id="cuda:0")
# Direct import
from bentoml._internal.frameworks.pytorch import save_model, load_model, get, get_runnable
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| name | Tag or str | Yes (save_model) | Name or tag for the model in the store |
| model | torch.nn.Module | Yes (save_model) | The PyTorch module instance to save |
| tag_like | str or Tag | Yes (get) | Tag to retrieve from the model store |
| bentoml_model | str, Tag, or Model | Yes (load_model) | Model identifier or object to load |
| device_id | str or None | No (default "cpu") | Device to load the model onto (e.g., "cpu", "cuda:0") |
| signatures | ModelSignaturesType or None | No | Method signatures; defaults to {"__call__": {"batchable": False}} |
| labels | dict[str, str] or None | No | User-defined labels for model management |
| custom_objects | dict[str, Any] or None | No | Additional objects to save with the model |
| external_modules | List[ModuleType] or None | No | Additional Python modules to bundle |
| metadata | dict[str, Any] or None | No | Custom metadata for the model |
| **torch_load_args | Any | No | Additional keyword arguments passed to torch.load() |
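The signatures default listed in the table can be pictured as a simple fallback. The sketch below is an assumption-laden illustration: resolve_signatures is a hypothetical helper, not a BentoML function, and the batchable example assumes the batch dimension is axis 0.

```python
import copy
from typing import Any, Dict, Optional

SignatureDict = Dict[str, Dict[str, Any]]

# Default described in the Inputs table: a single non-batchable __call__ entry.
DEFAULT_SIGNATURES: SignatureDict = {"__call__": {"batchable": False}}


def resolve_signatures(signatures: Optional[SignatureDict] = None) -> SignatureDict:
    """Return the caller's signatures, or a copy of the default when None is given.

    Hypothetical helper for illustration only.
    """
    if signatures is None:
        return copy.deepcopy(DEFAULT_SIGNATURES)
    return signatures


default_signatures = resolve_signatures()
# Opting in to batching along the first axis (assumed batch dimension).
batched_signatures = resolve_signatures({"__call__": {"batchable": True, "batch_dim": 0}})
```

A dictionary of the second form is what would be passed as the signatures argument of save_model() to opt in to batching.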
Outputs
| Name | Type | Description |
|---|---|---|
| bentoml.Model (save_model) | bentoml.Model | The saved model reference in the BentoML store |
| torch.nn.Module (load_model) | torch.nn.Module | The deserialized PyTorch module on the specified device |
| Model (get) | Model | BentoML Model metadata object |
| PytorchModelRunnable (get_runnable) | type[PytorchModelRunnable] | A Runnable class built with the shared PyTorch utilities |
Usage Examples
import torch
import torch.nn as nn
import bentoml

# Define a PyTorch model
class NGramLanguageModeler(nn.Module):
    def __init__(self, vocab_size, embedding_dim, context_size):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(context_size * embedding_dim, 128)
        self.linear2 = nn.Linear(128, vocab_size)

    def forward(self, inputs):
        embeds = self.embeddings(inputs).view((1, -1))
        out = torch.relu(self.linear1(embeds))
        out = self.linear2(out)
        return torch.log_softmax(out, dim=1)

# Placeholder vocabulary (not in the original example); substitute your own
vocab = {"the", "quick", "brown", "fox"}
model = NGramLanguageModeler(len(vocab), 128, 2)

# Save to BentoML; save_model() returns a bentoml.Model, and .tag identifies it
bento_model = bentoml.pytorch.save_model("ngrams", model)
print(f"Saved: {bento_model.tag}")

# Load on GPU (requires a CUDA-capable device)
loaded = bentoml.pytorch.load_model("ngrams:latest", device_id="cuda:0")

# Save a Torch Hub model
resnet50 = torch.hub.load("pytorch/vision", "resnet50", pretrained=True)
bento_model = bentoml.pytorch.save_model("resnet50", resnet50)