Implementation:Ollama Ollama Imagegen ZImage VAE
| Knowledge Sources | |
|---|---|
| Domains | Image Generation, VAE |
| Last Updated | 2025-02-15 00:00 GMT |
Overview
Implements the Variational Autoencoder (VAE) for the Z-Image pipeline. Its decoder converts latent representations to pixel images, and its encoder maps images back to latents; both operate on NHWC tensors.
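To make the channels-last layout concrete, here is a minimal plain-Go sketch (no MLX) of how a flat offset is computed in a contiguous NHWC buffer compared with NCHW. The helper names `nhwcOffset` and `nchwOffset` are hypothetical. MLX handles this indexing internally, and the sketch assumes a contiguous, row-major buffer.

```go
package main

import "fmt"

// nhwcOffset returns the flat index of element (b, h, w, c) in a contiguous
// NHWC (channels-last) buffer of shape [N, H, W, C].
// The channel index varies fastest, so all channels of one pixel are adjacent.
func nhwcOffset(b, h, w, c, H, W, C int) int {
	return ((b*H+h)*W+w)*C + c
}

// nchwOffset is the channels-first equivalent, shown for comparison.
func nchwOffset(b, c, h, w, C, H, W int) int {
	return ((b*C+c)*H+h)*W + w
}

func main() {
	const H, W, C = 4, 4, 3
	// Pixel (h=1, w=2) of image 0: its three channels are consecutive in NHWC.
	fmt.Println(nhwcOffset(0, 1, 2, 0, H, W, C), nhwcOffset(0, 1, 2, 1, H, W, C), nhwcOffset(0, 1, 2, 2, H, W, C))
	// In NCHW the same three channels are H*W = 16 elements apart.
	fmt.Println(nchwOffset(0, 0, 1, 2, C, H, W), nchwOffset(0, 1, 1, 2, C, H, W), nchwOffset(0, 2, 1, 2, C, H, W))
}
```

Keeping a pixel's channels adjacent means per-pixel operations over channels (such as a 1x1 convolution) read contiguous memory.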
Description
The vae.go file implements the Z-Image VAE with both an encoder (image to latent, used for image-to-image) and a decoder (latent to image). The architecture uses GroupNorm for normalization, ResNet blocks that add a shortcut convolution when the input and output channel counts differ, and a multi-resolution decoder that upsamples progressively. The VAE config defines block_out_channels [128, 256, 512, 512] together with the scaling_factor and shift_factor used to normalize latents. All tensors use NHWC (channels-last) layout for MLX efficiency. Encode (img2img latent extraction) and Decode (final image generation) apply the matching scale/shift transform, with Decode denormalizing latents before decoding them. Conv2D computes cross-correlation (the kernel is not flipped), matching PyTorch conventions.
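The cross-correlation convention can be illustrated with a small reference loop. This is a minimal sketch, not the MLX implementation: a single image, stride 1, no padding, NHWC input, and a weight layout of `[Cout][KH][KW][Cin]`, which is an assumption of this sketch (chosen to be channels-last like the input). The function name `conv2dNHWC` is hypothetical.

```go
package main

import "fmt"

// conv2dNHWC computes a "valid" (no padding), stride-1 2D cross-correlation
// for a single image in channels-last layout.
// in: [H][W][Cin], weight: [Cout][KH][KW][Cin], bias: [Cout]
// returns: [H-KH+1][W-KW+1][Cout]
// Cross-correlation means the kernel is NOT flipped, which is what
// PyTorch's nn.Conv2d computes.
func conv2dNHWC(in [][][]float32, weight [][][][]float32, bias []float32) [][][]float32 {
	H, W, Cin := len(in), len(in[0]), len(in[0][0])
	Cout, KH, KW := len(weight), len(weight[0]), len(weight[0][0])
	outH, outW := H-KH+1, W-KW+1
	out := make([][][]float32, outH)
	for y := 0; y < outH; y++ {
		out[y] = make([][]float32, outW)
		for x := 0; x < outW; x++ {
			out[y][x] = make([]float32, Cout)
			for o := 0; o < Cout; o++ {
				sum := bias[o]
				// Kernel element (i, j) multiplies input (y+i, x+j): no flip.
				for i := 0; i < KH; i++ {
					for j := 0; j < KW; j++ {
						for c := 0; c < Cin; c++ {
							sum += in[y+i][x+j][c] * weight[o][i][j][c]
						}
					}
				}
				out[y][x][o] = sum
			}
		}
	}
	return out
}

func main() {
	// 1x3 single-channel input and an asymmetric 1x2 kernel [1, -1].
	in := [][][]float32{{{1}, {2}, {4}}}
	weight := [][][][]float32{{{{1}, {-1}}}}
	out := conv2dNHWC(in, weight, []float32{0})
	// Cross-correlation gives 1-2 = -1 and 2-4 = -2; a flipped (true)
	// convolution would give +1 and +2 instead.
	fmt.Println(out[0][0][0], out[0][1][0])
}
```

An asymmetric kernel is the easiest way to check the convention: the sign of the result reveals whether the kernel was flipped.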
Usage
Serves as the final stage of the Z-Image pipeline, decoding denoised latents into RGB images. It can also encode input images into latent space for image editing (img2img).
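At the boundary between the VAE and the diffusion model, latents are normalized with scaling_factor and shift_factor. The sketch below assumes the diffusers-style convention `(z - shift) * scale` after encoding and its inverse `z / scale + shift` before decoding; this page does not spell out the exact formula, so treat it as an assumption. The helper names are hypothetical and the factor values are illustrative, not the real config values.

```go
package main

import "fmt"

// normalizeLatents maps raw encoder latents into the space the diffusion
// model works in, assuming the convention (z - shift) * scale.
func normalizeLatents(z []float32, scale, shift float32) []float32 {
	out := make([]float32, len(z))
	for i, v := range z {
		out[i] = (v - shift) * scale
	}
	return out
}

// denormalizeLatents inverts normalizeLatents before decoding: z/scale + shift.
func denormalizeLatents(z []float32, scale, shift float32) []float32 {
	out := make([]float32, len(z))
	for i, v := range z {
		out[i] = v/scale + shift
	}
	return out
}

func main() {
	// Illustrative factors only; real values come from the VAE config
	// fields scaling_factor and shift_factor.
	const scale, shift = 0.5, 0.25
	raw := []float32{1.25, -0.75, 0.25}
	norm := normalizeLatents(raw, scale, shift)
	back := denormalizeLatents(norm, scale, shift)
	fmt.Println(norm, back)
}
```

Because the two functions are exact inverses, an encode-then-decode round trip only loses information inside the VAE itself, not in the normalization step.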
Code Reference
Source Location
- Repository: Ollama
- File: x/imagegen/models/zimage/vae.go
- Lines: 1-822
Signature
```go
type AutoencoderKLConfig struct {
	BlockOutChannels []int32  `json:"block_out_channels"` // [128, 256, 512, 512]
	DownBlockTypes   []string `json:"down_block_types"`
	UpBlockTypes     []string `json:"up_block_types"`
	LatentChannels   int32    `json:"latent_channels"` // 16
	LayersPerBlock   int32    `json:"layers_per_block"`
	ScalingFactor    float32  `json:"scaling_factor"`
	ShiftFactor      float32  `json:"shift_factor"`
}

type GroupNorm struct {
	Weight    *mlx.Array `weight:"weight"`
	Bias      *mlx.Array `weight:"bias"`
	NumGroups int32
	Eps       float32
}

type AutoencoderKL struct { ... }

func (v *AutoencoderKL) Load(manifest *manifest.ModelManifest) error
func (v *AutoencoderKL) Decode(z *mlx.Array) *mlx.Array
func (v *AutoencoderKL) Encode(x *mlx.Array) *mlx.Array
```
Import
```go
import "github.com/ollama/ollama/x/imagegen/models/zimage"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| z | *mlx.Array | Yes | Latent tensor passed to Decode, shape [B, H, W, C] in NHWC format, where C = latent_channels (16) |
Outputs
| Name | Type | Description |
|---|---|---|
| (return value) | *mlx.Array | Decoded RGB image [B, H*8, W*8, 3] in NHWC format |
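The 8x factor in the output shape follows from block_out_channels: with four decoder stages there are three 2x upsampling steps, and 2^3 = 8. The sketch below computes the decoded shape under that assumption (one 2x upsample between consecutive stages, RGB output). The function `decodedShape` is hypothetical and only does shape arithmetic.

```go
package main

import "fmt"

// decodedShape returns the NHWC image shape produced by decoding a latent of
// shape [B, H, W, C], assuming one 2x upsampling step between consecutive
// decoder stages (len(blockOutChannels)-1 steps in total) and 3 output channels.
func decodedShape(latent [4]int, blockOutChannels []int32) [4]int {
	factor := 1 << (len(blockOutChannels) - 1)
	return [4]int{latent[0], latent[1] * factor, latent[2] * factor, 3}
}

func main() {
	blocks := []int32{128, 256, 512, 512}
	// A 128x128 latent with 16 channels decodes to a 1024x1024 RGB image.
	fmt.Println(decodedShape([4]int{1, 128, 128, 16}, blocks))
}
```

Running the arithmetic in reverse gives the latent size to request for a target resolution: divide the image height and width by 8.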
Usage Examples
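```go
// loadAndRun is a hypothetical helper; mlx and manifest are the imagegen
// packages used in the signatures above (import paths not shown here).
func loadAndRun(m *manifest.ModelManifest, denoisedLatents, inputImage *mlx.Array) (image, latents *mlx.Array, err error) {
	vae := &zimage.AutoencoderKL{}
	if err = vae.Load(m); err != nil {
		return nil, nil, err
	}
	// Decode latents to image (8x upsampling): [B, H, W, 16] -> [B, 8H, 8W, 3]
	image = vae.Decode(denoisedLatents)
	// For img2img: encode an NHWC input image to latent space
	latents = vae.Encode(inputImage)
	return image, latents, nil
}
```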