Implementation:Ollama Ollama Imagegen Flux2 VAE
| Knowledge Sources | |
|---|---|
| Domains | Image Generation, VAE |
| Last Updated | 2025-02-15 00:00 GMT |
Overview
Implements the AutoencoderKL VAE decoder for FLUX.2 Klein, converting latent representations to pixel images.
Description
The vae.go file implements the FLUX.2 VAE decoder (AutoencoderKLFlux2). Its main components are:
- BatchNorm2D, which provides both forward normalization and denormalization of latents.
- GroupNormLayer, which normalizes spatial feature maps.
- The decoder itself, built from ResNet blocks, attention layers, and progressive upsampling through block_out_channels [128, 256, 512, 512].
The VAE keeps tensors in NHWC (channels-last) layout throughout for efficient execution on MLX's GPU backend, and uses patch_size [2, 2] to unpack latents into the pixel-space grid. The force_upcast config option runs decoding in float32 for numerical stability.
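The description says patch_size [2, 2] is used to unpack latents but does not spell out the rearrangement. The sketch below shows one way a depth-to-space unpack can work on a flat NHWC buffer, turning [b, h, w, c·p·p] into [b, h·p, w·p, c]. The function name `unpatchifyNHWC` is hypothetical, and the packed channel order `ch*p*p + ph*p + pw` is an assumption that mirrors the common reshape/permute convention; vae.go may order channels differently. With LatentChannels 32 and a patch of 2, a packed tensor of this kind would carry 128 channels.

```go
package main

// unpatchifyNHWC rearranges a packed latent of shape [b, h, w, c*p*p]
// (flat, NHWC, channels fastest) into [b, h*p, w*p, c].
// ASSUMPTION: packed channel index = ch*p*p + ph*p + pw.
func unpatchifyNHWC(x []float32, b, h, w, c, p int) []float32 {
	packed := c * p * p
	if len(x) != b*h*w*packed {
		panic("unpatchifyNHWC: input length does not match shape")
	}
	outH, outW := h*p, w*p
	out := make([]float32, b*outH*outW*c)
	for bi := 0; bi < b; bi++ {
		for hi := 0; hi < h; hi++ {
			for wi := 0; wi < w; wi++ {
				// Offset of the packed channel vector at (bi, hi, wi).
				base := ((bi*h+hi)*w + wi) * packed
				for ch := 0; ch < c; ch++ {
					for ph := 0; ph < p; ph++ {
						for pw := 0; pw < p; pw++ {
							// Each packed channel becomes one pixel of a p x p patch.
							oh, ow := hi*p+ph, wi*p+pw
							out[((bi*outH+oh)*outW+ow)*c+ch] = x[base+ch*p*p+ph*p+pw]
						}
					}
				}
			}
		}
	}
	return out
}
```

Because both buffers are channels-last, this matches the NHWC layout the VAE uses; in MLX the same effect would normally come from a reshape and transpose rather than explicit loops.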
Usage
Used as the final stage in the FLUX.2 pipeline to decode denoised latent tensors into RGB images, with support for tiled decoding on large images.
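Tiled decoding is mentioned, but the tile size, overlap, and blending strategy are not described here. As a hypothetical sketch, the helper below plans overlapping tiles along one spatial axis; running it once for height and once for width gives a grid of tiles that can each be decoded separately. `tileSpan` and `planTiles` are illustrative names, not part of vae.go, and blending the overlapping regions is not shown.

```go
package main

// tileSpan is a hypothetical half-open [Start, End) range along one axis.
type tileSpan struct {
	Start, End int
}

// planTiles splits an axis of length n into tiles of at most size elements,
// where neighbouring tiles overlap by at least overlap elements. The final
// tile is shifted back so it ends exactly at n instead of being shorter,
// so it may overlap its neighbour by more than overlap.
func planTiles(n, size, overlap int) []tileSpan {
	if size <= overlap {
		panic("planTiles: tile size must exceed overlap")
	}
	if n <= size {
		// The whole axis fits in a single tile.
		return []tileSpan{{0, n}}
	}
	stride := size - overlap
	var tiles []tileSpan
	for start := 0; ; start += stride {
		if start+size >= n {
			tiles = append(tiles, tileSpan{n - size, n})
			break
		}
		tiles = append(tiles, tileSpan{start, start + size})
	}
	return tiles
}
```

For example, an axis of 10 latent rows with 4-row tiles and an overlap of 1 yields [0, 4), [3, 7), and [6, 10).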
Code Reference
Source Location
- Repository: Ollama
- File: x/imagegen/models/flux2/vae.go
- Lines: 1-804
Signature
```go
type VAEConfig struct {
	BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
	LatentChannels   int32   `json:"latent_channels"`    // 32
	LayersPerBlock   int32   `json:"layers_per_block"`   // 2
	NormNumGroups    int32   `json:"norm_num_groups"`    // 32
	PatchSize        []int32 `json:"patch_size"`         // [2, 2]
}

type BatchNorm2D struct {
	RunningMean *mlx.Array
	RunningVar  *mlx.Array
	Weight      *mlx.Array
	Bias        *mlx.Array
	Eps         float32
}

func (bn *BatchNorm2D) Forward(x *mlx.Array) *mlx.Array
func (bn *BatchNorm2D) Denormalize(x *mlx.Array) *mlx.Array

type AutoencoderKLFlux2 struct { ... }

func (v *AutoencoderKLFlux2) Load(manifest *manifest.ModelManifest) error
func (v *AutoencoderKLFlux2) Decode(latents *mlx.Array) *mlx.Array
```
Import
```go
import "github.com/ollama/ollama/x/imagegen/models/flux2"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| latents | *mlx.Array | Yes | Denoised latent tensor [B, H, W, C] in NHWC format |
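The contract takes latents in NHWC order while the output below is listed as [B, C, H, W], so it helps to see how the two flat layouts index the same element. Where vae.go performs this transpose is not stated here; the sketch below is a hypothetical plain-Go illustration of the index arithmetic on `[]float32` buffers.

```go
package main

// nhwcToNCHW copies a flat [n, h, w, c] buffer into [n, c, h, w] order.
func nhwcToNCHW(x []float32, n, h, w, c int) []float32 {
	if len(x) != n*h*w*c {
		panic("nhwcToNCHW: input length does not match shape")
	}
	out := make([]float32, len(x))
	for b := 0; b < n; b++ {
		for row := 0; row < h; row++ {
			for col := 0; col < w; col++ {
				for ch := 0; ch < c; ch++ {
					// NHWC: channel varies fastest.
					src := ((b*h+row)*w+col)*c + ch
					// NCHW: column varies fastest within one channel plane.
					dst := ((b*c+ch)*h+row)*w + col
					out[dst] = x[src]
				}
			}
		}
	}
	return out
}
```

For a 1x2 RGB image stored as [r0, g0, b0, r1, g1, b1], the NCHW result groups each channel into its own plane: [r0, r1, g0, g1, b0, b1].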
Outputs
| Name | Type | Description |
|---|---|---|
| (return value) | *mlx.Array | Decoded RGB image [B, C, H, W] with values in [0, 1] |
Usage Examples
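```go
// Inside a function that returns error: manifest is the loaded
// *manifest.ModelManifest, and denoisedLatents is the NHWC latent
// tensor produced by the denoising loop.
vae := &flux2.AutoencoderKLFlux2{}
if err := vae.Load(manifest); err != nil {
	return err
}

// Decode the denoised latents into an RGB image tensor.
image := vae.Decode(denoisedLatents)
```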
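When sizing buffers around a call like `vae.Decode`, it can help to estimate the output resolution from the config. The helper below is a hypothetical sketch under two stated assumptions: every decoder up block except the last doubles the spatial resolution (the usual AutoencoderKL arrangement), and the patch_size unpacking multiplies it once more. Whether vae.go's Decode expects packed or already-unpacked latents is not stated here; pass a patch of 1 for the latter.

```go
package main

// decodedSize estimates pixel height and width for a latent grid.
// ASSUMPTIONS: len(blockOutChannels)-1 upsampling stages, each 2x,
// followed by a patch x patch unpacking step.
func decodedSize(latentH, latentW int, blockOutChannels []int32, patch int) (int, int) {
	scale := patch
	for i := 0; i < len(blockOutChannels)-1; i++ {
		scale *= 2
	}
	return latentH * scale, latentW * scale
}
```

Under these assumptions, block_out_channels [128, 256, 512, 512] gives three 2x stages (8x), and patch_size 2 raises the total to 16x, so a 64x48 packed latent grid would decode to 1024x768 pixels.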