Implementation: deepspeedai/DeepSpeed XPU Adam Header
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Deep Learning, XPU Computing |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Header file defining the SIMD-accelerated Adam optimizer class for Intel XPU devices with mixed precision support.
Description
This header defines the Adam_Optimizer class for Intel XPU platforms, featuring SIMD acceleration via AVX2/AVX512 and special handling for half-precision training. Unlike the standard CPU version, this implementation explicitly manages a separate device parameter pointer (dev_param) so that optimizer states can be kept in higher precision than the model parameters during mixed-precision training. The class provides a Step_AVX template method parameterized by a span (unroll) factor and uses bit-shifting logic to index half-precision buffers efficiently.
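The per-element update that the SIMD kernels vectorize can be sketched scalar-wise. This is a simplified illustration, not the header's internals: variable names are illustrative, and the bias correction that update_state folds into the learning rate is omitted for brevity.

```cpp
#include <cassert>
#include <cmath>

// One scalar Adam/AdamW update for a single parameter -- a simplified
// sketch of the math the vectorized Step_* kernels apply lane by lane.
float adam_step(float param, float grad, float& exp_avg, float& exp_avg_sq,
                float lr, float beta1, float beta2, float eps,
                float weight_decay, bool adamw_mode)
{
    // Classic Adam: L2 regularization is folded into the gradient.
    if (!adamw_mode && weight_decay > 0.0f) grad += weight_decay * param;

    exp_avg    = beta1 * exp_avg    + (1.0f - beta1) * grad;
    exp_avg_sq = beta2 * exp_avg_sq + (1.0f - beta2) * grad * grad;

    float update = exp_avg / (std::sqrt(exp_avg_sq) + eps);

    // AdamW: weight decay is decoupled from the adaptive update.
    if (adamw_mode && weight_decay > 0.0f) param -= lr * weight_decay * param;
    return param - lr * update;
}
```

A single gradient step with a positive gradient decreases the parameter and populates both moment estimates.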
Usage
Include this header when implementing Intel XPU-based Adam optimization with mixed precision training capabilities.
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/xpu/includes/cpu_adam.h
Signature
```cpp
typedef unsigned short ds_half_precision_t;

class Adam_Optimizer {
public:
    Adam_Optimizer(float alpha = 1e-3,
                   float betta1 = 0.9,
                   float betta2 = 0.999,
                   float eps = 1e-8,
                   float weight_decay = 0,
                   bool adamw_mode = true);
    ~Adam_Optimizer();

#if defined(__AVX512__) or defined(__AVX256__)
    template <int span>
    void Step_AVX(size_t* rounded_size,
                  float* _params,
                  float* grads,
                  float* _exp_avg,
                  float* _exp_avg_sq,
                  size_t param_size,
                  ds_half_precision_t* dev_param = nullptr,
                  bool half_precision = false);
#endif

    void Step_1(float* _params,
                float* grads,
                float* _exp_avg,
                float* _exp_avg_sq,
                size_t _param_size,
                ds_half_precision_t* dev_param = nullptr,
                bool half_precision = false);

    void Step_4(float* _params,
                float* grads,
                float* _exp_avg,
                float* _exp_avg_sq,
                size_t _param_size,
                ds_half_precision_t* dev_param = nullptr,
                bool half_precision = false);

    void Step_8(float* _params,
                float* grads,
                float* _exp_avg,
                float* _exp_avg_sq,
                size_t _param_size,
                ds_half_precision_t* dev_param = nullptr,
                bool half_precision = false);

    inline void IncrementStep(size_t step, float beta1, float beta2);
    inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction);
};
```
Import
```cpp
#include "cpu_adam.h"
#include "simd.h"
#include <cmath>
```
I/O Contract
Constructor Parameters
| Parameter | Type | Description |
|---|---|---|
| alpha | float | Learning rate (default: 1e-3) |
| betta1 | float | Exponential decay rate for first moment (default: 0.9) |
| betta2 | float | Exponential decay rate for second moment (default: 0.999) |
| eps | float | Small constant for numerical stability (default: 1e-8) |
| weight_decay | float | Weight decay coefficient (default: 0) |
| adamw_mode | bool | Use AdamW (decoupled) if true, Adam (L2) if false (default: true) |
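The beta parameters above feed the bias-correction bookkeeping performed by IncrementStep and update_state. A hedged sketch of what that bookkeeping implies (the struct and names here are illustrative, not the header's private members):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Sketch: track beta1^t and beta2^t across steps, as IncrementStep implies,
// and fold the standard Adam bias correction into an effective learning rate.
struct StepState {
    size_t step = 0;
    float betta1_t = 1.0f;  // beta1^step
    float betta2_t = 1.0f;  // beta2^step

    void increment(float beta1, float beta2)
    {
        ++step;
        betta1_t *= beta1;
        betta2_t *= beta2;
    }

    // Effective step size when bias_correction is enabled:
    // lr * sqrt(1 - beta2^t) / (1 - beta1^t).
    float corrected_lr(float lr) const
    {
        return lr * std::sqrt(1.0f - betta2_t) / (1.0f - betta1_t);
    }
};
```

At step 1 with the default betas, the correction is largest; it decays toward 1 as training progresses.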
Step_AVX Template Parameters
| Parameter | Type | Description |
|---|---|---|
| span | int | SIMD vector span factor (1, 4, or 8) |
| rounded_size | size_t* | Output: number of elements processed with SIMD |
| _params | float* | FP32 parameter buffer (in/out) |
| grads | float* | FP32 gradients buffer (in) |
| _exp_avg | float* | FP32 first moment estimates (in/out) |
| _exp_avg_sq | float* | FP32 second moment estimates (in/out) |
| param_size | size_t | Total number of parameters |
| dev_param | ds_half_precision_t* | Optional buffer that receives the updated parameters in FP16 (out) |
| half_precision | bool | Enable FP16 mode for params/grads |
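rounded_size reports how many leading elements the vectorized loop covered; the remaining tail falls back to the scalar path. A plausible sketch of the rounding, assuming a SIMD register width of 8 floats (the actual width depends on AVX2 vs AVX512 and is an assumption here):

```cpp
#include <cassert>
#include <cstddef>

// Sketch: round param_size down to a multiple of (SIMD width * span),
// the chunk size the vectorized loop consumes per iteration. The tail
// (param_size - rounded) is handled by the scalar path.
constexpr size_t kSimdWidth = 8;  // e.g. 8 floats per AVX2 register (assumed)

template <int span>
size_t rounded_size(size_t param_size)
{
    const size_t chunk = kSimdWidth * span;
    return param_size - (param_size % chunk);
}
```

For example, with span = 8 (chunk of 64 floats), 1000 parameters round down to 960, leaving a 40-element scalar tail.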
Mixed Precision Modes
| Mode | Description |
|---|---|
| FP32 | All buffers in FP32, dev_param = nullptr, half_precision = false |
| Mixed FP16 | Optimizer states in FP32, params/grads in FP16 via dev_param |
| Mixed BF16 | Similar to FP16 but with BFloat16 (AVX512 only) |
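ds_half_precision_t is an unsigned short holding raw IEEE-754 binary16 bits. A hedged sketch of packing an FP32 master weight into that representation (normal values only, mantissa truncated without rounding; the real kernels use hardware conversion instructions rather than this bit manipulation):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Sketch: pack a float into raw IEEE-754 binary16 bits, as stored in a
// ds_half_precision_t buffer. Handles zeros and normal values; truncates
// the mantissa and flushes subnormals to zero for simplicity.
using ds_half_precision_t = unsigned short;

ds_half_precision_t float_to_half_bits(float f)
{
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    uint32_t sign = (bits >> 16) & 0x8000u;                              // sign bit
    int32_t  exp  = static_cast<int32_t>((bits >> 23) & 0xFF) - 127 + 15; // re-bias exponent
    uint32_t mant = (bits >> 13) & 0x3FFu;                               // top 10 mantissa bits
    if (exp <= 0) return static_cast<ds_half_precision_t>(sign);           // flush to zero
    if (exp >= 31) return static_cast<ds_half_precision_t>(sign | 0x7C00u); // overflow -> inf
    return static_cast<ds_half_precision_t>(sign | (static_cast<uint32_t>(exp) << 10) | mant);
}
```

This is why the FP16 buffers in the examples below are declared as unsigned short: the optimizer treats them as opaque bit patterns, not C++ floating-point types.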
Usage Examples
```cpp
#include "cpu_adam.h"

// Create an Adam optimizer for XPU with mixed precision (AdamW mode).
Adam_Optimizer opt(
    /* alpha = */ 0.001,
    /* betta1 = */ 0.9,
    /* betta2 = */ 0.999,
    /* eps = */ 1e-8,
    /* weight_decay = */ 0.01,
    /* adamw_mode = */ true);

// Advance the step counter and refresh hyperparameters before each step.
opt.IncrementStep(1, 0.9, 0.999);
opt.update_state(0.001, 1e-8, 0.01, true);

// Execute a step with mixed precision (FP16 device params, FP32 states).
// Buffers are value-initialized here; in real use, grads holds the output
// of the backward pass and fp32_buffer holds the FP32 master weights.
size_t param_size = 1024;
float* fp32_buffer = new float[param_size]();  // master weights
float* grads = new float[param_size]();        // FP32 gradients
float* exp_avg = new float[param_size]();      // first-moment state
float* exp_avg_sq = new float[param_size]();   // second-moment state
ds_half_precision_t* fp16_params = new ds_half_precision_t[param_size]();  // FP16 device params

opt.Step_8(fp32_buffer, grads, exp_avg, exp_avg_sq, param_size,
           fp16_params, /* half_precision = */ true);

delete[] fp32_buffer;
delete[] grads;
delete[] exp_avg;
delete[] exp_avg_sq;
delete[] fp16_params;
```