
Implementation:Mlc ai Mlc llm Attach Softmax Temperature Pass

From Leeroopedia


Overview

The file python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py defines a TVM compiler pass named AttachSoftmaxWithTemperature that rewrites the standard softmax operation into a numerically stable, two-stage chunked softmax with temperature scaling. The pass generates a Relax function, softmax_with_temperature, backed by two TIR kernels: one computes per-chunk log-sum-exp values, and the other combines those chunk results into the final softmax. The pass also handles the special case of a zero temperature (greedy decoding).

Location

  • Repository: Mlc_ai_Mlc_llm
  • File: python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py
  • Lines: 274

Pass Architecture

AttachSoftmaxWithTemperature

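from typing import Any, Dict, Optional

import tvm
from tvm import IRModule

@tvm.transform.module_pass(opt_level=0, name="AttachSoftmaxWithTemperature")
class AttachSoftmaxWithTemperature:
    """Module pass that attaches softmax_with_temperature to the IRModule."""

    def __init__(
        self, target: tvm.target.Target, metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        self.target = target  # selects GPU-scheduled vs. unscheduled (CPU) kernels
        self.metadata = metadata  # may carry "active_vocab_size"

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        # All of the rewriting is delegated to _Rewriter (defined below).
        return _Rewriter(mod, self.target, self.metadata).transform()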

The pass accepts a compilation target and an optional metadata dictionary. The metadata may contain an "active_vocab_size" field, which lets the kernels skip computation over padding entries of the vocabulary.

_Rewriter (PyExprMutator)

The _Rewriter class extends PyExprMutator and uses TVM's Relax IR builder to construct the softmax_with_temperature function:

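from tvm.relax.expr_functor import PyExprMutator, mutator

@mutator
class _Rewriter(PyExprMutator):
    def __init__(self, mod, target, metadata=None):
        super().__init__(mod)
        self.mod = mod
        self.target = target
        self.metadata = metadata
        self.chunk_size = 4096  # vocabulary entries per chunk
        # active_vocab_size is read only when metadata was provided.
        self.active_vocab_size = self.metadata.get("active_vocab_size") if self.metadata else None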

The chunk_size of 4096 controls how the vocabulary dimension is partitioned for the chunked softmax computation.
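As a worked example of the partitioning, the sketch below computes the number of chunks and the amount of padding for a few illustrative vocabulary sizes. It assumes that num_chunks is the ceiling of vocab_size / chunk_size, which is what padding the last chunk up to a full chunk implies; chunk_layout is a hypothetical helper, not part of MLC LLM.

```python
CHUNK_SIZE = 4096  # chunk_size used by _Rewriter


def chunk_layout(vocab_size: int, chunk_size: int = CHUNK_SIZE):
    """Return (num_chunks, padded_length, padding) for a vocabulary.

    Assumption: num_chunks = ceil(vocab_size / chunk_size), so the last
    chunk is padded up to a full chunk_size.
    """
    num_chunks = (vocab_size + chunk_size - 1) // chunk_size  # integer ceil-div
    padded_length = num_chunks * chunk_size
    return num_chunks, padded_length, padded_length - vocab_size


for vocab in (32000, 128256, 151936):
    print(vocab, chunk_layout(vocab))
# 32000 (8, 32768, 768)
# 128256 (32, 131072, 2816)
# 151936 (38, 155648, 3712)
```

Only the last chunk carries padding, so at most chunk_size - 1 padded entries are processed per row.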

The transform method constructs a Relax function with the following dataflow:

  1. Reshape the input logits from (batch_size, 1, vocab_size) to (batch_size, vocab_size).
  2. Call chunk_lse, a TIR function that computes per-chunk log-sum-exp and max values.
  3. Call softmax_with_chunked_sum, a TIR function that combines the chunked results to produce the final softmax output.
  4. Reshape the output back to the original 3D shape (batch_size, 1, vocab_size).
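For reference, the plain-Python sketch below computes what the whole softmax_with_temperature dataflow should return, with steps 2 and 3 collapsed into a single numerically stable pass. It works on nested lists rather than TVM tensors; the function name softmax_with_temperature_ref and the GREEDY_EPS constant are illustrative, and the 1e-5 greedy threshold follows this page's description of zero-temperature handling.

```python
import math
from typing import List

GREEDY_EPS = 1e-5  # temperatures at or below this are treated as zero (greedy)


def softmax_with_temperature_ref(
    logits: List[List[List[float]]], temperature: List[float]
) -> List[List[List[float]]]:
    """One-pass reference for logits of shape (batch_size, 1, vocab_size).

    Steps 2-3 (chunk_lse + softmax_with_chunked_sum) are replaced here by a
    single max-subtracted softmax per row; this is a sketch, not the TIR code.
    """
    rows = [x[0] for x in logits]  # step 1: (b, 1, v) -> (b, v)
    out = []
    for row, t in zip(rows, temperature):
        if t > GREEDY_EPS:
            scaled = [a / t for a in row]
            m = max(scaled)  # subtract the max so exp() cannot overflow
            denom = sum(math.exp(s - m) for s in scaled)
            probs = [math.exp(s - m) / denom for s in scaled]
        else:
            # Greedy: spread probability evenly over every position equal to the max.
            m = max(row)
            count = sum(1 for a in row if a == m)
            probs = [(1.0 if a == m else 0.0) / count for a in row]
        out.append([probs])  # step 4: (b, v) -> (b, 1, v)
    return out
```

For example, a greedy row [5.0, 1.0, 5.0] yields [0.5, 0.0, 0.5], because the two tied maxima share the probability mass.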

TIR Kernels

chunk_lse (Chunked Log-Sum-Exp)

The first kernel divides the vocabulary into chunks of size 4096 and computes per-chunk statistics:

Buffer layout:

| Buffer | Shape | Type | Description |
|---|---|---|---|
| A | (batch_size, vocab_size) | float32 | Input logits. |
| temperature | (batch_size,) | float32 | Per-sequence temperature values. |
| chunked_sum | (batch_size, num_chunks) | float32 | Output: per-chunk log-sum-exp (temperature > 0) or the number of entries equal to the chunk maximum (temperature == 0). |
| chunked_max | (batch_size, num_chunks) | float32 | Output: per-chunk maximum values. |

Computation stages:

  1. Pad: creates a padded view A_pad of shape (batch_size, num_chunks, chunk_size). Values beyond the active vocabulary are set to T.min_value("float32"). When temperature > 0, the logits are also divided by the temperature during padding.
  2. Max -- Computes the maximum value within each chunk via a reduction.
  3. Sum-Exp -- Computes the sum of exp(x - chunk_max) for each chunk (when temperature > 0), or counts the number of elements equal to the max (when temperature == 0).
  4. Log -- Takes the log of the chunked sum (when temperature > 0) to produce the log-sum-exp. When temperature == 0, the raw count is preserved.
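The four stages can be sketched for a single row in plain Python. This is an illustrative re-implementation, not the TIR kernel: chunk_lse_row is a hypothetical name, rows are Python lists, and the 1e-5 greedy threshold is taken from this page's description of zero-temperature handling.

```python
import math
from typing import List, Optional, Tuple

GREEDY_EPS = 1e-5
MIN_F32 = -3.4028234663852886e38  # value of T.min_value("float32")


def chunk_lse_row(
    row: List[float],
    temperature: float,
    chunk_size: int = 4096,
    active_vocab_size: Optional[int] = None,
) -> Tuple[List[float], List[float]]:
    """Return (chunked_sum, chunked_max) for one row of logits."""
    vocab_size = len(row)
    active = vocab_size if active_vocab_size is None else active_vocab_size
    num_chunks = (vocab_size + chunk_size - 1) // chunk_size
    greedy = temperature <= GREEDY_EPS
    chunked_sum, chunked_max = [], []
    for c in range(num_chunks):
        # Pad: entries past the active vocabulary become the float32 minimum;
        # in temperature mode the real logits are divided by temperature.
        chunk = []
        for i in range(c * chunk_size, (c + 1) * chunk_size):
            if i < active:
                chunk.append(row[i] if greedy else row[i] / temperature)
            else:
                chunk.append(MIN_F32)
        m = max(chunk)  # Max
        if greedy:
            s = float(sum(1 for x in chunk if x == m))  # count of ties, kept raw
        else:
            # Sum-Exp relative to the chunk max, then Log -> chunk log-sum-exp
            s = math.log(sum(math.exp(x - m) for x in chunk))
        chunked_sum.append(s)
        chunked_max.append(m)
    return chunked_sum, chunked_max
```

Padding with the float32 minimum makes padded entries contribute exp(MIN_F32 - m) = 0 to a chunk's sum whenever the chunk also holds real logits.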

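Temperature == 0 (greedy) handling: when the temperature is at or below 1e-5, it is treated as zero and the kernel switches to argmax mode. Instead of computing exponentials, it counts how many entries in each chunk equal the chunk maximum. The final kernel then divides the probability mass evenly among all positions that attain the global maximum, so greedy decoding remains deterministic even when several logits tie.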

softmax_with_chunked_sum (Final Softmax)

The second kernel merges the per-chunk results to compute the final softmax probabilities:

Buffer layout:

| Buffer | Shape | Type | Description |
|---|---|---|---|
| A | (batch_size, vocab_size) | float32 | Input logits. |
| temperature | (batch_size,) | float32 | Per-sequence temperature values. |
| chunked_sum | (batch_size, num_chunks) | float32 | Per-chunk log-sum-exp or count from the first kernel. |
| chunked_max | (batch_size, num_chunks) | float32 | Per-chunk maximum values from the first kernel. |
| softmax | (batch_size, vocab_size) | float32 | Output softmax probabilities. |

Computation stages:

  1. Global max -- Finds the global maximum across all chunk maxima.
  2. Global sum -- Merges the per-chunk sums, adjusting for the difference between the chunk max and global max:
    • When temperature > 0: sum += exp(chunk_log_sum_exp + chunk_max - global_max)
    • When temperature == 0: sum += (chunk_max == global_max) * chunk_count
  3. Final softmax -- For each vocabulary position:
    • When temperature > 0: softmax[v] = exp(A[v] / temperature - (log(global_sum) + global_max))
    • When temperature == 0: softmax[v] = (A[v] == global_max) / global_count
    • Positions beyond active_vocab_size are set to 0.
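The merge can be checked numerically with a small plain-Python sketch. softmax_with_chunked_sum_row is a hypothetical single-row helper that takes chunk statistics in the layout of the table above. In the temperature path it relies on the identity sum_c exp(lse_c + max_c - M) = sum_v exp(x_v - M), where M is the global max, so log(global_sum) + global_max equals the log-sum-exp of the whole scaled row.

```python
import math
from typing import List, Optional

GREEDY_EPS = 1e-5  # assumed greedy threshold, as described for chunk_lse


def softmax_with_chunked_sum_row(
    row: List[float],
    temperature: float,
    chunked_sum: List[float],
    chunked_max: List[float],
    active_vocab_size: Optional[int] = None,
) -> List[float]:
    """Merge per-chunk statistics into softmax probabilities for one row."""
    active = len(row) if active_vocab_size is None else active_vocab_size
    greedy = temperature <= GREEDY_EPS
    global_max = max(chunked_max)  # 1. global max over chunk maxima
    # 2. global sum, rescaling each chunk from its own max to the global max
    if greedy:
        global_sum = sum(
            cnt for cnt, cm in zip(chunked_sum, chunked_max) if cm == global_max
        )
    else:
        global_sum = sum(
            math.exp(lse + cm - global_max) for lse, cm in zip(chunked_sum, chunked_max)
        )
    # 3. final softmax per vocabulary position
    out = []
    for v, a in enumerate(row):
        if v >= active:
            out.append(0.0)  # padding positions get zero probability
        elif greedy:
            out.append((1.0 if a == global_max else 0.0) / global_sum)
        else:
            out.append(math.exp(a / temperature - (math.log(global_sum) + global_max)))
    return out
```

Splitting the row [1.0, 2.0, 3.0, 0.5] into two chunks of size 2 and merging their statistics this way gives the same probabilities as a direct softmax over all four values.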

GPU Scheduling

For non-LLVM (GPU) targets, the pass applies a custom schedule to the softmax_with_chunked_sum kernel:

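def apply_gpu_schedule(target, sch):
    # sch is a tvm.tir.Schedule; get_max_num_threads_per_block is an MLC LLM helper.
    max_threads = get_max_num_threads_per_block(target)
    TX = 32                   # threads along x (one warp)
    TY = max_threads // TX    # remaining threads of the block along y
    unroll_depth = 64

    sch.work_on("softmax_with_chunked_sum")
    # log_pad has three loops: batch, chunk, and position within the chunk.
    l0, l1, l2 = sch.get_loops("log_pad")
    bx = sch.fuse(l0, l1)     # one thread block per (batch, chunk) pair
    sch.bind(bx, "blockIdx.x")
    unroll, ty, tx = sch.split(l2, [None, TY, TX])
    sch.bind(ty, "threadIdx.y")
    sch.bind(tx, "threadIdx.x")
    # The leftover serial loop is unrolled explicitly.
    sch.annotate(unroll, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
    sch.annotate(unroll, ann_key="pragma_unroll_explicit", ann_val=1)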

Key scheduling decisions:

  • The log_pad block (the main output loop) is mapped to a 2D thread configuration: TX=32 (threads in x, matching warp size) and TY=max_threads/32 (threads in y).
  • The batch and chunk dimensions are fused into blockIdx.x.
  • The max and sum_exp reduction blocks (scheduled by code not shown in the excerpt above) are computed at the block level, with their results stored in shared memory (storage_scope="shared") so that all threads of a block can read them.
  • Loop unrolling is set to a depth of 64 with explicit unrolling enabled.
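The launch shape implied by this schedule follows from simple arithmetic. The sketch below assumes a chunk_size of 4096 and takes max_threads as an input, since get_max_num_threads_per_block returns a target-dependent value (1024 is common on CUDA GPUs but is only an example here); launch_shape is a hypothetical helper.

```python
def launch_shape(batch_size: int, num_chunks: int, max_threads: int, chunk_size: int = 4096):
    """Grid/block shape implied by apply_gpu_schedule for the log_pad block.

    max_threads stands in for get_max_num_threads_per_block(target).
    """
    tx = 32
    ty = max_threads // tx
    # split(l2, [None, TY, TX]): the None factor is ceil(chunk_size / (TY * TX))
    unroll_extent = (chunk_size + ty * tx - 1) // (ty * tx)
    grid_x = batch_size * num_chunks  # fused batch x chunk loops -> blockIdx.x
    return {"blockIdx.x": grid_x, "threadIdx.y": ty, "threadIdx.x": tx, "unroll": unroll_extent}


print(launch_shape(batch_size=4, num_chunks=32, max_threads=1024))
# {'blockIdx.x': 128, 'threadIdx.y': 32, 'threadIdx.x': 32, 'unroll': 4}
```

With 1024 threads per block, each thread handles 4 of the 4096 positions in its chunk; a target limited to 256 threads would leave 16 unrolled iterations per thread.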

For CPU targets (target.kind.name == "llvm"), the unscheduled TIR functions are returned directly, relying on TVM's default CPU codegen.

Design Notes

  • The two-stage approach is used because computing softmax over a very large vocabulary (e.g., 128K+ tokens) in a single pass can exceed GPU shared memory limits. Splitting each row into 4096-entry chunks keeps every reduction small, and subtracting the chunk and global maxima before exponentiating keeps the computation numerically stable.
  • A comment in the source notes that an earlier optimization, computing the exponential in base 2 after multiplying by log2(e), was reverted because it was numerically unstable for large input values and caused the softmax outputs not to sum to 1.
  • The active_vocab_size optimization avoids unnecessary computation over padding tokens in models where the actual vocabulary is smaller than the padded tensor dimension.
  • The chunk_lse kernel is left unscheduled (not marked as tir.is_scheduled), while softmax_with_chunked_sum is marked as tir.is_scheduled: 1 since it receives explicit GPU scheduling in the pass.
