Implementation:Ollama Ollama Imagegen Flux2 VAE

From Leeroopedia
Knowledge Sources
Domains: Image Generation, VAE
Last Updated: 2025-02-15 00:00 GMT

Overview

Implements the AutoencoderKL VAE decoder for FLUX.2 Klein, converting latent representations to pixel images.

Description

The vae.go file implements the FLUX.2 VAE decoder (AutoencoderKLFlux2). Its building blocks are BatchNorm2D, which provides both forward normalization and denormalization for latent processing; GroupNormLayer, which normalizes spatial features; and the full decoder stack of ResNet blocks, attention layers, and progressive upsampling configured by block_out_channels [128, 256, 512, 512]. The VAE uses NHWC (channels-last) format throughout for MLX GPU efficiency, with patch_size [2, 2] for latent-to-pixel unpacking. The config also supports force_upcast, which upcasts decoding to float32 for numerical stability.
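The relationship between forward normalization and denormalization can be illustrated on plain slices. This is a minimal sketch, not the code in vae.go: the type batchNormParams and its methods are hypothetical stand-ins for the BatchNorm2D fields, the buffer is assumed to be channels-last so that the channel is the fastest-varying index, and whether the real layer applies the affine weight and bias is an assumption.

```go
package main

import "math"

// batchNormParams is a hypothetical plain-slice stand-in for the
// RunningMean, RunningVar, Weight, Bias, and Eps fields of BatchNorm2D.
type batchNormParams struct {
	Mean, Var, Weight, Bias []float32
	Eps                     float32
}

// std returns sqrt(var + eps) for one channel.
func (p batchNormParams) std(ch int) float32 {
	return float32(math.Sqrt(float64(p.Var[ch] + p.Eps)))
}

// normalize applies y = (x - mean) / sqrt(var + eps) * weight + bias.
// x is a flattened channels-last buffer, so index i belongs to channel i % C.
func (p batchNormParams) normalize(x []float32) []float32 {
	c := len(p.Mean)
	out := make([]float32, len(x))
	for i, v := range x {
		ch := i % c
		out[i] = (v-p.Mean[ch])/p.std(ch)*p.Weight[ch] + p.Bias[ch]
	}
	return out
}

// denormalize inverts normalize:
// x = (y - bias) / weight * sqrt(var + eps) + mean.
func (p batchNormParams) denormalize(y []float32) []float32 {
	c := len(p.Mean)
	out := make([]float32, len(y))
	for i, v := range y {
		ch := i % c
		out[i] = (v-p.Bias[ch])/p.Weight[ch]*p.std(ch) + p.Mean[ch]
	}
	return out
}
```

Applying denormalize to the result of normalize should recover the input up to floating-point rounding, which is the property a decoder relies on when it maps normalized latents back to the scale the VAE expects.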

Usage

Used as the final stage in the FLUX.2 pipeline to decode denoised latent tensors into RGB images, with support for tiled decoding on large images.
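The page does not describe how tiled decoding is implemented. As a rough illustration of the general idea only, the sketch below plans overlapping windows over a latent grid; the names tile, tileStarts, and planTiles, as well as the tile size and overlap parameters, are hypothetical and not taken from vae.go.

```go
package main

// tile is one decode window in latent coordinates, as half-open ranges
// [Y0, Y1) and [X0, X1).
type tile struct{ Y0, Y1, X0, X1 int }

// tileStarts returns window start offsets that cover [0, size) with windows
// of length tileSize. Consecutive windows share at least `overlap` positions.
func tileStarts(size, tileSize, overlap int) []int {
	if size <= tileSize {
		return []int{0} // a single window covers everything
	}
	stride := tileSize - overlap
	if stride < 1 {
		stride = 1 // guard against overlap >= tileSize
	}
	var starts []int
	for s := 0; s+tileSize < size; s += stride {
		starts = append(starts, s)
	}
	// The final window is shifted back so it ends exactly at the boundary.
	return append(starts, size-tileSize)
}

// minInt avoids depending on the built-in min added in Go 1.21.
func minInt(a, b int) int {
	if a < b {
		return a
	}
	return b
}

// planTiles combines row and column offsets into a grid of windows.
func planTiles(h, w, tileSize, overlap int) []tile {
	var tiles []tile
	for _, y := range tileStarts(h, tileSize, overlap) {
		for _, x := range tileStarts(w, tileSize, overlap) {
			tiles = append(tiles, tile{
				Y0: y, Y1: minInt(y+tileSize, h),
				X0: x, X1: minInt(x+tileSize, w),
			})
		}
	}
	return tiles
}
```

In a scheme like this, each window would be decoded independently and the overlapping borders blended in pixel space to hide seams; the blending step is omitted here.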

Code Reference

Source Location

  • Repository: Ollama
  • File: x/imagegen/models/flux2/vae.go
  • Lines: 1-804

Signature

type VAEConfig struct {
	BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
	LatentChannels   int32   `json:"latent_channels"`    // 32
	LayersPerBlock   int32   `json:"layers_per_block"`   // 2
	NormNumGroups    int32   `json:"norm_num_groups"`    // 32
	PatchSize        []int32 `json:"patch_size"`         // [2, 2]
}

type BatchNorm2D struct {
	RunningMean *mlx.Array
	RunningVar  *mlx.Array
	Weight      *mlx.Array
	Bias        *mlx.Array
	Eps         float32
}

func (bn *BatchNorm2D) Forward(x *mlx.Array) *mlx.Array
func (bn *BatchNorm2D) Denormalize(x *mlx.Array) *mlx.Array

type AutoencoderKLFlux2 struct { ... }
func (v *AutoencoderKLFlux2) Load(manifest *manifest.ModelManifest) error
func (v *AutoencoderKLFlux2) Decode(latents *mlx.Array) *mlx.Array
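The JSON tags on VAEConfig suggest that the struct is populated from a config document. The following sketch shows one way such a document could be decoded and sanity-checked with encoding/json. The struct repeats only the fields shown in the signature; sampleVAEConfig, parseVAEConfig, and the validation rules are hypothetical, and the sample values are taken from the field comments.

```go
package main

import (
	"encoding/json"
	"errors"
)

// VAEConfig repeats only the fields shown in the signature.
type VAEConfig struct {
	BlockOutChannels []int32 `json:"block_out_channels"`
	LatentChannels   int32   `json:"latent_channels"`
	LayersPerBlock   int32   `json:"layers_per_block"`
	NormNumGroups    int32   `json:"norm_num_groups"`
	PatchSize        []int32 `json:"patch_size"`
}

// sampleVAEConfig is a hypothetical config document. Keys without a matching
// struct field (here force_upcast) are ignored by json.Unmarshal.
const sampleVAEConfig = `{
  "block_out_channels": [128, 256, 512, 512],
  "latent_channels": 32,
  "layers_per_block": 2,
  "norm_num_groups": 32,
  "patch_size": [2, 2],
  "force_upcast": true
}`

// parseVAEConfig decodes the JSON and applies two illustrative checks.
func parseVAEConfig(data []byte) (*VAEConfig, error) {
	var cfg VAEConfig
	if err := json.Unmarshal(data, &cfg); err != nil {
		return nil, err
	}
	if len(cfg.PatchSize) != 2 {
		return nil, errors.New("patch_size must have two entries")
	}
	// Group normalization splits channels evenly across groups, so every
	// block width must be divisible by the group count.
	for _, ch := range cfg.BlockOutChannels {
		if cfg.NormNumGroups <= 0 || ch%cfg.NormNumGroups != 0 {
			return nil, errors.New("block width not divisible by norm_num_groups")
		}
	}
	return &cfg, nil
}
```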

Import

import "github.com/ollama/ollama/x/imagegen/models/flux2"

I/O Contract

Inputs

Name | Type | Required | Description
latents | *mlx.Array | Yes | Denoised latent tensor [B, H, W, C] in NHWC format
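To make the patch_size [2, 2] unpacking concrete, the sketch below expands a packed channels-last grid of shape [H, W, C*2*2] into [2H, 2W, C] for a single batch element. It is an illustration on plain slices, not the MLX code in vae.go. The function unpackPatches is hypothetical, and the packed channel ordering (channel-major, then patch row, then patch column) is an assumption.

```go
package main

// unpackPatches expands packed patches into a spatial grid.
//
// packed has shape [h, w, c*ph*pw] in channels-last order, and the result has
// shape [h*ph, w*pw, c]. The packed channel index is assumed to be
// (ch*ph + dy)*pw + dx, where (dy, dx) is the position inside the patch.
func unpackPatches(packed []float32, h, w, c, ph, pw int) []float32 {
	out := make([]float32, h*ph*w*pw*c)
	outW := w * pw
	for y := 0; y < h; y++ {
		for x := 0; x < w; x++ {
			// Offset of the packed channel vector at grid cell (y, x).
			base := (y*w + x) * c * ph * pw
			for ch := 0; ch < c; ch++ {
				for dy := 0; dy < ph; dy++ {
					for dx := 0; dx < pw; dx++ {
						src := base + (ch*ph+dy)*pw + dx
						oy, ox := y*ph+dy, x*pw+dx
						out[(oy*outW+ox)*c+ch] = packed[src]
					}
				}
			}
		}
	}
	return out
}
```

With the documented values latent_channels 32 and patch_size [2, 2], each packed cell would carry 128 values under this assumption, which unpack into a 2x2 block of 32-channel latent positions.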

Outputs

Name | Type | Description
(return value) | *mlx.Array | Decoded RGB image [B, C, H, W] with values in [0, 1]
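Once the decoded values have been copied out of the array into host memory, a planar [C, H, W] float image in [0, 1] can be turned into an 8-bit image with the standard library. This is a sketch under stated assumptions: toRGBA and toByte are hypothetical helpers, the input is one batch element with exactly three channels, and how the values are read out of *mlx.Array is not shown.

```go
package main

import (
	"image"
	"image/color"
)

// toByte maps a value expected in [0, 1] to 0..255 with rounding.
// Out-of-range values are clamped, and NaN maps to 0.
func toByte(v float32) uint8 {
	if v != v || v <= 0 {
		return 0
	}
	if v >= 1 {
		return 255
	}
	return uint8(v*255 + 0.5)
}

// toRGBA converts one planar CHW image (R plane, then G, then B) of size
// h x w into an interleaved *image.RGBA with full opacity.
func toRGBA(chw []float32, h, w int) *image.RGBA {
	img := image.NewRGBA(image.Rect(0, 0, w, h))
	plane := h * w
	for y := 0; y < h; y++ {
		for x := 0; x < w; x++ {
			i := y*w + x
			img.SetRGBA(x, y, color.RGBA{
				R: toByte(chw[i]),
				G: toByte(chw[plane+i]),
				B: toByte(chw[2*plane+i]),
				A: 255,
			})
		}
	}
	return img
}
```

The resulting image can be passed to an encoder such as png.Encode from image/png.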

Usage Examples

// Load the VAE weights from the model manifest
vae := &flux2.AutoencoderKLFlux2{}
if err := vae.Load(manifest); err != nil {
    return err
}

// Decode latents after denoising
image := vae.Decode(denoisedLatents)
