

Implementation:Deepspeedai DeepSpeed Evoformer Epilogue Rescale

From Leeroopedia


Knowledge Sources
Domains: Attention, CUTLASS_Kernels, DeepSpeed4Science
Last Updated: 2026-02-09 00:00 GMT

Overview

An epilogue thread operator that performs online softmax renormalization during attention computation by rescaling outputs using row-wise statistics.

Description

MemoryEfficientAttentionNormalize implements the rescaling step of the online (incremental) softmax used in memory-efficient attention. When attention is processed in key tiles, each tile's partial result is computed relative to the row maximum seen so far. When a later tile raises that maximum, the previously accumulated output must be rescaled before the new tile's contribution is added, and the running row sum must be updated to match.

For each output element the operator computes output = alpha × accumulator + beta × source, where accumulator is the current tile's GEMM result and source is the previously stored output. The factor alpha is 1/s_prime[row] on the last tile (isLast) and 1 otherwise, so the final normalization is applied only once; beta = alpha × m_prime[row], where m_prime[row] = exp(m_old - m_new) corrects the previous output for the changed maximum. When isFirst is set there is no source to load, and the output is simply alpha × accumulator. Together, the isFirst and isLast template flags select the first, intermediate, and final stages of the multi-stage accumulation. The per-row statistics arrays s_prime (running sum) and m_prime (rescale factor) are indexed by the row argument and applied element-wise during the epilogue, leveraging CUTLASS's infrastructure for type conversions and vectorized operations.
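As a minimal sketch of the per-element formula, the scalar C++ function below models one output element of one row. It is an illustration under stated assumptions, not the CUTLASS implementation: rescale_element is a hypothetical name, plain float stands in for all element types, and isFirst/isLast are runtime booleans instead of template parameters. In this sketch alpha is 1/s_prime only on the last tile and 1 otherwise, and beta = alpha × m_prime.

```cpp
#include <cassert>

// Scalar model of the epilogue rescale for ONE output element of one row.
// Hedged sketch: hypothetical helper, not the DeepSpeed/CUTLASS code.
//   alpha = 1 / s_prime  on the last key tile, otherwise 1
//   beta  = alpha * m_prime
//   out   = alpha * accum + beta * source   (source term omitted on the first tile)
inline float rescale_element(float accum, float source, float s_prime,
                             float m_prime, bool is_first, bool is_last) {
  if (is_last) {
    assert(s_prime > 0.0f && "final normalization needs a positive row sum");
  }
  const float alpha = is_last ? 1.0f / s_prime : 1.0f;
  if (is_first) {
    return alpha * accum;  // nothing to load from a previous tile
  }
  const float beta = alpha * m_prime;  // rescales the previous partial output
  return alpha * accum + beta * source;
}
```

For example, a middle tile with accum = 3, source = 2 and m_prime = 0.5 yields 3 + 0.5 × 2 = 4; the same values on the last tile with s_prime = 8 yield 4 / 8 = 0.5.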

Usage

This operator is used in the epilogue of the Evoformer attention forward pass when accumulating attention-weighted values across key blocks: after each new tile, the running output is rescaled as the running row max and sum are updated.

Code Reference

Source Location

Signature

template <typename ElementOutput_,      // Output element type
          typename ElementSource_,      // Source (previous output) element type
          int Count,                    // Elements per operation
          typename ElementAccumulator_, // Accumulator type
          typename ElementCompute_,     // Compute type
          bool isFirst,                 // First key tile (no source to load)
          bool isLast,                  // Last key tile (apply 1/s_prime)
          typename FragmentAlphaBeta_,  // Per-row array type for s_prime/m_prime
          FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
class MemoryEfficientAttentionNormalize {
public:
    using ElementOutput = ElementOutput_;
    using ElementSource = ElementSource_;
    using ElementAccumulator = ElementAccumulator_;
    using ElementCompute = ElementCompute_;

    using FragmentOutput = Array<ElementOutput, Count>;
    using FragmentSource = Array<ElementSource, Count>;
    using FragmentAccumulator = Array<ElementAccumulator, Count>;
    using FragmentAlphaBeta = FragmentAlphaBeta_;

    CUTLASS_HOST_DEVICE
    MemoryEfficientAttentionNormalize(
        FragmentAlphaBeta const& s_prime,  // Running row sums
        FragmentAlphaBeta const& m_prime   // Row rescale factors exp(m_old - m_new)
    );

    // Used when !isFirst: blends the accumulator with the previous output
    CUTLASS_HOST_DEVICE
    FragmentOutput operator()(int row,
                              FragmentAccumulator const& accumulator,
                              FragmentSource const& source) const;

    // Used when isFirst: scales the accumulator only
    CUTLASS_HOST_DEVICE
    FragmentOutput operator()(int row,
                              FragmentAccumulator const& accumulator) const;
};
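To make the shape of this interface concrete, here is a host-only C++17 stand-in that mirrors the template flags and the two operator() overloads. It is a hedged sketch, not the real class: NormalizeSketch and Rows are hypothetical names, std::array replaces cutlass::Array, float replaces every element type, and the CUTLASS_HOST_DEVICE qualifiers and rounding parameter are omitted.

```cpp
#include <array>
#include <cstddef>

// Host-only stand-in for MemoryEfficientAttentionNormalize (hypothetical sketch).
template <int Count, bool isFirst, bool isLast, std::size_t Rows>
class NormalizeSketch {
 public:
  using Fragment = std::array<float, Count>;
  using FragmentAlphaBeta = std::array<float, Rows>;

  NormalizeSketch(FragmentAlphaBeta const& s_prime,
                  FragmentAlphaBeta const& m_prime)
      : s_prime_(s_prime), m_prime_(m_prime) {}

  // Later tiles: blend the new accumulator with the previous output.
  Fragment operator()(int row, Fragment const& accumulator,
                      Fragment const& source) const {
    static_assert(!isFirst, "the first tile has no source to load");
    const float alpha = isLast ? 1.0f / s_prime_[row] : 1.0f;
    const float beta = alpha * m_prime_[row];
    Fragment out{};
    for (int i = 0; i < Count; ++i) {
      out[i] = alpha * accumulator[i] + beta * source[i];
    }
    return out;
  }

  // First tile: scale the accumulator only.
  Fragment operator()(int row, Fragment const& accumulator) const {
    static_assert(isFirst, "use the two-fragment overload after the first tile");
    const float alpha = isLast ? 1.0f / s_prime_[row] : 1.0f;
    Fragment out{};
    for (int i = 0; i < Count; ++i) {
      out[i] = alpha * accumulator[i];
    }
    return out;
  }

 private:
  // Held by reference, so the statistics arrays must outlive the operator.
  FragmentAlphaBeta const& s_prime_;
  FragmentAlphaBeta const& m_prime_;
};
```

The static_assert calls are a design choice of this sketch: because member function bodies of a class template are only instantiated when called, calling the wrong overload for a given isFirst becomes a compile-time error rather than a silent misuse.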

Import

#include "csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_rescale_output.h"

I/O Contract

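| Parameter | Type | Description |
|---|---|---|
| Constructor inputs | | |
| s_prime | FragmentAlphaBeta | Per-row running sum of exp(S - m) over the tiles seen so far; on the last tile this is the full softmax denominator |
| m_prime | FragmentAlphaBeta | Per-row rescale factor exp(m_old - m_new) applied to the previous output |
| Operator inputs | | |
| row | int | Row index used to look up s_prime[row] and m_prime[row] |
| accumulator | FragmentAccumulator | GEMM accumulator for the current key tile |
| source | FragmentSource | Previously stored output (only when !isFirst) |
| Output | | |
| (return value) | FragmentOutput (Array<ElementOutput, Count>) | Rescaled output fragment; divided by s_prime[row] only when isLast |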
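The s_prime and m_prime inputs are produced before the epilogue runs. The sketch below shows one common way to fold a tile of scores into per-row running statistics; it is an assumption-laden illustration rather than DeepSpeed's code, and RowStats and update_row_stats are hypothetical names. It assumes every tile is non-empty and holds finite scores.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Running statistics for one query row (hypothetical sketch).
struct RowStats {
  float m = -std::numeric_limits<float>::infinity();  // running max
  float s = 0.0f;                                      // running sum of exp(x - m)
};

// Folds one non-empty tile of scores into the running statistics and returns
// the factor exp(m_old - m_new) that the previous output must be scaled by.
// p receives the tile's unnormalized probabilities exp(x - m_new).
inline float update_row_stats(RowStats& st, const std::vector<float>& scores,
                              std::vector<float>& p) {
  float m_new = st.m;
  for (float x : scores) m_new = std::max(m_new, x);
  // On the first tile m_old is -inf, so the factor is exp(-inf) = 0.
  const float m_prime = std::exp(st.m - m_new);
  p.resize(scores.size());
  float tile_sum = 0.0f;
  for (std::size_t j = 0; j < scores.size(); ++j) {
    p[j] = std::exp(scores[j] - m_new);
    tile_sum += p[j];
  }
  st.s = st.s * m_prime + tile_sum;  // old sum re-expressed relative to m_new
  st.m = m_new;
  return m_prime;
}
```

On the first tile the returned factor is 0 because the running max starts at -inf, which is consistent with the isFirst stage not reading a source at all.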

Usage Examples

// Online softmax computation across key blocks (abridged; device-side setup omitted)
constexpr int kQueriesPerBlock = 64;               // illustrative value (assumption)
cutlass::Array<float, kQueriesPerBlock> mi;        // Running max per query row
cutlass::Array<float, kQueriesPerBlock> si;        // Running sum per query row
cutlass::Array<float, kQueriesPerBlock> mi_prime;  // Rescale factors exp(mi_old - mi_new)

// First key block
using NormalizeFirst = MemoryEfficientAttentionNormalize<
    cutlass::half_t, cutlass::half_t, 8, float, float,
    true,   // isFirst=true (no source to load)
    false,  // isLast=false (not final)
    cutlass::Array<float, kQueriesPerBlock>
>;

NormalizeFirst normalize_first(si, mi_prime);
// Apply in epilogue: output = accumulator (alpha = 1 because isLast=false)

// Subsequent key blocks
using NormalizeMid = MemoryEfficientAttentionNormalize<
    cutlass::half_t, cutlass::half_t, 8, float, float,
    false,  // isFirst=false (load previous output)
    false,  // isLast=false
    cutlass::Array<float, kQueriesPerBlock>
>;

NormalizeMid normalize_mid(si, mi_prime);
// Apply in epilogue: output = accumulator + mi_prime * output_old

// Final key block
using NormalizeLast = MemoryEfficientAttentionNormalize<
    /* ... */, false, true /* isLast */>;
// Apply in epilogue: output = (accumulator + mi_prime * output_old) / si
// (with a single key block, isFirst=isLast=true gives output = accumulator / si)
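To check that the first/middle/last schedule reproduces an ordinary softmax-weighted sum, the sketch below runs it on the host for a single query row and compares it with a direct computation. It is a hedged illustration, not DeepSpeed code: tiled_attention_row and reference_attention_row are hypothetical helpers, plain float replaces the half_t/float fragments, and the GEMMs are written as explicit loops.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Tiled online softmax for one query row (hypothetical sketch).
// scores: length K (one row of Q·K^T); values: K x D, row-major.
// Block 0 takes the isFirst path, the final block the isLast path.
inline std::vector<float> tiled_attention_row(const std::vector<float>& scores,
                                              const std::vector<float>& values,
                                              std::size_t D, std::size_t block) {
  const std::size_t K = scores.size();
  assert(K > 0 && block > 0 && values.size() == K * D);
  float m = -std::numeric_limits<float>::infinity();  // running max (mi)
  float s = 0.0f;                                      // running sum (si)
  std::vector<float> out(D, 0.0f);                     // stored output (source)
  for (std::size_t k0 = 0; k0 < K; k0 += block) {
    const std::size_t k1 = std::min(K, k0 + block);
    const bool is_first = (k0 == 0);
    const bool is_last = (k1 == K);
    float m_new = m;
    for (std::size_t j = k0; j < k1; ++j) m_new = std::max(m_new, scores[j]);
    const float m_prime = is_first ? 0.0f : std::exp(m - m_new);  // mi_prime
    // Accumulator for this tile: sum_j exp(score_j - m_new) * V[j, :].
    std::vector<float> accum(D, 0.0f);
    float tile_sum = 0.0f;
    for (std::size_t j = k0; j < k1; ++j) {
      const float p = std::exp(scores[j] - m_new);
      tile_sum += p;
      for (std::size_t d = 0; d < D; ++d) accum[d] += p * values[j * D + d];
    }
    s = s * m_prime + tile_sum;
    m = m_new;
    // Epilogue rescale: alpha = 1/s only on the last tile, beta = alpha * m_prime.
    const float alpha = is_last ? 1.0f / s : 1.0f;
    const float beta = alpha * m_prime;
    for (std::size_t d = 0; d < D; ++d) {
      out[d] = is_first ? alpha * accum[d] : alpha * accum[d] + beta * out[d];
    }
  }
  return out;
}

// Reference: softmax over the whole row, then a weighted sum of value rows.
inline std::vector<float> reference_attention_row(const std::vector<float>& scores,
                                                  const std::vector<float>& values,
                                                  std::size_t D) {
  const float m = *std::max_element(scores.begin(), scores.end());
  float s = 0.0f;
  for (float x : scores) s += std::exp(x - m);
  std::vector<float> out(D, 0.0f);
  for (std::size_t j = 0; j < scores.size(); ++j) {
    const float w = std::exp(scores[j] - m) / s;
    for (std::size_t d = 0; d < D; ++d) out[d] += w * values[j * D + d];
  }
  return out;
}
```

Both helpers should agree to within float rounding for any block size, including one larger than the number of keys, which exercises the isFirst && isLast case.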
