Implementation:Deepspeedai DeepSpeed XPU Multi Tensor Apply
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Deep Learning, XPU Computing, SYCL |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
SYCL-based multi-tensor processing framework for applying operations across multiple tensor groups on Intel XPU devices.
Description
This file provides the infrastructure for efficiently applying functors (such as optimizer steps) across multiple tensor lists on Intel XPU devices using SYCL. Adapted from NVIDIA's apex library, it implements a batching system that groups tensors into chunks and blocks to maximize device utilization while respecting kernel argument size limits. The framework uses TensorListMetadata structures to organize tensor addresses, sizes, and block mappings, and provides a generic multi_tensor_apply_kernel wrapper that handles parameter pack expansion for arbitrary functor signatures.
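The chunk/block batching described above can be modeled in plain host-side C++. The sketch below is a simplified, hypothetical model (the names `Launch` and `plan_launches` are illustrative, not the actual DeepSpeed API): each tensor is split into chunk_size chunks, each chunk becomes one work-group ("block"), and a launch is flushed whenever the per-launch tensor-slot or block-slot limit is hit, with a mid-tensor flush carrying the current tensor into the next launch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Simplified model of the batching loop; one Launch corresponds to one
// kernel launch in the real framework.
struct Launch {
    std::vector<int> block_to_tensor;  // work-group index -> tensor index
    std::vector<int> block_to_chunk;   // work-group index -> chunk index
};

std::vector<Launch> plan_launches(const std::vector<int>& tensor_sizes,
                                  int chunk_size,
                                  int max_tensors,
                                  int max_blocks) {
    std::vector<Launch> launches(1);
    int tensors_this_launch = 0;
    for (std::size_t t = 0; t < tensor_sizes.size(); ++t) {
        ++tensors_this_launch;
        int chunks = (tensor_sizes[t] + chunk_size - 1) / chunk_size;
        for (int c = 0; c < chunks; ++c) {
            Launch& cur = launches.back();
            cur.block_to_tensor.push_back(static_cast<int>(t));
            cur.block_to_chunk.push_back(c);
            bool last_chunk_overall =
                (t == tensor_sizes.size() - 1) && (c == chunks - 1);
            bool tensors_full =
                (tensors_this_launch == max_tensors) && (c == chunks - 1);
            bool blocks_full =
                static_cast<int>(cur.block_to_tensor.size()) == max_blocks;
            if ((tensors_full || blocks_full) && !last_chunk_overall) {
                launches.emplace_back();
                // A tensor flushed mid-way still occupies a slot next launch.
                tensors_this_launch = (c == chunks - 1) ? 0 : 1;
            }
        }
    }
    return launches;
}
```

For example, two 100-element tensors with chunk_size 50 produce four blocks; with a block limit of 3, the fourth block spills into a second launch.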
Usage
Use this framework when implementing custom operations that need to process multiple tensor groups simultaneously on Intel XPU devices, such as fused optimizer kernels.
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/xpu/adam/multi_tensor_apply.dp.hpp
Signature
template <int n>
struct TensorListMetadata {
    void* addresses[n][depth_to_max_tensors[n - 1]];
    int sizes[depth_to_max_tensors[n - 1]];
    unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
    int block_to_chunk[depth_to_max_blocks[n - 1]];
    int start_tensor_this_launch;
};
template <typename T, typename U, typename... ArgTypes>
class multi_tensor_apply_kernel {
public:
    multi_tensor_apply_kernel(int chunk_size,
                              volatile int* noop_flag,
                              T tl,
                              U callable,
                              ArgTypes... args);
    void operator()(sycl::nd_item<3>) const;
};
template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(int block_size,
                        int chunk_size,
                        const at::Tensor& noop_flag,
                        const std::vector<std::vector<at::Tensor>>& tensor_lists,
                        T callable,
                        ArgTypes... args);
Import
#include <ATen/ATen.h>
#include <sycl/sycl.hpp>
#include <c10/xpu/XPUStream.h>
I/O Contract
multi_tensor_apply Parameters
| Parameter | Type | Description |
|---|---|---|
| depth | int (template) | Number of tensor groups (e.g., 4 for [g, p, m, v]) |
| block_size | int | Number of threads per block (typically 512) |
| chunk_size | int | Number of elements per chunk (typically 65536) |
| noop_flag | const at::Tensor& | Flag tensor; processing is skipped when its value is non-zero |
| tensor_lists | const std::vector<std::vector<at::Tensor>>& | Vector of tensor groups |
| callable | T | Functor object to apply to tensors |
| args | ArgTypes... | Additional arguments passed to callable |
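The contract above implies some shape invariants on tensor_lists: it must contain exactly depth groups, every group must hold the same number of tensors, and corresponding tensors across groups must match in length. A hedged host-side sketch of such a check (the name `check_tensor_lists` is hypothetical, and tensors are modeled as element counts only, not at::Tensor):

```cpp
#include <stdexcept>
#include <vector>

// Illustrative validation mirroring the parameter table above; the real
// framework operates on std::vector<std::vector<at::Tensor>> instead.
void check_tensor_lists(const std::vector<std::vector<int>>& numel_lists,
                        int depth) {
    if (static_cast<int>(numel_lists.size()) != depth)
        throw std::invalid_argument("tensor_lists.size() must equal depth");
    for (const auto& group : numel_lists) {
        if (group.size() != numel_lists[0].size())
            throw std::invalid_argument("all groups need the same tensor count");
    }
    for (std::size_t i = 0; i < numel_lists[0].size(); ++i) {
        for (int d = 1; d < depth; ++d) {
            if (numel_lists[d][i] != numel_lists[0][i])
                throw std::invalid_argument(
                    "corresponding tensors must match in numel");
        }
    }
}
```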
TensorListMetadata Fields
| Field | Type | Description |
|---|---|---|
| addresses | void*[n][] | Pointers to tensor data for each group |
| sizes | int[] | Number of elements in each tensor |
| block_to_tensor | unsigned char[] | Mapping from block index to tensor index |
| block_to_chunk | int[] | Mapping from block index to chunk within tensor |
| start_tensor_this_launch | int | Starting tensor index for this kernel launch |
Depth Limits
| Depth | Max Tensors | Max Blocks | Description |
|---|---|---|---|
| 1 | 110 | 320 | Single tensor group |
| 2 | 64 | 320 | Two tensor groups |
| 3 | 48 | 320 | Three tensor groups |
| 4 | 36 | 320 | Four tensor groups (typical for optimizers) |
| 5 | 30 | 320 | Five tensor groups |
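The tensor-slot column above puts a lower bound on how many kernel launches a call needs. A small sketch under the documented limits (the helper `min_launches` is illustrative; block-slot exhaustion can force additional launches beyond this bound):

```cpp
// Depth-limit tables with the values documented above.
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};

// Lower bound on kernel launches implied by the per-launch tensor-slot
// limit alone (ceiling division of tensor count by the slot capacity).
int min_launches(int depth, int num_tensors) {
    int cap = depth_to_max_tensors[depth - 1];
    return (num_tensors + cap - 1) / cap;
}
```

For a depth-4 optimizer, 36 parameter tensors fit in a single launch, while 37 require at least two.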
Usage Examples
#include "multi_tensor_apply.dp.hpp"
// Define a simple functor
template <typename T>
struct ScaleFunctor {
void operator()(int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<1>& tl,
float scale_factor) {
auto item = sycl::ext::oneapi::experimental::this_nd_item<3>();
int tensor_loc = tl.block_to_tensor[item.get_group(2)];
int chunk_idx = tl.block_to_chunk[item.get_group(2)];
int n = tl.sizes[tensor_loc];
T* data = (T*)tl.addresses[0][tensor_loc];
data += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
for (int i = item.get_local_id(2); i < n && i < chunk_size;
i += item.get_local_range(2)) {
data[i] = data[i] * scale_factor;
}
}
};
// Apply functor to tensor list
std::vector<std::vector<at::Tensor>> tensor_lists = {params_list};
at::Tensor noop_flag = at::zeros({1}, at::kInt).to(at::kXPU);
multi_tensor_apply<1>(
/* block_size = */ 512,
/* chunk_size = */ 65536,
noop_flag,
tensor_lists,
ScaleFunctor<float>(),
/* scale_factor = */ 2.0f
);