Implementation: deepspeedai/DeepSpeed CPU Lion Header
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Deep Learning, CPU Computing |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Header file defining the SIMD-accelerated Lion optimizer class with AVX2/AVX512 support for efficient CPU-based training.
Description
This header defines the Lion_Optimizer class, implementing the Lion (EvoLved Sign Momentum) optimization algorithm with SIMD acceleration. Lion combines momentum interpolation with sign-based updates, requiring less memory than Adam (one momentum buffer instead of two) while often achieving comparable or better performance. The class exposes Step_1, Step_4, and Step_8 entry points alongside a Step_AVX template method that uses AVX intrinsics with span factors (1, 4, or 8) for vectorized operations, and it uses bit-manipulation operations (AND, XOR) for efficient sign extraction.
Usage
Include this header when implementing CPU-based Lion optimization for memory-efficient training with adaptive sign-based updates.
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/includes/cpu_lion.h
Signature
class Lion_Optimizer {
public:
    Lion_Optimizer(float alpha = 1e-3,
                   float betta1 = 0.9,
                   float betta2 = 0.999,
                   float weight_decay = 0);
    ~Lion_Optimizer();

#if defined(__AVX512__) or defined(__AVX256__)
    template <int span, typename ds_params_precision_t, typename ds_state_precision_t>
    void Step_AVX(size_t* rounded_size,
                  ds_params_precision_t* _params,
                  ds_params_precision_t* grads,
                  ds_state_precision_t* _exp_avg,
                  size_t param_size);
#endif

    template <typename ds_params_precision_t, typename ds_state_precision_t>
    void Step_1(ds_params_precision_t* _params,
                ds_params_precision_t* grads,
                ds_state_precision_t* _exp_avg,
                size_t _param_size);

    template <typename ds_params_precision_t, typename ds_state_precision_t>
    void Step_4(ds_params_precision_t* _params,
                ds_params_precision_t* grads,
                ds_state_precision_t* _exp_avg,
                size_t _param_size);

    template <typename ds_params_precision_t, typename ds_state_precision_t>
    void Step_8(ds_params_precision_t* _params,
                ds_params_precision_t* grads,
                ds_state_precision_t* _exp_avg,
                size_t _param_size);

    inline void IncrementStep(size_t step, float beta1, float beta2);
    inline void update_state(float lr, float weight_decay);

private:
    float _alpha;
    float _betta1;
    float _betta2;
    float _weight_decay;
    size_t _step;
};
Import
#include "cpu_lion.h"
#include "simd.h"
I/O Contract
Constructor Parameters
| Parameter | Type | Description |
|---|---|---|
| alpha | float | Learning rate (default: 1e-3) |
| betta1 | float | Momentum interpolation coefficient for update (default: 0.9) |
| betta2 | float | Momentum coefficient for EMA (default: 0.999) |
| weight_decay | float | Weight decay coefficient (default: 0) |
Step_AVX Template Parameters
| Parameter | Type | Description |
|---|---|---|
| span | int | SIMD vector span factor (1, 4, or 8) |
| rounded_size | size_t* | Output: number of elements processed with SIMD |
| _params | ds_params_precision_t* | Model parameters array (in/out) |
| grads | ds_params_precision_t* | Gradients array (in) |
| _exp_avg | ds_state_precision_t* | Momentum buffer (in/out) |
| param_size | size_t | Total number of parameters |
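The rounded_size output exists because Step_AVX can only vectorize a whole number of SIMD blocks; the scalar path then finishes the tail. A hedged sketch of that rounding (kSimdWidth of 8 float lanes is an assumption for AVX2; the actual width in DeepSpeed depends on the build flags):

```cpp
#include <cstddef>

// Hypothetical illustration of the rounding Step_AVX reports via
// rounded_size: the parameter count rounded down to a multiple of
// SIMD_WIDTH * span. Elements [rounded, param_size) are left for the
// scalar fallback.
constexpr std::size_t kSimdWidth = 8; // assumed AVX2 float lane count

std::size_t rounded_size(std::size_t param_size, int span)
{
    std::size_t block = kSimdWidth * static_cast<std::size_t>(span);
    return (param_size / block) * block;
}
```

For example, with span 8 and 1000 parameters, a block covers 64 floats, so 960 elements go through the SIMD path and the remaining 40 through the scalar one.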
Supported Type Combinations
| Parameters Type | State Type | Description |
|---|---|---|
| c10::Half | float | FP16 parameters, FP32 state |
| c10::Half | c10::Half | FP16 parameters, FP16 state |
| c10::BFloat16 | float | BF16 parameters, FP32 state (AVX512 only) |
| c10::BFloat16 | c10::BFloat16 | BF16 parameters, BF16 state (AVX512 only) |
| float | float | FP32 parameters, FP32 state |
Usage Examples
#include "cpu_lion.h"
// Create a Lion optimizer instance
Lion_Optimizer opt(
    /* alpha = */ 0.001f,
    /* betta1 = */ 0.9f,
    /* betta2 = */ 0.999f,
    /* weight_decay = */ 0.01f
);
// Advance the step counter and refresh hyperparameters for this step
opt.IncrementStep(1, 0.9f, 0.999f);
opt.update_state(0.001f, 0.01f);
// Execute an optimizer step with FP32 buffers
size_t param_size = 1024;
float* params = new float[param_size]();   // value-initialized to zero
float* grads = new float[param_size]();    // fill with real gradients in practice
float* exp_avg = new float[param_size]();  // momentum buffer must start at zero
opt.Step_8(params, grads, exp_avg, param_size);
delete[] params;
delete[] grads;
delete[] exp_avg;
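As a side note on buffer management: std::vector zero-initializes its elements and frees memory automatically, and the Step_* methods can be fed the raw pointers via .data(). A hypothetical sketch (the LionBuffers helper is not part of DeepSpeed):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper bundling the three zero-initialized buffers
// the Step_* methods expect; pass e.g. buffers.params.data().
struct LionBuffers {
    std::vector<float> params, grads, exp_avg;
    explicit LionBuffers(std::size_t n)
        : params(n, 0.0f), grads(n, 0.0f), exp_avg(n, 0.0f) {}
};
```

This avoids the manual new/delete pairing and guarantees the momentum buffer starts at zero.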