Implementation:Deepspeedai DeepSpeed Evoformer Bias Broadcast
| Knowledge Sources | |
|---|---|
| Domains | Attention, CUTLASS_Kernels, DeepSpeed4Science |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Template-based utilities that load bias tensors from global memory and broadcast them on the fly, so that row-wise (mask) and pair biases can be added to attention logits without first expanding them to the full [B, N, H, L, L] logit shape.
Description
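BroadcastA and BroadcastB load bias tensors with two different dimensionality patterns and broadcast them while loading. BroadcastA reads a bias of shape [B, N, 1, 1, L]: for each (B, N) it loads one length-L row vector and replicates it across every row of the [L, L] logit matrix, so the bias depends only on the key position; this suits row-wise biases such as the mask bias. BroadcastB reads a bias of shape [B, 1, H, L, L]: for each (B, H) it loads a full [L, L] matrix that is shared by every index along N, which suits pair-wise biases such as the pair-representation bias. In both cases a single call loads one Shape-sized tile (for example 64 × 64) of the conceptual [L, L] matrix. BroadcastNoLoad is a no-op placeholder (kEnable is false) for an unused bias slot. AttentionBiasEpilogue coordinates up to two bias sources, summing them into one float tile in shared memory that the kernel then adds to the attention logits. The loads go through CUTLASS PredicatedTileIterator, which masks out-of-bounds elements of partial tiles; BroadcastA pairs it with cutlass::layout::AffineRank2RowMajor, which can express a zero row stride so that every tile row reads the same vector, while BroadcastB uses plain cutlass::layout::RowMajor with the caller-supplied row stride.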
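The indexing behind the two patterns can be checked with a small host-side C++ model. The sketch below is not the CUDA implementation: it assumes both bias tensors are contiguous and row-major, and `BiasShapes`, `bias1_at`, `bias2_at` and `combined_bias` are hypothetical names. It returns the bias value that ends up added to logit (b, n, h, i, j), where i is the query row and j is the key column.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference model (not DeepSpeed code).
// Assumes contiguous row-major tensors:
//   bias1: [B, N, 1, 1, L]  (BroadcastA pattern)
//   bias2: [B, 1, H, L, L]  (BroadcastB pattern)
struct BiasShapes {
    int B, N, H, L;
};

// BroadcastA pattern: the query row i is ignored, so every row of the
// [L, L] matrix reads the same L-vector for a given (b, n).
inline float bias1_at(const std::vector<float>& bias1, const BiasShapes& s,
                      int b, int n, int /*i*/, int j) {
    assert(b < s.B && n < s.N && j < s.L);
    std::size_t row = static_cast<std::size_t>(b) * s.N + n;
    return bias1[row * s.L + j];
}

// BroadcastB pattern: the N index is ignored, so every n shares the same
// [L, L] matrix for a given (b, h).
inline float bias2_at(const std::vector<float>& bias2, const BiasShapes& s,
                      int b, int /*n*/, int h, int i, int j) {
    assert(b < s.B && h < s.H && i < s.L && j < s.L);
    std::size_t mat = static_cast<std::size_t>(b) * s.H + h;
    return bias2[(mat * s.L + i) * s.L + j];
}

// Total bias added to the logit for (b, n, h, query i, key j).
inline float combined_bias(const std::vector<float>& bias1,
                           const std::vector<float>& bias2,
                           const BiasShapes& s, int b, int n, int h, int i,
                           int j) {
    return bias1_at(bias1, s, b, n, i, j) + bias2_at(bias2, s, b, n, h, i, j);
}
```

Because `bias1_at` never reads `i` and `bias2_at` never reads `n`, neither tensor has to be expanded to [B, N, H, L, L] before the kernel runs.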
Usage
These broadcast policies are passed as template arguments to the Evoformer attention forward and backward kernels whenever bias terms must be added to the attention logits, which covers the mask and pair-bias terms used by the attention layers of AlphaFold2's Evoformer.
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/deepspeed4science/evoformer_attn/transform/bias_broadcast.h
Signature
// No-op broadcast (no bias)
template <typename ThreadMap, typename Shape, typename scalar_t>
struct BroadcastNoLoad {
    static const bool kEnable = false;
    using Fragment = cutlass::Array<scalar_t, ThreadMap::Iterations::kCount *
                                                  ThreadMap::kElementsPerAccess>;
    CUTLASS_DEVICE static void load(Fragment& frag, scalar_t* ptr, int thread_id,
                                    const cutlass::MatrixCoord& extent, int stride) {}
};
// Row vector broadcast: [B, N, 1, 1, L] -> [L, L]
template <typename ThreadMap, typename Shape, typename scalar_t>
struct BroadcastA : public BroadcastNoLoad<ThreadMap, Shape, scalar_t> {
    static const bool kEnable = true;
    using layout = cutlass::layout::AffineRank2RowMajor;
    CUTLASS_DEVICE static void load(Fragment& frag, scalar_t* ptr, int thread_id,
                                    const cutlass::MatrixCoord& extent, int stride);
    CUTLASS_DEVICE static scalar_t* advance(scalar_t* ptr, int B_id, int N_id, int H_id,
                                            int strideB, int strideN, int strideH);
};
// Matrix broadcast: [B, 1, H, L, L] -> [L, L]
template <typename ThreadMap, typename Shape, typename scalar_t>
struct BroadcastB : public BroadcastNoLoad<ThreadMap, Shape, scalar_t> {
    static const bool kEnable = true;
    using layout = cutlass::layout::RowMajor;
    CUTLASS_DEVICE static void load(Fragment& frag, scalar_t* ptr, int thread_id,
                                    const cutlass::MatrixCoord& extent, int stride);
    CUTLASS_DEVICE static scalar_t* advance(scalar_t* ptr, int B_id, int N_id, int H_id,
                                            int strideB, int strideN, int strideH);
};
// Epilogue coordinator for loading and combining biases.
// ThreadMap and Ref (the shared-memory tensor reference) are type aliases
// declared inside the struct; their definitions are omitted here.
template <typename Shape, typename scalar_t, int kThreads,
          template <typename, typename, typename> class Broadcast1_,
          template <typename, typename, typename> class Broadcast2_>
struct AttentionBiasEpilogue {
    using Broadcast1 = Broadcast1_<ThreadMap, Shape, scalar_t>;
    using Broadcast2 = Broadcast2_<ThreadMap, Shape, scalar_t>;
    CUTLASS_DEVICE void operator()(const Ref& ref, scalar_t* ptr1, scalar_t* ptr2, int thread_id,
                                   const cutlass::MatrixCoord& extent, int stride);
};
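The `advance` hooks select the slice of a bias tensor that one (B, N, H) work item needs. The host-side sketch below models that offset arithmetic under two assumptions: the tensors are contiguous, and, consistent with the shapes above, the row-bias advance ignores `H_id` while the pair-bias advance ignores `N_id`. `Strides`, `contiguous_strides_row_bias`, `contiguous_strides_pair_bias`, `advance_row_bias` and `advance_pair_bias` are hypothetical helpers, and all offsets are in elements.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical host-side model of the pointer offsets (not DeepSpeed code).
// Strides are in elements, like the strideB/strideN/strideH arguments.
struct Strides {
    std::int64_t B, N, H;
};

// Row bias [B, N, 1, 1, L], contiguous: one L-vector per (b, n).
inline Strides contiguous_strides_row_bias(int N, int L) {
    return Strides{static_cast<std::int64_t>(N) * L, L, 0};
}

// Pair bias [B, 1, H, L, L], contiguous: one L x L matrix per (b, h).
inline Strides contiguous_strides_pair_bias(int H, int L) {
    std::int64_t LL = static_cast<std::int64_t>(L) * L;
    return Strides{H * LL, 0, LL};
}

// BroadcastA-style offset: depends on (B_id, N_id); H_id is broadcast.
inline std::int64_t advance_row_bias(int B_id, int N_id, int /*H_id*/,
                                     const Strides& s) {
    assert(B_id >= 0 && N_id >= 0);
    return B_id * s.B + N_id * s.N;
}

// BroadcastB-style offset: depends on (B_id, H_id); N_id is broadcast.
inline std::int64_t advance_pair_bias(int B_id, int /*N_id*/, int H_id,
                                      const Strides& s) {
    assert(B_id >= 0 && H_id >= 0);
    return B_id * s.B + H_id * s.H;
}
```

For a row bias with N = 3 and L = 4, work item (B_id = 1, N_id = 2) starts 1·12 + 2·4 = 20 elements into the tensor, whatever its head index.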
Import
#include "csrc/deepspeed4science/evoformer_attn/transform/bias_broadcast.h"
I/O Contract
| Component | Input Shape | Broadcast To | Description |
|---|---|---|---|
| BroadcastNoLoad | N/A | N/A | No-op when no bias is used |
| BroadcastA | [B, N, 1, 1, L] | [L, L] | Row vector repeated for each row (row-wise bias) |
| BroadcastB | [B, 1, H, L, L] | [L, L] | Full matrix shared across N (pair-wise bias) |
| AttentionBiasEpilogue | ptr1, ptr2 | Float tile in shared memory | Loads up to 2 bias sources and sums them |
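AttentionBiasEpilogue composes two broadcast policies through template template parameters and skips any policy whose `kEnable` is false, so an empty slot costs nothing. The host-side C++17 sketch below imitates that composition on a small 4 × 4 tile: the policy templates are simplified to one type parameter, elements outside the extent stay zero (standing in for the predicated iterator), and enabled tiles are summed in float. All names are hypothetical; the real struct operates on per-thread register fragments and CUTLASS iterators rather than whole tiles.

```cpp
#include <array>
#include <cassert>

// Hypothetical host-side analogue of the policy pattern (not DeepSpeed code).
constexpr int kRows = 4;
constexpr int kCols = 4;
using Tile = std::array<float, kRows * kCols>;

struct Extent {
    int rows, cols;  // valid region of the tile (smaller for edge tiles)
};

// Disabled policy, playing the role of BroadcastNoLoad.
template <typename T>
struct HostNoLoad {
    static constexpr bool kEnable = false;
    static void load(Tile&, const T*, Extent, int) {}
};

// Row-vector policy: element (r, c) reads ptr[c] for every r (zero row stride).
template <typename T>
struct HostRowBroadcast {
    static constexpr bool kEnable = true;
    static void load(Tile& t, const T* ptr, Extent e, int /*stride*/) {
        for (int r = 0; r < e.rows; ++r)
            for (int c = 0; c < e.cols; ++c)
                t[r * kCols + c] = static_cast<float>(ptr[c]);
    }
};

// Matrix policy: element (r, c) reads ptr[r * stride + c].
template <typename T>
struct HostMatrixLoad {
    static constexpr bool kEnable = true;
    static void load(Tile& t, const T* ptr, Extent e, int stride) {
        for (int r = 0; r < e.rows; ++r)
            for (int c = 0; c < e.cols; ++c)
                t[r * kCols + c] = static_cast<float>(ptr[r * stride + c]);
    }
};

// Sums the enabled policies into one float tile; disabled ones are compiled out.
template <typename T, template <typename> class P1, template <typename> class P2>
Tile combine_biases(const T* ptr1, const T* ptr2, Extent e, int stride) {
    Tile out{};  // zero-initialized float accumulator
    if constexpr (P1<T>::kEnable) {
        Tile t{};
        P1<T>::load(t, ptr1, e, stride);
        for (int k = 0; k < kRows * kCols; ++k) out[k] += t[k];
    }
    if constexpr (P2<T>::kEnable) {
        Tile t{};
        P2<T>::load(t, ptr2, e, stride);
        for (int k = 0; k < kRows * kCols; ++k) out[k] += t[k];
    }
    return out;
}
```

For example, `combine_biases<float, HostRowBroadcast, HostNoLoad>(row, nullptr, Extent{4, 4}, 4)` yields a tile in which every row equals `row`, mirroring the BroadcastA plus BroadcastNoLoad configuration in the usage example.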
Usage Examples
// Configure attention with row-wise position bias
using AttentionWithBias = AttentionKernel<
    cutlass::half_t,
    cutlass::arch::Sm80,
    true,            // Aligned
    64, 64,          // Block sizes
    false,           // Multi-value iteration
    true,            // Support bias
    BroadcastA,      // Row-wise bias from [B, N, 1, 1, L]
    BroadcastNoLoad  // No second bias
>;
// Inside the kernel: load the bias tile into shared memory before it is added to the logits
__shared__ float smem_bias[64][64 + 4];
using BiasEpilogue = AttentionBiasEpilogue<
    cutlass::MatrixShape<64, 64>,
    cutlass::half_t,
    256,             // Number of threads
    BroadcastA,      // First bias type
    BroadcastNoLoad  // Second bias type
>;
BiasEpilogue bias_loader;
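bias_loader(
    {smem_bias[0], 64 + 4},  // Shared-memory reference (row pitch 64 + 4 floats)
    bias1_ptr,               // Row bias [B, N, 1, 1, L], assumed already offset to this (B, N) slice and tile
    nullptr,                 // No second bias
    thread_id,
    {64, 64},                // Tile extent (valid rows, valid columns)
    seq_len                  // Row stride of the [L, L] bias in elements (relevant for BroadcastB)
);
__syncthreads();
// smem_bias now holds the broadcast bias tile; add it to the attention logits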