

Implementation:Mlc ai Mlc llm RNN State

From Leeroopedia


Domains: Machine Learning, State Space Models, Memory Management
Last Updated: 2026-02-09 19:00 GMT

Overview

Provides a managed RNN state container for State Space Models (SSMs) in MLC-LLM, enabling efficient per-layer, per-sequence state storage with batched get/set operations and circular history buffering.

Description

This module implements the RNNState class, which subclasses Object from TVM Relax's nn frontend (tvm.relax.frontend.nn). It manages the recurrent state tensors used by State Space Models and similar recurrent architectures (such as Mamba and RWKV). The class provides a structured way to create, read, and write state tensors across multiple hidden layers and multiple sequences in a batch, with support for a circular history buffer.

The core abstraction manages state storage as tensors of shape (max_batch_size, max_history, *state_shape) per state component per layer. The max_history dimension enables circular buffering of historical states, useful for models that condition on previous time steps.
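As a rough sizing aid, the sketch below computes the per-component storage shape described above, the total bytes across all layers, and the ring-buffer advance. It is plain Python, not MLC-LLM code: DTYPE_BYTES, the helper names, and the example sizes are hypothetical, and the byte count ignores any padding or alignment the real allocator may add.

```python
from math import prod
from typing import Dict, List, Sequence, Tuple

# Hypothetical bytes-per-element table (illustrative subset of dtypes).
DTYPE_BYTES: Dict[str, int] = {"float16": 2, "float32": 4}


def storage_shape(max_batch_size: int, max_history: int,
                  state_shape: Sequence[int]) -> Tuple[int, ...]:
    """Shape of one state component's buffer for one layer."""
    return (max_batch_size, max_history, *state_shape)


def total_storage_bytes(max_batch_size: int, num_hidden_layers: int,
                        max_history: int,
                        components: List[Tuple[Sequence[int], str]]) -> int:
    """Sum buffer sizes over all state components (state_ids) and layers.

    `components` holds one (state_shape, dtype) pair per state_id.
    Ignores any allocator padding or alignment.
    """
    per_layer = sum(
        prod(storage_shape(max_batch_size, max_history, shape)) * DTYPE_BYTES[dtype]
        for shape, dtype in components
    )
    return per_layer * num_hidden_layers


def next_history_slot(history_slot: int, max_history: int) -> int:
    """Ring-buffer advance applied when a new state is written."""
    return (history_slot + 1) % max_history


components = [((1536,), "float16"), ((1536, 4), "float16")]
print(storage_shape(8, 1, (1536, 4)))               # (8, 1, 1536, 4)
print(total_storage_bytes(8, 24, 1, components))    # 2949120
print([next_history_slot(h, 4) for h in range(4)])  # [1, 2, 3, 0]
```

With max_history=1 the advance always yields slot 0, so every write overwrites the only slot; larger values retain that many past states per sequence at a proportional memory cost.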

Key features include:

  • State Creation (create static method) -- Allocates the state storage and registers TIR (Tensor IR) primitive functions for efficient get/set operations. The storage is initialized from init_values (one initial value per state component, i.e. per state_id), and the state object is registered with the current TVM Relax block builder.
  • State Retrieval (get method) -- Retrieves the state tensor of one layer and one state component for every sequence in the current batch. Each sequence's sequence slot ID and history slot ID select where to read from; these IDs are not arguments of get but are supplied by the runtime when the generated TIR functions run. Internally, get calls the TVM runtime packed function "vm.builtin.rnn_state_get" as a DPS (destination-passing style) call whose output has the given shape and dtype.
  • State Update (set method) -- Writes an updated state tensor back to storage and returns the updated RNNState, which should replace the previous object. Writes use circular indexing: the target history slot is (history_slot + 1) % max_history, which gives ring-buffer behavior.
  • TIR Function Generation -- The create_get_func and create_set_func static methods generate TIR primitive functions specialized to each state tensor's shape and dtype. These functions perform batch-indexed copies between the global storage buffer and the per-batch output/input tensors. One-dimensional and higher-dimensional state shapes use separate code paths, because the generic form would need star-unpacking inside a subscript (for example buf[i, j, *idx]), which Python only accepts from version 3.11.
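The Python 3.11 constraint can be checked directly. The sketch below (plain Python; the buffer and variable names are hypothetical) parses the star-in-subscript form, which PEP 646 made legal in Python 3.11, alongside an explicit tuple form that parses on older versions. It assumes, without confirming against the MLC-LLM source, that TVMScript kernels are parsed from Python source by the host interpreter, so a form the interpreter rejects cannot appear inside a generated kernel.

```python
import ast
import sys

STAR_IN_SUBSCRIPT = "buf[seq_id, history_id, *idx]"  # PEP 646 form, Python >= 3.11
PORTABLE_TUPLE = "buf[(seq_id, history_id, *idx)]"   # tuple display, Python >= 3.5


def parses(src: str) -> bool:
    """Return True if the running interpreter accepts `src` as Python syntax."""
    try:
        ast.parse(src)
        return True
    except SyntaxError:
        return False


# First value is True on 3.11+ and False before; second is always True.
print(sys.version_info[:2], parses(STAR_IN_SUBSCRIPT), parses(PORTABLE_TUPLE))

# At run time both spellings build the same tuple key, so a dict-backed
# toy "buffer" is indexed identically either way.
buf = {(0, 1, 2, 3): 42.0}
seq_id, history_id, idx = 0, 1, (2, 3)
print(buf[(seq_id, history_id, *idx)])
```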

Usage

Use this module when implementing State Space Models (SSMs) or any recurrent architecture in MLC-LLM that requires managed per-layer, per-sequence state. It is typically instantiated within model implementations like Mamba or RWKV during the model construction phase and used throughout forward passes to retrieve and update recurrent states.
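The call pattern described here (get the previous state, compute, then set and carry the returned object forward) can be illustrated with a hypothetical pure-Python stand-in. ToyRNNState, token_shift_layer, and the fixed 0.5 mixing are invented for this sketch (loosely modeled on an RWKV-style token shift) and ignore batching entirely; only the get/compute/set threading is meant to mirror RNNState usage.

```python
from typing import Dict, List, Sequence, Tuple

Key = Tuple[int, int]  # (layer_id, state_id)


class ToyRNNState:
    """Hypothetical stand-in for RNNState; set() returns a new object."""

    def __init__(self, values: Dict[Key, List[float]]):
        self._values = values

    @staticmethod
    def create(num_hidden_layers: int,
               init_values: Sequence[Sequence[float]]) -> "ToyRNNState":
        # One copy of each initial value per (layer, state component).
        return ToyRNNState({(layer, sid): list(v)
                            for layer in range(num_hidden_layers)
                            for sid, v in enumerate(init_values)})

    def get(self, layer_id: int, state_id: int) -> List[float]:
        return list(self._values[(layer_id, state_id)])

    def set(self, layer_id: int, state_id: int,
            value: Sequence[float]) -> "ToyRNNState":
        updated = dict(self._values)
        updated[(layer_id, state_id)] = list(value)
        return ToyRNNState(updated)


def token_shift_layer(x: List[float], state: ToyRNNState,
                      layer_id: int) -> Tuple[List[float], ToyRNNState]:
    prev = state.get(layer_id, 0)                       # previous step's input
    out = [0.5 * a + 0.5 * b for a, b in zip(x, prev)]  # toy fixed mixing
    return out, state.set(layer_id, 0, x)               # store current input


def forward(x: List[float], state: ToyRNNState,
            num_layers: int) -> Tuple[List[float], ToyRNNState]:
    for layer_id in range(num_layers):
        x, state = token_shift_layer(x, state, layer_id)  # thread the new state
    return x, state


state = ToyRNNState.create(num_hidden_layers=2, init_values=[[0.0, 0.0]])
y1, state = forward([4.0, 8.0], state, num_layers=2)
y2, state = forward([4.0, 8.0], state, num_layers=2)
print(y1, y2)  # [1.0, 2.0] [3.0, 6.0]
```

Returning a new object from set keeps each state update an explicit value that flows through the forward pass, which matches how the traced RNNState.set result must replace the previous state.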

Code Reference

Source Location

Signature

class RNNState(Object):
    @staticmethod
    def create(
        max_batch_size: tir.Var,
        num_hidden_layers: int,
        max_history: int,
        init_values: Sequence[Tensor],
        name: str = "rnn_state",
    ) -> "RNNState": ...

    def get(
        self,
        layer_id: int,
        state_id: int,
        shape: Sequence[tir.PrimExpr],
        dtype: str,
    ) -> Tensor: ...

    def set(
        self,
        layer_id: int,
        state_id: int,
        value: Tensor,
    ) -> "RNNState": ...

    @staticmethod
    def create_get_func(
        shape: Sequence[Union[int, tir.Var]],
        dtype: str,
        max_batch_size: Union[int, tir.Var],
        max_history: Union[int, tir.Var],
        state_id: int,
    ) -> tir.PrimFunc: ...

    @staticmethod
    def create_set_func(
        shape: Sequence[Union[int, tir.Var]],
        dtype: str,
        max_batch_size: Union[int, tir.Var],
        max_history: Union[int, tir.Var],
        state_id: int,
    ) -> tir.PrimFunc: ...

Import

from mlc_llm.nn.rnn_state import RNNState

I/O Contract

| Method | Input | Output | Description |
|---|---|---|---|
| create | max_batch_size, num_hidden_layers, max_history, init_values, name | RNNState object | Creates a new RNN state with allocated storage and registered TIR get/set functions |
| get | layer_id: int, state_id: int, shape (full output shape, batch dimension first), dtype | Tensor[batch_size, *state_shape] | Retrieves the state tensor for the specified layer and state index, reading at each sequence's sequence and history slot IDs |
| set | layer_id: int, state_id: int, value: Tensor | RNNState (updated) | Writes the state tensor to storage at history slot (history_slot + 1) % max_history |

| Storage Layout | Shape | Description |
|---|---|---|
| Per-state storage | (max_batch_size, max_history, *state_shape) | Buffer of rank 3 or more, indexed by sequence slot, history slot, and then the state dimensions |
| seq_slot_ids | (batch_size,), int32 | Maps each batch index to its sequence storage slot |
| history_slot_ids | (batch_size,), int32 | Maps each batch index to its current position in the circular history buffer |

| TIR Function | Purpose |
|---|---|
| rnn_state_get_{id} | Copies storage[seq_id, history_id, ...] to output[batch_idx, ...], where seq_id = seq_slot_ids[batch_idx] and history_id = history_slot_ids[batch_idx] |
| rnn_state_set_{id} | Copies input[batch_idx, ...] to storage[seq_id, (history_id + 1) % max_history, ...], with seq_id and history_id looked up the same way |
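The copy semantics in these tables can be modeled in plain Python. This is a toy sketch of a single state component of a single layer, not the generated TIR: the function names, the nested-list storage, and the choice to advance the history slots explicitly between steps are all assumptions made for illustration.

```python
from typing import List, Sequence

# storage[seq_slot][history_slot] is one sequence's flattened state vector.
Storage = List[List[List[float]]]


def toy_rnn_state_get(storage: Storage, seq_slot_ids: Sequence[int],
                      history_slot_ids: Sequence[int]) -> List[List[float]]:
    """output[b] = storage[seq_slot_ids[b]][history_slot_ids[b]]"""
    return [list(storage[s][h]) for s, h in zip(seq_slot_ids, history_slot_ids)]


def toy_rnn_state_set(storage: Storage, seq_slot_ids: Sequence[int],
                      history_slot_ids: Sequence[int],
                      values: Sequence[Sequence[float]]) -> None:
    """storage[seq_slot_ids[b]][(history_slot_ids[b] + 1) % max_history] = values[b]"""
    max_history = len(storage[0])
    for b, (s, h) in enumerate(zip(seq_slot_ids, history_slot_ids)):
        storage[s][(h + 1) % max_history] = list(values[b])


# max_batch_size=3, max_history=2, state_shape=(2,), zero-initialized.
storage: Storage = [[[0.0, 0.0] for _ in range(2)] for _ in range(3)]

# A batch of two sequences that live in storage slots 2 and 0.
seq_slot_ids = [2, 0]
history_slot_ids = [0, 1]

toy_rnn_state_set(storage, seq_slot_ids, history_slot_ids, [[1.0, 1.0], [2.0, 2.0]])
# Writes landed at history slots (0 + 1) % 2 = 1 and (1 + 1) % 2 = 0.
print(storage[2][1], storage[0][0])  # [1.0, 1.0] [2.0, 2.0]

# Assumed: the runtime advances each sequence's history slot before the next read.
next_ids = [(h + 1) % 2 for h in history_slot_ids]
print(toy_rnn_state_get(storage, seq_slot_ids, next_ids))  # [[1.0, 1.0], [2.0, 2.0]]
```

Because batch index and storage slot are decoupled through seq_slot_ids, sequences can sit in arbitrary slots and the batch can be reordered without moving any stored state.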

Usage Examples

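from tvm import tir
from tvm.relax.frontend.nn import op
from mlc_llm.nn.rnn_state import RNNState

# These calls emit Relax IR, so they only work while an nn.Module is being
# traced (i.e. inside an active Relax BlockBuilder), not as standalone Python.
d_inner, d_conv = 1536, 4                    # illustrative state sizes
batch_size = tir.Var("batch_size", "int64")  # symbolic size of the current batch

# Create an RNN state for a model with 24 layers, a symbolic max batch size,
# and 1 history step. init_values holds one initial value per state_id,
# shaped like a single sequence's state (no batch dimension).
rnn_state = RNNState.create(
    max_batch_size=tir.Var("max_batch_size", "int64"),
    num_hidden_layers=24,
    max_history=1,
    init_values=[
        op.zeros((d_inner,), dtype="float16"),         # state_id 0: hidden state
        op.zeros((d_inner, d_conv), dtype="float16"),  # state_id 1: conv state
    ],
    name="mamba_state",
)

# In the forward pass of layer i, retrieve the hidden state.
# shape is the full output shape, with the batch dimension first.
hidden = rnn_state.get(layer_id=i, state_id=0, shape=(batch_size, d_inner), dtype="float16")

# After computing new_hidden (placeholder for the layer's updated state),
# write it back; set returns an updated RNNState that replaces the old one.
rnn_state = rnn_state.set(layer_id=i, state_id=0, value=new_hidden)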
