Implementation:Mlc ai Mlc llm Mistral Model
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, LLM |
| Last Updated | 2026-02-09 19:00 GMT |
Overview
Implements the Mistral architecture for causal language modeling within the MLC LLM framework, featuring grouped-query attention with optional sliding window attention and attention sink support.
Description
This module provides the TVM Relax-based implementation of Mistral AI's Mistral model architecture. The implementation closely mirrors the Llama architecture with the addition of sliding window attention capabilities. Key characteristics include:
- Sliding window attention: Supports a configurable sliding_window_size for efficient long-context processing. When the sliding window is disabled (set to None or -1), the model falls back to standard full attention with a fixed context window.
- Attention sink: Configurable attention_sink_size (default 4) that determines the number of initial tokens always retained in the attention window, even with sliding window enabled.
- Grouped-query attention (GQA): Supports fewer key-value heads than query heads for memory-efficient attention computation.
- Fused QKV projection: Single qkv_proj linear layer for all query, key, and value projections.
- SiLU-gated MLP: Fused gate_up_proj with SiLU activation and down_proj, identical to the Llama FFN architecture.
- RMSNorm: Uses RMSNorm for both input and post-attention normalization, without bias.
- RoPE positional embeddings: Standard rotary position embeddings with configurable position_embedding_base (rope_theta).
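As a rough illustration of how the sliding window and the attention sink interact, the sketch below lists which key positions a causal query can still attend to. It is not taken from the MLC LLM source: the helper name visible_positions is hypothetical, and it assumes the sink tokens are kept in addition to the most recent sliding_window_size tokens, whereas the actual KV cache may count the sinks against the window budget.

```python
def visible_positions(query_pos, sliding_window_size, attention_sink_size=4):
    """Return the key positions a causal query at `query_pos` can attend to.

    Hypothetical helper for illustration only. Assumes sink tokens are kept
    in addition to the most recent `sliding_window_size` positions.
    """
    if sliding_window_size in (None, -1):
        # Sliding window disabled: plain causal attention over all earlier tokens.
        return list(range(query_pos + 1))
    # First position that still falls inside the recent-token window.
    window_start = max(0, query_pos - sliding_window_size + 1)
    # Sink tokens are only listed separately once they have left the window.
    sinks = range(min(attention_sink_size, window_start))
    recent = range(window_start, query_pos + 1)
    return list(sinks) + list(recent)


# Query at position 9, window of 4 tokens, 2 sink tokens.
print(visible_positions(9, sliding_window_size=4, attention_sink_size=2))
```

Under these assumptions, the sink positions stay visible no matter how far the window has moved, which is the behavior the attention sink bullet describes.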
The model consists of MistralModel (embedding + decoder layers + final RMSNorm), wrapped by MistralForCasualLM (note: the class name has a typo "Casual" instead of "Causal" in the source code) which adds the LM head.
Usage
Use this module when compiling Mistral family models (Mistral 7B v0.1, Mistral 7B v0.3, and similar) for deployment with MLC LLM. The model is identified by the mistral model type in configuration files.
Code Reference
Source Location
- Repository: Mlc_ai_Mlc_llm
- File: python/mlc_llm/model/mistral/mistral_model.py
Signature
@dataclasses.dataclass
class MistralConfig(ConfigBase):
    hidden_size: int
    intermediate_size: int
    num_attention_heads: int
    num_hidden_layers: int
    rms_norm_eps: float
    vocab_size: int
    position_embedding_base: int = 0
    num_key_value_heads: int = 0
    head_dim: int = 0
    context_window_size: int = 0
    sliding_window_size: int = 0
    prefill_chunk_size: int = 0
    attention_sink_size: int = 4
    tensor_parallel_shards: int = 1
    max_batch_size: int = 1
    ...

class MistralForCasualLM(nn.Module):
    def __init__(self, config: MistralConfig): ...
    def embed(self, input_ids: Tensor): ...
    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): ...
    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): ...
    def batch_prefill(self, input_embeds, logit_positions, paged_kv_cache): ...
    def batch_decode(self, input_embeds, paged_kv_cache): ...
    def batch_verify(self, input_embeds, paged_kv_cache): ...
    def create_paged_kv_cache(self, ...): ...
    def get_default_spec(self): ...
Import
from mlc_llm.model.mistral.mistral_model import MistralConfig, MistralForCasualLM
I/O Contract
Primary Classes
| Class | Role | Key Characteristics |
|---|---|---|
| MistralConfig | Model configuration | Sliding window and attention sink configuration |
| MistralMLP | Gated feed-forward | Same as LlamaFFN: SiLU-gated gate_up_proj + down_proj |
| MistralAttention | GQA attention | Same as LlamaAttention with sliding window cache support |
| MistralDecoderLayer | Transformer block | Same as LlamaDecoderLayer with RMSNorm |
| MistralModel | Core model | embed_tokens + layers + norm |
| MistralForCasualLM | Top-level model | Standard LM head, sliding window KV cache support |
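To make the GQA and fused-projection entries above more concrete, the following sketch derives the per-head sizes from the configuration fields. The helper attention_shapes is hypothetical and not part of MLC LLM. It assumes that head_dim falls back to hidden_size // num_attention_heads when left at 0, that the fused qkv_proj emits all query heads followed by the key and value heads, and that tensor parallelism is not in use.

```python
def attention_shapes(hidden_size, num_attention_heads, num_key_value_heads, head_dim=0):
    """Derive GQA and fused-QKV sizes (illustrative helper, not MLC LLM API)."""
    if head_dim == 0:
        # Assumption: an unset head_dim defaults to an even split of hidden_size.
        head_dim = hidden_size // num_attention_heads
    # Each key-value head must serve a whole number of query heads.
    assert num_attention_heads % num_key_value_heads == 0
    return {
        "head_dim": head_dim,
        "queries_per_kv_head": num_attention_heads // num_key_value_heads,
        # One fused projection produces the Q heads plus the K and V heads.
        "qkv_proj_out_features": (num_attention_heads + 2 * num_key_value_heads) * head_dim,
    }


# Values from the Mistral 7B configuration used elsewhere on this page.
shapes = attention_shapes(hidden_size=4096, num_attention_heads=32, num_key_value_heads=8)
print(shapes)
```

With 32 query heads and 8 key-value heads, every key-value head is shared by 4 query heads, so the K and V parts of the projection are a quarter of the size they would have under full multi-head attention.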
Forward Methods
| Method | Input | Output |
|---|---|---|
| embed | Tensor[seq_len] (int32) | Tensor[seq_len, hidden_size] |
| prefill | Tensor[1, seq_len, hidden_size], PagedKVCache | (Tensor[1, 1, vocab_size], PagedKVCache) |
| decode | Tensor[1, 1, hidden_size], PagedKVCache | (Tensor[1, 1, vocab_size], PagedKVCache) |
| batch_prefill | Tensor[1, seq_len, hidden_size], Tensor[batch_size], PagedKVCache | (Tensor, PagedKVCache) |
| batch_decode | Tensor[batch_size, 1, hidden_size], PagedKVCache | (Tensor, PagedKVCache) |
| batch_verify | Tensor[1, seq_len, hidden_size], PagedKVCache | (Tensor, PagedKVCache) |
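The single-sequence rows of the table can be summarized as plain shape tuples, which may help when checking buffers on the caller side. This is only a bookkeeping sketch: single_sequence_shapes is a hypothetical helper that restates the table and does not call into MLC LLM.

```python
def single_sequence_shapes(seq_len, hidden_size, vocab_size):
    """Shapes for embed, prefill, and decode as listed in the table (illustrative)."""
    return {
        # Token ids in, one embedding vector per token out.
        "embed": {"in": (seq_len,), "out": (seq_len, hidden_size)},
        # Prefill takes the embeddings with a leading batch dimension of 1
        # and returns logits for a single position.
        "prefill": {"in": (1, seq_len, hidden_size), "out": (1, 1, vocab_size)},
        # Decode processes one new token per call.
        "decode": {"in": (1, 1, hidden_size), "out": (1, 1, vocab_size)},
    }


contract = single_sequence_shapes(seq_len=5, hidden_size=4096, vocab_size=32000)
print(contract["prefill"])
```

Note that the output of embed has no batch dimension while prefill expects one, so the caller is expected to reshape between the two calls.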
Sliding Window Configuration Logic
The config's __post_init__ determines context window behavior:
- If sliding_window_size is set (not -1), context_window_size is set to -1 (unlimited).
- If sliding_window_size is -1 (disabled), context_window_size falls back to max_position_embeddings or max_sequence_length from kwargs.
- prefill_chunk_size defaults to min(sliding_window_size, context_window_size, 8192).
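The rules above can be sketched as a standalone function. This is a reading of the description, not the MLC LLM source: resolve_window_config is a hypothetical name, and the sketch assumes that 0 means "unset", that None is treated like -1, and that values of -1 are skipped when taking the minimum (otherwise the minimum would always be -1 whenever one of the windows is disabled).

```python
def resolve_window_config(sliding_window_size=0, context_window_size=0,
                          prefill_chunk_size=0, **kwargs):
    """Sketch of the window-resolution rules described above (hypothetical helper)."""
    if sliding_window_size in (0, None):
        # Assumption: an unset or None sliding window means "disabled".
        sliding_window_size = -1
    if context_window_size == 0:
        if sliding_window_size != -1:
            # Sliding window enabled: the context is treated as unlimited.
            context_window_size = -1
        else:
            # Sliding window disabled: take the limit from the extra kwargs.
            for name in ("max_position_embeddings", "max_sequence_length"):
                if name in kwargs:
                    context_window_size = kwargs[name]
                    break
            else:
                raise ValueError("no context window size found in kwargs")
    if prefill_chunk_size == 0:
        # Assumption: disabled (-1) values do not take part in the minimum.
        enabled = [s for s in (sliding_window_size, context_window_size) if s != -1]
        prefill_chunk_size = min(enabled + [8192])
    return sliding_window_size, context_window_size, prefill_chunk_size


print(resolve_window_config(sliding_window_size=4096))
print(resolve_window_config(sliding_window_size=-1, max_position_embeddings=32768))
```

The function returns a tuple of (sliding_window_size, context_window_size, prefill_chunk_size), so the two calls show the sliding-window case and the full-attention case side by side.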
Usage Examples
from mlc_llm.model.mistral.mistral_model import MistralConfig, MistralForCasualLM

# Creating a Mistral 7B v0.1 config with sliding window
config = MistralConfig(
    hidden_size=4096,
    intermediate_size=14336,
    num_attention_heads=32,
    num_hidden_layers=32,
    rms_norm_eps=1e-5,
    vocab_size=32000,
    num_key_value_heads=8,
    position_embedding_base=10000,
    sliding_window_size=4096,  # enables sliding window attention
    attention_sink_size=4,     # first 4 tokens are always retained
)
model = MistralForCasualLM(config)

# Creating a Mistral 7B v0.3 config without sliding window
config_v03 = MistralConfig(
    hidden_size=4096,
    intermediate_size=14336,
    num_attention_heads=32,
    num_hidden_layers=32,
    rms_norm_eps=1e-5,
    vocab_size=32768,
    num_key_value_heads=8,
    position_embedding_base=1000000,
    context_window_size=32768,  # fixed context window, full attention
)
model_v03 = MistralForCasualLM(config_v03)