Principle: MLC-AI web-llm Grammar-Constrained Decoding
Overview
Grammar-Constrained Decoding is the algorithm that enforces structural constraints on language model output by applying grammar-derived token masks during the autoregressive decoding process. It integrates a formal grammar parser (GrammarMatcher) into the decode loop so that before sampling each token, invalid continuations are masked out with -infinity logit values. This guarantees every generated sequence is syntactically valid according to the specified grammar, whether that grammar was derived from a JSON Schema, an EBNF definition, or a structural tag specification.
Description
In standard autoregressive decoding, a language model generates tokens one at a time by sampling from a probability distribution over the entire vocabulary. Grammar-constrained decoding adds a filtering step: at each position, a GrammarMatcher determines which tokens are valid continuations given the current parse state, and all other tokens are masked out before sampling.
The process involves three collaborating components:
- GrammarCompiler -- Compiles a user-provided schema (JSON Schema, EBNF, or structural tag definition) into a CompiledGrammar, an optimized internal representation of the grammar's parse automaton.
- GrammarMatcher -- A stateful object that tracks the current parse position within the CompiledGrammar. It exposes two key operations: getNextTokenBitmask() to produce a binary mask over the vocabulary, and acceptToken(tokenId) to advance the parse state.
- Decoding loop (in LLMChatPipeline) -- The standard autoregressive loop that, when grammar-constrained mode is active, calls the GrammarMatcher before each sampling step.
The key insight is that grammar compilation happens once (and is cached), while the per-token bitmask computation and state update happen at each decoding step. The bitmask is applied on GPU for efficiency: the Int32Array bitmask is copied to a GPU tensor, and a TVM function applies it to the logits tensor in place.
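To make the packed bitmask concrete, here is a minimal CPU-side sketch of what the GPU masking step computes. It assumes the common packing convention (token i lives in word i / 32, bit i % 32 of the Int32Array); the function name is illustrative, not part of the web-llm API.

```typescript
// CPU reference sketch of the GPU bitmask application: each bit in the
// packed Int32Array marks one vocabulary token as valid (1) or invalid (0).
// Invalid tokens get -Infinity logits so sampling can never pick them.
function applyBitmaskCPU(logits: Float32Array, bitmask: Int32Array): void {
  for (let tokenId = 0; tokenId < logits.length; tokenId++) {
    const word = bitmask[tokenId >> 5];       // 32 tokens per int32 word
    const bit = (word >> (tokenId & 31)) & 1; // extract this token's bit
    if (bit === 0) logits[tokenId] = -Infinity;
  }
}

// Toy vocabulary of 4 tokens; bits 0 and 2 set means only tokens 0 and 2
// are valid continuations under the grammar.
const logits = new Float32Array([1, 2, 3, 4]);
applyBitmaskCPU(logits, new Int32Array([0b0101]));
```

The production path performs the same per-bit test in a TVM-compiled GPU kernel so the full-vocabulary loop never touches the CPU.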
Usage
Grammar-constrained decoding is the underlying mechanism activated automatically when response_format is set to "json_object" (with or without a schema), "grammar", or "structural_tag" in a ChatCompletionRequest. Developers do not invoke the grammar matcher directly -- they only need to specify the schema via response_format.
The algorithm activates in these scenarios:
- JSON object mode: response_format: { type: "json_object", schema: "..." } -- Compiles the JSON Schema into a grammar.
- Plain JSON mode: response_format: { type: "json_object" } -- Uses the built-in JSON grammar (any valid JSON).
- EBNF grammar mode: response_format: { type: "grammar", grammar: "..." } -- Compiles the raw EBNF grammar.
- Structural tag mode: response_format: { type: "structural_tag", structural_tag: {...} } -- Compiles the tag definition.
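The modes above can be written as plain request fragments. A minimal sketch follows; the schema and grammar strings are placeholder examples, and the structural tag definition is omitted because its shape is specified elsewhere.

```typescript
// Illustrative response_format fragments for each grammar-constrained mode.
// Field names follow the ChatCompletionRequest shapes described above.
const jsonSchemaMode = {
  response_format: {
    type: "json_object",
    // JSON Schema is passed as a *string*, not an object
    schema: JSON.stringify({
      type: "object",
      properties: { name: { type: "string" } },
    }),
  },
};

const plainJsonMode = {
  response_format: { type: "json_object" }, // any valid JSON
};

const ebnfMode = {
  response_format: {
    type: "grammar",
    grammar: 'root ::= "yes" | "no"', // raw EBNF, placeholder rule
  },
};
// structural_tag mode omitted here: its tag definition object is
// described by the structural tag specification, not sketched inline.
```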
Theoretical Basis
Algorithm: Per-Token Constrained Sampling
The algorithm operates within the autoregressive decode loop. For each token position t:
- Bitmask computation: bitmask = grammarMatcher.getNextTokenBitmask() returns an Int32Array of length ceil(vocabSize / 32). Each bit corresponds to one token in the vocabulary: 1 = valid continuation, 0 = invalid.
- GPU bitmask application: The bitmask is copied to GPU memory and applied to the logits tensor using fapplyBitmask(logits, seqIds, bitmask). Invalid tokens receive -infinity logit values.
- Standard sampling: The model samples from the modified logit distribution (top-p, top-k, temperature, etc. are applied after the grammar mask).
- State update: grammarMatcher.acceptToken(sampledTokenId) advances the parse state to reflect the newly generated token. If the token is somehow invalid (which should never happen due to masking), an error is thrown.
Initialization Phase
Grammar initialization runs concurrently with prefilling to hide latency:
- During prefillStep(), a grammarMatcherInitPromise is created that runs in parallel with prompt encoding and prefill computation.
- The promise performs: TokenizerInfo creation (if first time) -> GrammarCompiler creation (if first time) -> schema compilation -> GrammarMatcher creation.
- The compiled grammar and matcher are cached using a response format key. If the same schema is used in the next request, the matcher is simply reset rather than recompiled.
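The caching behavior in the last step can be sketched as follows. This is a simplified model of the idea, not the actual web-llm internals; the names (`matcherCache`, `getMatcher`) are illustrative, and the cache key is derived by serializing the response format.

```typescript
// Sketch: compiled grammars/matchers are cached by a key derived from the
// response format, so a repeated schema only resets the existing matcher
// instead of paying the compilation cost again.
interface Matcher {
  reset(): void; // rewind parse state to the grammar's start
}

const matcherCache = new Map<string, Matcher>();

function getMatcher(responseFormat: object, compile: () => Matcher): Matcher {
  const key = JSON.stringify(responseFormat); // response-format cache key
  const cached = matcherCache.get(key);
  if (cached) {
    cached.reset(); // same schema as before: reuse, just reset parse state
    return cached;
  }
  const fresh = compile(); // new schema: one-time compilation cost
  matcherCache.set(key, fresh);
  return fresh;
}
```

Two consecutive requests with the same response format therefore trigger exactly one compilation.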
Formal Properties
- Soundness: Every generated sequence is a valid parse of the grammar. This follows because at each step, only tokens leading to valid partial parses are available for sampling.
- Completeness: If a valid completion exists from the current parse state, the model can reach it (no valid tokens are incorrectly masked).
- Distribution preservation: The conditional distribution over valid tokens is proportional to the model's original distribution. Only the normalization constant changes, so relative probabilities among valid tokens are preserved.
- Termination: The grammar matcher tracks whether the current parse state can reach a terminal state. The decode loop checks grammarMatcher.isTerminated() to determine whether the grammar has been fully satisfied.
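The distribution-preservation property can be verified numerically without a model: applying softmax to the masked logits gives the same token-to-token ratios as the original softmax restricted to valid tokens. A small self-contained check (toy logits, hand-picked mask):

```typescript
// Softmax that treats -Infinity logits as zero-probability tokens.
function softmax(logits: number[]): number[] {
  const m = Math.max(...logits.filter((x) => isFinite(x)));
  const exps = logits.map((x) => (isFinite(x) ? Math.exp(x - m) : 0));
  const z = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / z);
}

const logits = [2.0, 1.0, 0.5, -1.0];
const valid = [true, false, true, false]; // toy grammar mask
const masked = logits.map((x, i) => (valid[i] ? x : -Infinity));

const p = softmax(logits); // original distribution
const q = softmax(masked); // grammar-constrained distribution
// Masked tokens get probability 0, and the ratio q[0]/q[2] among valid
// tokens equals the original ratio p[0]/p[2]: only the normalization
// constant changed.
```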
Performance Characteristics
- Grammar compilation: One-time cost, reported as grammar_init_s in usage statistics. Cached across requests with the same schema.
- Per-token overhead: The bitmask computation and token acceptance add a small per-token cost, reported as grammar_per_token_s in usage statistics.
- GPU bitmask application: The bitmask is applied on GPU via a TVM-compiled function, making the masking operation itself very fast.
Usage Examples
Conceptual Flow of Grammar-Constrained Decoding
The following pseudocode illustrates the algorithm as it operates within web-llm:
// --- Initialization (runs concurrently with prefill) ---
const tokenizerInfo = await xgr.TokenizerInfo.createTokenizerInfo(
  rawTokenTable, postprocMethod, prependSpace, vocabSize, stopTokens,
);
const grammarCompiler = await xgr.GrammarCompiler.createGrammarCompiler(tokenizerInfo);
const compiledGrammar = await grammarCompiler.compileJSONSchema(schemaString);
const grammarMatcher = await xgr.GrammarMatcher.createGrammarMatcher(compiledGrammar);
compiledGrammar.dispose(); // the compiled grammar can be freed once the matcher exists

// --- Decode loop (runs for each output token) ---
while (!stopped) {
  // 1. Forward pass: compute logits for the next token
  const logitsOnGPU = model.forward(inputTokens);
  // 2. Get the grammar bitmask (computed on CPU)
  const bitmask: Int32Array = await grammarMatcher.getNextTokenBitmask();
  // 3. Apply the bitmask to the logits (on GPU)
  const bitmaskGPU = tvm.empty([1, bitmaskSize], "int32", device).copyFrom(bitmask);
  fapplyBitmask(logitsOnGPU, seqIds, bitmaskGPU);
  // 4. Sample a token from the masked logits
  const sampledToken = sampleFromLogits(logitsOnGPU, temperature, topP);
  // 5. Update the grammar state
  const accepted = grammarMatcher.acceptToken(sampledToken);
  if (!accepted) throw Error("Grammar matcher rejected token");
  // 6. Check termination
  if (grammarMatcher.isTerminated()) stopped = true;
}
Observing Grammar Performance Metrics
import * as webllm from "@mlc-ai/web-llm";

const engine = await webllm.CreateMLCEngine("Phi-3.5-mini-instruct-q4f16_1-MLC");

const schema = JSON.stringify({
  type: "object",
  properties: {
    name: { type: "string" },
    age: { type: "integer" },
  },
  required: ["name", "age"],
});

const reply = await engine.chat.completions.create({
  stream: false,
  messages: [{ role: "user", content: "Generate a JSON person record." }],
  max_tokens: 64,
  response_format: { type: "json_object", schema } as webllm.ResponseFormat,
});

// Access grammar-specific performance metrics
const extra = reply.usage?.extra;
console.log("Grammar init time (s):", extra?.grammar_init_s);
console.log("Grammar per-token time (s):", extra?.grammar_per_token_s);
Related Pages
- Implementation: Grammar Matcher Decoding -- Implementation:Mlc_ai_Web_llm_Grammar_Matcher_Decoding
- Principle: Schema Definition -- The upstream principle for defining the schemas that this algorithm enforces
- Implementation: Response Format -- The interface through which users specify the schema to be enforced
- Heuristic:Mlc_ai_Web_llm_Grammar_Matcher_Reuse
- Heuristic:Mlc_ai_Web_llm_Tokenizer_JSON_Preference