Implementation:Ollama Ollama Imagegen Cache Step
| Knowledge Sources | |
|---|---|
| Domains | Image Generation, Caching |
| Last Updated | 2025-02-15 00:00 GMT |
Overview
Implements step-level caching for diffusion denoising, enabling layer output reuse between consecutive denoising steps.
Description
The step.go file implements StepCache, which caches transformer layer outputs across diffusion denoising steps, following DeepCache (CVPR 2024) and Learning-to-Cache (NeurIPS 2024). Shallow layers change little between consecutive steps, so their outputs can be computed on refresh steps and reused on the non-refresh steps in between. The cache provides:
- Single-stream caching via Get/Set.
- Dual-stream caching: Get/Set for the image stream and Get2/Set2 for the text stream.
- An optional constant cache (GetConstant/SetConstant), e.g. for text embeddings.
- Interval-based refresh logic via ShouldRefresh.
- Memory management: Set frees the previously stored array before storing a new one, and Free releases all cached arrays.
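The storage and memory rules above can be sketched in a few lines of Go. This is a minimal, self-contained sketch, not the code in step.go: `Array` is a hypothetical stand-in for `*mlx.Array` whose `Free` only records that it was called, and `replace` is an illustrative helper name. It assumes every slot starts empty and that `layer` is always in range.

```go
package main

import "fmt"

// Array is a hypothetical stand-in for *mlx.Array; only Free matters here.
type Array struct {
	name  string
	freed bool
}

// Free marks the array as released (a real array would release its memory).
func (a *Array) Free() { a.freed = true }

// StepCache mirrors the three fields listed in the Signature section.
type StepCache struct {
	layers   []*Array // single-stream or image-stream layer outputs
	layers2  []*Array // text-stream layer outputs for dual-stream models
	constant *Array   // optional step-invariant value, e.g. text embeddings
}

func NewStepCache(numLayers int) *StepCache {
	return &StepCache{
		layers:  make([]*Array, numLayers),
		layers2: make([]*Array, numLayers),
	}
}

// replace frees the array already in slot (unless it is arr itself),
// then stores arr. Passing nil simply frees and clears the slot.
func replace(slot **Array, arr *Array) {
	if *slot != nil && *slot != arr {
		(*slot).Free()
	}
	*slot = arr
}

func (c *StepCache) Get(layer int) *Array       { return c.layers[layer] }
func (c *StepCache) Set(layer int, arr *Array)  { replace(&c.layers[layer], arr) }
func (c *StepCache) Get2(layer int) *Array      { return c.layers2[layer] }
func (c *StepCache) Set2(layer int, arr *Array) { replace(&c.layers2[layer], arr) }
func (c *StepCache) GetConstant() *Array        { return c.constant }
func (c *StepCache) SetConstant(arr *Array)     { replace(&c.constant, arr) }

// Free releases every cached array and leaves all slots empty.
func (c *StepCache) Free() {
	for i := range c.layers {
		replace(&c.layers[i], nil)
	}
	for i := range c.layers2 {
		replace(&c.layers2[i], nil)
	}
	replace(&c.constant, nil)
}

func main() {
	sc := NewStepCache(2)
	a, b := &Array{name: "a"}, &Array{name: "b"}
	sc.Set(0, a)
	sc.Set(0, b)                  // a is freed before b is stored
	fmt.Println(a.freed, b.freed) // true false
	sc.Free()
	fmt.Println(b.freed, sc.Get(0) == nil) // true true
}
```

The `*slot != arr` guard is a choice made for this sketch: storing the same array twice does not free it while it is still cached.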
Usage
Used by Z-Image and FLUX.2 transformers to skip redundant layer computations during the denoising loop, reducing inference time.
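A rough estimate of the saving counts layer forward passes. Assume, hypothetically, $S$ denoising steps, a transformer with $L$ layers of which the first $m$ are cached, and a refresh on every step whose index is a multiple of the interval $k$ (so step 0 always refreshes). Refresh steps run all $L$ layers; every other step skips the $m$ cached ones, giving $R$ refresh steps and $E$ layer evaluations in total:

$$
R = \left\lceil \frac{S}{k} \right\rceil, \qquad E = S\,L - (S - R)\,m
$$

For hypothetical values $S = 50$, $L = 30$, $m = 15$, $k = 3$: $R = 17$ and $E = 1500 - 33 \cdot 15 = 1005$, about one third fewer layer evaluations than the 1500 of an uncached run. The wall-clock saving depends on how expensive the skipped layers are relative to the rest of the model.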
Code Reference
Source Location
- Repository: Ollama
- File: x/imagegen/cache/step.go
- Lines: 1-164
Signature
```go
type StepCache struct {
	layers   []*mlx.Array // per-layer outputs (single stream, or image stream)
	layers2  []*mlx.Array // per-layer outputs for the text stream (dual-stream models)
	constant *mlx.Array   // optional constant, e.g. text embeddings
}

func NewStepCache(numLayers int) *StepCache
func (c *StepCache) ShouldRefresh(step, interval int) bool
func (c *StepCache) Get(layer int) *mlx.Array
func (c *StepCache) Set(layer int, arr *mlx.Array)
func (c *StepCache) Get2(layer int) *mlx.Array
func (c *StepCache) Set2(layer int, arr *mlx.Array)
func (c *StepCache) GetConstant() *mlx.Array
func (c *StepCache) SetConstant(arr *mlx.Array)
func (c *StepCache) Free()
```
Import
```go
import "github.com/ollama/ollama/x/imagegen/cache"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
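| numLayers | int | Yes | Number of layers to cache (NewStepCache) |
| step | int | Yes | Current denoising step (ShouldRefresh) |
| interval | int | Yes | Refresh interval in steps, e.g. 3 to refresh every 3 steps (ShouldRefresh) |
| layer | int | Yes | Index of the cached layer (Get, Set, Get2, Set2) |
| arr | *mlx.Array | Yes | Array to store (Set, Set2, SetConstant) |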
Outputs
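| Returned by | Type | Description |
|---|---|---|
| NewStepCache | *StepCache | Step cache for layer output reuse |
| ShouldRefresh | bool | Whether layers should be recomputed and the cache refreshed at this step |
| Get, Get2, GetConstant | *mlx.Array | Cached array, or nil if nothing has been stored yet |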
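A plausible refresh rule, assumed here for illustration rather than taken from step.go, is to refresh on step 0 and then on every `interval`-th step, so the first step always fills the cache. The hypothetical `shouldRefresh` and `refreshSteps` functions below list which steps would recompute under that assumption; treating an interval of 1 or less as "always refresh" is also an assumption of this sketch.

```go
package main

import "fmt"

// shouldRefresh is a hypothetical stand-in for StepCache.ShouldRefresh:
// refresh on step 0 and on every interval-th step after it.
// An interval of 1 or less means every step refreshes (no reuse).
func shouldRefresh(step, interval int) bool {
	if interval <= 1 {
		return true
	}
	return step%interval == 0
}

// refreshSteps lists the steps in [0, numSteps) that would refresh the cache.
func refreshSteps(numSteps, interval int) []int {
	var steps []int
	for s := 0; s < numSteps; s++ {
		if shouldRefresh(s, interval) {
			steps = append(steps, s)
		}
	}
	return steps
}

func main() {
	fmt.Println(refreshSteps(10, 3)) // [0 3 6 9]
}
```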
Usage Examples
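```go
const cachedLayers = 15 // number of shallow layers whose outputs are cached

sc := cache.NewStepCache(cachedLayers)
defer sc.Free() // release all cached arrays when denoising is done

for step := 0; step < numSteps; step++ {
	refresh := sc.ShouldRefresh(step, 3) // recompute and re-cache every 3 steps
	h := input                           // hidden state entering the first layer at this step
	for i, layer := range layers {
		if i < cachedLayers && !refresh {
			if cached := sc.Get(i); cached != nil {
				h = cached // reuse the output stored on the last refresh step
				continue
			}
		}
		h = layer.Forward(h) // each layer consumes the previous layer's output
		if i < cachedLayers && refresh {
			sc.Set(i, h) // Set frees the previously cached array for layer i
		}
	}
	// h now holds the transformer output for this step
}
```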