Implementation:Deepspeedai DeepSpeed Memory Access Utils
| Knowledge Sources | |
|---|---|
| Domains | Memory, CUDA_Kernels, Performance_Optimization |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Low-level CUDA memory access primitives providing cache-policy-aware loads/stores and asynchronous memory copy operations for optimal memory bandwidth.
Description
This comprehensive header provides templated device functions for efficient memory access patterns in CUDA kernels. It supports vectorized loads and stores at various granularities (2, 4, 8, 16 bytes) with explicit cache control policies including CacheAll (L1+L2), CacheGlobal (L2 only), and CacheStreaming (evict-first policy). The implementation uses PTX inline assembly when available for precise control over memory transactions, with portable fallbacks for non-PTX environments. For Ampere and newer architectures, it includes asynchronous memory copy (cp.async) primitives that enable pipelined memory transfers overlapping with computation. The header also provides utilities for shared memory access, address space conversion, pipeline synchronization, and a BufferTracker class for managing multi-stage buffering patterns common in high-performance kernels.
Usage
Use these utilities when implementing performance-critical CUDA kernels that require fine-grained control over memory access patterns. The cache policy selection can significantly impact performance for different access patterns. Asynchronous copy is particularly valuable for implementing software pipelining in kernels with regular memory access patterns like matrix multiplication or convolution.
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/includes/memory_access_utils.h
Signature
namespace mem_access {

// Cache behaviour requested for a global-memory load:
//   CacheAll       - cache in L1 and L2
//   CacheGlobal    - cache in L2 only
//   CacheStreaming - evict-first policy, for data that is read once
enum class LoadPolicy { CacheAll, CacheGlobal, CacheStreaming };

// Cache behaviour requested for a global-memory store.
enum class StorePolicy { Writeback, CacheGlobal, CacheStreaming };

// Global memory loads with cache control.
// AccessSize is the number of bytes transferred (2, 4, 8 or 16); `src` is the
// global-memory source and `dst` receives the loaded bytes.
template <int AccessSize, LoadPolicy policy = LoadPolicy::CacheAll>
__device__ __forceinline__ void load_global(void* dst, const void* src);

// Predicated variant: `do_access` selects whether the access is performed.
// The usage example on this page relies on `dst` being zero-filled when
// `do_access` is false.
template <int AccessSize, LoadPolicy policy = LoadPolicy::CacheAll>
__device__ __forceinline__ void load_global(void* dst, const void* src, bool do_access);

// Shared memory loads (no cache policy).
template <int AccessSize>
__device__ __forceinline__ void load_shared(void* dst, const void* src);

// Global memory stores with cache control.
// `dst` is the global-memory destination, `src` the data to be written.
template <int AccessSize, StorePolicy policy = StorePolicy::Writeback>
__device__ __forceinline__ void store_global(void* dst, const void* src);

// Asynchronous copy (Ampere+).
#ifdef ASYNC_COPY_AVAILABLE
// Starts an asynchronous copy of AccessSize bytes.
// NOTE(review): parameter names suggest `shr` is the shared-memory destination
// and `gbl` the global-memory source - confirm against the implementation.
template <int AccessSize>
__device__ __forceinline__ void memcpy_async(void* shr, const void* gbl);

// Closes the current batch of asynchronous copies so it can be waited on.
__device__ __forceinline__ void memcpy_async_fence();

// Waits on previously fenced batches of asynchronous copies.
// NOTE(review): `stages` presumably is the number of most-recent batches that
// may remain in flight (cp.async.wait_group semantics) - confirm.
template <int stages>
__device__ __forceinline__ void memcpy_async_wait();
#endif

// Buffer management for multi-stage buffering; `max` is the number of buffers.
template <int max>
class BufferTracker {
public:
    // Members must be public: kernels call get() on a tracker object directly,
    // and `class` members are private by default.
    int current_state;

