Implementation:Mlc ai Mlc llm RNN State
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, State Space Models, Memory Management |
| Last Updated | 2026-02-09 19:00 GMT |
Overview
Provides a managed RNN state container for State Space Models (SSMs) in MLC-LLM, enabling efficient per-layer, per-sequence state storage with batched get/set operations and circular history buffering.
Description
This module implements the RNNState class, which extends TVM Relax's Object type to manage recurrent state tensors used in State Space Models (such as Mamba, RWKV, and similar architectures). The class provides a structured mechanism for creating, reading, and writing state tensors across multiple hidden layers and multiple sequences in a batch, with support for a circular history buffer.
The core abstraction manages state storage as tensors of shape (max_batch_size, max_history, *state_shape) per state component per layer. The max_history dimension enables circular buffering of historical states, useful for models that condition on previous time steps.
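The layout above can be sketched in plain Python. This is an illustrative model only: `storage_shape` and `flat_offset` are hypothetical helpers, not part of `RNNState`, and the row-major offset is an assumption about how such a buffer would typically be laid out in memory.

```python
from math import prod


def storage_shape(max_batch_size, max_history, state_shape):
    """Shape of the storage buffer for one state component of one layer."""
    return (max_batch_size, max_history, *state_shape)


def flat_offset(seq_slot, history_slot, max_history, state_shape):
    """Offset of the first element of storage[seq_slot, history_slot].

    Assumes a dense row-major layout (an assumption made for illustration).
    """
    elems_per_state = prod(state_shape)
    return (seq_slot * max_history + history_slot) * elems_per_state


shape = storage_shape(max_batch_size=8, max_history=2, state_shape=(16, 4))
offset = flat_offset(seq_slot=3, history_slot=1, max_history=2, state_shape=(16, 4))
```

Each (sequence slot, history slot) pair owns one contiguous block of `prod(state_shape)` elements, which is why a get or set for one batch entry reduces to copying a single block.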
Key features include:
- State Creation (create static method) -- Allocates the state storage and registers TIR (Tensor IR) primitive functions for efficient get/set operations. The created state is initialized with provided initial values and registered in the TVM Relax block builder.
- State Retrieval (get method) -- Retrieves state tensors for the current batch using sequence slot IDs and history slot IDs. Calls the TVM runtime packed function "vm.builtin.rnn_state_get" via a DPS (destination-passing style) call.
- State Update (set method) -- Writes updated state tensors back to storage. Uses circular indexing: each write goes to slot (history_slot + 1) % max_history, so the history dimension behaves as a ring buffer. With max_history = 1 every write lands in slot 0.
- TIR Function Generation -- The create_get_func and create_set_func static methods dynamically generate TIR primitive functions customized to the state tensor shape and dtype. These functions handle batch-indexed copy operations between the global storage buffer and per-batch output/input tensors. One-dimensional and higher-dimensional state shapes take separate code paths, because star-unpacking inside a subscript is only valid syntax from Python 3.11 onward and the module stays compatible with earlier versions.
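The Python-version constraint mentioned in the last item can be illustrated without TVM. The sketch below uses a dictionary keyed by index tuples as a hypothetical stand-in for a multi-dimensional buffer; `write` and `read` are illustrative names. It shows one general way to index with a variable number of trailing dimensions on any Python 3 version, and is not necessarily how rnn_state.py structures its two code paths.

```python
# Hypothetical stand-in for a multi-dimensional buffer: a dict keyed by index tuples.
storage = {}


def write(storage, seq_slot, history_slot, inner_index, value):
    # `storage[seq_slot, history_slot, *inner_index]` is a SyntaxError before
    # Python 3.11. Building the index as an explicit parenthesized tuple
    # parses on older versions as well.
    storage[(seq_slot, history_slot, *inner_index)] = value


def read(storage, seq_slot, history_slot, inner_index):
    return storage[(seq_slot, history_slot, *inner_index)]


write(storage, 2, 0, (5,), 1.5)    # one-dimensional state shape
write(storage, 2, 0, (5, 7), 2.5)  # two-dimensional state shape
```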
Usage
Use this module when implementing State Space Models (SSMs) or any recurrent architecture in MLC-LLM that requires managed per-layer, per-sequence state. It is typically instantiated within model implementations like Mamba or RWKV during the model construction phase and used throughout forward passes to retrieve and update recurrent states.
Code Reference
Source Location
- Repository: Mlc_ai_Mlc_llm
- File: python/mlc_llm/nn/rnn_state.py
Signature
class RNNState(Object):
    @staticmethod
    def create(
        max_batch_size: tir.Var,
        num_hidden_layers: int,
        max_history: int,
        init_values: Sequence[Tensor],
        name: str = "rnn_state",
    ) -> "RNNState": ...

    def get(
        self,
        layer_id: int,
        state_id: int,
        shape: Sequence[tir.PrimExpr],
        dtype: str,
    ) -> Tensor: ...

    def set(
        self,
        layer_id: int,
        state_id: int,
        value: Tensor,
    ) -> "RNNState": ...

    @staticmethod
    def create_get_func(
        shape: Sequence[Union[int, tir.Var]],
        dtype: str,
        max_batch_size: Union[int, tir.Var],
        max_history: Union[int, tir.Var],
        state_id: int,
    ) -> tir.PrimFunc: ...

    @staticmethod
    def create_set_func(
        shape: Sequence[Union[int, tir.Var]],
        dtype: str,
        max_batch_size: Union[int, tir.Var],
        max_history: Union[int, tir.Var],
        state_id: int,
    ) -> tir.PrimFunc: ...
Import
from mlc_llm.nn.rnn_state import RNNState
I/O Contract
| Method | Input | Output | Description |
|---|---|---|---|
| create | max_batch_size, num_hidden_layers, max_history, init_values, name | RNNState object | Creates a new RNN state with allocated storage and registered TIR get/set functions |
| get | layer_id: int, state_id: int, shape, dtype | Tensor[batch_size, *state_shape] | Retrieves the state tensor for the specified layer and state index; uses seq/history slot IDs |
| set | layer_id: int, state_id: int, value: Tensor | RNNState (updated) | Writes the state tensor to storage with circular history indexing (history_slot + 1) % max_history |
| Storage Layout | Shape | Description |
|---|---|---|
| Per-state storage | (max_batch_size, max_history, *state_shape) | 3D+ buffer indexed by sequence slot, history slot, and state dimensions |
| seq_slot_ids | (batch_size,) int32 | Maps batch indices to sequence storage slots |
| history_slot_ids | (batch_size,) int32 | Maps batch indices to history buffer positions (circular) |
| TIR Function | Purpose |
|---|---|
| rnn_state_get_{id} | Copy from storage[seq_id, history_id, ...] to output[batch_idx, ...] |
| rnn_state_set_{id} | Copy from input[batch_idx, ...] to storage[seq_id, (history_id + 1) % max_history, ...] |
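The copy semantics in the table above can be modeled in plain Python with nested lists. This is a reference sketch for a one-dimensional state shape, not the generated TIR: `make_storage`, `rnn_state_get`, and `rnn_state_set` are hypothetical names, and the sketch assumes the caller supplies the slot IDs, since this page does not describe who advances `history_slot_ids` after a write.

```python
def make_storage(max_batch_size, max_history, state_len, fill=0.0):
    """Nested-list stand-in for a (max_batch_size, max_history, state_len) buffer."""
    return [
        [[fill] * state_len for _ in range(max_history)]
        for _ in range(max_batch_size)
    ]


def rnn_state_get(storage, seq_slot_ids, history_slot_ids):
    """output[b] = storage[seq_slot_ids[b]][history_slot_ids[b]]"""
    return [list(storage[s][h]) for s, h in zip(seq_slot_ids, history_slot_ids)]


def rnn_state_set(storage, seq_slot_ids, history_slot_ids, values):
    """storage[seq][(history + 1) % max_history] = values[b]"""
    max_history = len(storage[0])
    for b, (s, h) in enumerate(zip(seq_slot_ids, history_slot_ids)):
        storage[s][(h + 1) % max_history] = list(values[b])


storage = make_storage(max_batch_size=4, max_history=2, state_len=3)
seq_slot_ids = [2, 0]      # batch entry 0 lives in slot 2, entry 1 in slot 0
history_slot_ids = [1, 0]  # current history position of each batch entry

# Entry 0 writes to slot (1 + 1) % 2 = 0 (wraps); entry 1 writes to slot 1.
rnn_state_set(storage, seq_slot_ids, history_slot_ids,
              [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

# Reading back requires the advanced history positions, supplied here by hand.
latest = rnn_state_get(storage, seq_slot_ids, [0, 1])
```

Writing to the next slot rather than the current one leaves the previous state intact until the ring wraps around.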
Usage Examples
from mlc_llm.nn.rnn_state import RNNState
# Create an RNN state for a model with 24 layers,
# max batch size of 8, and 1 history step
rnn_state = RNNState.create(
    max_batch_size=8,
    num_hidden_layers=24,
    max_history=1,
    init_values=[initial_hidden_state, initial_conv_state],
    name="mamba_state",
)
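# In the forward pass of layer i, retrieve the hidden state.
# shape describes the returned tensor, so it includes the batch dimension
# (see the I/O contract: Tensor[batch_size, *state_shape]).
hidden = rnn_state.get(layer_id=i, state_id=0, shape=(batch_size, d_inner), dtype="float16")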
# After computation, write the updated state back
rnn_state = rnn_state.set(layer_id=i, state_id=0, value=new_hidden)