Implementation:Mlc ai Mlc llm Attach Softmax Temperature Pass
Overview
The file `python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py` defines a TVM compiler pass named `AttachSoftmaxWithTemperature` that rewrites the standard one-shot softmax into a numerically stable, two-stage chunked softmax with temperature scaling. The pass generates a Relax function, `softmax_with_temperature`, backed by two TIR kernels: one computes per-chunk log-sum-exp and maximum values, and the other combines those chunk statistics into the final softmax. The pass also handles the special case of zero temperature (greedy decoding).
Location
- Repository: Mlc_ai_Mlc_llm
- File: `python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py`
- Lines: 274
Pass Architecture
AttachSoftmaxWithTemperature
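```python
from typing import Any, Dict, Optional
import tvm
from tvm import IRModule

@tvm.transform.module_pass(opt_level=0, name="AttachSoftmaxWithTemperature")
class AttachSoftmaxWithTemperature:
    def __init__(
        self, target: tvm.target.Target, metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        self.target = target      # decides between the GPU schedule and the CPU path
        self.metadata = metadata  # may carry "active_vocab_size"

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        # Delegate to the Relax rewriter, which attaches softmax_with_temperature to mod.
        return _Rewriter(mod, self.target, self.metadata).transform()
```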
The pass accepts a compilation target and optional metadata dictionary. The metadata may contain an "active_vocab_size" field, which allows the kernels to skip computation over padding vocabulary entries.
_Rewriter (PyExprMutator)
The _Rewriter class extends PyExprMutator and uses TVM's Relax IR builder to construct the softmax_with_temperature function:
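```python
from tvm.relax.expr_functor import PyExprMutator, mutator


@mutator
class _Rewriter(PyExprMutator):
    def __init__(self, mod, target, metadata=None):
        super().__init__(mod)
        self.mod = mod
        self.target = target
        self.metadata = metadata
        self.chunk_size = 4096  # vocabulary entries per chunk
        # Optional number of real (non-padding) vocabulary entries.
        self.active_vocab_size = self.metadata.get("active_vocab_size") if self.metadata else None
```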
The chunk_size of 4096 controls how the vocabulary dimension is partitioned for the chunked softmax computation.
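To make the partitioning concrete, the sketch below computes how many chunks a vocabulary needs and how many slots of the padded view hold no real logit. It assumes `num_chunks` is the ceiling of `vocab_size / chunk_size`, which is what the padded `(batch_size, num_chunks, chunk_size)` layout of `A_pad` implies; the helper names and the two vocabulary sizes are illustrative, not taken from the pass.

```python
CHUNK_SIZE = 4096  # matches self.chunk_size in _Rewriter


def num_chunks(vocab_size: int, chunk_size: int = CHUNK_SIZE) -> int:
    # Ceiling division: a partially filled last chunk still counts as a chunk.
    return (vocab_size + chunk_size - 1) // chunk_size


def padded_entries(vocab_size: int, chunk_size: int = CHUNK_SIZE) -> int:
    # Slots in the padded (num_chunks, chunk_size) view that hold no real logit.
    return num_chunks(vocab_size, chunk_size) * chunk_size - vocab_size


for vocab in (32000, 128256):
    print(vocab, num_chunks(vocab), padded_entries(vocab))
# 32000 8 768
# 128256 32 2816
```

Only the last chunk is ever partially filled, so padding never exceeds `chunk_size - 1` slots per row unless `active_vocab_size` masks additional entries.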
The transform method constructs a Relax function with the following dataflow:
- Reshape the input logits from `(batch_size, 1, vocab_size)` to `(batch_size, vocab_size)`.
- Call `chunk_lse`, a TIR function that computes per-chunk log-sum-exp and max values.
- Call `softmax_with_chunked_sum`, a TIR function that combines the chunked results to produce the final softmax output.
- Reshape the output back to the original 3D shape `(batch_size, 1, vocab_size)`.
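Why the two calls reproduce an exact softmax can be written out for a single row whose temperature $\tau$ is above the greedy threshold. The notation is introduced here for illustration: $z_i = A_i / \tau$ are the scaled logits, $C_c$ is the set of vocabulary indices in chunk $c$, $m_c$ and $\ell_c$ are the values `chunk_lse` writes to `chunked_max` and `chunked_sum`, $M$ is the global maximum, and $S$ is the global sum.

$$
\begin{aligned}
m_c &= \max_{i \in C_c} z_i, \qquad \ell_c = \log \sum_{i \in C_c} e^{z_i - m_c}, \qquad M = \max_c m_c, \\
S &= \sum_c e^{\ell_c + m_c - M} = \sum_c \sum_{i \in C_c} e^{z_i - M} = \sum_i e^{z_i - M}, \\
\operatorname{softmax}(z)_v &= \frac{e^{z_v - M}}{S} = \exp\bigl(z_v - (\log S + M)\bigr).
\end{aligned}
$$

The exponents $z_i - m_c$ and $z_v - M$ are never positive and $\ell_c \le \log(\text{chunk\_size})$, so no intermediate can overflow; padded entries, set to the most negative float32 value, contribute terms that underflow to zero.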
TIR Kernels
chunk_lse (Chunked Log-Sum-Exp)
The first kernel divides the vocabulary into chunks of size 4096 and computes per-chunk statistics:
Buffer layout:
| Buffer | Shape | Type | Description |
|---|---|---|---|
| `A` | `(batch_size, vocab_size)` | float32 | Input logits. |
| `temperature` | `(batch_size,)` | float32 | Per-sequence temperature values. |
| `chunked_sum` | `(batch_size, num_chunks)` | float32 | Output: per-chunk log-sum-exp of the max-shifted logits (when temperature > 0), or the number of elements equal to the chunk maximum (when temperature == 0). |
| `chunked_max` | `(batch_size, num_chunks)` | float32 | Output: per-chunk maximum values. |
Computation stages:
- Pad: Creates a padded view `A_pad` of shape `(batch_size, num_chunks, chunk_size)`. Values beyond the active vocabulary are set to `T.min_value("float32")`. When temperature > 0, the logits are divided by the temperature during padding.
- Max: Computes the maximum value within each chunk via a reduction.
- Sum-Exp: Computes the sum of `exp(x - chunk_max)` over each chunk (when temperature > 0), or counts the elements equal to the chunk maximum (when temperature == 0).
- Log: Takes the log of the chunked sum (when temperature > 0) to produce the log-sum-exp. When temperature == 0, the raw count is preserved.
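Temperature == 0 (greedy) handling: The kernel treats any temperature at or below `1e-5` as zero, so the "temperature > 0" and "temperature == 0" cases in this section mean above and at-or-below that threshold. In this mode the kernel switches to argmax counting: instead of computing exponentials, it counts the occurrences of the maximum value in each chunk. The second kernel then divides the probability mass equally among all positions tied for the global maximum, which gives deterministic greedy decoding even when several tokens share the top logit.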
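A minimal pure-Python sketch of `chunk_lse` for a single sequence follows, covering the Pad, Max, Sum-Exp, and Log stages and the greedy counting mode. It is an illustrative reference under stated assumptions, not the TIR source: `FLOAT32_MIN` stands in for `T.min_value("float32")`, Python floats are double precision rather than float32, and the `GREEDY_EPS` name and the clamping of `active_vocab_size` to the row length are hypothetical choices of this sketch.

```python
import math

FLOAT32_MIN = -3.4028234663852886e38  # stand-in for T.min_value("float32")
GREEDY_EPS = 1e-5  # temperatures at or below this use counting (greedy) mode


def chunk_lse(logits, temperature, chunk_size=4096, active_vocab_size=None):
    """Reference for one row: returns (chunked_sum, chunked_max)."""
    vocab_size = len(logits)
    limit = vocab_size if active_vocab_size is None else min(active_vocab_size, vocab_size)
    num_chunks = (vocab_size + chunk_size - 1) // chunk_size
    greedy = temperature <= GREEDY_EPS
    chunked_sum, chunked_max = [], []
    for c in range(num_chunks):
        # Pad: one chunk of A_pad; slots past the active vocabulary get FLOAT32_MIN.
        chunk = [
            (logits[i] if greedy else logits[i] / temperature) if i < limit else FLOAT32_MIN
            for i in range(c * chunk_size, (c + 1) * chunk_size)
        ]
        # Max: per-chunk maximum.
        m = max(chunk)
        if greedy:
            # Sum-Exp in greedy mode counts entries equal to the max; Log is skipped.
            s = float(sum(1 for x in chunk if x == m))
        else:
            # Sum-Exp over the max-shifted values, then Log.
            s = math.log(sum(math.exp(x - m) for x in chunk))
        chunked_sum.append(s)
        chunked_max.append(m)
    return chunked_sum, chunked_max
```

For example, `chunk_lse([1.0, 3.0, 3.0, 0.5], 0.0, chunk_size=2)` returns `([1.0, 1.0], [3.0, 3.0])`: each two-element chunk contains exactly one copy of the maximum, 3.0.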
softmax_with_chunked_sum (Final Softmax)
The second kernel merges the per-chunk results to compute the final softmax probabilities:
Buffer layout:
| Buffer | Shape | Type | Description |
|---|---|---|---|
| `A` | `(batch_size, vocab_size)` | float32 | Input logits. |
| `temperature` | `(batch_size,)` | float32 | Per-sequence temperature values. |
| `chunked_sum` | `(batch_size, num_chunks)` | float32 | Per-chunk log-sum-exp or count from the first kernel. |
| `chunked_max` | `(batch_size, num_chunks)` | float32 | Per-chunk maximum values from the first kernel. |
| `softmax` | `(batch_size, vocab_size)` | float32 | Output softmax probabilities. |
Computation stages:
- Global max: Finds the global maximum across all chunk maxima.
- Global sum: Merges the per-chunk sums, adjusting for the difference between each chunk max and the global max:
  - When temperature > 0: `sum += exp(chunk_log_sum_exp + chunk_max - global_max)`
  - When temperature == 0: `sum += (chunk_max == global_max) * chunk_count`
- Final softmax: For each vocabulary position `v`:
  - When temperature > 0: `softmax[v] = exp(A[v] / temperature - (log(global_sum) + global_max))`
  - When temperature == 0: `softmax[v] = (A[v] == global_max) / global_count`, where `global_count` is the global sum computed in the previous stage.
  - Positions beyond `active_vocab_size` are set to 0.
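The merge step can be sketched in plain Python for one row. This is a hedged reference, not the TIR kernel: it loops sequentially where the GPU schedule uses threads, takes the chunk statistics as plain lists, and assumes the second kernel uses the same `1e-5` greedy threshold as `chunk_lse`. The function and parameter names mirror the buffers in the table above but are otherwise illustrative.

```python
import math

GREEDY_EPS = 1e-5  # assumed to match the threshold used by chunk_lse


def softmax_with_chunked_sum(logits, temperature, chunked_sum, chunked_max,
                             active_vocab_size=None):
    """Reference for one row of the merge kernel."""
    greedy = temperature <= GREEDY_EPS
    # Global max: maximum over the per-chunk maxima.
    global_max = max(chunked_max)
    # Global sum: rescale each chunk to the shared global max before adding.
    global_sum = 0.0
    for s, m in zip(chunked_sum, chunked_max):
        if greedy:
            global_sum += s if m == global_max else 0.0  # s is a tie count here
        else:
            global_sum += math.exp(s + m - global_max)
    limit = len(logits) if active_vocab_size is None else min(active_vocab_size, len(logits))
    out = []
    for v, x in enumerate(logits):
        if v >= limit:
            out.append(0.0)  # padding positions get zero probability
        elif greedy:
            # Entries tied at the global max share the mass equally.
            out.append((1.0 if x == global_max else 0.0) / global_sum)
        else:
            out.append(math.exp(x / temperature - (math.log(global_sum) + global_max)))
    return out
```

In the greedy case the output is uniform over the tied maxima: logits `[1.0, 3.0, 3.0]` with chunk counts `[1.0, 1.0]` and chunk maxima `[3.0, 3.0]` yield `[0.0, 0.5, 0.5]`.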
GPU Scheduling
For non-LLVM (GPU) targets, the pass applies a custom schedule to the softmax_with_chunked_sum kernel:
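```python
def apply_gpu_schedule(target, sch):
    # Maximum number of threads per block supported by `target`.
    max_threads = get_max_num_threads_per_block(target)
    TX = 32                 # threads in x (warp size)
    TY = max_threads // TX  # threads in y fill the rest of the block
    unroll_depth = 64

    sch.work_on("softmax_with_chunked_sum")
    # log_pad loops: (batch, chunk, position within chunk)
    l0, l1, l2 = sch.get_loops("log_pad")
    bx = sch.fuse(l0, l1)
    sch.bind(bx, "blockIdx.x")
    unroll, ty, tx = sch.split(l2, [None, TY, TX])
    sch.bind(ty, "threadIdx.y")
    sch.bind(tx, "threadIdx.x")
    sch.annotate(unroll, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
    sch.annotate(unroll, ann_key="pragma_unroll_explicit", ann_val=1)
    # (Scheduling of the max and sum_exp reductions is omitted from this excerpt.)
```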
Key scheduling decisions:
- The `log_pad` block (the main output loop) is mapped to a 2D thread configuration: `TX = 32` threads in x (matching the warp size) and `TY = max_threads // 32` threads in y.
- The batch and chunk dimensions are fused into `blockIdx.x`.
- The `max` and `sum_exp` reduction blocks are computed at the block level, using shared memory (`storage_scope="shared"`) for inter-thread communication.
- Loop unrolling is set to a depth of 64, with explicit unrolling enabled.
For CPU targets (target.kind.name == "llvm"), the unscheduled TIR functions are returned directly, relying on TVM's default CPU codegen.
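The launch shape this schedule produces can be worked out by hand. The helper below is hypothetical and not part of the pass; it assumes the `log_pad` loops are `(batch_size, num_chunks, chunk_size)` and that the `None` factor in `sch.split` resolves to the ceiling of `chunk_size / (TY * TX)`. The 1024- and 256-thread limits are illustrative values, not queried from a real device.

```python
def launch_config(batch_size, vocab_size, max_threads, chunk_size=4096, tx=32):
    """Hypothetical helper: thread/block shape implied by the schedule above."""
    ty = max_threads // tx  # TY = max_threads // TX
    num_chunks = (vocab_size + chunk_size - 1) // chunk_size
    # fuse(l0, l1): one thread block per (sequence, chunk) pair.
    grid_x = batch_size * num_chunks
    # split(l2, [None, TY, TX]): the None factor is inferred by ceiling division.
    unroll_extent = -(-chunk_size // (ty * tx))
    return {
        "blockIdx.x": grid_x,
        "threadIdx.y": ty,
        "threadIdx.x": tx,
        "unroll_extent": unroll_extent,
    }


print(launch_config(batch_size=4, vocab_size=128256, max_threads=1024))
# {'blockIdx.x': 128, 'threadIdx.y': 32, 'threadIdx.x': 32, 'unroll_extent': 4}
```

With a 1024-thread limit each thread covers 4 positions per chunk in the unrolled loop; a 256-thread limit raises that to 16, still well under the unroll depth of 64.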
Design Notes
- The two-stage approach is necessary because computing softmax over very large vocabularies (e.g., 128K+ tokens) in a single pass can exceed GPU shared memory limits and cause numerical issues.
- The source code explicitly notes that a previous attempt to use `log2e` multiplication for performance was reverted because of numerical instability with large input values, which caused softmax outputs not to sum to 1.
- The `active_vocab_size` optimization avoids unnecessary computation over padding tokens in models whose actual vocabulary is smaller than the padded tensor dimension.
- The `chunk_lse` kernel is left unscheduled (not marked as `tir.is_scheduled`), whereas `softmax_with_chunked_sum` is marked `tir.is_scheduled: 1` on GPU targets because the pass schedules it explicitly.