Implementation: deepspeedai/DeepSpeed CPU Adam Header
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Deep Learning, CPU Computing |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Header file defining the SIMD-accelerated Adam/AdamW optimizer class with AVX2/AVX512 support for high-performance CPU training.
Description
This header defines the Adam_Optimizer class with template-based SIMD implementations supporting both Adam (L2 regularization) and AdamW (decoupled weight decay) modes. The Step_AVX template method uses AVX intrinsics for vectorized operations, with unroll spans of 1, 4, and 8 exposed through the Step_1/Step_4/Step_8 wrappers for efficient batch processing. The class supports bias correction, mixed-precision training with FP16/BFloat16/FP32 parameter and state types, and manages the exponential moving average state, including incremental computation of the beta powers used for bias correction.
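As a rough illustration of the two modes the header supports, the following is a minimal scalar sketch of one Adam/AdamW update for a single parameter. This is not the DeepSpeed implementation (it omits bias correction, the SIMD paths, and DeepSpeed's internal sign conventions); it only shows where the weight-decay term differs between the two modes:

```cpp
#include <cassert>
#include <cmath>

// Scalar sketch of one Adam/AdamW update for a single parameter.
// adamw_mode = true  -> decoupled weight decay applied to the parameter (AdamW)
// adamw_mode = false -> weight decay folded into the gradient (Adam + L2)
void adam_step(float& param, float grad, float& exp_avg, float& exp_avg_sq,
               float alpha, float beta1, float beta2, float eps,
               float weight_decay, bool adamw_mode)
{
    if (!adamw_mode && weight_decay > 0)           // L2 mode: decay enters the gradient
        grad += weight_decay * param;
    exp_avg    = beta1 * exp_avg    + (1 - beta1) * grad;         // first moment
    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;  // second moment
    if (adamw_mode && weight_decay > 0)            // AdamW: decay applied directly
        param -= alpha * weight_decay * param;
    param -= alpha * exp_avg / (std::sqrt(exp_avg_sq) + eps);
}
```

The header's Step_* methods vectorize exactly this kind of per-element update across the parameter, gradient, and moment arrays.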
Usage
Include this header when implementing or extending CPU-based Adam/AdamW optimization with SIMD acceleration capabilities.
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/includes/cpu_adam.h
Signature
class Adam_Optimizer {
public:
    Adam_Optimizer(float alpha = 1e-3,
                   float betta1 = 0.9,
                   float betta2 = 0.999,
                   float eps = 1e-8,
                   float weight_decay = 0,
                   bool adamw_mode = true);
    ~Adam_Optimizer();

#if defined(__AVX512__) or defined(__AVX256__)
    template <int span, typename ds_params_precision_t, typename ds_state_precision_t>
    void Step_AVX(size_t* rounded_size,
                  ds_params_precision_t* _params,
                  ds_params_precision_t* grads,
                  ds_state_precision_t* _exp_avg,
                  ds_state_precision_t* _exp_avg_sq,
                  size_t param_size);
#endif

    template <typename ds_params_precision_t, typename ds_state_precision_t>
    void Step_1(ds_params_precision_t* _params,
                ds_params_precision_t* grads,
                ds_state_precision_t* _exp_avg,
                ds_state_precision_t* _exp_avg_sq,
                size_t _param_size);

    template <typename ds_params_precision_t, typename ds_state_precision_t>
    void Step_4(ds_params_precision_t* _params,
                ds_params_precision_t* grads,
                ds_state_precision_t* _exp_avg,
                ds_state_precision_t* _exp_avg_sq,
                size_t _param_size);

    template <typename ds_params_precision_t, typename ds_state_precision_t>
    void Step_8(ds_params_precision_t* _params,
                ds_params_precision_t* grads,
                ds_state_precision_t* _exp_avg,
                ds_state_precision_t* _exp_avg_sq,
                size_t _param_size);

    inline void IncrementStep(size_t step, float beta1, float beta2);
    inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction);
};
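IncrementStep advances the step counter and the running powers of beta1/beta2 that bias correction needs. A hedged sketch of that bookkeeping follows; the struct and member names are illustrative stand-ins, not the header's private members:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Illustrative sketch: maintain betta1^t and betta2^t incrementally each step,
// rather than calling std::pow on every update.
struct BetaPowers {
    float b1_pow = 1.0f;          // betta1^step
    float b2_pow = 1.0f;          // betta2^step
    std::size_t step = 0;

    void increment(float beta1, float beta2)
    {
        ++step;
        b1_pow *= beta1;
        b2_pow *= beta2;
    }

    // Bias-correction divisors used to de-bias the moment estimates.
    float bias_correction1() const { return 1.0f - b1_pow; }
    float bias_correction2() const { return 1.0f - b2_pow; }
};
```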
Import
#include "cpu_adam.h"
#include "simd.h"
I/O Contract
Constructor Parameters
| Parameter | Type | Description |
|---|---|---|
| alpha | float | Learning rate (default: 1e-3) |
| betta1 | float | Exponential decay rate for first moment (default: 0.9) |
| betta2 | float | Exponential decay rate for second moment (default: 0.999) |
| eps | float | Small constant for numerical stability (default: 1e-8) |
| weight_decay | float | Weight decay coefficient (default: 0) |
| adamw_mode | bool | Use AdamW (decoupled) if true, Adam (L2) if false (default: true) |
Step_AVX Template Parameters
| Parameter | Type | Description |
|---|---|---|
| span | int | SIMD vector span factor (1, 4, or 8) |
| rounded_size | size_t* | Output: number of elements processed with SIMD |
| _params | ds_params_precision_t* | Model parameters array (in/out) |
| grads | ds_params_precision_t* | Gradients array (in) |
| _exp_avg | ds_state_precision_t* | First moment estimates (in/out) |
| _exp_avg_sq | ds_state_precision_t* | Second moment estimates (in/out) |
| param_size | size_t | Total number of parameters |
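The rounded_size output exists because the vectorized path can only process a whole number of SIMD tiles; the remaining tail elements fall through to a scalar path. A hedged sketch of that contract, with a fixed TILE of 8 floats standing in for the AVX-width-times-span tile (this is illustrative, not the DeepSpeed kernel):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

constexpr std::size_t TILE = 8;  // stand-in for AVX width * span unroll

// "Vector" path: handles the largest multiple of TILE and reports how many
// elements it covered via rounded_size, mirroring the Step_AVX contract.
void step_simd_like(float* params, const float* grads,
                    std::size_t* rounded_size, std::size_t param_size)
{
    std::size_t rounded = (param_size / TILE) * TILE;
    for (std::size_t i = 0; i < rounded; i += TILE)
        for (std::size_t j = 0; j < TILE; ++j)     // body a real kernel vectorizes
            params[i + j] -= 0.1f * grads[i + j];
    *rounded_size = rounded;
}

// Scalar tail: finishes the remainder [rounded_size, param_size).
void step_scalar_tail(float* params, const float* grads,
                      std::size_t begin, std::size_t end)
{
    for (std::size_t i = begin; i < end; ++i)
        params[i] -= 0.1f * grads[i];
}
```

In the header, Step_1/Step_4/Step_8 follow the same shape: the AVX body runs first when available, then a scalar loop covers everything past rounded_size.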
Supported Type Combinations
| Parameters Type | State Type | Description |
|---|---|---|
| c10::Half | float | FP16 parameters, FP32 state |
| c10::Half | c10::Half | FP16 parameters, FP16 state |
| c10::BFloat16 | float | BF16 parameters, FP32 state (AVX512 only) |
| c10::BFloat16 | c10::BFloat16 | BF16 parameters, BF16 state (AVX512 only) |
| float | float | FP32 parameters, FP32 state |
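The point of keeping ds_params_precision_t and ds_state_precision_t as separate template parameters is that parameters can be stored in reduced precision while the moment estimates stay wider. A hedged sketch of that contract, using float params with double state as stand-ins for the c10::Half/float pairings above (illustrative only):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Sketch: parameter/gradient precision (P) and optimizer-state precision (S)
// are independent, so e.g. FP16 params can keep FP32 moments. The arithmetic
// is done in S and the result narrowed back to P.
template <typename P, typename S>
void mixed_step(P* params, const P* grads, S* exp_avg, std::size_t n,
                S beta1, S lr)
{
    for (std::size_t i = 0; i < n; ++i) {
        S g = static_cast<S>(grads[i]);                  // widen gradient to state type
        exp_avg[i] = beta1 * exp_avg[i] + (S(1) - beta1) * g;
        params[i] = static_cast<P>(static_cast<S>(params[i]) - lr * exp_avg[i]);
    }
}
```

This also explains the AVX512-only rows: the BF16 load/store conversions in the header rely on AVX512 instructions, so those pairings are compiled only on AVX512 targets.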
Usage Examples
#include "cpu_adam.h"
// Create Adam optimizer instance with AdamW mode
Adam_Optimizer opt(
/* alpha = */ 0.001,
/* betta1 = */ 0.9,
/* betta2 = */ 0.999,
/* eps = */ 1e-8,
/* weight_decay = */ 0.01,
/* adamw_mode = */ true
);
// Update state for current step
opt.IncrementStep(1, 0.9, 0.999);
opt.update_state(0.001, 1e-8, 0.01, true);
// Execute optimizer step with FP32.
// The moment buffers must start at zero, so they are value-initialized.
size_t param_size = 1024;
float* params = new float[param_size];
float* grads = new float[param_size];
float* exp_avg = new float[param_size]();     // first moments, zero-initialized
float* exp_avg_sq = new float[param_size]();  // second moments, zero-initialized
// (params and grads are assumed to be filled with model data before the step)
opt.Step_8(params, grads, exp_avg, exp_avg_sq, param_size);
delete[] params;
delete[] grads;
delete[] exp_avg;
delete[] exp_avg_sq;