Implementation:Mlc ai Mlc llm Mistral Model

From Leeroopedia


Knowledge Sources
Domains Model_Architecture, LLM
Last Updated 2026-02-09 19:00 GMT

Overview

Implements the Mistral architecture for causal language modeling within the MLC LLM framework, featuring grouped-query attention with optional sliding window attention and attention sink support.

Description

This module provides the TVM Relax-based implementation of Mistral AI's Mistral model architecture. The implementation closely mirrors the Llama architecture with the addition of sliding window attention capabilities. Key characteristics include:

  • Sliding window attention: Supports configurable sliding_window_size for efficient long-context processing. When the sliding window is disabled (set to None or -1), the model falls back to standard full attention with a fixed context window.
  • Attention sink: Configurable attention_sink_size (default 4) that determines the number of initial tokens always retained in the attention window, even with sliding window enabled.
  • Grouped-query attention (GQA): Supports fewer key-value heads than query heads for memory-efficient attention computation.
  • Fused QKV projection: Single qkv_proj linear layer for all query, key, and value projections.
  • SiLU-gated MLP: Fused gate_up_proj with SiLU activation and down_proj, identical to the Llama FFN architecture.
  • RMSNorm: Uses RMSNorm for both input and post-attention normalization, without bias.
  • RoPE positional embeddings: Standard rotary position embeddings with configurable position_embedding_base (rope_theta).
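To make the interaction between the sliding window and the attention sink concrete, here is a minimal pure-Python sketch of which token positions remain visible to the newest token. The helper name `retained_positions` is hypothetical, and the sketch assumes the sink slots count toward the `sliding_window_size` budget; the actual KV cache in MLC LLM may account for them differently.

```python
def retained_positions(num_tokens, sliding_window_size, attention_sink_size=4):
    """Return the token positions still visible to the newest token.

    Illustrative helper only (not part of MLC LLM). Assumes the sink
    slots are counted inside the sliding window budget.
    """
    # Sliding window disabled: every position stays visible (full attention).
    if sliding_window_size is None or sliding_window_size == -1:
        return list(range(num_tokens))
    if attention_sink_size > sliding_window_size:
        raise ValueError("attention_sink_size must fit inside the window")
    # Nothing has been evicted yet.
    if num_tokens <= sliding_window_size:
        return list(range(num_tokens))
    # The first attention_sink_size tokens are always kept.
    sinks = list(range(attention_sink_size))
    # The remaining slots hold the most recent tokens.
    recent_count = sliding_window_size - attention_sink_size
    recent = list(range(num_tokens - recent_count, num_tokens))
    return sinks + recent


visible = retained_positions(num_tokens=10, sliding_window_size=8, attention_sink_size=2)
print(visible)
```

With 10 tokens, a window of 8, and 2 sink tokens, positions 2 and 3 are the ones evicted: the two sink positions and the six most recent positions remain.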

The model consists of MistralModel (embedding, decoder layers, and final RMSNorm), wrapped by MistralForCasualLM, which adds the LM head. Note that the class name is spelled "Casual" rather than "Causal" in the source code, so imports must use that spelling.

Usage

Use this module when compiling Mistral family models (Mistral 7B v0.1, Mistral 7B v0.3, and similar) for deployment with MLC LLM. The model is identified by the mistral model type in configuration files.

Code Reference

Source Location

Signature

@dataclasses.dataclass
class MistralConfig(ConfigBase):
    hidden_size: int
    intermediate_size: int
    num_attention_heads: int
    num_hidden_layers: int
    rms_norm_eps: float
    vocab_size: int
    position_embedding_base: int = 0
    num_key_value_heads: int = 0
    head_dim: int = 0
    context_window_size: int = 0
    sliding_window_size: int = 0
    prefill_chunk_size: int = 0
    attention_sink_size: int = 4
    tensor_parallel_shards: int = 1
    max_batch_size: int = 1
    ...

class MistralForCasualLM(nn.Module):
    def __init__(self, config: MistralConfig): ...
    def embed(self, input_ids: Tensor): ...
    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): ...
    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): ...
    def batch_prefill(self, input_embeds, logit_positions, paged_kv_cache): ...
    def batch_decode(self, input_embeds, paged_kv_cache): ...
    def batch_verify(self, input_embeds, paged_kv_cache): ...
    def create_paged_kv_cache(self, ...): ...
    def get_default_spec(self): ...

Import

from mlc_llm.model.mistral.mistral_model import MistralConfig, MistralForCasualLM

I/O Contract

Primary Classes

Class | Role | Key Characteristics
MistralConfig | Model configuration | Sliding window and attention sink configuration
MistralMLP | Gated feed-forward | Same as LlamaFFN: SiLU-gated gate_up_proj + down_proj
MistralAttention | GQA attention | Same as LlamaAttention with sliding window cache support
MistralDecoderLayer | Transformer block | Same as LlamaDecoderLayer with RMSNorm
MistralModel | Core model | embed_tokens + layers + norm
MistralForCasualLM | Top-level model | Standard LM head, sliding window KV cache support

Forward Methods

Method | Input | Output
embed | Tensor[seq_len] (int32) | Tensor[seq_len, hidden_size]
prefill | Tensor[1, seq_len, hidden_size], PagedKVCache | (Tensor[1, 1, vocab_size], PagedKVCache)
decode | Tensor[1, 1, hidden_size], PagedKVCache | (Tensor[1, 1, vocab_size], PagedKVCache)
batch_prefill | Tensor[1, seq_len, hidden_size], Tensor[batch_size], PagedKVCache | (Tensor, PagedKVCache)
batch_decode | Tensor[batch_size, 1, hidden_size], PagedKVCache | (Tensor, PagedKVCache)
batch_verify | Tensor[1, seq_len, hidden_size], PagedKVCache | (Tensor, PagedKVCache)

Sliding Window Configuration Logic

The config's __post_init__ determines context window behavior:

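  • If sliding_window_size is enabled (any value other than -1), context_window_size is set to -1 (unlimited).
  • If sliding_window_size is -1 (disabled), context_window_size falls back to max_position_embeddings or max_sequence_length from kwargs.
  • prefill_chunk_size defaults to the minimum of 8192 and whichever of sliding_window_size and context_window_size is enabled. A value of -1 is excluded from the comparison, since otherwise the minimum would always be -1.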
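The rules above can be sketched as a standalone function. This is an illustration under stated assumptions, not the code from the module: the function name `resolve_window_config` is hypothetical, 0 is assumed to be the "not provided" sentinel (matching the dataclass defaults), and the `sliding_window` kwargs key is assumed to follow the Hugging Face config naming.

```python
def resolve_window_config(sliding_window_size=0, context_window_size=0,
                          prefill_chunk_size=0, **kwargs):
    """Resolve window settings: 0 means unset, -1 means disabled/unlimited.

    Hypothetical sketch of the __post_init__ rules described above.
    """
    # Assumption: an unset sliding window is read from a "sliding_window"
    # kwarg, and a missing or None value means the window is disabled.
    if sliding_window_size == 0:
        sliding_window_size = kwargs.get("sliding_window", -1)
    if sliding_window_size is None:
        sliding_window_size = -1

    if context_window_size == 0:
        if sliding_window_size != -1:
            # Sliding window enabled: the context is unlimited.
            context_window_size = -1
        else:
            # Sliding window disabled: take a fixed context length from kwargs.
            for name in ("max_position_embeddings", "max_sequence_length"):
                if name in kwargs:
                    context_window_size = kwargs[name]
                    break
            else:
                raise ValueError("no context window size could be determined")

    if prefill_chunk_size == 0:
        # Only enabled (non -1) sizes take part in the minimum.
        candidates = [size for size in (sliding_window_size, context_window_size)
                      if size != -1]
        prefill_chunk_size = min(candidates + [8192])

    return sliding_window_size, context_window_size, prefill_chunk_size


print(resolve_window_config(sliding_window_size=4096))
print(resolve_window_config(max_position_embeddings=32768))
```

The first call corresponds to a model with a 4096-token sliding window; the second to a model with no sliding window and a 32768-token context read from the checkpoint configuration.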

Usage Examples

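from mlc_llm.model.mistral.mistral_model import MistralConfig, MistralForCasualLM

# Creating a Mistral 7B v0.1 config with sliding window
# (context_window_size is left unset, so __post_init__ resolves it to -1)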
config = MistralConfig(
    hidden_size=4096,
    intermediate_size=14336,
    num_attention_heads=32,
    num_hidden_layers=32,
    rms_norm_eps=1e-5,
    vocab_size=32000,
    num_key_value_heads=8,
    position_embedding_base=10000,
    sliding_window_size=4096,
    attention_sink_size=4,
)
model = MistralForCasualLM(config)

# Creating a Mistral 7B v0.3 config without sliding window
config_v03 = MistralConfig(
    hidden_size=4096,
    intermediate_size=14336,
    num_attention_heads=32,
    num_hidden_layers=32,
    rms_norm_eps=1e-5,
    vocab_size=32768,
    num_key_value_heads=8,
    position_embedding_base=1000000,
    context_window_size=32768,
)
model_v03 = MistralForCasualLM(config_v03)
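For the 7B hyperparameters used above, the effect of grouped-query attention on KV cache size can be worked out with simple arithmetic. The sketch below is illustrative only: the helper `kv_cache_bytes_per_token` is hypothetical, head_dim is assumed to be derived as hidden_size // num_attention_heads when left unset, and 2 bytes per element assumes float16 storage.

```python
def kv_cache_bytes_per_token(num_hidden_layers, num_key_value_heads,
                             head_dim, bytes_per_element=2):
    """Bytes of KV cache per token: one key and one value vector per
    KV head, per layer. Hypothetical helper for illustration."""
    return 2 * num_hidden_layers * num_key_value_heads * head_dim * bytes_per_element


hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8
num_hidden_layers = 32

# Assumption: head_dim is derived from hidden_size when not set explicitly.
head_dim = hidden_size // num_attention_heads
# Number of query heads that share each key-value head.
group_size = num_attention_heads // num_key_value_heads

gqa_bytes = kv_cache_bytes_per_token(num_hidden_layers, num_key_value_heads, head_dim)
mha_bytes = kv_cache_bytes_per_token(num_hidden_layers, num_attention_heads, head_dim)

print(head_dim, group_size)    # 128 4
print(gqa_bytes // 1024)       # 128 (KiB per token with 8 KV heads)
print(mha_bytes // 1024)       # 512 (KiB per token if every query head had its own KV head)
```

Under these assumptions, using 8 key-value heads instead of 32 reduces the per-token KV cache by a factor of four.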
