Implementation:Ollama Ollama MLXRunner Sample
| Knowledge Sources | |
|---|---|
| Domains | MLX Runtime, Sampling |
| Last Updated | 2025-02-15 00:00 GMT |
Overview
Implements token sampling strategies for the MLX runner, including greedy decoding, temperature scaling, top-k filtering, and chainable sampler composition.
Description
Defines a Sampler interface with a single method, Sample(*mlx.Array) *mlx.Array. New constructs the appropriate sampler: a greedy (argmax) sampler when temperature is 0, otherwise a chain of filters. TopK uses argpartition to mask out all but the k highest-probability tokens. Temperature divides the logits by the temperature and then samples from the resulting categorical distribution. TopP and MinP are declared but not yet implemented, so the top_p and min_p arguments currently have no effect on sampling.
Usage
Called by the text generation pipeline to select the next token from model logits at each generation step.
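As a rough illustration of where a sampler sits in a decode loop, the sketch below calls a sampler once per step until an end-of-sequence token or a length limit is reached. Everything here is hypothetical: toyModel stands in for a real forward pass, the Sampler interface is simplified to take []float32 and return a token id, and generate is not a function from the Ollama codebase.

```go
package main

import "fmt"

// Sampler is a simplified stand-in for the package's interface:
// it takes logits as a slice and returns the chosen token id.
type Sampler interface {
	Sample(logits []float32) int
}

// greedy picks the highest-scoring token.
type greedy struct{}

func (greedy) Sample(logits []float32) int {
	best := 0
	for i, v := range logits {
		if v > logits[best] {
			best = i
		}
	}
	return best
}

// toyModel is a hypothetical forward pass: it always favours the token
// that follows the previous one, wrapping around the vocabulary.
func toyModel(last, vocab int) []float32 {
	logits := make([]float32, vocab)
	logits[(last+1)%vocab] = 1
	return logits
}

// generate runs one sampling call per step and stops at eos or maxTokens.
func generate(s Sampler, start, vocab, eos, maxTokens int) []int {
	tokens := []int{start}
	for len(tokens) < maxTokens {
		logits := toyModel(tokens[len(tokens)-1], vocab)
		next := s.Sample(logits)
		tokens = append(tokens, next)
		if next == eos {
			break
		}
	}
	return tokens
}

func main() {
	fmt.Println(generate(greedy{}, 0, 5, 3, 10)) // prints [0 1 2 3]
}
```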
Code Reference
Source Location
- Repository: Ollama
- File: x/mlxrunner/sample/sample.go
- Lines: 1-77
Signature
type Sampler interface {
	Sample(*mlx.Array) *mlx.Array
}

func New(temp, top_p, min_p float32, top_k int) Sampler

type greedy struct{}
func (greedy) Sample(logits *mlx.Array) *mlx.Array

type chain []Sampler
func (c chain) Sample(logits *mlx.Array) *mlx.Array

type Temperature float32
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array

type TopK int
func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array
Import
import "github.com/ollama/ollama/x/mlxrunner/sample"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| temp | float32 | Yes | Temperature (0 for greedy) |
| top_p | float32 | No | Nucleus sampling threshold (not yet implemented) |
| min_p | float32 | No | Minimum probability threshold (not yet implemented) |
| top_k | int | No | Top-k filter count (0 to disable) |
Outputs
| Name | Type | Description |
|---|---|---|
| sampler | Sampler | Configured sampler ready for use |
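The contract above can be summarised as a small decision procedure. The sketch below is a hypothetical helper named plan that lists which stages a constructor like New might assemble for a given set of arguments. The stage order (top-k before temperature) and the silent ignoring of top_p and min_p are assumptions made for illustration, not confirmed behaviour of the package.

```go
package main

import "fmt"

// plan is a hypothetical helper that names the stages implied by the
// inputs table: greedy when temp is 0, otherwise an optional top-k
// filter followed by temperature sampling.
func plan(temp, topP, minP float32, topK int) []string {
	if temp == 0 {
		return []string{"greedy"}
	}
	var stages []string
	if topK > 0 {
		stages = append(stages, fmt.Sprintf("top_k(%d)", topK))
	}
	// Assumption: topP and minP are accepted but ignored, since the
	// page lists both as not yet implemented.
	_, _ = topP, minP
	stages = append(stages, fmt.Sprintf("temperature(%g)", temp))
	return stages
}

func main() {
	fmt.Println(plan(0, 0, 0, 0))
	fmt.Println(plan(0.7, 0, 0, 40))
	fmt.Println(plan(0.7, 0, 0, 0))
}
```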
Usage Examples
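// Greedy decoding: a temperature of 0 selects argmax.
greedySampler := sample.New(0, 0, 0, 0)
token := greedySampler.Sample(logits)
// Temperature + top-k sampling: temperature 0.7 over the 40 most probable tokens.
topKSampler := sample.New(0.7, 0, 0, 40)
token = topKSampler.Sample(logits)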