Implementation:Ollama Ollama XModels NN
| Knowledge Sources | |
|---|---|
| Domains | MLX Runtime, Neural Network Layers |
| Last Updated | 2025-02-15 00:00 GMT |
Overview
Defines reusable neural network layer types for the MLX runner (linear, quantized linear, normalization, and embedding layers) together with attention utility functions.
Description
Provides the fundamental building blocks:
- Linear: affine transform y = x @ W.T + b
- QuantizedLinear: quantized weights, with dequantization via QuantizedMatmul
- RMSNorm
- LayerNorm (with bias)
- Embedding: lookup table
- MultiLinear: per-head projections

Includes attention helpers:
- RepeatKV: repeats K/V heads to match the number of Q heads for grouped query attention
- ApplyCausalMask: standard causal masking
- ApplyCausalMaskWithOffset: causal masking for cached attention with position offsets
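To make the arithmetic behind two of these layers concrete, the sketch below reimplements the affine transform and RMS normalization on plain Go slices. It is not the package's code: the real layers operate on *mlx.Array, and the helper names linearForward and rmsNorm are hypothetical. The RMSNorm formula used here is the common definition (divide by the root mean square, then scale by the weight), which is assumed to match what the layer computes.

```go
package main

import (
	"fmt"
	"math"
)

// linearForward computes y = x @ W.T + b for a single input row.
// w has shape [out][in]; b may be nil, mirroring the optional bias.
// Hypothetical helper for illustration only.
func linearForward(x []float32, w [][]float32, b []float32) []float32 {
	y := make([]float32, len(w))
	for o, row := range w {
		var sum float32
		for i, wi := range row {
			sum += wi * x[i]
		}
		if b != nil {
			sum += b[o]
		}
		y[o] = sum
	}
	return y
}

// rmsNorm divides x by sqrt(mean(x^2) + eps) and scales by weight.
// Assumed formula; hypothetical helper for illustration only.
func rmsNorm(x, weight []float32, eps float32) []float32 {
	var sumSq float64
	for _, v := range x {
		sumSq += float64(v) * float64(v)
	}
	inv := 1 / math.Sqrt(sumSq/float64(len(x))+float64(eps))
	y := make([]float32, len(x))
	for i, v := range x {
		y[i] = float32(float64(v)*inv) * weight[i]
	}
	return y
}

func main() {
	w := [][]float32{{1, 0}, {0, 1}, {1, 1}} // 3 outputs, 2 inputs
	fmt.Println(linearForward([]float32{1, 2}, w, []float32{0.5, 0, -1}))
	fmt.Println(rmsNorm([]float32{3, 4}, []float32{1, 1}, 0))
}
```

Storing the weight as [out][in] is what makes the transpose in y = x @ W.T necessary: each output element is the dot product of the input with one weight row.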
Usage
Used as the layer building blocks by all model implementations (e.g., GLM4-MoE-Lite). Models compose these layers into transformer blocks.
Code Reference
Source Location
- Repository: Ollama
- File: x/models/nn/nn.go
- Lines: 1-188
Signature
type Layer interface {
Forward(x *mlx.Array) *mlx.Array
}
type LinearLayer interface {
Forward(x *mlx.Array) *mlx.Array
OutputDim() int32
}
type Linear struct {
Weight *mlx.Array
Bias *mlx.Array
}
func NewLinear(weight, bias *mlx.Array) *Linear
type QuantizedLinear struct {
Weight, Scales, QBiases, Bias *mlx.Array
GroupSize, Bits int
Mode string
}
func NewQuantizedLinear(weight, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear
type RMSNorm struct {
Weight *mlx.Array
Eps float32
}
type Embedding struct {
Weight *mlx.Array
}
func RepeatKV(x *mlx.Array, repeatFactor int32) *mlx.Array
func ApplyCausalMask(scores *mlx.Array) *mlx.Array
func ApplyCausalMaskWithOffset(scores *mlx.Array, offset int32) *mlx.Array
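The attention helpers are easier to reason about with a small slice-based model. The sketch below is an illustration under stated assumptions, not the package's implementation: repeatKV and causalMaskWithOffset are hypothetical names, each head is flattened to a single slice, repeated heads are assumed to be placed consecutively (h0, h0, h1, h1), and query row i is assumed to sit at absolute position offset+i.

```go
package main

import (
	"fmt"
	"math"
)

// repeatKV repeats each KV head repeatFactor times along the head axis,
// so [h0, h1] with factor 2 becomes [h0, h0, h1, h1].
// The consecutive ordering is an assumption of this sketch.
func repeatKV(heads [][]float32, repeatFactor int) [][]float32 {
	if repeatFactor == 1 {
		return heads
	}
	out := make([][]float32, 0, len(heads)*repeatFactor)
	for _, h := range heads {
		for r := 0; r < repeatFactor; r++ {
			out = append(out, h)
		}
	}
	return out
}

// causalMaskWithOffset returns a copy of scores in which scores[i][j] is
// -Inf wherever key j lies after query i's absolute position offset+i.
// With offset 0 this reduces to the standard causal mask.
func causalMaskWithOffset(scores [][]float32, offset int) [][]float32 {
	negInf := float32(math.Inf(-1))
	out := make([][]float32, len(scores))
	for i, row := range scores {
		out[i] = make([]float32, len(row))
		for j, s := range row {
			if j > offset+i {
				out[i][j] = negInf
			} else {
				out[i][j] = s
			}
		}
	}
	return out
}

func main() {
	kv := [][]float32{{1, 2}, {3, 4}} // two KV heads
	fmt.Println(repeatKV(kv, 2))

	// Two new queries attending over four keys, two of which are cached.
	scores := [][]float32{{0, 0, 0, 0}, {0, 0, 0, 0}}
	fmt.Println(causalMaskWithOffset(scores, 2))
}
```

In the cached case the offset equals the number of positions already in the cache, so the first new query can still see every cached key plus itself.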
Import
import "github.com/ollama/ollama/x/models/nn"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x | *mlx.Array | Yes | Input tensor |
| weight | *mlx.Array | Yes | Layer weight tensor |
| bias | *mlx.Array | No | Optional bias tensor |
Outputs
| Name | Type | Description |
|---|---|---|
| output | *mlx.Array | Transformed tensor |
Usage Examples
// Standard linear layer
linear := nn.NewLinear(weight, bias)
out := linear.Forward(input)
// Quantized linear layer
qlinear := nn.NewQuantizedLinear(weight, nil, 32, 4, "affine")
qout := qlinear.Forward(input)
// RMS normalization
norm := &nn.RMSNorm{Weight: w, Eps: 1e-5}
normalized := norm.Forward(input, 0)
// Grouped query attention helper
expandedKV := nn.RepeatKV(kv, 4) // Repeat 4x for GQA
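The repeat factor of 4 in the RepeatKV call above would normally come from the model configuration rather than being hard-coded. In grouped query attention it is the ratio of query heads to K/V heads. A minimal sketch, assuming a hypothetical helper gqaRepeatFactor and illustrative head counts that are not taken from any particular model:

```go
package main

import (
	"errors"
	"fmt"
)

// gqaRepeatFactor returns how many times each KV head must be repeated so
// that the KV head count matches the query head count.
// Hypothetical helper; not part of the nn package.
func gqaRepeatFactor(numHeads, numKVHeads int32) (int32, error) {
	if numHeads <= 0 || numKVHeads <= 0 || numHeads%numKVHeads != 0 {
		return 0, errors.New("numHeads must be a positive multiple of numKVHeads")
	}
	return numHeads / numKVHeads, nil
}

func main() {
	// Illustrative values: 32 query heads sharing 8 KV heads.
	factor, err := gqaRepeatFactor(32, 8)
	if err != nil {
		panic(err)
	}
	fmt.Println(factor) // prints 4
}
```

The result is typed int32 so it can be passed directly as the repeatFactor argument of RepeatKV.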