Implementation:Ollama Ollama Imagegen Safetensors Loader
| Knowledge Sources | |
|---|---|
| Domains | Image Generation, Model Loading |
| Last Updated | 2025-02-15 00:00 GMT |
Overview
Provides reflection-based weight loading from safetensors files into Go model structs, using struct tag-driven tensor mapping.
Description
The loader.go file implements a generic weight-loading system that maps tensor names in safetensors files to Go struct fields through `weight` struct tags. LoadModule recursively walks the struct's fields, matches each tag (for example `weight:"model.layers"`) against tensor names, and automatically creates nn.Linear, nn.Embedding, nn.RMSNorm, nn.LayerNorm, and nn.QuantizedLinear instances. It also handles:
- Layer indexing: a field tagged `weight:"model.layers"` maps to `model.layers.0.*`, `model.layers.1.*`, and so on.
- Optional fields: a field tagged `weight:"bias,optional"` is left unset, rather than reported as an error, when its tensor is absent.
- Quantization: quantized layers are detected from `_scale`/`_qbias` tensor suffixes.
- Dtype conversion, and memory-efficient loading by batching mlx.Eval calls.
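The recursive walk can be sketched as follows. This is a simplified stand-in, not the loader.go code. `Tensor` replaces `*mlx.Array`, a plain map replaces `*ModelWeights`, and `Linear` stands in for nn.Linear/nn.QuantizedLinear. The `<name>.weight_scale` / `<name>.weight_qbias` spelling of the quantization tensors is also an assumption. The sketch shows slice indexing, optional fields, missing-weight errors, and suffix-based quantization detection.

```go
package main

import (
	"fmt"
	"reflect"
	"strings"
)

// Tensor is a hypothetical stand-in for *mlx.Array.
type Tensor struct{ Name string }

// Linear is a hypothetical stand-in for nn.Linear / nn.QuantizedLinear.
type Linear struct {
	Weight    *Tensor
	Scales    *Tensor // set when "<name>.weight_scale" exists (assumed naming)
	QBias     *Tensor // set when "<name>.weight_qbias" exists (assumed naming)
	Quantized bool
}

var (
	tensorType = reflect.TypeOf((*Tensor)(nil))
	linearType = reflect.TypeOf((*Linear)(nil))
)

// loadModule fills every `weight`-tagged field of the struct ptr points to.
// weights maps full tensor names to tensors (stand-in for *ModelWeights).
func loadModule(ptr interface{}, weights map[string]*Tensor, prefix string) error {
	v := reflect.ValueOf(ptr)
	if v.Kind() != reflect.Ptr || v.IsNil() || v.Elem().Kind() != reflect.Struct {
		return fmt.Errorf("module must be a non-nil pointer to a struct, got %T", ptr)
	}
	return loadStruct(v.Elem(), weights, prefix)
}

// loadStruct visits tagged fields; tagged fields must be exported to be settable.
func loadStruct(sv reflect.Value, weights map[string]*Tensor, prefix string) error {
	for i := 0; i < sv.NumField(); i++ {
		tag, ok := sv.Type().Field(i).Tag.Lookup("weight")
		if !ok {
			continue
		}
		parts := strings.Split(tag, ",")
		name := parts[0]
		if prefix != "" {
			name = prefix + "." + name
		}
		optional := len(parts) > 1 && parts[1] == "optional"
		if err := loadField(sv.Field(i), weights, name, optional); err != nil {
			return err
		}
	}
	return nil
}

func loadField(f reflect.Value, weights map[string]*Tensor, name string, optional bool) error {
	switch {
	case f.Type() == tensorType: // leaf tensor: look it up by its full name
		t, ok := weights[name]
		if !ok {
			if optional {
				return nil
			}
			return fmt.Errorf("missing required weight %q", name)
		}
		f.Set(reflect.ValueOf(t))
	case f.Type() == linearType: // layer built from "<name>.weight" plus optional quant tensors
		w, ok := weights[name+".weight"]
		if !ok {
			if optional {
				return nil
			}
			return fmt.Errorf("missing required weight %q", name+".weight")
		}
		lin := &Linear{Weight: w}
		if s, ok := weights[name+".weight_scale"]; ok {
			lin.Scales, lin.Quantized = s, true
			lin.QBias = weights[name+".weight_qbias"] // nil if absent
		}
		f.Set(reflect.ValueOf(lin))
	case f.Kind() == reflect.Slice: // element i of "model.layers" reads "model.layers.<i>.*"
		for i := 0; i < f.Len(); i++ {
			if err := loadField(f.Index(i), weights, fmt.Sprintf("%s.%d", name, i), optional); err != nil {
				return err
			}
		}
	case f.Kind() == reflect.Ptr && f.Type().Elem().Kind() == reflect.Struct: // nested module
		if f.IsNil() {
			f.Set(reflect.New(f.Type().Elem()))
		}
		return loadStruct(f.Elem(), weights, name)
	}
	return nil // other field kinds are ignored in this sketch
}
```

The sketch allocates nil struct pointers as it descends, so the caller only needs to pre-size slices such as `Layers`.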
Usage
Used by all model loaders (Llama, Gemma3, GPT-OSS, etc.) to populate model structs from safetensors weight files with minimal boilerplate.
Code Reference
Source Location
- Repository: Ollama
- File: x/imagegen/safetensors/loader.go
- Lines: 1-429
Signature
```go
func LoadModule(module interface{}, weights *ModelWeights, prefix string) error
func Collect(module interface{}) []*mlx.Array
```
Import
```go
import "github.com/ollama/ollama/x/imagegen/safetensors"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| module | interface{} | Yes | Pointer to model struct with weight tags |
| weights | *ModelWeights | Yes | Loaded weight index from safetensors files |
| prefix | string | Yes | Weight name prefix (e.g., "model" or "") |
Outputs
| Name | Type | Description |
|---|---|---|
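| error | error | nil on success, with module populated in place; non-nil if a required weight is missing |
| (return value of Collect) | []*mlx.Array | Arrays held by the loaded module, suitable for passing to mlx.Eval |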
Usage Examples
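```go
// Build the weight index from the model's safetensors files.
weights, err := safetensors.LoadModelWeights("/path/to/model")
if err != nil {
	return err
}

// Model and Layer are the caller's own structs with weight tags;
// numLayers is defined by the caller.
m := &Model{
	Layers: make([]*Layer, numLayers),
}
if err := safetensors.LoadModule(m, weights, ""); err != nil {
	return err
}

// Evaluate all loaded weights in one call.
mlx.Eval(safetensors.Collect(m)...)
```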
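Collect can be pictured as the same reflective walk run in reverse: it gathers every array reachable from the module so that a single mlx.Eval call can materialize them all. The sketch below uses a stand-in `Tensor` type instead of `*mlx.Array` and assumes a tree-shaped module with no pointer cycles. It is illustrative only and may differ from the real Collect in traversal order or in which fields it visits.

```go
package main

import "reflect"

// Tensor is a hypothetical stand-in for *mlx.Array.
type Tensor struct{ Name string }

var tensorPtr = reflect.TypeOf((*Tensor)(nil))

// collect returns every non-nil *Tensor reachable from module, walking
// exported struct fields, pointers, and slices depth-first in field order.
// It assumes the module graph has no pointer cycles.
func collect(module interface{}) []*Tensor {
	var out []*Tensor
	var walk func(v reflect.Value)
	walk = func(v reflect.Value) {
		switch {
		case v.Type() == tensorPtr:
			if !v.IsNil() {
				out = append(out, v.Interface().(*Tensor))
			}
		case v.Kind() == reflect.Ptr:
			if !v.IsNil() {
				walk(v.Elem())
			}
		case v.Kind() == reflect.Struct:
			for i := 0; i < v.NumField(); i++ {
				if v.Type().Field(i).IsExported() { // unexported fields are skipped
					walk(v.Field(i))
				}
			}
		case v.Kind() == reflect.Slice:
			for i := 0; i < v.Len(); i++ {
				walk(v.Index(i))
			}
		}
	}
	v := reflect.ValueOf(module)
	if !v.IsValid() { // nil module
		return nil
	}
	walk(v)
	return out
}
```

Gathering everything first and evaluating once lets the runtime schedule all pending loads together instead of forcing each tensor individually.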