Implementation:Ollama Ollama Imagegen NN
| Knowledge Sources | |
|---|---|
| Domains | Image Generation, Neural Networks |
| Last Updated | 2025-02-15 00:00 GMT |
Overview
Provides foundational neural network layer types (Linear, QuantizedLinear, RMSNorm, Embedding, LayerNorm) used by all models in the imagegen subsystem.
Description
The nn.go file defines the core building blocks for neural network models in the imagegen subsystem. Linear applies an affine transformation (x @ W.T + b) and uses AddMM to fuse the bias addition into the matrix multiply. QuantizedLinear supports multiple quantization modes (affine, with scales and biases; nvfp4, with E4M3 scales) via mlx.QuantizedMatmul. RMSNorm and LayerNorm provide normalization layers, and Embedding performs lookup-based token embedding. Utility functions include RepeatKV for GQA head expansion (inserting and tiling along the head dimension), ApplyCausalMask for lower-triangular attention masking, and ApplyCausalMaskWithOffset for cached attention with prior context. A Linear layer can be converted to a QuantizedLinear at runtime via ToQuantized.
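The GQA head expansion described above can be illustrated on plain Go slices. This is a minimal sketch and not the mlx-based implementation: the function name `repeatKV`, the nested-slice layout, and the consecutive ordering of repeated heads are all assumptions made for illustration.

```go
package main

// repeatKV is a hypothetical plain-slice stand-in for nn.RepeatKV.
// kv is laid out as [kvHeads][seqLen][headDim]. The result has
// kvHeads*repeatFactor heads, with each KV head repeated consecutively
// (an assumed ordering) so that every query head has a matching KV head.
func repeatKV(kv [][][]float32, repeatFactor int) [][][]float32 {
	if repeatFactor <= 1 {
		return kv
	}
	out := make([][][]float32, 0, len(kv)*repeatFactor)
	for _, head := range kv {
		for r := 0; r < repeatFactor; r++ {
			// Copy the head so the repeated heads do not alias each other.
			cp := make([][]float32, len(head))
			for i, row := range head {
				cp[i] = append([]float32(nil), row...)
			}
			out = append(out, cp)
		}
	}
	return out
}
```

With 2 KV heads and a repeat factor of 3, the sketch yields 6 heads, which is the shape a model with 6 query heads would need before computing attention scores.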
Usage
Used by every model architecture (Llama, Gemma3, GLM4, GPT-OSS, Flux2 transformer, Z-Image transformer) as the foundational layer types.
Code Reference
Source Location
- Repository: Ollama
- File: x/imagegen/nn/nn.go
- Lines: 1-260
Signature
type Layer interface {
	Forward(x *mlx.Array) *mlx.Array
}

type LinearLayer interface {
	Forward(x *mlx.Array) *mlx.Array
	OutputDim() int32
}

type Linear struct {
	Weight *mlx.Array `weight:"weight"`
	Bias   *mlx.Array `weight:"bias,optional"`
}

type QuantizedLinear struct {
	Weight    *mlx.Array
	Scales    *mlx.Array
	QBiases   *mlx.Array
	Bias      *mlx.Array
	GroupSize int
	Bits      int
	Mode      string
}

type RMSNorm struct {
	Weight *mlx.Array `weight:"weight"`
	Eps    float32
}

type Embedding struct {
	Weight *mlx.Array `weight:"weight"`
}

func RepeatKV(x *mlx.Array, repeatFactor int32) *mlx.Array
func ApplyCausalMask(scores *mlx.Array) *mlx.Array
func ApplyCausalMaskWithOffset(scores *mlx.Array, offset int32) *mlx.Array
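To make the role of the `offset` parameter concrete, the following sketch applies a causal mask to a plain score matrix. It is an illustration under stated assumptions, not the library code: the lowercase function names are hypothetical, scores are assumed to be laid out as [queryLen][keyLen], `offset` is assumed to count the cached positions that precede the first query, and masking is done by writing negative infinity.

```go
package main

import "math"

// applyCausalMaskWithOffset is a hypothetical plain-slice analogue of
// nn.ApplyCausalMaskWithOffset. Query i sits at absolute position
// offset+i and may attend to key j only when j <= offset+i. Every other
// score is replaced with -Inf so that softmax assigns it zero weight.
func applyCausalMaskWithOffset(scores [][]float32, offset int) [][]float32 {
	negInf := float32(math.Inf(-1))
	out := make([][]float32, len(scores))
	for i, row := range scores {
		out[i] = make([]float32, len(row))
		for j, s := range row {
			if j <= i+offset {
				out[i][j] = s
			} else {
				out[i][j] = negInf
			}
		}
	}
	return out
}

// applyCausalMask is the zero-offset case: a lower-triangular mask.
func applyCausalMask(scores [][]float32) [][]float32 {
	return applyCausalMaskWithOffset(scores, 0)
}
```

With 2 new queries and 2 cached positions (4 keys in total), the first query sees keys 0 through 2 and the second sees all 4, whereas the zero-offset form would hide the cached context from both.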
Import
import "github.com/ollama/ollama/x/imagegen/nn"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x | *mlx.Array | Yes | Input tensor for forward pass |
| weight | *mlx.Array | Yes | Weight tensor [out_features, in_features] |
| eps | float32 | Yes | Normalization epsilon for RMSNorm |
Outputs
| Name | Type | Description |
|---|---|---|
| output | *mlx.Array | Transformed output tensor |
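The weight layout [out_features, in_features] determines how shapes flow through a Linear layer. The sketch below spells out x @ W.T + b on plain slices. The name `linearForward` and the slice-based types are assumptions for illustration, and the real layer operates on *mlx.Array values instead.

```go
package main

// linearForward is a hypothetical plain-slice sketch of y = x @ W.T + b.
// x is [batch][inFeatures], weight is [outFeatures][inFeatures], and
// bias is [outFeatures] or nil (mirroring the optional Bias field).
// The result is [batch][outFeatures].
func linearForward(x, weight [][]float32, bias []float32) [][]float32 {
	out := make([][]float32, len(x))
	for n, row := range x {
		out[n] = make([]float32, len(weight))
		for o, w := range weight {
			var sum float32
			if bias != nil {
				sum = bias[o] // start from the bias, as a fused add would
			}
			for i, v := range row {
				sum += v * w[i] // dot product of the input row with weight row o
			}
			out[n][o] = sum
		}
	}
	return out
}
```

Because each weight row holds the coefficients for one output feature, no explicit transpose is needed: the dot product against row `o` is exactly column `o` of W.T.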
Usage Examples
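// Affine layer: computes input @ W.T + b.
linear := nn.NewLinear(weight, bias)
linearOut := linear.Forward(input)
// Quantized layer in "affine" mode.
qlinear := nn.NewQuantizedLinear(weight, bias, 64, 4, "affine")
quantOut := qlinear.Forward(input)
// RMS normalization with eps = 1e-6.
norm := nn.NewRMSNorm(weight, 1e-6)
normOut := norm.Forward(input, 0)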
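For readers unfamiliar with what the RMSNorm call computes, here is a minimal sketch of the standard RMS normalization formula on a single vector. It assumes the conventional definition y_i = x_i / sqrt(mean(x^2) + eps) * w_i. The function name `rmsNorm` is hypothetical, and whether nn.RMSNorm matches this form in every detail is not confirmed here.

```go
package main

import "math"

// rmsNorm is a hypothetical plain-slice sketch of RMS normalization over
// one feature vector. It divides x by its root mean square (stabilized
// by eps) and then scales each element by the learned weight.
// x and weight must have the same non-zero length.
func rmsNorm(x, weight []float32, eps float32) []float32 {
	var sumSq float64
	for _, v := range x {
		sumSq += float64(v) * float64(v)
	}
	// Reciprocal of sqrt(mean(x^2) + eps), accumulated in float64.
	inv := 1 / math.Sqrt(sumSq/float64(len(x))+float64(eps))
	out := make([]float32, len(x))
	for i, v := range x {
		out[i] = float32(float64(v)*inv) * weight[i]
	}
	return out
}
```

With unit weights and a negligible eps, the output has a mean square of approximately 1 regardless of the input scale, which is the property that keeps activations in a stable range between layers.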