    // Returns a buffer index for the caller to use.
    __device__ __forceinline__ int get();
};

// Lane-index helper.
// NOTE(review): presumably the calling thread's lane within its warp - confirm.
__device__ __forceinline__ uint32_t lane_id();

}  // namespace mem_access
Import
#include "csrc/includes/memory_access_utils.h"
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| AccessSize | int (template) | Bytes to transfer: 2, 4, 8, or 16 |
| policy | LoadPolicy/StorePolicy | Cache behavior control |
| dst | void* | Destination pointer (thread-local buffer for loads; global memory for stores) |
| src | const void* | Source pointer (global/shared memory for loads; thread-local data for stores) |
| do_access | bool | Predicate for conditional access |
| Output | Type | Description |
|---|---|---|
| data | via dst pointer | Loaded/stored data in target location |
Usage Examples
Vectorized Global Load with Streaming:
// Writes output[i] = input[i] * 2 for i in [0, n).
//
// Launch layout: 1-D grid of 1-D blocks; each thread owns 8 consecutive
// __half values (one 16-byte vector access).
// Preconditions: `input` and `output` hold at least `n` elements.
// NOTE(review): the 16-byte accesses presumably require `input`/`output` to be
// 16-byte aligned - confirm against the load_global/store_global implementation.
__global__ void stream_data_kernel(__half* input, __half* output, int n) {
    constexpr int kHalfsPerAccess = 8;  // 16 bytes / sizeof(__half)

    // 64-bit indexing: blockIdx.x * blockDim.x * 8 can exceed 32 bits.
    const long long idx = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    const long long offset = idx * kHalfsPerAccess;
    const __half two = __float2half(2.0f);

    if (offset + kHalfsPerAccess <= n) {
        // The whole 8-element vector is in bounds: use the 16-byte path.
        // Aligned so the buffer can be treated as one 16-byte vector.
        alignas(16) __half local_buffer[kHalfsPerAccess];

        // Use streaming cache policy for one-time access
        mem_access::load_global<16, mem_access::LoadPolicy::CacheStreaming>(
            local_buffer, input + offset);

        // Process data...
        for (int i = 0; i < kHalfsPerAccess; i++) {
            local_buffer[i] = local_buffer[i] * two;
        }

        mem_access::store_global<16>(output + offset, local_buffer);
    } else {
        // Partial vector at the end of the array (n not a multiple of 8):
        // fall back to scalar accesses so nothing past element n-1 is touched.
        // Threads whose offset is already >= n skip the loop entirely.
        for (long long i = offset; i < n; i++) {
            output[i] = input[i] * two;
        }
    }
}
Predicated Load for Boundary Handling:
// Writes output[i] = data[i] + 1.0f for i in [0, n), demonstrating the
// predicated form of load_global for boundary handling.
//
// Launch layout: 1-D grid of 1-D blocks; each thread owns 4 consecutive floats
// (one 16-byte vector access).
// Preconditions: `data` and `output` hold at least `n` elements.
// NOTE(review): the 16-byte accesses presumably require `data`/`output` to be
// 16-byte aligned - confirm against the load_global/store_global implementation.
__global__ void safe_load_kernel(float* data, float* output, int n) {
    constexpr int kFloatsPerAccess = 4;  // 16 bytes / sizeof(float)

    // 64-bit indexing: blockIdx.x * blockDim.x * 4 can exceed 32 bits.
    const long long idx = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    const long long offset = idx * kFloatsPerAccess;

    // The access is valid only if the whole 4-float vector lies inside [0, n);
    // testing just the first element would let the load and store run past the
    // end of the buffers when n is not a multiple of 4.
    const bool valid = (offset + kFloatsPerAccess) <= n;

    float4 loaded;
    // Predicated load: when `valid` is false no memory is read and `loaded`
    // is expected to hold zeros.
    mem_access::load_global<16>(&loaded, data + offset, valid);

    // Process (zeros for invalid indices)
    loaded.x += 1.0f; loaded.y += 1.0f;
    loaded.z += 1.0f; loaded.w += 1.0f;

    if (valid) {
        mem_access::store_global<16>(output + offset, &loaded);
    } else {
        // Partial vector at the end of the array: handle the remaining
        // elements one at a time. Threads with offset >= n do nothing.
        for (long long i = offset; i < n; i++) {
            output[i] = data[i] + 1.0f;
        }
    }
}
Asynchronous Pipeline for Matmul:
#ifdef ASYNC_COPY_AVAILABLE
// Sketch of a software-pipelined tile loop built on the asynchronous copy
// primitives: up to STAGES-1 tile loads are kept in flight while the current
// tile is consumed.
//
// Pipeline bookkeeping (exactly one fence per prologue step and per main-loop
// iteration, even when nothing was copied, so the batch count stays uniform):
//   - prologue fences batches for tiles 0 .. STAGES-2
//   - iteration i fences the batch for tile i + STAGES-1
//   - before consuming tile i, at most STAGES-2 newer batches may be pending
//
// NOTE(review): TILE_M / TILE_K / TILE_N / tx / ty are not defined in this
// snippet, and the global/shared index expressions are kept from the original
// sketch. Each 16-byte copy moves 8 __half values, so tx presumably has to
// advance in steps of 8 and both addresses must be 16-byte aligned - confirm
// before reusing this code.
__global__ void pipelined_matmul(const __half* A, const __half* B, __half* C,
                                 int M, int N, int K) {
    constexpr int STAGES = 3;
    static_assert(STAGES >= 2, "pipelining needs at least two buffers");

    __shared__ __half smem_A[STAGES][TILE_M][TILE_K];
    __shared__ __half smem_B[STAGES][TILE_K][TILE_N];

    mem_access::BufferTracker<STAGES> tracker;

    // Prologue: start loading the first STAGES-1 tiles.
    for (int stage = 0; stage < STAGES - 1; stage++) {
        int k = stage * TILE_K;
        if (k < K) {  // do not read past the K dimension for short problems
            mem_access::memcpy_async<16>(&smem_A[stage][ty][tx],
                                         &A[k * M + ty * TILE_M + tx]);
            mem_access::memcpy_async<16>(&smem_B[stage][ty][tx],
                                         &B[k * N + ty * TILE_K + tx]);
        }
        // Fence unconditionally (an empty batch is fine) to keep the count uniform.
        mem_access::memcpy_async_fence();
    }

    // Main loop with pipelining
    for (int k = 0; k < K; k += TILE_K) {
        int stage = tracker.get();

        // Wait for this stage's data: allow only the STAGES-2 most recent
        // batches to remain in flight, which guarantees the batch for the
        // current tile has landed. (Waiting with STAGES-1 would return
        // immediately because only STAGES-1 batches are ever outstanding.)
        mem_access::memcpy_async_wait<STAGES - 2>();
        // Makes every thread's copies visible block-wide and guarantees all
        // threads are done reading the buffer consumed in the previous
        // iteration, which is the one refilled below.
        __syncthreads();

        // Prefetch the tile STAGES-1 steps ahead into the buffer that was
        // consumed in the previous iteration.
        int next_k = k + (STAGES - 1) * TILE_K;
        int next_stage = (stage + STAGES - 1) % STAGES;
        if (next_k < K) {
            mem_access::memcpy_async<16>(&smem_A[next_stage][ty][tx],
                                         &A[next_k * M + ty * TILE_M + tx]);
            mem_access::memcpy_async<16>(&smem_B[next_stage][ty][tx],
                                         &B[next_k * N + ty * TILE_K + tx]);
        }
        mem_access::memcpy_async_fence();

        // Compute with current stage (overlaps with the copies issued above)
        // ... matmul computation on smem_A[stage] / smem_B[stage] ...
    }
}
#endif
L2-Only Caching for Reused Data:
// Multiplies ten consecutive input slices per block element-wise by a
// per-thread weight vector and writes the products to `output`.
//
// Each thread loads one float4 of weights once (L2-only caching) and then, for
// each of the 10 passes, streams one float4 of input through (streaming
// policy), scales it, and stores the result at the same offset in `output`.
//
// NOTE(review): `n` is never read and no bounds check is performed - the
// launch configuration must match the buffer sizes exactly.
__global__ void multi_pass_kernel(const float* weights, const float* input,
                                  float* output, int n) {
    const int tid = threadIdx.x;

    // Weights are shared by every block: keep them out of L1, resident in L2.
    float4 w;
    mem_access::load_global<16, mem_access::LoadPolicy::CacheGlobal>(
        &w, weights + tid * 4);

    for (int pass = 0; pass < 10; pass++) {
        // Element offset of this thread's float4 within the current slice.
        const unsigned int elem =
            (blockIdx.x * 10 + pass) * blockDim.x * 4 + tid * 4;

        // Input is touched once: request the streaming (evict-first) policy.
        float4 v;
        mem_access::load_global<16, mem_access::LoadPolicy::CacheStreaming>(
            &v, input + elem);

        // Element-wise scale by the weights.
        v.x *= w.x;
        v.y *= w.y;
        v.z *= w.z;
        v.w *= w.w;

        mem_access::store_global<16>(output + elem, &v);
    }
}
Related Pages
- Quantization Utils - Uses memory access primitives extensively
- Reduction Utils - Combines with memory ops for efficient reductions
- Conversion Utils - Often used together for type conversions during loads