Workflow: Bitsandbytes 8-bit Optimizer Training

From Leeroopedia


Knowledge Sources
Domains: LLMs, Training, Optimization
Last Updated: 2026-02-07 14:00 GMT

Overview

End-to-end process for training a PyTorch model with 8-bit quantized optimizer states, reducing optimizer memory consumption by approximately 75% while matching the performance of 32-bit training.

Description

This workflow demonstrates how to replace standard 32-bit optimizers (Adam, AdamW, SGD, etc.) with their bitsandbytes 8-bit equivalents. The 8-bit optimizers use block-wise dynamic quantization to maintain optimizer states (momentum, variance) in INT8 format. States are dequantized to FP32 before each update step, the standard optimizer algorithm is applied, and the updated states are re-quantized to INT8. This is transparent to the training loop. The workflow also covers paged optimizers that can offload optimizer states from GPU to CPU when memory pressure is detected, and the GlobalOptimManager for per-parameter optimizer configuration overrides.

Usage

Execute this workflow when training any PyTorch model for which optimizer memory is a bottleneck. This is especially valuable for large-model fine-tuning, where the optimizer states (e.g., Adam's first and second moments) consume significant GPU memory. The 8-bit optimizers are drop-in replacements for their PyTorch counterparts.
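To see where the roughly 75% figure comes from: Adam keeps two FP32 state tensors (momentum and variance), i.e. 8 bytes per parameter, while INT8 states need 2 bytes per parameter plus a small per-block scaling overhead. A back-of-the-envelope sketch (the helper function is illustrative, not part of bitsandbytes; the 2048-element block size matches the block-wise quantization used by the library):

```python
import math

def adam_state_bytes(num_params: int, bits: int, block_size: int = 2048) -> int:
    """Approximate memory for Adam's two optimizer states (momentum + variance)."""
    per_state = num_params * (bits // 8)
    overhead = 0
    if bits == 8:
        # one FP32 absmax scale per block, for each state tensor
        overhead = math.ceil(num_params / block_size) * 4
    return 2 * (per_state + overhead)

params_7b = 7_000_000_000
fp32 = adam_state_bytes(params_7b, 32)  # ~56 GB
int8 = adam_state_bytes(params_7b, 8)   # ~14 GB
print(f"FP32 states: {fp32 / 1e9:.0f} GB, INT8 states: {int8 / 1e9:.0f} GB")
print(f"reduction: {1 - int8 / fp32:.1%}")
```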

Execution Steps

Step 1: Define Model and Parameters

Create or load the model to be trained. Identify which parameters require gradients and will be passed to the optimizer. Optionally, use the GlobalOptimManager to register parameters and configure per-parameter overrides (e.g., keeping embedding weights in 32-bit precision while using 8-bit for other parameters).

Key considerations:

  • The GlobalOptimManager singleton allows per-parameter control of optim_bits, percentile_clipping, and other settings
  • StableEmbedding layers automatically register themselves for 32-bit optimizer states via the GlobalOptimManager
  • Parameters must be registered with the manager before the optimizer is created

Step 2: Initialize 8-bit Optimizer

Replace the standard PyTorch optimizer with its bitsandbytes 8-bit variant. The available optimizers include Adam8bit, AdamW8bit, SGD8bit, Lion8bit, LAMB8bit, LARS8bit, RMSprop8bit, Adagrad8bit, and AdEMAMix8bit. Each optimizer accepts the same hyperparameters as its PyTorch equivalent, plus additional options like optim_bits (8 or 32), percentile_clipping, and is_paged (for CPU offloading).

Key considerations:

  • The bnb.optim.Adam constructor accepts optim_bits=8 to enable 8-bit states (or use bnb.optim.Adam8bit directly)
  • Paged variants (e.g., PagedAdamW8bit) enable automatic GPU-to-CPU state offloading under memory pressure
  • The optimizer hierarchy is: Optimizer8bit base -> Optimizer2State (Adam-family) or Optimizer1State (SGD-family) -> specific optimizer

Step 3: Run Training Loop

Execute the standard PyTorch training loop (forward pass, loss computation, backward pass, optimizer step). The Optimizer8bit.step() method transparently handles state quantization. On the first call, optimizer states are initialized in 32-bit and immediately quantized to 8-bit. On subsequent calls, states are dequantized from INT8 to FP32, the optimizer update is computed in FP32, and the updated states are re-quantized to INT8.

Key considerations:

  • The dequantize-update-quantize cycle is fully transparent to user code
  • Block-wise dynamic quantization uses blocks of 2048 elements with per-block scaling
  • The quantization uses dynamic exponent lookup tables for non-linear mapping
  • For paged optimizers, states may be prefetched from CPU before the update step
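The block-wise scheme can be illustrated in plain PyTorch. This sketch uses a simple linear absmax INT8 mapping for clarity; bitsandbytes actually applies a non-linear dynamic-exponent code via lookup tables, fused into a GPU kernel:

```python
import torch

def blockwise_quantize(x: torch.Tensor, block_size: int = 2048):
    """Quantize a tensor to INT8 with one absmax scale per block."""
    flat = x.flatten()
    pad = (-flat.numel()) % block_size
    flat = torch.nn.functional.pad(flat, (0, pad))
    blocks = flat.view(-1, block_size)
    absmax = blocks.abs().amax(dim=1, keepdim=True).clamp_min(1e-8)
    q = torch.round(blocks / absmax * 127).to(torch.int8)
    return q, absmax, x.shape, pad

def blockwise_dequantize(q, absmax, shape, pad):
    flat = (q.to(torch.float32) / 127 * absmax).flatten()
    if pad:
        flat = flat[:-pad]
    return flat.view(shape)

state = torch.randn(4096)  # e.g. an Adam momentum tensor
q, absmax, shape, pad = blockwise_quantize(state)
restored = blockwise_dequantize(q, absmax, shape, pad)
err = (state - restored).abs().max()
```

Because the scale is per block rather than per tensor, a single outlier only degrades precision within its own 2048-element block.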

Step 4: Handle State Serialization

When saving checkpoints, the optimizer state_dict contains 8-bit quantized states along with their quantization metadata (quantization maps and absmax values). Loading from a checkpoint restores the quantized states directly without re-quantization.

Key considerations:

  • The state_dict includes both the quantized state tensors and their QuantState metadata
  • Checkpoints from 8-bit optimizers are smaller than 32-bit equivalents
  • States can be loaded across different devices (CPU/GPU)

Step 5: Monitor and Validate Training

Verify that training loss converges comparably to 32-bit optimization. The 8-bit optimizer should produce nearly identical training dynamics. Use standard logging and evaluation practices to confirm model quality.

Key considerations:

  • Empirically, 8-bit optimizers match 32-bit performance on virtually all benchmarks
  • If instability is observed, increase optim_bits to 32 for specific sensitive parameters via GlobalOptimManager
  • Percentile clipping can further stabilize training by dynamically scaling the gradient range

Execution Diagram

GitHub URL

Workflow Repository