

Implementation:Mlc ai Mlc llm Attach Logit Processor Pass

From Leeroopedia


Overview

The file python/mlc_llm/compiler_pass/attach_logit_processor.py defines a TVM compiler pass named AttachLogitProcessFunc that attaches three logit-processing TIR (Tensor IR) functions to the compiled model's IRModule. These functions are used at serving time to apply logit biases, repetition/frequency/presence penalties, and vocabulary bitmask constraints directly on the GPU or CPU.

Location

  • Repository: Mlc_ai_Mlc_llm
  • File: python/mlc_llm/compiler_pass/attach_logit_processor.py
  • Lines: 285

Pass Architecture

The pass is registered as a TVM module-level transformation at optimization level 0:

```python
import tvm
from tvm import IRModule


@tvm.transform.module_pass(opt_level=0, name="AttachLogitProcessFunc")
class AttachLogitProcessFunc:
    """Attach the logit-processing TIR functions to an IRModule."""

    def __init__(self, target: tvm.target.Target):
        self.target = target

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        mod = mod.clone()  # work on a copy; the input module is left untouched
        # The _get_* builders are defined elsewhere in the same file; each returns a TIR PrimFunc.
        if str(self.target.kind) == "llvm":
            # CPU targets: sequential loop-based kernels.
            mod["apply_logit_bias_inplace"] = _get_apply_logit_bias_inplace_cpu()
            mod["apply_penalty_inplace"] = _get_apply_penalty_inplace_cpu()
            mod["apply_bitmask_inplace"] = _get_apply_bitmask_inplace_cpu()
        else:
            # GPU targets: thread-bound kernels sized for the target.
            mod["apply_logit_bias_inplace"] = _get_apply_logit_bias_inplace(self.target)
            mod["apply_penalty_inplace"] = _get_apply_penalty_inplace(self.target)
            mod["apply_bitmask_inplace"] = _get_apply_bitmask_inplace(self.target)
        return mod
```

The pass branches on the target kind:

  • llvm targets (CPU): sequential loop-based implementations.
  • All other targets (GPU): thread-bound implementations whose block size is clamped to the target's thread limit.

Attached TIR Functions

apply_logit_bias_inplace

Adds per-token logit biases to the logit tensor. It can be used to implement features such as guided generation or token suppression.

Buffer layout:

| Buffer | Shape | Type | Description |
|---|---|---|---|
| logits | (batch_size, vocab_size) | float32 | The logit tensor to modify in place. |
| pos2seq_id | (num_token,) | int32 | Maps each token position to its sequence ID in the batch. |
| token_ids | (num_token,) | int32 | The vocabulary token IDs to apply biases to. |
| logit_bias | (num_token,) | float32 | The bias values to add. |

Computation: For each token position i, the function executes:

```python
logits[pos2seq_id[i], token_ids[i]] += logit_bias[i]
```
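For checking kernel outputs, the same update can be written as a plain-Python reference. This is a minimal sketch, assuming logits is held as a list of rows (one per batch entry); `apply_logit_bias_reference` is a hypothetical name, not part of MLC LLM.

```python
from typing import List


def apply_logit_bias_reference(
    logits: List[List[float]],
    pos2seq_id: List[int],
    token_ids: List[int],
    logit_bias: List[float],
) -> None:
    """Hypothetical CPU reference for apply_logit_bias_inplace (mutates logits)."""
    assert len(pos2seq_id) == len(token_ids) == len(logit_bias)
    for i in range(len(token_ids)):
        # Row = batch slot of the sequence, column = vocabulary token ID.
        logits[pos2seq_id[i]][token_ids[i]] += logit_bias[i]


# Two sequences, vocabulary of 4 tokens.
logits = [[0.0, 1.0, 2.0, 3.0], [0.0, 0.0, 0.0, 0.0]]
apply_logit_bias_reference(
    logits, pos2seq_id=[0, 1, 1], token_ids=[3, 0, 2], logit_bias=[-5.0, 2.5, 1.0]
)
print(logits)  # [[0.0, 1.0, 2.0, -2.0], [2.5, 0.0, 1.0, 0.0]]
```

Each entry i touches exactly one logit, which is why the GPU variant can assign token positions to threads independently.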

The GPU variant uses thread binding with up to 1024 threads per block (clamped to the target's maximum), processing token positions in parallel across thread blocks and threads. The same scheme is used for every non-llvm target, not only CUDA.

apply_penalty_inplace

Applies three types of penalties to discourage token repetition: presence penalty, frequency penalty, and repetition penalty.

Buffer layout:

| Buffer | Shape | Type | Description |
|---|---|---|---|
| logits | (batch_size, vocab_size) | float32 | The logit tensor to modify in place. |
| seq_ids | (num_seq,) | int32 | Sequence IDs within the batch. |
| pos2seq_id | (num_token,) | int32 | Maps token positions to sequence indices (indices into seq_ids and penalties). |
| token_ids | (num_token,) | int32 | Token IDs that have appeared in the context. |
| token_cnt | (num_token,) | int32 | The number of times each token has appeared. |
| penalties | (num_seq, 3) | float32 | Per-sequence penalty values: column 0 = presence penalty, column 1 = frequency penalty, column 2 = repetition penalty. |

Computation: For each token position, two operations are applied sequentially:

  1. Presence and frequency penalties: subtracts presence_penalty + token_count * frequency_penalty from the logit.
  2. Repetition penalty: if the adjusted logit is negative, it is multiplied by the repetition penalty; otherwise (zero or positive), it is divided by it. For a repetition penalty greater than 1, both branches lower every nonzero logit.
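
```python
# vp: token position (0 <= vp < num_token); pos2seq_id[vp] indexes seq_ids and penalties.
# logits[...] abbreviates logits[seq_ids[pos2seq_id[vp]], token_ids[vp]].
logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] -= (
    penalties[pos2seq_id[vp], 0] + token_cnt[vp] * penalties[pos2seq_id[vp], 1]
)
logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(
    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < T.float32(0),
    logits[...] * penalties[pos2seq_id[vp], 2],
    logits[...] / penalties[pos2seq_id[vp], 2],
)
```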
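A plain-Python reference makes the order of operations explicit, in particular that the repetition penalty sees the logit after the presence and frequency penalties have been subtracted. This is a sketch assuming list-of-lists logits; `apply_penalty_reference` is a hypothetical name, not an MLC LLM API.

```python
from typing import List


def apply_penalty_reference(
    logits: List[List[float]],
    seq_ids: List[int],
    pos2seq_id: List[int],
    token_ids: List[int],
    token_cnt: List[int],
    penalties: List[List[float]],
) -> None:
    """Hypothetical CPU reference for apply_penalty_inplace (mutates logits)."""
    for vp in range(len(token_ids)):
        s = pos2seq_id[vp]  # index into seq_ids and into penalties
        row, col = seq_ids[s], token_ids[vp]
        presence, frequency, repetition = penalties[s]
        # Step 1: presence + frequency penalty.
        logits[row][col] -= presence + token_cnt[vp] * frequency
        # Step 2: repetition penalty on the already-adjusted logit.
        if logits[row][col] < 0.0:
            logits[row][col] *= repetition
        else:
            logits[row][col] /= repetition


# One sequence in batch row 0; token 0 seen twice, token 2 seen once.
logits = [[2.0, 0.5, -1.0]]
apply_penalty_reference(
    logits,
    seq_ids=[0],
    pos2seq_id=[0, 0],
    token_ids=[0, 2],
    token_cnt=[2, 1],
    penalties=[[0.5, 0.25, 2.0]],  # presence, frequency, repetition
)
print(logits)  # [[0.5, 0.5, -3.5]]
```

Token 0 goes 2.0 → 1.0 (after subtracting 0.5 + 2 × 0.25) → 0.5 (divided by 2); token 2 goes -1.0 → -1.75 → -3.5 (multiplied by 2); token 1 has no entry and is untouched.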

apply_bitmask_inplace

Applies a vocabulary bitmask to enforce constrained decoding (e.g., grammar-guided generation). Tokens not permitted by the bitmask have their logits set to the lowest finite float32 value (about -3.4e38).

Buffer layout:

| Buffer | Shape | Type | Description |
|---|---|---|---|
| logits | (batch_size, vocab_size) | float32 | The logit tensor to modify in place. |
| seq_ids | (num_seq,) | int32 | Sequence IDs within the batch. |
| bitmask | (batch_size, (vocab_size + 31) // 32) | int32 | A packed bitmask: bit v % 32 of word v // 32 indicates whether token v is allowed (1) or masked (0). |
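The packed layout can be made concrete with a helper that builds one row of the bitmask from a set of allowed token IDs. This is an illustrative sketch: `pack_allowed_tokens` is a hypothetical name, and the final conversion assumes each word is stored as a two's-complement int32, as the buffer's int32 dtype implies (so a word with bit 31 set is negative).

```python
from typing import Iterable, List


def pack_allowed_tokens(allowed: Iterable[int], vocab_size: int) -> List[int]:
    """Hypothetical helper: pack allowed token IDs into int32 bitmask words."""
    num_words = (vocab_size + 31) // 32
    words = [0] * num_words
    for v in allowed:
        if not 0 <= v < vocab_size:
            raise ValueError(f"token id {v} out of range")
        words[v // 32] |= 1 << (v % 32)  # token v -> bit v % 32 of word v // 32
    # Reinterpret each 32-bit pattern as a signed int32 (bit 31 set -> negative).
    return [w - (1 << 32) if w >= (1 << 31) else w for w in words]


row = pack_allowed_tokens([0, 5, 31, 37], vocab_size=40)
print(row)  # [-2147483615, 32]
```

With vocab_size = 40 the row needs (40 + 31) // 32 = 2 words; token 37 lands in word 1 at bit 5, hence the value 32.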

Computation: For each sequence and vocabulary position, the function checks the corresponding bit:

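```python
# vs: index into seq_ids (0 <= vs < num_seq); vv: vocabulary position (0 <= vv < vocab_size).
logits[seq_ids[vs], vv] = T.if_then_else(
    (bitmask[seq_ids[vs], vv // 32] >> (vv % 32)) & 1 == 1,
    logits[seq_ids[vs], vv],
    T.min_value("float32"),
)
```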

If the bit is 0, the logit is replaced with T.min_value("float32"), effectively forcing the token's probability to zero after softmax.
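The masking rule and its effect on softmax can be checked with a plain-Python sketch. It assumes list-of-lists inputs and emulates the float32 kernel in Python's double precision; `apply_bitmask_reference`, `softmax`, and `FLOAT32_LOWEST` are hypothetical names introduced for illustration.

```python
import math
import struct
from typing import List

# Lowest finite float32 (about -3.4e38), the value T.min_value("float32") denotes.
FLOAT32_LOWEST = -struct.unpack("<f", struct.pack("<I", 0x7F7FFFFF))[0]


def apply_bitmask_reference(
    logits: List[List[float]], seq_ids: List[int], bitmask: List[List[int]]
) -> None:
    """Hypothetical CPU reference for apply_bitmask_inplace (mutates logits)."""
    for vs in range(len(seq_ids)):
        row = seq_ids[vs]
        for vv in range(len(logits[row])):
            # Python's >> on ints is arithmetic, so bit 31 of a negative int32 word reads correctly.
            allowed = ((bitmask[row][vv // 32] >> (vv % 32)) & 1) == 1
            if not allowed:
                logits[row][vv] = FLOAT32_LOWEST


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]  # exp of a huge negative number underflows to 0.0
    total = sum(exps)
    return [e / total for e in exps]


logits = [[1.0, 2.0, 3.0]]
apply_bitmask_reference(logits, seq_ids=[0], bitmask=[[0b101]])  # token 1 is not allowed
print(softmax(logits[0])[1])  # 0.0
```

Using the lowest finite value rather than negative infinity keeps the arithmetic finite while still driving the masked token's probability to zero.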

GPU Thread Configuration

All three GPU function variants share the same threading strategy:

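```python
# get_max_num_threads_per_block and check_thread_limits are MLC LLM support helpers (not shown).
tx = 1024  # default threads per block
max_num_threads_per_block = get_max_num_threads_per_block(target)
tx = min(tx, max_num_threads_per_block)  # clamp to the hardware limit
check_thread_limits(target, bdx=tx, bdy=1, bdz=1, gdz=1)  # check launch dims against target limits
```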

The number of threads per block defaults to 1024 but is clamped to the target's hardware limit. Work is distributed across blocks via T.thread_binding with "blockIdx.x" and "threadIdx.x", and out-of-bounds threads are guarded by T.where conditions.
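The launch geometry described above can be emulated in plain Python: one helper picks the block size and block count, and another walks every (blockIdx.x, threadIdx.x) pair, skipping out-of-range threads the way the T.where guard does. This is a sketch under stated assumptions: `launch_plan`, `covered_items`, and `max_threads` are hypothetical names (`max_threads` stands in for the value returned by get_max_num_threads_per_block(target)), and the ceil-division grid size is assumed rather than taken from the source.

```python
from typing import List, Tuple


def launch_plan(num_items: int, max_threads: int, default_tx: int = 1024) -> Tuple[int, int]:
    """Return (num_blocks, threads_per_block) for a 1-D launch (illustrative sketch)."""
    tx = min(default_tx, max_threads)  # clamp to the hardware limit
    num_blocks = (num_items + tx - 1) // tx  # ceil division so every item is covered
    return num_blocks, tx


def covered_items(num_items: int, max_threads: int) -> List[int]:
    """Emulate blockIdx.x / threadIdx.x indexing with an out-of-bounds guard."""
    num_blocks, tx = launch_plan(num_items, max_threads)
    covered = []
    for block_idx in range(num_blocks):  # blockIdx.x
        for thread_idx in range(tx):  # threadIdx.x
            i = block_idx * tx + thread_idx
            if i < num_items:  # plays the role of the T.where guard
                covered.append(i)
    return covered


print(launch_plan(2500, max_threads=1024))  # (3, 1024)
print(launch_plan(2500, max_threads=256))  # (10, 256)
```

The last block is usually only partly used (3 × 1024 = 3072 threads for 2500 items), which is why the guard is needed.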

Design Notes

  • All functions use T.func_attr({"tir.is_scheduled": True}) to mark them as already scheduled, so later scheduling passes (for example, dlight rules or TVM's default GPU schedule) leave them unchanged.
  • The T.func_attr({"tir.noalias": True}) annotation asserts that no two buffer arguments alias each other, which lets the compiler optimize memory accesses more aggressively.
  • The pass clones the IRModule before modification (mod = mod.clone()) to avoid mutating the input module, following TVM's immutability convention for compiler passes.
  • All dynamic dimensions (batch_size, vocab_size, num_token, num_seq) are declared as TIR size variables, enabling the functions to work with varying input sizes at runtime.
