Implementation:Deepspeedai DeepSpeed CPU Lion Header

From Leeroopedia


Knowledge Sources
Domains: Optimization, Deep Learning, CPU Computing
Last Updated: 2026-02-09 00:00 GMT

Overview

Header file defining the SIMD-accelerated Lion optimizer class with AVX2/AVX512 support for efficient CPU-based training.

Description

This header defines the Lion_Optimizer class, implementing the Lion (EvoLved Sign Momentum) optimization algorithm with SIMD acceleration. Lion uses sign-based updates with momentum interpolation, requiring less memory than Adam (one momentum buffer instead of two) while often achieving comparable or better performance. The class exposes Step_1, Step_4, and Step_8 entry points together with a Step_AVX template parameterized by a span factor (1, 4, or 8) that uses AVX intrinsics for vectorized updates, and relies on bitwise operations (AND, XOR) for efficient sign extraction.
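To make the update rule concrete, the following is a minimal scalar sketch of one Lion step. It is not the SIMD path from this header; lr, beta1, beta2, and wd correspond to the class's _alpha, _betta1, _betta2, and _weight_decay members, and the decoupled weight-decay placement follows the published Lion algorithm.

```cpp
#include <cmath>
#include <cstddef>

// Scalar reference for one Lion step (illustrative sketch only).
void lion_step_scalar(float* params, const float* grads, float* exp_avg,
                      std::size_t n, float lr, float beta1, float beta2, float wd)
{
    for (std::size_t i = 0; i < n; ++i) {
        float g = grads[i];
        float m = exp_avg[i];
        // Interpolate momentum and gradient, then keep only the sign.
        float c = beta1 * m + (1.0f - beta1) * g;
        float update = (c > 0.0f) ? 1.0f : ((c < 0.0f) ? -1.0f : 0.0f);
        // Decoupled weight decay, then the signed update.
        params[i] = params[i] * (1.0f - lr * wd) - lr * update;
        // The momentum EMA uses the second coefficient.
        exp_avg[i] = beta2 * m + (1.0f - beta2) * g;
    }
}
```

Because the parameter update depends only on the sign of the interpolated value, the SIMD implementation can extract that sign with bitwise AND/XOR on the float sign bit rather than a full comparison, which is why the header includes those bit-manipulation operations.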

Usage

Include this header when implementing CPU-based Lion optimization for memory-efficient training with adaptive sign-based updates.

Code Reference

Source Location

Signature

class Lion_Optimizer {
public:
    Lion_Optimizer(float alpha = 1e-3,
                   float betta1 = 0.9,
                   float betta2 = 0.999,
                   float weight_decay = 0);
    ~Lion_Optimizer();

#if defined(__AVX512__) or defined(__AVX256__)
    template <int span, typename ds_params_precision_t, typename ds_state_precision_t>
    void Step_AVX(size_t* rounded_size,
                  ds_params_precision_t* _params,
                  ds_params_precision_t* grads,
                  ds_state_precision_t* _exp_avg,
                  size_t param_size);
#endif

    template <typename ds_params_precision_t, typename ds_state_precision_t>
    void Step_1(ds_params_precision_t* _params,
                ds_params_precision_t* grads,
                ds_state_precision_t* _exp_avg,
                size_t _param_size);

    template <typename ds_params_precision_t, typename ds_state_precision_t>
    void Step_4(ds_params_precision_t* _params,
                ds_params_precision_t* grads,
                ds_state_precision_t* _exp_avg,
                size_t _param_size);

    template <typename ds_params_precision_t, typename ds_state_precision_t>
    void Step_8(ds_params_precision_t* _params,
                ds_params_precision_t* grads,
                ds_state_precision_t* _exp_avg,
                size_t _param_size);

    inline void IncrementStep(size_t step, float beta1, float beta2);
    inline void update_state(float lr, float weight_decay);

private:
    float _alpha;
    float _betta1;
    float _betta2;
    float _weight_decay;
    size_t _step;
};

Import

#include "cpu_lion.h"
#include "simd.h"

I/O Contract

Constructor Parameters

alpha (float): learning rate (default: 1e-3)
betta1 (float): momentum interpolation coefficient for the update (default: 0.9)
betta2 (float): momentum coefficient for the EMA (default: 0.999)
weight_decay (float): weight decay coefficient (default: 0)

Step_AVX Template Parameters

span (int, template parameter): SIMD vector span factor (1, 4, or 8)
rounded_size (size_t*): output, number of elements processed with SIMD
_params (ds_params_precision_t*): model parameters array (in/out)
grads (ds_params_precision_t*): gradients array (in)
_exp_avg (ds_state_precision_t*): momentum buffer (in/out)
param_size (size_t): total number of parameters
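The value written through rounded_size lets the caller finish the tail of the array with a scalar loop: the AVX path covers the largest multiple of the effective vector width, and the remaining elements fall through to Step_1. A minimal sketch of that rounding (the helper name simd_rounded is hypothetical, not part of the header):

```cpp
#include <cstddef>

// Largest multiple of the effective SIMD width that fits in param_size.
// Elements in [simd_rounded(...), param_size) are handled scalar-wise.
std::size_t simd_rounded(std::size_t param_size, std::size_t simd_width)
{
    return param_size - (param_size % simd_width);
}
```

For example, with 1000 parameters and an effective width of 16 floats, the AVX path processes 992 elements and the scalar tail processes the last 8.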

Supported Type Combinations

c10::Half params / float state: FP16 parameters, FP32 state
c10::Half params / c10::Half state: FP16 parameters, FP16 state
c10::BFloat16 params / float state: BF16 parameters, FP32 state (AVX512 only)
c10::BFloat16 params / c10::BFloat16 state: BF16 parameters, BF16 state (AVX512 only)
float params / float state: FP32 parameters, FP32 state

Usage Examples

#include "cpu_lion.h"

#include <vector>

// Create a Lion optimizer instance
Lion_Optimizer opt(
    /* alpha = */ 0.001f,
    /* betta1 = */ 0.9f,
    /* betta2 = */ 0.999f,
    /* weight_decay = */ 0.01f);

// Refresh per-step state (step count and betas, then lr and weight decay)
opt.IncrementStep(1, 0.9f, 0.999f);
opt.update_state(0.001f, 0.01f);

// Execute an optimizer step with FP32 parameters and state.
// Buffers must be initialized before use; in practice the gradients
// come from a backward pass.
size_t param_size = 1024;
std::vector<float> params(param_size, 0.5f);
std::vector<float> grads(param_size, 0.1f);
std::vector<float> exp_avg(param_size, 0.0f);

opt.Step_8(params.data(), grads.data(), exp_avg.data(), param_size);
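The split between IncrementStep and update_state suggests the class caches per-step hyperparameters so the hot Step_* loops read plain members. A minimal mirror of that bookkeeping, inferred from the signatures above (an assumption, not the header's actual implementation):

```cpp
#include <cstddef>

// Hypothetical mirror of Lion_Optimizer's private state and the two
// inline setters: IncrementStep records the step count and betas,
// update_state records the learning rate and weight decay.
struct LionState {
    float alpha = 1e-3f;
    float betta1 = 0.9f;
    float betta2 = 0.999f;
    float weight_decay = 0.0f;
    std::size_t step = 0;

    void IncrementStep(std::size_t s, float b1, float b2) {
        step = s;
        betta1 = b1;
        betta2 = b2;
    }
    void update_state(float lr, float wd) {
        alpha = lr;
        weight_decay = wd;
    }
};
```

Calling both setters once per training step, before any Step_* call, keeps the optimizer's cached hyperparameters consistent with a scheduler that changes the learning rate over time.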
