Implementation:Ollama Ollama MLXRunner Fast

From Leeroopedia
Knowledge Sources
Domains: MLX Runtime, Tensor Operations
Last Updated: 2025-02-15 00:00 GMT

Overview

Go bindings for MLX's optimized "fast" operations: scaled dot-product attention, layer normalization, RMS normalization, and rotary position embeddings (RoPE).

Description

ScaledDotProductAttention calls mlx_fast_scaled_dot_product_attention with causal masking mode. LayerNorm and RMSNorm are struct types that hold their weight arrays (LayerNorm also holds a bias); their Forward methods call the respective mlx_fast_* C functions. RoPE applies rotary position embeddings with a configurable number of rotated dimensions, traditional or non-traditional mode, base frequency, scale, and position offset.

Usage

Transformer model implementations use these bindings for their most performance-critical operations. On Apple Silicon, the underlying MLX kernels run with Metal GPU acceleration.

Code Reference

Source Location

  • Repository: Ollama
  • File: x/mlxrunner/mlx/fast.go
  • Lines: 1-74

Signature

func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *Array

type LayerNorm struct {
    Weight Array `weight:"weight"`
    Bias   Array `weight:"bias"`
}
func (r *LayerNorm) Forward(x *Array, eps float32) *Array

type RMSNorm struct {
    Weight Array `weight:"weight"`
}
func (r RMSNorm) Forward(x *Array, eps float32) *Array

type RoPE struct {
    Dims        int
    Traditional bool
    Base        float32 `json:"rope_theta"`
    Scale       float32
}
func (r RoPE) Forward(t *Array, offset int) *Array

Import

import "github.com/ollama/ollama/x/mlxrunner/mlx"

I/O Contract

Inputs

Name   Type     Required  Description
query  *Array   Yes       Query tensor [batch, heads, seq, dim]
key    *Array   Yes       Key tensor [batch, heads, seq, dim]
value  *Array   Yes       Value tensor [batch, heads, seq, dim]
mask   *Array   No        Attention mask; the usage example passes nil
scale  float32  Yes       Attention scale factor, typically 1/sqrt(d) for head dimension d

Outputs

Name    Type    Description
output  *Array  Attention output tensor

Usage Examples

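// Scaled dot-product attention (head dimension 64, no explicit mask)
// math.Sqrt returns float64, so convert to match the float32 scale parameter.
attn := mlx.ScaledDotProductAttention(q, k, v, nil, float32(1.0/math.Sqrt(64)))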

// RMS normalization
norm := mlx.RMSNorm{Weight: weightArray}
normalized := norm.Forward(input, 1e-5)

// Rotary position embeddings
rope := mlx.RoPE{Dims: 64, Base: 10000, Scale: 1.0}
rotated := rope.Forward(tensor, offset)
