Implementation:Ollama Ollama MLXRunner Fast
| Knowledge Sources | |
|---|---|
| Domains | MLX Runtime, Tensor Operations |
| Last Updated | 2025-02-15 00:00 GMT |
Overview
Go bindings for MLX's optimized "fast" operations: scaled dot-product attention, layer normalization, RMS normalization, and rotary position embeddings (RoPE).
Description
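ScaledDotProductAttention wraps the C function mlx_fast_scaled_dot_product_attention and invokes it in causal masking mode. LayerNorm (weight and bias) and RMSNorm (weight only) are struct types whose Forward methods call the corresponding mlx_fast_* C functions with a caller-supplied epsilon. RoPE applies rotary position embeddings with a configurable number of rotated dimensions (Dims), traditional or non-traditional mode (Traditional), base frequency (Base, read from rope_theta), scale (Scale), and a position offset passed to Forward.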
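To make the two normalization operations concrete, the following is a minimal plain-Go sketch of the standard RMS normalization and layer normalization formulas applied to a single vector. It is not the MLX implementation and does not use the mlx package; the names rmsNormRef and layerNormRef are hypothetical and exist only for illustration.

```go
package main

import (
	"fmt"
	"math"
)

// rmsNormRef is a hypothetical reference for RMS normalization of one vector:
// y[i] = x[i] / sqrt(mean(x^2) + eps) * weight[i].
func rmsNormRef(x, weight []float32, eps float32) []float32 {
	var sumSq float64
	for _, v := range x {
		sumSq += float64(v) * float64(v)
	}
	inv := 1.0 / math.Sqrt(sumSq/float64(len(x))+float64(eps))
	out := make([]float32, len(x))
	for i, v := range x {
		out[i] = float32(float64(v)*inv) * weight[i]
	}
	return out
}

// layerNormRef is a hypothetical reference for layer normalization of one vector:
// y[i] = (x[i] - mean) / sqrt(variance + eps) * weight[i] + bias[i].
func layerNormRef(x, weight, bias []float32, eps float32) []float32 {
	n := float64(len(x))
	var mean float64
	for _, v := range x {
		mean += float64(v)
	}
	mean /= n
	var variance float64
	for _, v := range x {
		d := float64(v) - mean
		variance += d * d
	}
	variance /= n
	inv := 1.0 / math.Sqrt(variance+float64(eps))
	out := make([]float32, len(x))
	for i, v := range x {
		out[i] = float32((float64(v)-mean)*inv)*weight[i] + bias[i]
	}
	return out
}

func main() {
	x := []float32{3, 4}
	ones := []float32{1, 1}
	zeros := []float32{0, 0}
	fmt.Println(rmsNormRef(x, ones, 1e-5))
	fmt.Println(layerNormRef(x, ones, zeros, 1e-5))
}
```

The difference between the two is that RMS normalization skips the mean subtraction and the bias term, which is why RMSNorm carries only a weight array.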
Usage
Transformer model implementations use these functions for their most performance-critical operations. On Apple Silicon, the underlying MLX kernels run with Metal GPU acceleration.
Code Reference
Source Location
- Repository: Ollama
- File: x/mlxrunner/mlx/fast.go
- Lines: 1-74
Signature
func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *Array
type LayerNorm struct {
Weight Array `weight:"weight"`
Bias Array `weight:"bias"`
}
func (r *LayerNorm) Forward(x *Array, eps float32) *Array
type RMSNorm struct {
Weight Array `weight:"weight"`
}
func (r RMSNorm) Forward(x *Array, eps float32) *Array
type RoPE struct {
Dims int
Traditional bool
Base float32 `json:"rope_theta"`
Scale float32
}
func (r RoPE) Forward(t *Array, offset int) *Array
Import
import "github.com/ollama/ollama/x/mlxrunner/mlx"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| query | *Array | Yes | Query tensor [batch, heads, seq, dim] |
| key | *Array | Yes | Key tensor [batch, heads, seq, dim] |
| value | *Array | Yes | Value tensor [batch, heads, seq, dim] |
| mask | *Array | No | Attention mask; may be nil, as in the usage example |
| scale | float32 | Yes | Attention scale factor (1/sqrt(d)) |
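Because scale is a float32 while math.Sqrt works in float64, callers need an explicit conversion when deriving the scale from the head dimension. A small sketch follows; attentionScale is a hypothetical helper, not part of the mlx package.

```go
package main

import (
	"fmt"
	"math"
)

// attentionScale is a hypothetical helper that returns 1/sqrt(headDim)
// as a float32, the type expected by the scale parameter.
func attentionScale(headDim int) float32 {
	return float32(1.0 / math.Sqrt(float64(headDim)))
}

func main() {
	fmt.Println(attentionScale(64)) // prints 0.125
}
```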
Outputs
| Name | Type | Description |
|---|---|---|
| output | *Array | Attention output tensor |
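For intuition about what the output contains, here is a plain-Go sketch of causal scaled dot-product attention for a single head. It assumes query and key cover the same positions (no cached prefix) and follows the textbook formula softmax(scale * Q K^T) V with future positions masked out. The function causalAttentionRef is hypothetical and is not how the fused MLX kernel is implemented.

```go
package main

import (
	"fmt"
	"math"
)

// causalAttentionRef is a hypothetical single-head reference.
// q, k and v are [seq][dim] slices; position i attends to positions 0..i.
func causalAttentionRef(q, k, v [][]float32, scale float32) [][]float32 {
	out := make([][]float32, len(q))
	for i := range q {
		// Score the query against keys 0..i only; later keys are masked out.
		scores := make([]float64, i+1)
		maxScore := math.Inf(-1)
		for j := 0; j <= i; j++ {
			var dot float64
			for d := range q[i] {
				dot += float64(q[i][d]) * float64(k[j][d])
			}
			scores[j] = dot * float64(scale)
			maxScore = math.Max(maxScore, scores[j])
		}
		// Numerically stable softmax over the visible positions.
		var sum float64
		for j := range scores {
			scores[j] = math.Exp(scores[j] - maxScore)
			sum += scores[j]
		}
		// Weighted sum of the visible value rows.
		out[i] = make([]float32, len(v[0]))
		for j, w := range scores {
			for d := range out[i] {
				out[i][d] += float32(w / sum * float64(v[j][d]))
			}
		}
	}
	return out
}

func main() {
	q := [][]float32{{1, 0}, {0, 1}}
	k := [][]float32{{1, 0}, {0, 1}}
	v := [][]float32{{1, 2}, {3, 4}}
	scale := float32(1.0 / math.Sqrt(2))
	fmt.Println(causalAttentionRef(q, k, v, scale))
}
```

The first output row always equals the first value row, since position 0 can only attend to itself.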
Usage Examples
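// Scaled dot-product attention for a head dimension of 64.
// math.Sqrt returns float64, so the scale is converted to float32 explicitly.
attn := mlx.ScaledDotProductAttention(q, k, v, nil, float32(1.0/math.Sqrt(64)))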
// RMS normalization
norm := mlx.RMSNorm{Weight: weightArray}
normalized := norm.Forward(input, 1e-5)
// Rotary position embeddings
rope := mlx.RoPE{Dims: 64, Base: 10000, Scale: 1.0}
rotated := rope.Forward(tensor, offset)
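To illustrate what the RoPE fields control, the sketch below rotates one token vector in plain Go. It assumes the pairing convention described in MLX's documentation for its fast RoPE operation (traditional mode rotates consecutive pairs, non-traditional mode pairs element i with element i + Dims/2), that Scale multiplies the position, and that offset is the position of the first token in the sequence. The function ropeRef is hypothetical and does not call the mlx package.

```go
package main

import (
	"fmt"
	"math"
)

// ropeRef is a hypothetical reference that rotates the first dims features
// of one token vector x located at position pos. Features beyond dims are
// copied through unchanged.
func ropeRef(x []float32, dims int, traditional bool, base, scale float32, pos int) []float32 {
	out := append([]float32(nil), x...)
	half := dims / 2
	for i := 0; i < half; i++ {
		// Frequency of the i-th pair: base^(-2i/dims).
		freq := math.Pow(float64(base), -2*float64(i)/float64(dims))
		angle := float64(pos) * float64(scale) * freq
		sin, cos := math.Sincos(angle)
		// Assumed pairing: split halves by default, consecutive pairs if traditional.
		a, b := i, i+half
		if traditional {
			a, b = 2*i, 2*i+1
		}
		xa, xb := float64(x[a]), float64(x[b])
		out[a] = float32(xa*cos - xb*sin)
		out[b] = float32(xa*sin + xb*cos)
	}
	return out
}

func main() {
	seq := [][]float32{{1, 0, 0, 1}, {1, 0, 0, 1}}
	offset := 3 // assumed meaning: position of the first token in seq
	for t, tok := range seq {
		fmt.Println(ropeRef(tok, 4, false, 10000, 1.0, offset+t))
	}
}
```

Each pair is rotated by an angle that depends only on its position, so the vector norm is preserved and a token at position 0 is left unchanged.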