Implementation:Mlc ai Mlc llm Attach Logit Processor Pass
Overview
The file python/mlc_llm/compiler_pass/attach_logit_processor.py defines a TVM compiler pass named AttachLogitProcessFunc that attaches three logit-processing TIR (Tensor IR) functions to the compiled model's IRModule. These functions are used at serving time to apply logit biases, repetition/frequency/presence penalties, and vocabulary bitmask constraints directly on the GPU or CPU.
Location
- Repository: Mlc_ai_Mlc_llm
- File: `python/mlc_llm/compiler_pass/attach_logit_processor.py`
- Lines: 285
Pass Architecture
The pass is registered as a TVM module-level transformation at optimization level 0:
```python
import tvm
from tvm import IRModule

# The _get_apply_*_inplace[_cpu] builders are defined in the same file.


@tvm.transform.module_pass(opt_level=0, name="AttachLogitProcessFunc")
class AttachLogitProcessFunc:
    """Attach logit-processing TIR functions to the IRModule."""

    def __init__(self, target: tvm.target.Target):
        self.target = target

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        mod = mod.clone()  # work on a copy; never mutate the input module
        if str(self.target.kind) == "llvm":
            # CPU: sequential, loop-based kernels
            mod["apply_logit_bias_inplace"] = _get_apply_logit_bias_inplace_cpu()
            mod["apply_penalty_inplace"] = _get_apply_penalty_inplace_cpu()
            mod["apply_bitmask_inplace"] = _get_apply_bitmask_inplace_cpu()
        else:
            # GPU: thread-bound kernels sized for self.target
            mod["apply_logit_bias_inplace"] = _get_apply_logit_bias_inplace(self.target)
            mod["apply_penalty_inplace"] = _get_apply_penalty_inplace(self.target)
            mod["apply_bitmask_inplace"] = _get_apply_bitmask_inplace(self.target)
        return mod
```
The pass branches on the target kind:
- `llvm` targets (CPU): use sequential, loop-based implementations.
- Other targets (GPU): use thread-bound implementations whose block size is derived from the target.
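The branching and the clone-before-modify step can be illustrated with a minimal pure-Python mock of `transform_module`. It assumes a plain `dict` stands in for the `IRModule` and strings stand in for the TIR functions. `attach_logit_process_funcs` and `FUNC_NAMES` are hypothetical names, not part of MLC LLM or TVM.

```python
from typing import Dict

# Hypothetical: names of the three functions the real pass attaches.
FUNC_NAMES = (
    "apply_logit_bias_inplace",
    "apply_penalty_inplace",
    "apply_bitmask_inplace",
)


def attach_logit_process_funcs(mod: Dict[str, str], target_kind: str) -> Dict[str, str]:
    """Mock of AttachLogitProcessFunc.transform_module on a dict-based 'module'."""
    mod = dict(mod)  # copy first, mirroring mod.clone(), so the caller's module is untouched
    variant = "cpu" if target_kind == "llvm" else "gpu"  # same branch as the real pass
    for name in FUNC_NAMES:
        # Strings stand in for the TIR PrimFuncs the real helpers build.
        mod[name] = f"{name}[{variant}]"
    return mod


original = {"decode": "<decode func>"}
cpu_mod = attach_logit_process_funcs(original, "llvm")
gpu_mod = attach_logit_process_funcs(original, "cuda")
print(cpu_mod["apply_penalty_inplace"])  # apply_penalty_inplace[cpu]
print(gpu_mod["apply_bitmask_inplace"])  # apply_bitmask_inplace[gpu]
print(original)                          # {'decode': '<decode func>'}
```

Copying before mutation means the same input module can be passed through the pass for different targets without interference.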
Attached TIR Functions
apply_logit_bias_inplace
Adds per-token logit biases to the logit tensor. Used to implement features such as guided generation or token suppression.
Buffer layout:
| Buffer | Shape | Type | Description |
|---|---|---|---|
| `logits` | `(batch_size, vocab_size)` | float32 | The logit tensor to modify in place. |
| `pos2seq_id` | `(num_token,)` | int32 | Maps each token position to its sequence ID in the batch. |
| `token_ids` | `(num_token,)` | int32 | The vocabulary token IDs to apply biases to. |
| `logit_bias` | `(num_token,)` | float32 | The bias values to add. |
Computation: for each token position `i`, the function executes:

```python
logits[pos2seq_id[i], token_ids[i]] += logit_bias[i]
```
The GPU variant binds the token loop to GPU threads, using up to 1024 threads per block (clamped to the target's maximum). Tokens are therefore processed in parallel across thread blocks and threads; the target is not limited to CUDA.
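The indexing can be checked with a pure-Python reference of the bias update, written as the sequential loop the CPU variant performs. It assumes `logits` is a list of rows. `apply_logit_bias_reference` is a hypothetical helper, not the TIR function itself.

```python
from typing import List, Sequence


def apply_logit_bias_reference(
    logits: List[List[float]],
    pos2seq_id: Sequence[int],
    token_ids: Sequence[int],
    logit_bias: Sequence[float],
) -> None:
    """Sequential reference for apply_logit_bias_inplace (hypothetical helper)."""
    if not (len(pos2seq_id) == len(token_ids) == len(logit_bias)):
        raise ValueError("pos2seq_id, token_ids and logit_bias must all have length num_token")
    for i in range(len(token_ids)):
        # Row comes from pos2seq_id, column from token_ids; update in place.
        logits[pos2seq_id[i]][token_ids[i]] += logit_bias[i]


logits = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
apply_logit_bias_reference(logits, [0, 1, 1], [2, 0, 2], [-100.0, 0.5, 1.0])
print(logits)  # [[0.0, 1.0, -98.0], [3.5, 4.0, 6.0]]
```

A large negative bias such as `-100.0` effectively suppresses a token. If the same (row, token) pair appeared twice, this sequential sketch would simply accumulate both biases.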
apply_penalty_inplace
Applies three types of penalties to discourage token repetition: presence penalty, frequency penalty, and repetition penalty.
Buffer layout:
| Buffer | Shape | Type | Description |
|---|---|---|---|
| `logits` | `(batch_size, vocab_size)` | float32 | The logit tensor to modify in place. |
| `seq_ids` | `(num_seq,)` | int32 | Row of `logits` for each sequence being processed. |
| `pos2seq_id` | `(num_token,)` | int32 | Maps each token position to a sequence index, i.e. an index into `seq_ids` and `penalties`. |
| `token_ids` | `(num_token,)` | int32 | Token IDs that have appeared in the sequence's context. |
| `token_cnt` | `(num_token,)` | int32 | The number of times each of those tokens has appeared. |
| `penalties` | `(num_seq, 3)` | float32 | Per-sequence penalty values: column 0 = presence penalty, column 1 = frequency penalty, column 2 = repetition penalty. |
Computation: for each token position, two operations are applied in sequence:
- Presence and frequency penalty: subtracts `presence_penalty + token_cnt * frequency_penalty` from the logit.
- Repetition penalty: if the adjusted logit is negative, it is multiplied by the repetition penalty; otherwise (zero or positive) it is divided by the repetition penalty.
```python
# Inside the per-token block; vp indexes the token position.
# logits[...] below abbreviates logits[seq_ids[pos2seq_id[vp]], token_ids[vp]].
logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] -= (
    penalties[pos2seq_id[vp], 0] + token_cnt[vp] * penalties[pos2seq_id[vp], 1]
)
logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(
    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < T.float32(0),
    logits[...] * penalties[pos2seq_id[vp], 2],
    logits[...] / penalties[pos2seq_id[vp], 2],
)
```
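The double indirection (`pos2seq_id` → sequence index → `seq_ids` → row) and the order of the two steps are easy to misread. Below is a pure-Python reference sketch with a small worked example. `apply_penalty_reference` is a hypothetical helper, and `logits` is assumed to be a list of rows.

```python
from typing import List, Sequence


def apply_penalty_reference(
    logits: List[List[float]],
    seq_ids: Sequence[int],
    pos2seq_id: Sequence[int],
    token_ids: Sequence[int],
    token_cnt: Sequence[int],
    penalties: Sequence[Sequence[float]],
) -> None:
    """Sequential reference for apply_penalty_inplace (hypothetical helper)."""
    for vp in range(len(token_ids)):
        s = pos2seq_id[vp]   # sequence index: selects the penalties row and the seq_ids entry
        row = seq_ids[s]     # row of logits that belongs to this sequence
        tok = token_ids[vp]
        presence, frequency, repetition = penalties[s]
        # 1) Presence + frequency penalty.
        logits[row][tok] -= presence + token_cnt[vp] * frequency
        # 2) Repetition penalty on the adjusted logit (with repetition > 1 this lowers it).
        if logits[row][tok] < 0:
            logits[row][tok] *= repetition
        else:
            logits[row][tok] /= repetition


logits = [[9.0, 9.0, 9.0], [2.0, 0.5, -1.0]]
apply_penalty_reference(
    logits,
    seq_ids=[1],                   # the single sequence lives in row 1
    pos2seq_id=[0, 0],
    token_ids=[0, 2],
    token_cnt=[2, 1],
    penalties=[[0.5, 0.25, 2.0]],  # presence, frequency, repetition
)
print(logits)  # [[9.0, 9.0, 9.0], [0.5, 0.5, -3.5]]
```

Token 0 goes from `2.0 - (0.5 + 2 * 0.25) = 1.0`, which is non-negative, so it is divided: `1.0 / 2.0 = 0.5`. Token 2 goes from `-1.0 - (0.5 + 0.25) = -1.75`, which is negative, so it is multiplied: `-1.75 * 2.0 = -3.5`. Row 0 is untouched because `seq_ids` never points to it.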
apply_bitmask_inplace
Applies a vocabulary bitmask to enforce constrained decoding (e.g., grammar-guided generation). Tokens not permitted by the bitmask have their logits set to the lowest finite float32 value (`T.min_value("float32")`, about -3.4e38).
Buffer layout:
| Buffer | Shape | Type | Description |
|---|---|---|---|
logits |
(batch_size, vocab_size) |
float32 |
The logit tensor to modify in place. |
seq_ids |
(num_seq,) |
int32 |
Sequence IDs within the batch. |
bitmask |
(batch_size, (vocab_size + 31) // 32) |
int32 |
A packed bitmask where bit v of word v // 32 indicates whether token v is allowed (1) or masked (0).
|
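The packed layout can be shown by building one row of the `bitmask` buffer from a set of allowed token IDs. `pack_bitmask` is a hypothetical helper for illustration only. It reinterprets each word as a signed value because the buffer's dtype is int32.

```python
from typing import Iterable, List


def pack_bitmask(allowed: Iterable[int], vocab_size: int) -> List[int]:
    """Pack allowed token IDs into int32 words (hypothetical helper).

    Token v maps to bit (v % 32) of word (v // 32).
    """
    words = [0] * ((vocab_size + 31) // 32)
    for v in allowed:
        if not 0 <= v < vocab_size:
            raise ValueError(f"token id {v} is outside [0, {vocab_size})")
        words[v // 32] |= 1 << (v % 32)
    # Reinterpret each 32-bit pattern as a signed value to match the int32 dtype.
    return [w - (1 << 32) if w >= (1 << 31) else w for w in words]


print(pack_bitmask([0, 1, 33], vocab_size=40))  # [3, 2]
print(pack_bitmask([31], vocab_size=32))        # [-2147483648]
```

The `(vocab_size + 31) // 32` word count rounds up, so a 40-token vocabulary needs two words, with the upper 24 bits of the second word unused.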
Computation: for each sequence and vocabulary position, the function checks the corresponding bit:

```python
logits[seq_ids[vs], vv] = T.if_then_else(
    (bitmask[seq_ids[vs], vv // 32] >> (vv % 32)) & 1 == 1,
    logits[seq_ids[vs], vv],
    T.min_value("float32"),
)
```
If the bit is 0, the logit is replaced with T.min_value("float32"), effectively forcing the token's probability to zero after softmax.
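Here is a pure-Python reference of the masking step, which also checks the softmax claim numerically. `apply_bitmask_reference` and `FLOAT32_LOWEST` are hypothetical names. The constant is computed from the float32 bit pattern `0x7F7FFFFF` and is assumed to match what `T.min_value("float32")` denotes.

```python
import math
import struct
from typing import List, Sequence

# Lowest finite float32 (about -3.4028235e38), assumed equal to T.min_value("float32").
FLOAT32_LOWEST = -struct.unpack("<f", struct.pack("<I", 0x7F7FFFFF))[0]


def apply_bitmask_reference(
    logits: List[List[float]],
    seq_ids: Sequence[int],
    bitmask: Sequence[Sequence[int]],
) -> None:
    """Sequential reference for apply_bitmask_inplace (hypothetical helper)."""
    for row in seq_ids:
        for v in range(len(logits[row])):
            # Python's >> on negative ints is arithmetic, so bit 31 of a signed word reads correctly.
            if (bitmask[row][v // 32] >> (v % 32)) & 1 != 1:
                logits[row][v] = FLOAT32_LOWEST


logits = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
bitmask = [[0b111], [0b101]]  # row 1 masks token 1
apply_bitmask_reference(logits, seq_ids=[1], bitmask=bitmask)
print(logits[1][1] == FLOAT32_LOWEST)           # True
print(math.exp(logits[1][1] - max(logits[1])))  # 0.0
```

Under the usual max-subtracted softmax, the masked entry's exponent underflows to exactly `0.0`, so the token receives zero probability. The finite minimum is used rather than negative infinity.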
GPU Thread Configuration
All three GPU function variants share the same threading strategy:
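```python
# `target` is the tvm.target.Target given to the pass; both helpers are MLC LLM utilities.
tx = 1024  # default threads per block
max_num_threads_per_block = get_max_num_threads_per_block(target)
tx = min(tx, max_num_threads_per_block)  # clamp to the target's hardware limit
check_thread_limits(target, bdx=tx, bdy=1, bdz=1, gdz=1)  # validate the launch shape
```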
The number of threads per block defaults to 1024 but is clamped to the target's hardware limit, and `check_thread_limits` then validates the resulting launch configuration. Work is distributed across blocks via `T.thread_binding` with `"blockIdx.x"` and `"threadIdx.x"`. Threads whose index falls past the end of the work are guarded by `T.where` conditions.
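The block-count arithmetic and the role of the `T.where` guard can be simulated on the CPU. This sketch assumes a kernel that flattens a (sequence, vocabulary) grid into one linear index, as a bitmask-style kernel might. `launch_config` and `simulate_flat_mapping` are hypothetical helpers, not MLC LLM functions.

```python
from typing import List, Tuple


def launch_config(total_work: int, max_threads_per_block: int, default_tx: int = 1024) -> Tuple[int, int]:
    """Return (num_blocks, threads_per_block) with tx clamped to the hardware limit."""
    tx = min(default_tx, max_threads_per_block)
    num_blocks = (total_work + tx - 1) // tx  # ceiling division so every item gets a thread
    return num_blocks, tx


def simulate_flat_mapping(num_seq: int, vocab_size: int, max_threads_per_block: int) -> List[Tuple[int, int]]:
    """Enumerate the (sequence, vocab) pairs that in-bounds threads would handle."""
    total = num_seq * vocab_size
    num_blocks, tx = launch_config(total, max_threads_per_block)
    handled = []
    for bx in range(num_blocks):   # plays the role of blockIdx.x
        for tid in range(tx):      # plays the role of threadIdx.x
            fused = bx * tx + tid
            if fused < total:      # the T.where guard: surplus threads do nothing
                handled.append((fused // vocab_size, fused % vocab_size))
    return handled


print(launch_config(5000, max_threads_per_block=1024))  # (5, 1024)
print(launch_config(5000, max_threads_per_block=256))   # (20, 256)
pairs = simulate_flat_mapping(num_seq=3, vocab_size=100, max_threads_per_block=128)
print(len(pairs), len(set(pairs)))                       # 300 300
```

Rounding the block count up launches a few surplus threads (384 for 300 items in the last example), and the guard ensures each item is still handled exactly once.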
Design Notes
- All functions set `T.func_attr({"tir.is_scheduled": True})` to mark them as already fully scheduled, so TVM's automatic scheduling passes leave them unchanged.
- The `T.func_attr({"tir.noalias": True})` annotation declares that the buffer arguments do not alias one another, which enables more aggressive compiler optimizations.
- The pass clones the `IRModule` before modifying it (`mod = mod.clone()`) to avoid mutating the input module, following TVM's convention that passes do not modify their input.
- All dynamic dimensions (`batch_size`, `vocab_size`, `num_token`, `num_seq`) are declared as TIR size variables, so the functions accept varying input sizes at runtime.