Implementation:Deepspeedai DeepSpeed RISCV SHM

From Leeroopedia


Knowledge Sources
Domains: SIMD, CPU_Optimization, RISC-V, Communication
Last Updated: 2026-02-09 00:00 GMT

Overview

RISC-V Vector Extension (RVV) SIMD primitives for shared memory allreduce operations with dynamic vector length support.

Description

The RISCV_SHM header provides a hardware abstraction layer for SIMD-accelerated reduction operations on RISC-V processors with the V (Vector) extension. Unlike x86_64's fixed-width AVX-512 vectors, RVV is vector-length agnostic: the register width (VLEN) is fixed by each implementation, and code requests an active vector length at runtime through the vsetvl family of instructions. The same code therefore runs efficiently on hardware with different vector register lengths (128-bit, 256-bit, 512-bit, or larger).

The implementation provides inline conversion functions between BFloat16/Float16 and Float32, so that sums are accumulated in FP32 and the 16-bit formats are used only for storage, which limits rounding error during reduction. BF16 to FP32 conversion zero-extends each 16-bit value and shifts it left by 16 bits so that it forms the upper half of a 32-bit float. FP32 to BF16 conversion adds a rounding bias before dropping the lower 16 bits and handles NaN values separately. Float16 conversion uses native RVV instructions (vfwcvt for widening, vfncvt for narrowing) when the Zvfh extension is available.

All functions are marked with GCC's target attribute to ensure proper instruction selection and include explicit vector length parameters passed to each intrinsic. The header defines a comprehensive set of macros (VLOAD_*, VSTORE_*, VADD_*, CVT_*) that map generic operations to RISC-V Vector intrinsics, allowing the main SHM allreduce code to remain architecture-neutral.

Usage

Use RISCV_SHM primitives when compiling DeepSpeed for RISC-V processors with the V extension, particularly for inference workloads requiring low-latency CPU communication. No manual include is needed: shm.cpp selects the header when the compiler defines __riscv, and defines TARGET_RISCV at the same point.

Code Reference

Source Location

Signature

// BFloat16 conversion (requires RVV base)
inline vfloat32m2_t cvt_bf16_to_fp32(vuint16m1_t src, size_t vl)
    __attribute__((target("arch=+v")));
inline vuint16m1_t cvt_fp32_to_bf16(vfloat32m2_t src, size_t vl)
    __attribute__((target("arch=+v")));

// Float16 conversion (requires RVV + Zvfh half-precision extension)
inline vfloat32m2_t cvt_fp16_to_fp32(vfloat16m1_t src, size_t vl)
    __attribute__((target("arch=+v,+zvfh")));
inline vfloat16m1_t cvt_fp32_to_fp16(vfloat32m2_t src, size_t vl)
    __attribute__((target("arch=+v,+zvfh")));

// Reduction function declarations
void reduce_bf16_buffers(int start_elements, int num_elements,
                         char* to_buffer, char** buffers)
    __attribute__((target("arch=+v")));
void reduce_fp16_buffers(int start_elements, int num_elements,
                         char* to_buffer, char** buffers)
    __attribute__((target("arch=+v,+zvfh")));
void reduce_fp32_buffers(int start_elements, int num_elements,
                         char* to_buffer, char** buffers)
    __attribute__((target("arch=+v")));

void parallel_memcpy(void* to, void* from, size_t n_bytes)
    __attribute__((target("arch=+v")));
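
The reduction declarations above do not spell out their semantics. A scalar reference helps when checking a vector kernel, so here is a minimal sketch under stated assumptions: offsets are counted in elements, each entry of `buffers` belongs to one rank, and the result is the element-wise sum. The function name and the explicit `world_size` parameter are hypothetical; the declared functions take no such argument.

```cpp
#include <cstddef>

// Hypothetical scalar reference (not part of the DeepSpeed header).
// Sums elements [start_elements, start_elements + num_elements) across
// `world_size` FP32 buffers and writes the sums to the same positions in
// to_buffer. All buffers are assumed to hold float data behind the char*.
inline void reduce_fp32_reference(int start_elements, int num_elements,
                                  char* to_buffer, char** buffers, int world_size)
{
    float* out = reinterpret_cast<float*>(to_buffer);
    for (int i = start_elements; i < start_elements + num_elements; ++i) {
        float sum = 0.0f;
        for (int r = 0; r < world_size; ++r) {
            sum += reinterpret_cast<const float*>(buffers[r])[i];
        }
        out[i] = sum;
    }
}
```

Comparing a vector kernel against a reference like this on small inputs is a cheap way to catch indexing mistakes, such as mixing byte offsets with element offsets.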

Import

// In C++ source files (automatically included by shm.cpp)
#if defined(__riscv)
#define TARGET_RISCV 1
#include "riscv64/shm.h"
#else
#include "x86_64/shm.h"
#endif

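// Use generic macros that map to RISC-V intrinsics.
// Illustrative only: the macros are assumed to read `vl` from the enclosing
// scope, and num_elements is assumed to be a positive multiple of vl.
void my_reduction_kernel(char* output, char* buffer, char* other, int num_elements) {
    size_t vl = __riscv_vsetvl_e32m1(num_elements);
    int vector_length_in_bytes = vl * sizeof(float);
    int size = num_elements * sizeof(float);

    // Byte offsets are valid here because the buffers are char*
    for (int i = 0; i < size; i += vector_length_in_bytes) {
        auto val = VLOAD_F32(buffer + i);
        auto other_val = VLOAD_F32(other + i);
        val = VADD_F32(val, other_val);
        VSTORE_F32(output + i, val);
    }
}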

I/O Contract

cvt_bf16_to_fp32(src, vl)

Parameter | Type | Description
src | vuint16m1_t | BF16 values (16-bit unsigned int vector)
vl | size_t | Vector length (number of elements)
Returns | vfloat32m2_t | FP32 values (LMUL=2 for widening)

cvt_fp32_to_bf16(src, vl)

Parameter | Type | Description
src | vfloat32m2_t | FP32 values (LMUL=2)
vl | size_t | Vector length (number of elements)
Returns | vuint16m1_t | BF16 values with rounding and NaN handling

cvt_fp16_to_fp32(src, vl)

Parameter | Type | Description
src | vfloat16m1_t | FP16 values (requires Zvfh)
vl | size_t | Vector length (number of elements)
Returns | vfloat32m2_t | FP32 values (widening conversion)

cvt_fp32_to_fp16(src, vl)

Parameter | Type | Description
src | vfloat32m2_t | FP32 values
vl | size_t | Vector length (number of elements)
Returns | vfloat16m1_t | FP16 values (round-to-odd narrowing)

Macro Definitions

Macro | RISC-V Intrinsic | Description
VLOAD_U8(X) | __riscv_vle8_v_u8m1 | Load uint8 vector
VLOAD_U16(X) | __riscv_vle16_v_u16m1 | Load uint16 vector
VLOAD_F16(X) | __riscv_vle16_v_f16m1 | Load float16 vector
VLOAD_F32(X) | __riscv_vle32_v_f32m1 | Load float32 vector
VSTORE_U8(A,B) | __riscv_vse8_v_u8m1 | Store uint8 vector
VSTORE_U16(A,B) | __riscv_vse16_v_u16m1 | Store uint16 vector
VSTORE_F16(A,B) | __riscv_vse16_v_f16m1 | Store float16 vector
VSTORE_F32(A,B) | __riscv_vse32_v_f32m1 | Store float32 vector
VADD_F32(A,B) | __riscv_vfadd_vv_f32m1 | Add FP32 vectors (LMUL=1)
VADD_F32_2VL(A,B) | __riscv_vfadd_vv_f32m2 | Add FP32 vectors (LMUL=2)
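
To illustrate how such macros keep a kernel architecture-neutral, the sketch below binds three of the macro names to a scalar stand-in backend. Everything except the macro names is hypothetical (`Vec4`, `load4`, `store4`, `add4`, `accumulate`), and the real header binds the names to RVV intrinsics instead. The point is only that the kernel body never names an intrinsic.

```cpp
#include <cstddef>

// Hypothetical scalar backend: a fixed 4-lane struct stands in for vfloat32m1_t.
struct Vec4 {
    float lane[4];
};

inline Vec4 load4(const float* p)
{
    Vec4 v;
    for (int k = 0; k < 4; ++k) v.lane[k] = p[k];
    return v;
}

inline void store4(float* p, const Vec4& v)
{
    for (int k = 0; k < 4; ++k) p[k] = v.lane[k];
}

inline Vec4 add4(const Vec4& a, const Vec4& b)
{
    Vec4 r;
    for (int k = 0; k < 4; ++k) r.lane[k] = a.lane[k] + b.lane[k];
    return r;
}

// Same macro names as in the table, bound to the scalar backend instead.
#define VLOAD_F32(X) load4(X)
#define VSTORE_F32(A, B) store4((A), (B))
#define VADD_F32(A, B) add4((A), (B))

// Architecture-neutral kernel: it names only the macros, never an intrinsic.
// Assumes n is a multiple of 4 (the width of this stand-in backend).
inline void accumulate(float* dst, const float* src, std::size_t n)
{
    for (std::size_t i = 0; i < n; i += 4) {
        VSTORE_F32(dst + i, VADD_F32(VLOAD_F32(dst + i), VLOAD_F32(src + i)));
    }
}
```

Swapping the three `#define` lines for a different backend leaves `accumulate` untouched, which is the property the header's macro layer is meant to provide.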

Usage Examples

// Example 1: Dynamic vector length determination
void process_buffer(float* data, int num_elements) {
    if (num_elements <= 0) return;  // vl would be 0 and the modulo below invalid

    // Set vector length for float32 elements
    size_t vl = __riscv_vsetvl_e32m1(num_elements);

    int main_elements = num_elements - (num_elements % vl);

    // Process the whole-vector portion. i counts elements, not bytes:
    // data is a float*, so pointer arithmetic already scales by 4.
    for (int i = 0; i < main_elements; i += vl) {
        auto vec = VLOAD_F32(data + i);
        auto result = VADD_F32(vec, vec);  // Example: double values
        VSTORE_F32(data + i, result);
    }

    // Handle remainder with scalar code
    for (int i = main_elements; i < num_elements; i++) {
        data[i] += data[i];
    }
}

// Example 2: BFloat16 reduction with conversion
// Adds `other` into `bf16_data`. Assumes num_elements is a positive multiple of vl.
void reduce_bf16_data(uint16_t* bf16_data, uint16_t* other, int num_elements) {
    size_t vl = __riscv_vsetvl_e16m1(num_elements);

    // i counts elements: uint16_t* arithmetic already scales by 2 bytes
    for (int i = 0; i < num_elements; i += vl) {
        // Load as uint16, convert to FP32 for computation
        auto fp32_vec = CVT_BF16_TO_FP32(VLOAD_U16(bf16_data + i));
        auto other_fp32_vec = CVT_BF16_TO_FP32(VLOAD_U16(other + i));

        // Perform reduction in FP32
        auto sum = VADD_F32_2VL(fp32_vec, other_fp32_vec);

        // Convert back to BF16 for storage
        auto result_bf16 = CVT_FP32_TO_BF16(sum);
        VSTORE_U16(bf16_data + i, result_bf16);
    }
}

// Example 3: Conditional compilation for FP16 support
#ifdef __riscv_zvfh
void reduce_fp16_with_native(float16_t* data, int num_elements) {
    size_t vl = 0;

    for (int i = 0; i < num_elements; i += vl) {
        // Re-query vl each iteration so the last, shorter chunk stays in bounds
        // (the macros are assumed to read `vl` from the enclosing scope)
        vl = __riscv_vsetvl_e16m1(num_elements - i);

        // Use native FP16 widening conversion
        auto fp16_vec = VLOAD_F16(data + i);
        auto fp32_vec = CVT_FP16_TO_FP32(fp16_vec);

        // Computation in FP32
        auto result = VADD_F32_2VL(fp32_vec, fp32_vec);

        // Native narrowing with round-to-odd
        auto result_fp16 = CVT_FP32_TO_FP16(result);
        VSTORE_F16(data + i, result_fp16);
    }
}
#endif

// Example 4: LMUL (Length Multiplier) handling
// Only the widening step is shown; the FP32 result is stored unreduced.
void widen_and_reduce(uint16_t* bf16_in, float* fp32_out, int count) {
    for (int i = 0; i < count;) {
        size_t vl = __riscv_vsetvl_e16m1(count - i);

        // Load with LMUL=1 (e16m1), convert to LMUL=2 (f32m2)
        auto bf16_data = __riscv_vle16_v_u16m1(bf16_in + i, vl);

        // Zero-extend to 32 bits (result type is LMUL=2), then shift into place
        vuint32m2_t widened = __riscv_vwcvtu_x_x_v_u32m2(bf16_data, vl);
        vfloat32m2_t fp32_data = __riscv_vreinterpret_v_u32m2_f32m2(
            __riscv_vsll_vx_u32m2(widened, 16, vl)
        );

        // Store vl FP32 elements. LMUL=2 doubles the registers per group,
        // not the element count: the same vl elements are now twice as wide.
        __riscv_vse32_v_f32m2(fp32_out + i, fp32_data, vl);
        i += vl;
    }
}

Implementation Details

RISC-V Vector Extension (RVV)

  • VLEN: Vector register bit width (implementation-defined)
  • LMUL: Length multiplier (1, 2, 4, 8 for grouping registers; fractional values 1/2, 1/4, 1/8 also exist)
  • SEW: Selected Element Width (8, 16, 32, 64 bits)
  • VL: Vector Length (dynamic, set by vsetvl instruction)
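
These parameters are linked by the RVV relation VLMAX = LMUL × VLEN / SEW, the largest vl a single instruction can cover. The small sketch below encodes it for integer LMUL; the helper name `vlmax` is hypothetical and not part of the header. It also shows why the conversion functions pair e16m1 with f32m2: both hold the same number of elements.

```cpp
#include <cstddef>

// Hypothetical helper (not part of the DeepSpeed header): maximum number of
// elements one vector instruction can process, for integer LMUL only.
// VLMAX = LMUL * VLEN / SEW.
constexpr std::size_t vlmax(std::size_t vlen_bits, std::size_t sew_bits, std::size_t lmul)
{
    return lmul * vlen_bits / sew_bits;
}

static_assert(vlmax(128, 32, 1) == 4, "e32m1 at VLEN=128 holds 4 floats");
static_assert(vlmax(128, 16, 1) == 8, "e16m1 at VLEN=128 holds 8 halves");
static_assert(vlmax(128, 32, 2) == 8, "f32m2 holds as many elements as e16m1");
```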

BFloat16 Conversion Algorithm

// To FP32: Shift left 16 bits to fill upper half
// [BF16: sign(1) | exp(8) | mantissa(7)] → [FP32: sign(1) | exp(8) | mantissa(23)]
fp32 = (uint32_t)bf16 << 16;

// To BF16: Round-to-nearest-even with tie-breaking
uint32_t lsb = (fp32 >> 16) & 1;
uint32_t rounding_bias = 0x7FFF + lsb;
fp32 += rounding_bias;
bf16 = (uint16_t)(fp32 >> 16);
// Special handling for NaN values
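
For checking a vector kernel against known bit patterns, the same algorithm can be written as scalar C++. This is a minimal sketch, not the header's vector code: the function names are hypothetical, and the NaN policy (keep the sign, force the quiet bit) is an assumption, since the pseudocode above only says NaN is handled specially.

```cpp
#include <cstdint>
#include <cstring>

// Scalar reference for BF16 -> FP32: place the 16 bits in the upper half.
inline float bf16_to_fp32(std::uint16_t bf16)
{
    const std::uint32_t bits = static_cast<std::uint32_t>(bf16) << 16;
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}

// Scalar reference for FP32 -> BF16 with round-to-nearest, ties to even.
inline std::uint16_t fp32_to_bf16(float value)
{
    std::uint32_t bits;
    std::memcpy(&bits, &value, sizeof(bits));

    // NaN: exponent all ones and a non-zero mantissa. Adding the bias could
    // carry out of the mantissa, so truncate and set the quiet bit instead
    // (assumed policy, labeled as such).
    if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
        return static_cast<std::uint16_t>((bits >> 16) | 0x0040u);
    }

    const std::uint32_t lsb = (bits >> 16) & 1u;  // lowest bit that survives
    bits += 0x7FFFu + lsb;                        // bias rounds ties to even
    return static_cast<std::uint16_t>(bits >> 16);
}
```

The `lsb` term is what makes ties go to even: an exact halfway value rounds up only when the surviving low bit is already 1.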

Float16 Conversion

  • Widening: Uses vfwcvt_f_f_v (widens FP16→FP32)
  • Narrowing: Uses vfncvt_rod_f_f_w (round-to-odd FP32→FP16)
  • Requires: Zvfh extension for native FP16 operations
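
Widening FP16 to FP32 is exact, because every FP16 value is representable in FP32. A scalar decoder that works on raw bits is therefore a convenient reference for the widening direction. The sketch below is an assumption-laden illustration: `fp16_bits_to_fp32` is a hypothetical name, it takes the raw 16-bit pattern rather than a `float16_t`, and it does not cover the narrowing direction, which needs rounding.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical scalar reference: decode an IEEE 754 binary16 bit pattern
// (1 sign, 5 exponent, 10 mantissa bits) into a float.
inline float fp16_bits_to_fp32(std::uint16_t h)
{
    const std::uint32_t sign = static_cast<std::uint32_t>(h & 0x8000u) << 16;
    std::uint32_t exp = (h >> 10) & 0x1Fu;
    std::uint32_t man = h & 0x3FFu;
    std::uint32_t bits;

    if (exp == 0x1Fu) {
        // Infinity or NaN: all-ones exponent, mantissa payload preserved
        bits = sign | 0x7F800000u | (man << 13);
    } else if (exp != 0) {
        // Normal number: rebias the exponent from 15 to 127 (add 112)
        bits = sign | ((exp + 112u) << 23) | (man << 13);
    } else if (man == 0) {
        // Signed zero
        bits = sign;
    } else {
        // Subnormal: shift until the implicit leading 1 appears
        exp = 113u;
        while ((man & 0x400u) == 0) {
            man <<= 1;
            --exp;
        }
        bits = sign | (exp << 23) | ((man & 0x3FFu) << 13);
    }

    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}
```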

Dynamic Vector Length

// Request a vector length for 'num_elements' float32 elements (SEW=32, LMUL=1)
size_t vl = __riscv_vsetvl_e32m1(num_elements);
// vl never exceeds VLMAX = VLEN/32 and is typically min(num_elements, VLMAX)
// Actual elements processed depends on hardware VLEN

// Common VLEN values:
// VLEN=128: vl up to 4 floats (16 bytes)
// VLEN=256: vl up to 8 floats (32 bytes)
// VLEN=512: vl up to 16 floats (64 bytes)
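
Because vl depends on the hardware, RVV loops are usually strip-mined: vl is requested again on every iteration, so the final short chunk goes through the same loop body. The portable sketch below simulates that pattern without intrinsics. `simulated_vsetvl` and `double_in_place` are hypothetical names, and modelling vsetvl as a plain min() is an assumption: the specification allows other choices when the remaining count lies between VLMAX and 2×VLMAX.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for __riscv_vsetvl_e32m1: how many elements the next
// iteration may process. min() is the common behavior and is assumed here.
inline std::size_t simulated_vsetvl(std::size_t avl, std::size_t vlmax)
{
    return std::min(avl, vlmax);
}

// Strip-mined in-place doubling. Returns the number of loop iterations so the
// chunking is observable. Requires vlmax > 0.
inline std::size_t double_in_place(std::vector<float>& data, std::size_t vlmax)
{
    std::size_t iterations = 0;
    for (std::size_t i = 0; i < data.size();) {
        const std::size_t vl = simulated_vsetvl(data.size() - i, vlmax);
        for (std::size_t j = 0; j < vl; ++j) {  // stands in for one vector op
            data[i + j] += data[i + j];
        }
        i += vl;
        ++iterations;
    }
    return iterations;
}
```

With ten elements and a VLMAX of four, the loop runs with vl = 4, 4, 2, so no separate scalar tail is needed.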

Architecture-Specific Optimizations

  • Masked Operations: RVV supports predicated operations (not used here)
  • Strided, Indexed, and Segment Loads: RVV can also load non-contiguous or interleaved data (future optimization)
  • Chaining: RVV pipelines can chain dependent vector operations
  • Memory Efficiency: Unit-stride loads/stores for optimal bandwidth

Comparison with x86_64

Feature | RISC-V RVV | x86_64 AVX-512
Vector Length | Dynamic (runtime) | Fixed (512-bit)
Elements/Op | VLEN/SEW | 64 bytes / element_size
Conversion | Explicit intrinsics | Shuffle + shift operations
LMUL | Explicit m1, m2, m4, m8 | Implicit via register names
Masking | Built-in predication | Explicit mask registers
