Implementation:Mlc ai Mlc llm Ministral3 Model
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, LLM |
| Last Updated | 2026-02-09 19:00 GMT |
Overview
Implements the Ministral 3 (Mistral 3) architecture for conditional generation within the MLC LLM framework, supporting YaRN-based RoPE scaling, configurable activation functions, FP8 quantization configuration, and optional tied word embeddings.
Description
This module provides the TVM Relax-based implementation of the Ministral 3 architecture, which is the text backbone used in Mistral 3 multimodal models. It extends the Llama-style decoder architecture with several enhancements:
- YaRN RoPE scaling: Supports the YaRN (Yet another RoPE extensioN) method for extending context windows. The attention module computes a modified softmax scale using yarn_get_sm_scale() when mscale_all_dim is provided in rope_parameters.
- Configurable activation functions: Supports multiple activation functions (silu, gelu, relu, swish, gelu_new) via the ACT2FN mapping dictionary, defaulting to SiLU.
- FP8 quantization support: The config handles quantization_config from HuggingFace, supporting FP8 static quantization with a configurable weight_block_size (default 128x128).
- Module quantization exclusion: Supports modules_to_not_convert to mark specific modules with no_quantization = True, allowing selective quantization.
- Tied word embeddings: Uses a custom Ministral3Embedding class that supports weight transposition, so the embedding table can be shared with the lm_head via lm_head_forward.
- Nested text_config support: The from_dict class method merges top-level and nested text_config fields for compatibility with multimodal model configurations.
- Sliding window attention: Configurable via sliding_window_size, with fallback logic for determining the context window size.
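The ACT2FN mapping and the gated MLP (gate_up_proj followed by down_proj) can be illustrated with a small pure-Python sketch. This is not the TVM Relax implementation: the scalar functions below are textbook definitions standing in for tensor ops, and the assumption that the first half of the fused gate_up_proj output is the gate branch and the second half is the up branch mirrors the Llama-style layout rather than anything this page confirms.

```python
import math
from typing import Callable, Dict, List


def _sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


# Hypothetical scalar stand-ins for the activations listed in ACT2FN.
ACT2FN: Dict[str, Callable[[float], float]] = {
    "silu": lambda x: x * _sigmoid(x),
    "swish": lambda x: x * _sigmoid(x),  # swish with beta = 1 equals SiLU
    "relu": lambda x: max(0.0, x),
    "gelu": lambda x: 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0))),
    # tanh approximation of GELU
    "gelu_new": lambda x: 0.5 * x * (1.0 + math.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3))),
}


def gated_mlp_activation(gate_up: List[float], hidden_act: str = "silu") -> List[float]:
    """Apply act(gate) * up to one token's fused gate_up_proj output.

    Assumption: the first half of gate_up is the gate branch and the second
    half is the up branch (Llama-style layout); down_proj would follow.
    """
    if len(gate_up) % 2:
        raise ValueError("fused gate_up output must have an even length")
    half = len(gate_up) // 2
    gate, up = gate_up[:half], gate_up[half:]
    act = ACT2FN[hidden_act]
    return [act(g) * u for g, u in zip(gate, up)]
```

Keeping the activation choice behind a dictionary lookup means hidden_act from the config selects the nonlinearity without any branching in the MLP itself.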
The top-level class is Mistral3ForConditionalGeneration (note the naming follows the HuggingFace convention for the multimodal variant), which wraps Ministral3Model containing the embedding, decoder layers, and final RMSNorm.
Usage
Use this module when compiling Ministral 3 / Mistral 3 family models for deployment with MLC LLM. The model is identified by the ministral3 model type in configuration files and uses the Mistral3ForConditionalGeneration architecture name.
Code Reference
Source Location
- Repository: Mlc_ai_Mlc_llm
- File: python/mlc_llm/model/ministral3/ministral3_model.py
Signature
```python
@dataclasses.dataclass
class Ministral3Config(ConfigBase):
    hidden_size: int
    intermediate_size: int
    num_attention_heads: int
    num_hidden_layers: int
    rms_norm_eps: float
    vocab_size: int
    attention_sink_size: int = 0
    context_window_size: int = 0
    head_dim: int = 0
    hidden_act: str = "silu"
    num_key_value_heads: int = 0
    position_embedding_base: int = 0
    rope_parameters: Optional[Dict[str, Any]] = None
    sliding_window_size: int = 0
    tensor_parallel_shards: int = 1
    tie_word_embeddings: bool = False
    weight_block_size: Optional[Tuple[int, int]] = None
    modules_to_not_convert: Tuple[str, ...] = ...
    ...


class Mistral3ForConditionalGeneration(nn.Module):
    def __init__(self, config: Ministral3Config): ...
    def embed(self, input_ids: Tensor): ...
    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): ...
    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): ...
    def batch_prefill(self, input_embeds, logit_positions, paged_kv_cache): ...
    def batch_decode(self, input_embeds, paged_kv_cache): ...
    def batch_verify(self, input_embeds, paged_kv_cache): ...
    def create_paged_kv_cache(self, ...): ...
    def get_default_spec(self): ...
```
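Block-wise FP8 quantization driven by weight_block_size typically stores one scale per [block_n, block_k] tile of a weight matrix. The sketch below assumes that convention, an absmax / 448 scale (448 being the largest finite float8 e4m3fn value), and HuggingFace's [out_features, in_features] weight layout; scale_grid_shape and block_scales are hypothetical helpers, not MLC LLM functions.

```python
import math
from typing import List, Optional, Tuple

FP8_E4M3_MAX = 448.0  # largest finite value of float8 e4m3fn


def scale_grid_shape(
    weight_shape: Tuple[int, int],
    weight_block_size: Optional[Tuple[int, int]] = None,
) -> Tuple[int, int]:
    """Shape of a per-block scale tensor for an [out_features, in_features] weight."""
    # Fall back to the 128x128 default mentioned in the description.
    block_n, block_k = weight_block_size or (128, 128)
    out_features, in_features = weight_shape
    return math.ceil(out_features / block_n), math.ceil(in_features / block_k)


def block_scales(
    weight: List[List[float]], weight_block_size: Tuple[int, int]
) -> List[List[float]]:
    """Hypothetical helper: one absmax / FP8_E4M3_MAX scale per weight tile."""
    rows, cols = len(weight), len(weight[0])
    block_n, block_k = weight_block_size
    n_row_blocks, n_col_blocks = scale_grid_shape((rows, cols), weight_block_size)
    scales = []
    for br in range(n_row_blocks):
        row_scales = []
        for bc in range(n_col_blocks):
            # Edge tiles may be smaller than the block size; min() clips them.
            absmax = max(
                abs(weight[r][c])
                for r in range(br * block_n, min((br + 1) * block_n, rows))
                for c in range(bc * block_k, min((bc + 1) * block_k, cols))
            )
            # An all-zero tile gets scale 1.0 to avoid dividing by zero later.
            row_scales.append(absmax / FP8_E4M3_MAX if absmax > 0 else 1.0)
        scales.append(row_scales)
    return scales
```

With 128x128 blocks, a [9216, 3072] weight (intermediate_size x hidden_size in the usage example) would carry a 72 x 24 grid of scales under these assumptions.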
Import
```python
from mlc_llm.model.ministral3.ministral3_model import Ministral3Config, Mistral3ForConditionalGeneration
```
I/O Contract
Primary Classes
| Class | Role | Key Characteristics |
|---|---|---|
| Ministral3Config | Model configuration | YaRN rope_parameters, FP8 quantization support, modules_to_not_convert |
| Ministral3Embedding | Shared embedding | lm_head_forward via weight transposition |
| Ministral3MLP | Gated feed-forward | Configurable activation via ACT2FN dict, gate_up_proj + down_proj |
| Ministral3Attention | GQA attention | YaRN softmax scale modification, fused qkv_proj |
| Ministral3DecoderLayer | Transformer block | Pre-norm with RMSNorm, residual connections with tensor parallel allreduce |
| Ministral3Model | Core model | embed_tokens + layers + norm |
| Mistral3ForConditionalGeneration | Top-level model | Supports tied embeddings, selective quantization exclusion |
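Tied embeddings let one [vocab_size, hidden_size] table serve both as the input embedding and, transposed, as the output projection. The pure-Python class below is a hypothetical illustration of what lm_head_forward computes (logits = hidden @ weight^T); TiedEmbeddingSketch is an invented name, not the MLC LLM Ministral3Embedding class.

```python
from typing import List

Matrix = List[List[float]]


class TiedEmbeddingSketch:
    """Hypothetical illustration: one [vocab_size, hidden_size] table used twice."""

    def __init__(self, weight: Matrix) -> None:
        self.weight = weight

    def forward(self, token_ids: List[int]) -> Matrix:
        # Input embedding: token id i selects row i of the table.
        return [list(self.weight[t]) for t in token_ids]

    def lm_head_forward(self, hidden: Matrix) -> Matrix:
        # Output projection: logits = hidden @ weight^T, so each logit is the
        # dot product between a hidden state and one vocabulary row.
        return [
            [sum(h * w for h, w in zip(row, vocab_row)) for vocab_row in self.weight]
            for row in hidden
        ]
```

Sharing the table saves vocab_size x hidden_size parameters, which is significant for a 131072-token vocabulary.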
Forward Methods
| Method | Input | Output |
|---|---|---|
| embed | Tensor[seq_len] (int32) | Tensor[seq_len, hidden_size] |
| prefill | Tensor[1, seq_len, hidden_size], PagedKVCache | (Tensor[1, 1, vocab_size], PagedKVCache) |
| decode | Tensor[1, 1, hidden_size], PagedKVCache | (Tensor[1, 1, vocab_size], PagedKVCache) |
| batch_prefill | Tensor[1, seq_len, hidden_size], Tensor[batch_size], PagedKVCache | (Tensor, PagedKVCache) |
| batch_decode | Tensor[batch_size, 1, hidden_size], PagedKVCache | (Tensor, PagedKVCache) |
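batch_prefill takes a Tensor[batch_size] of logit positions alongside the packed input embeddings. Assuming, as is typical for packed prefill, that several sequences are concatenated along the seq_len axis and only the last token of each sequence needs logits, those positions are the running sequence lengths minus one. Both helpers below are hypothetical illustrations of that indexing, not MLC LLM code.

```python
from itertools import accumulate
from typing import List, Sequence


def last_token_positions(seq_lens: Sequence[int]) -> List[int]:
    """Index of each sequence's final token in the packed seq_len axis."""
    if any(n <= 0 for n in seq_lens):
        raise ValueError("every sequence must contain at least one token")
    # Running totals give each sequence's exclusive end offset.
    return [end - 1 for end in accumulate(seq_lens)]


def gather_logit_rows(
    packed_hidden: List[List[float]], logit_positions: Sequence[int]
) -> List[List[float]]:
    """Keep only the hidden states that will be projected to logits."""
    return [packed_hidden[p] for p in logit_positions]
```

Gathering before the lm_head projection means the vocab-sized matmul runs on batch_size rows instead of the full packed length.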
YaRN Scale Computation
```python
import math


def yarn_get_sm_scale(scale=1, mscale=1):
    """Compute softmax scale for YaRN RoPE extension."""
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
```
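As a worked example, assuming the rope_parameters factor is passed as scale (s) and mscale_all_dim as mscale (μ), the usage-example values s = 16.0 and μ = 1.0 give:

$$
m = 0.1\,\mu\,\ln s + 1 = 0.1 \cdot 1.0 \cdot \ln 16 + 1 \approx 0.2773 + 1 = 1.2773
$$

How m is folded into the attention scale is not spelled out on this page. If the implementation follows the DeepSeek-V2 convention of multiplying the base 1/\sqrt{d_{head}} scale by m twice, the effective multiplier would be m^2 ≈ 1.6314; treat that as an assumption.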
Usage Examples
```python
from mlc_llm.model.ministral3.ministral3_model import (
    Ministral3Config,
    Mistral3ForConditionalGeneration,
)

# Creating a Ministral3 config from a multimodal HuggingFace config
config_dict = {
    "text_config": {
        "hidden_size": 3072,
        "intermediate_size": 9216,
        "num_attention_heads": 32,
        "num_hidden_layers": 26,
        "num_key_value_heads": 8,
        "rms_norm_eps": 1e-5,
        "rope_parameters": {
            "factor": 16.0,
            "mscale_all_dim": 1.0,
            "rope_theta": 1000000.0,
            "rope_type": "yarn",
        },
        "vocab_size": 131072,
        "tie_word_embeddings": True,
    },
    "model_type": "ministral3",
}
config = Ministral3Config.from_dict(config_dict)
model = Mistral3ForConditionalGeneration(config)
```
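The text_config handling in from_dict can be approximated with a plain dictionary merge. This is a hypothetical sketch: merge_text_config is an invented name, and the precedence rule (keys inside text_config override same-named top-level keys) is an assumption rather than documented behavior.

```python
from typing import Any, Dict


def merge_text_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten a multimodal HF config into one dict of text-model fields.

    Assumption: keys inside text_config override same-named top-level keys.
    """
    # Start from the top-level fields, dropping the nested block itself.
    merged = {k: v for k, v in config_dict.items() if k != "text_config"}
    # Overlay the nested text-model fields (no-op if text_config is absent).
    merged.update(config_dict.get("text_config") or {})
    return merged
```

Applied to a config like the one in the usage example, this yields a single flat dict where model_type sits alongside hidden_size, rope_parameters, and the other text fields, which is the shape a flat dataclass such as Ministral3Config expects.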