Implementation: Bitsandbytes Adam8bit
| Sources | Repo: bitsandbytes, Paper: 8-bit Optimizers via Block-wise Quantization |
|---|---|
| Domains | Optimization |
Overview
Concrete tool for memory-efficient Adam optimization with 8-bit quantized states provided by the bitsandbytes library. Adam8bit is a drop-in replacement for torch.optim.Adam that stores momentum and variance in 8-bit format for approximately 75% optimizer state memory savings.
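The ~75% figure follows from simple per-parameter arithmetic: standard Adam keeps two fp32 state tensors (8 bytes per parameter), while Adam8bit keeps two uint8 tensors (2 bytes per parameter), plus a small per-block overhead for the absmax scaling factors. A back-of-the-envelope sketch:

```python
# Back-of-the-envelope optimizer-state memory for a 7B-parameter model.
# Ignores the small per-block fp32 absmax overhead.
n_params = 7_000_000_000

fp32_state = n_params * 2 * 4   # two fp32 moments: 8 bytes/param
int8_state = n_params * 2 * 1   # two uint8 moments: 2 bytes/param

savings = 1 - int8_state / fp32_state
print(f"fp32 states:  {fp32_state / 2**30:.1f} GiB")
print(f"8-bit states: {int8_state / 2**30:.1f} GiB")
print(f"savings: {savings:.0%}")  # -> 75%
```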
Description
Adam8bit inherits from Optimizer2State, which inherits from Optimizer8bit. The inheritance chain is:

```
torch.optim.Optimizer
  -> Optimizer8bit    (step loop, state_dict, FSDP compat)
    -> Optimizer2State  (two-state init/update: momentum + variance)
      -> Adam8bit       (hardcodes optim_bits=8, optimizer_name="adam")
```
Key implementation details:
- Adam8bit hardcodes `optim_bits=8` in its call to `Optimizer2State.__init__`, regardless of the `optim_bits` parameter in its own signature (which exists only for API compatibility and raises `ValueError` if set to anything other than the default 32).
- The `optimizer_name` is set to `"adam"`, which selects the Adam kernel in the underlying CUDA optimizer update functions. `amsgrad` is not supported; passing `amsgrad=True` raises `ValueError`.
- The `Optimizer2State` class handles two state tensors (momentum as `state1` and variance as `state2`).
- On the first step, states are initialized to zeros via `init_state()`. For 8-bit mode, this creates `uint8` zero tensors with corresponding quantization maps and absmax scaling factors.
- On subsequent steps, `update_step()` dispatches to one of three CUDA kernel paths depending on precision and block mode:
  - 32-bit path: `optimizer_update_32bit` (for tensors below `min_8bit_size`)
  - 8-bit non-blockwise path: `optimizer_update_8bit`
  - 8-bit blockwise path: `optimizer_update_8bit_blockwise` (default and recommended)
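The three-way dispatch above can be sketched as follows. `select_update_path` is a hypothetical helper, not part of the bitsandbytes API; it only mirrors the selection rule described in this section.

```python
def select_update_path(numel: int, min_8bit_size: int = 4096,
                       block_wise: bool = True) -> str:
    """Illustrative mirror of the CUDA-kernel path selection described above;
    not part of the bitsandbytes API."""
    if numel < min_8bit_size:
        # Small tensors skip quantization and use the full-precision kernel.
        return "optimizer_update_32bit"
    if block_wise:
        return "optimizer_update_8bit_blockwise"  # default and recommended
    return "optimizer_update_8bit"

print(select_update_path(numel=1024))     # small tensor -> 32-bit path
print(select_update_path(numel=1 << 20))  # large tensor -> blockwise 8-bit
```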
Code Reference
| Source | File | Lines |
|---|---|---|
| bitsandbytes repo | bitsandbytes/optim/adam.py | `Adam8bit` L70-139 |
| bitsandbytes repo | bitsandbytes/optim/optimizer.py | `Optimizer2State` L384-625 |
Class signature:
```python
class Adam8bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        is_paged=False,
    ):
```
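The two compatibility-only constraints described earlier (`amsgrad` unsupported, `optim_bits` fixed) can be illustrated with a standalone sketch. `validate_adam8bit_args` is hypothetical; it only reproduces the validation behavior this page describes, without importing bitsandbytes.

```python
def validate_adam8bit_args(amsgrad: bool = False, optim_bits: int = 32) -> None:
    """Hypothetical mirror of the argument checks Adam8bit's __init__
    is described as performing (not the actual bitsandbytes code)."""
    if amsgrad:
        raise ValueError("Adam8bit does not support amsgrad=True")
    if optim_bits != 32:
        # The parameter exists only for API compatibility; 8-bit is hardcoded.
        raise ValueError("optim_bits must be left at its default of 32")

validate_adam8bit_args()  # defaults pass silently
try:
    validate_adam8bit_args(amsgrad=True)
except ValueError as e:
    print(e)
```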
Import:
```python
import bitsandbytes as bnb

optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)
```
I/O Contract
Inputs
| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| `params` | iterable | Yes | -- | Model parameters to optimize |
| `lr` | float | No | 1e-3 | Learning rate |
| `betas` | tuple(float, float) | No | (0.9, 0.999) | Decay rates for first and second moment estimates |
| `eps` | float | No | 1e-8 | Epsilon for numerical stability in denominator |
| `weight_decay` | float | No | 0 | Weight decay (L2 penalty) |
| `min_8bit_size` | int | No | 4096 | Minimum tensor size (number of elements) for 8-bit quantization; smaller tensors use 32-bit |
| `percentile_clipping` | int | No | 100 | Percentile for gradient clipping (100 disables clipping) |
| `block_wise` | bool | No | True | Whether to use block-wise quantization (recommended) |
| `is_paged` | bool | No | False | Whether to use paged memory for GPU-to-CPU offload |
Outputs
| Output | Type | Description |
|---|---|---|
| parameter updates | in-place | Model parameters are updated in-place |
| `state1` | torch.uint8 tensor | Momentum (first moment), stored in 8-bit quantized format |
| `state2` | torch.uint8 tensor | Variance (second moment), stored in 8-bit quantized format |
| `absmax1` | torch.float32 tensor | Per-block scaling factors for momentum |
| `absmax2` | torch.float32 tensor | Per-block scaling factors for variance |
| `qmap1` | torch.Tensor | Signed dynamic quantization map (256 levels) for momentum |
| `qmap2` | torch.Tensor | Unsigned dynamic quantization map (256 levels) for variance |
Usage Examples
Standard training with Adam8bit:
```python
import torch
import bitsandbytes as bnb

model = MyModel().cuda()
optimizer = bnb.optim.Adam8bit(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=0.01,
)

for epoch in range(num_epochs):
    for batch in dataloader:
        optimizer.zero_grad()
        outputs = model(batch["input_ids"].cuda())
        loss = loss_fn(outputs, batch["labels"].cuda())
        loss.backward()
        optimizer.step()
```
With paged memory for large models:
```python
optimizer = bnb.optim.PagedAdam8bit(
    model.parameters(),
    lr=1e-3,
)
```