Implementation:Deepspeedai DeepSpeed Evoformer Epilogue Rescale
| Knowledge Sources | |
|---|---|
| Domains | Attention, CUTLASS_Kernels, DeepSpeed4Science |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
An epilogue thread operator that performs online softmax renormalization during attention computation by rescaling outputs using row-wise statistics.
Description
MemoryEfficientAttentionNormalize implements the rescaling step of online (incremental) softmax in memory-efficient attention. Attention is computed tile by tile over key blocks. Each tile's partial result is computed against the running row maximum and contributes to the running row sum. When a later tile raises the maximum, the output stored from earlier tiles is on the wrong scale, so it must be rescaled before the new contribution is added.

The operator computes output = alpha × accumulator + beta × source. Here accumulator is the current tile's GEMM result and source is the output stored after the previous tile. On the last tile, alpha = 1/s_prime[row], which applies the final softmax normalization. On every other tile, alpha = 1. In both cases beta = alpha × m_prime[row], which moves the previous output onto the new maximum.

The template flags select the stage:
- isFirst: there is no previous output to load, so the result is simply alpha × accumulator.
- isLast: enables the final 1/s_prime normalization.

The constructor takes the per-row statistics arrays (s_prime for the running sum, m_prime for the max correction). The operator looks them up by row for each output fragment in the epilogue, relying on CUTLASS's infrastructure for type conversions and vectorized operations.
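The recurrence this operator takes part in can be written per query row as follows. This is a sketch of the standard online-softmax update, not a transcription of the kernel. The symbols are our own shorthand: j indexes key blocks, m is the running max (mi), s is the running sum (s_prime), m' is the correction (m_prime), and A is the current tile's accumulator. The recurrence starts from m^{(0)} = -\infty and s^{(0)} = 0, and there is no O^{(0)} term because the first block has no source.

```latex
\begin{aligned}
m^{(j)} &= \max\!\left(m^{(j-1)},\ \max_k S^{(j)}_{k}\right), &
m'^{(j)} &= \exp\!\left(m^{(j-1)} - m^{(j)}\right),\\
s^{(j)} &= m'^{(j)}\, s^{(j-1)} + \sum_k \exp\!\left(S^{(j)}_{k} - m^{(j)}\right), &
A^{(j)} &= \sum_k \exp\!\left(S^{(j)}_{k} - m^{(j)}\right) V^{(j)}_{k},\\
O^{(j)} &= \alpha^{(j)} A^{(j)} + \beta^{(j)} O^{(j-1)}, &
\alpha^{(j)} &= \begin{cases} 1/s^{(j)} & \text{last block}\\ 1 & \text{otherwise}\end{cases},
\qquad \beta^{(j)} = \alpha^{(j)}\, m'^{(j)},\\
O^{(J)} &= \frac{1}{s^{(J)}} \sum_{\text{all } k} \exp\!\left(S_k - m^{(J)}\right) V_k
= \operatorname{softmax}(S)\, V .
\end{aligned}
```

For j < J the stored output is the unnormalized sum \sum_{k \le j} \exp(S_k - m^{(j)}) V_k. Multiplying it by m'^{(j+1)} moves it onto the new maximum, and only the last block divides by the complete sum.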
Usage
This operator is used in the Evoformer attention forward pass epilogue when accumulating attention-weighted values across key blocks, rescaling the running output as the global max and sum are updated with each new tile.
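The following host-only C++ model of one query row checks that the per-tile rescaling reproduces ordinary softmax attention. It is a sketch under our own assumptions: std::vector stands in for the kernel's tiles, and online_attention_row and reference_attention_row are hypothetical names. The alpha/beta arithmetic follows the description in this page rather than calling the CUTLASS operator.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical host-only model of one query row. S[k] is the (already scaled)
// score for key k and V[k] is its value vector. Keys are processed in blocks of
// keys_per_block, mirroring a loop over key blocks.
std::vector<float> online_attention_row(const std::vector<float>& S,
                                        const std::vector<std::vector<float>>& V,
                                        int keys_per_block) {
  const int num_keys = static_cast<int>(S.size());
  const std::size_t dim = V.at(0).size();
  float mi = -std::numeric_limits<float>::infinity();  // running max
  float si = 0.0f;                                      // running sum
  std::vector<float> out(dim, 0.0f);                    // stored output ("source")

  for (int start = 0; start < num_keys; start += keys_per_block) {
    const int end = std::min(start + keys_per_block, num_keys);
    const bool is_first = (start == 0);
    const bool is_last = (end == num_keys);

    // New running max and correction m_prime = exp(old_max - new_max).
    float new_mi = mi;
    for (int k = start; k < end; ++k) new_mi = std::max(new_mi, S[k]);
    const float mi_prime = std::exp(mi - new_mi);  // exp(-inf) == 0 on the first block

    // Stand-in for the GEMM accumulator: sum_k exp(S[k] - new_mi) * V[k].
    std::vector<float> acc(dim, 0.0f);
    float block_sum = 0.0f;
    for (int k = start; k < end; ++k) {
      const float p = std::exp(S[k] - new_mi);
      block_sum += p;
      for (std::size_t d = 0; d < dim; ++d) acc[d] += p * V[k][d];
    }
    si = mi_prime * si + block_sum;
    mi = new_mi;

    // Epilogue rescale: alpha = 1/si only on the last block, beta = alpha * mi_prime.
    const float alpha = is_last ? 1.0f / si : 1.0f;
    const float beta = alpha * mi_prime;
    for (std::size_t d = 0; d < dim; ++d) {
      out[d] = is_first ? alpha * acc[d] : alpha * acc[d] + beta * out[d];
    }
  }
  return out;
}

// Direct softmax(S) @ V for comparison.
std::vector<float> reference_attention_row(const std::vector<float>& S,
                                           const std::vector<std::vector<float>>& V) {
  const float m = *std::max_element(S.begin(), S.end());
  float sum = 0.0f;
  for (float x : S) sum += std::exp(x - m);
  std::vector<float> out(V.at(0).size(), 0.0f);
  for (std::size_t k = 0; k < S.size(); ++k) {
    const float p = std::exp(S[k] - m) / sum;
    for (std::size_t d = 0; d < out.size(); ++d) out[d] += p * V[k][d];
  }
  return out;
}
```

For any block size, the two functions should agree to within floating-point rounding. That includes a block size at least as large as the number of keys, which makes the single block both first and last.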
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_rescale_output.h
Signature
```cpp
template <typename ElementOutput_,       // Output element type
          typename ElementSource_,       // Source element type (previous output)
          int Count,                     // Elements per operation
          typename ElementAccumulator_,  // Accumulator type
          typename ElementCompute_,      // Compute type
          bool isFirst,                  // First tile (no source)
          bool isLast,                   // Last tile (final normalize)
          typename FragmentAlphaBeta_,   // Type for s_prime/m_prime
          FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
class MemoryEfficientAttentionNormalize {
 public:
  using ElementOutput = ElementOutput_;
  using ElementSource = ElementSource_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;

  using FragmentOutput = Array<ElementOutput, Count>;
  using FragmentSource = Array<ElementSource, Count>;
  using FragmentAccumulator = Array<ElementAccumulator, Count>;
  using FragmentAlphaBeta = FragmentAlphaBeta_;

  CUTLASS_HOST_DEVICE
  MemoryEfficientAttentionNormalize(
      FragmentAlphaBeta const& s_prime,  // Running sum per row
      FragmentAlphaBeta const& m_prime   // Max correction per row
  );

  // True when !isFirst: the epilogue must load the previous output
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const;

  // D = alpha * accumulator + beta * source (used when isFirst == false)
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(int row,
                            FragmentAccumulator const& accumulator,
                            FragmentSource const& source) const;

  // D = alpha * accumulator (used when isFirst == true)
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(int row,
                            FragmentAccumulator const& accumulator) const;
};
```
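The following host-only mock mirrors the interface above, with std::array in place of cutlass::Array, and can be used without a CUTLASS build. NormalizeSketch is a hypothetical name. It models only the alpha/beta arithmetic and the per-row lookup. It does not model the numeric type conversions, rounding, or CUTLASS_HOST_DEVICE qualifiers, and it copies the statistics rather than storing whatever the real class keeps.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical host-only mock of MemoryEfficientAttentionNormalize's arithmetic.
template <int Count, bool isFirst, bool isLast, std::size_t Rows>
class NormalizeSketch {
 public:
  using Fragment = std::array<float, Count>;  // stands in for FragmentAccumulator/Output/Source
  using Stats = std::array<float, Rows>;      // stands in for FragmentAlphaBeta

  NormalizeSketch(const Stats& s_prime, const Stats& m_prime)
      : s_prime_(s_prime), m_prime_(m_prime) {}

  // Only non-first tiles read the previously stored output.
  static constexpr bool is_source_needed() { return !isFirst; }

  // D = alpha * accumulator + beta * source
  Fragment operator()(int row, const Fragment& accumulator, const Fragment& source) const {
    static_assert(!isFirst, "source overload requires isFirst == false");
    const float alpha = isLast ? 1.0f / s_prime_[row] : 1.0f;
    const float beta = alpha * m_prime_[row];
    Fragment out{};
    for (int i = 0; i < Count; ++i) out[i] = alpha * accumulator[i] + beta * source[i];
    return out;
  }

  // D = alpha * accumulator
  Fragment operator()(int row, const Fragment& accumulator) const {
    static_assert(isFirst, "accumulator-only overload requires isFirst == true");
    const float alpha = isLast ? 1.0f / s_prime_[row] : 1.0f;
    Fragment out{};
    for (int i = 0; i < Count; ++i) out[i] = alpha * accumulator[i];
    return out;
  }

 private:
  Stats s_prime_;  // running sum per row
  Stats m_prime_;  // exp(old_max - new_max) per row
};
```

Using static_assert makes a call to the wrong overload for a given isFirst a compile-time error in this sketch. No claim is made here about how the DeepSpeed header guards against it.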
Import
```cpp
#include "csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_rescale_output.h"
```
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| Constructor Inputs | | |
| s_prime | FragmentAlphaBeta | Per-row running sum of exp(S - max); read only when isLast (alpha = 1/s_prime[row]) |
| m_prime | FragmentAlphaBeta | Per-row correction exp(old_max - new_max) applied to the previous output (beta = alpha × m_prime[row]) |
| Operator Inputs | | |
| row | int | Row index used to look up s_prime[row] and m_prime[row] |
| accumulator | FragmentAccumulator | GEMM accumulator for the current key block |
| source | FragmentSource | Previously stored output; passed only when isFirst is false |
| Output | | |
| (return value) | FragmentOutput (Array<ElementOutput, Count>) | Rescaled output fragment, converted to ElementOutput |
Usage Examples
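```cpp
// Online softmax computation across key blocks.
// kQueriesPerBlock is the kernel's compile-time tile height. The statistics are
// cutlass::Array so they bind to the FragmentAlphaBeta const& constructor parameters.
cutlass::Array<float, kQueriesPerBlock> mi;        // Running max per query row
cutlass::Array<float, kQueriesPerBlock> si;        // Running sum per query row
cutlass::Array<float, kQueriesPerBlock> mi_prime;  // exp(old max - new max) per row

// First key block (more blocks follow)
using NormalizeFirst = MemoryEfficientAttentionNormalize<
    cutlass::half_t, cutlass::half_t, 8, float, float,
    true,   // isFirst=true (no source to load)
    false,  // isLast=false (not final)
    cutlass::Array<float, kQueriesPerBlock>>;
NormalizeFirst normalize_first(si, mi_prime);
// Apply in epilogue: output = accumulator (alpha = 1, no source)

// Subsequent key blocks
using NormalizeMid = MemoryEfficientAttentionNormalize<
    cutlass::half_t, cutlass::half_t, 8, float, float,
    false,  // isFirst=false (load previous output)
    false,  // isLast=false
    cutlass::Array<float, kQueriesPerBlock>>;
NormalizeMid normalize_mid(si, mi_prime);
// Apply in epilogue: output = accumulator + mi_prime[row] * output_old

// Final key block (same element types as above)
using NormalizeLast = MemoryEfficientAttentionNormalize<
    cutlass::half_t, cutlass::half_t, 8, float, float,
    false,  // isFirst=false (load previous output)
    true,   // isLast=true (final normalization)
    cutlass::Array<float, kQueriesPerBlock>>;
NormalizeLast normalize_last(si, mi_prime);
// Apply in epilogue: output = (accumulator + mi_prime[row] * output_old) / si[row]
```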
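Because isFirst and isLast are template parameters, a runtime loop over key blocks has to branch into one of four instantiations. This sketch assumes the flags follow from each block's start offset: a block is first when it starts at 0, and last when it reaches num_keys. Stage, stages_for, stage_name and dispatch are hypothetical helpers, not part of the header. A sequence with a single key block uses isFirst = isLast = true, a case the three-stage example does not show.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical description of one key-block iteration.
struct Stage {
  bool is_first;
  bool is_last;
};

// Assumed rule: first when the block starts at 0, last when it reaches num_keys.
std::vector<Stage> stages_for(int num_keys, int keys_per_block) {
  std::vector<Stage> stages;
  for (int start = 0; start < num_keys; start += keys_per_block) {
    stages.push_back({start == 0, start + keys_per_block >= num_keys});
  }
  return stages;
}

// Stand-in for "run the epilogue with MemoryEfficientAttentionNormalize<..., isFirst, isLast, ...>".
template <bool isFirst, bool isLast>
const char* stage_name() {
  if (isFirst && isLast) return "single";  // output = accumulator / s
  if (isFirst) return "first";             // output = accumulator
  if (isLast) return "last";               // output = (accumulator + m' * old) / s
  return "mid";                            // output = accumulator + m' * old
}

// Runtime flags -> compile-time instantiation.
const char* dispatch(const Stage& s) {
  if (s.is_first) {
    return s.is_last ? stage_name<true, true>() : stage_name<true, false>();
  }
  return s.is_last ? stage_name<false, true>() : stage_name<false, false>();
}
```

With 100 keys and 32 keys per block, this yields four iterations tagged first, mid, mid, last. With 20 keys, it yields one iteration tagged single.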