

Implementation:Sktime Pytorch forecasting mLSTMCell



Knowledge Sources
Domains Time_Series, Forecasting, Deep_Learning
Last Updated 2026-02-08 08:00 GMT

Overview

mLSTMCell implements one time step of a matrix Long Short-Term Memory (mLSTM) cell. It combines query, key, and value projections with input, forget, and output gating and a matrix-style memory update.

Description

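mLSTMCell extends nn.Module and implements the mLSTM cell described in the xLSTM paper (https://arxiv.org/pdf/2407.10240). Unlike a standard LSTM, it projects the input into query, key, and value vectors (as in attention) and combines them with input, forget, and output gates to update its memory. In the paper, this memory is a matrix accumulated from key-value outer products. However, the states this cell exchanges through forward and init_hidden are documented with shape (batch_size, hidden_size) (see the I/O contract). The cell maintains three states: hidden (h), cell (c), and normalizer (n). Layer normalization is optional (on by default), and the dropout rate is configurable (0.2 by default).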
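For intuition, the paper's matrix-memory recurrence can be sketched in plain PyTorch. This is a hypothetical, simplified sketch (the class name PaperStyleMLSTMStep is illustrative) and not the pytorch-forecasting source. It keeps a full hidden_size × hidden_size memory matrix C per sample, uses a scalar input gate, a scalar forget gate, and a vector output gate as in the paper, and retrieves with h = o ⊙ C q / max(|nᵀq|, 1). One deliberate swap: it uses a sigmoid input gate instead of the paper's exponential gate, so no stabilizer state is needed.

```python
import math

import torch
import torch.nn as nn


class PaperStyleMLSTMStep(nn.Module):
    """Hypothetical single mLSTM step with a full matrix memory (not the library code).

    Assumption: sigmoid input/forget gates replace the paper's exponential gating.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        # Attention-style query, key and value projections.
        self.W_q = nn.Linear(input_size, hidden_size)
        self.W_k = nn.Linear(input_size, hidden_size)
        self.W_v = nn.Linear(input_size, hidden_size)
        # Scalar input and forget gates, vector output gate (as in the paper).
        self.w_i = nn.Linear(input_size, 1)
        self.w_f = nn.Linear(input_size, 1)
        self.W_o = nn.Linear(input_size, hidden_size)

    def forward(self, x, C_prev, n_prev):
        # x: (B, input_size), C_prev: (B, d, d), n_prev: (B, d)
        q = self.W_q(x)
        k = self.W_k(x) / math.sqrt(self.hidden_size)  # key scaled by 1/sqrt(d)
        v = self.W_v(x)
        i = torch.sigmoid(self.w_i(x))  # (B, 1); the paper uses exp() here
        f = torch.sigmoid(self.w_f(x))  # (B, 1)
        o = torch.sigmoid(self.W_o(x))  # (B, d)

        # Matrix memory update: C_t = f * C_{t-1} + i * v k^T
        outer = v.unsqueeze(-1) @ k.unsqueeze(-2)  # (B, d, d)
        C = f.unsqueeze(-1) * C_prev + i.unsqueeze(-1) * outer
        # Normalizer update: n_t = f * n_{t-1} + i * k
        n = f * n_prev + i * k
        # Retrieval: h = o * (C q) / max(|n^T q|, 1)
        Cq = (C @ q.unsqueeze(-1)).squeeze(-1)  # (B, d)
        denom = torch.clamp((n * q).sum(-1, keepdim=True).abs(), min=1.0)
        h = o * (Cq / denom)
        return h, C, n


step = PaperStyleMLSTMStep(input_size=8, hidden_size=4)
B, d = 2, 4
C = torch.zeros(B, d, d)
n = torch.zeros(B, d)
h, C, n = step(torch.randn(B, 8), C, n)
print(h.shape, C.shape)  # torch.Size([2, 4]) torch.Size([2, 4, 4])
```

The max(|nᵀq|, 1) denominator keeps the retrieved vector bounded without amplifying it when the normalizer is small. Compare the state shapes here with the (batch_size, hidden_size) states documented for mLSTMCell before relying on either layout.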

Usage

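Use mLSTMCell as a building block for matrix-memory LSTM layers. Because forward processes a single time step, a layer calls it in a loop over the sequence and carries (h, c, n) from one step to the next. It suits sequence modeling tasks that benefit from a richer memory than the scalar memory cells of a standard LSTM.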
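A layer built from the cell can be as simple as a loop over time. The helper below is a minimal sketch; the name run_sequence and the batch-first layout (batch_size, seq_len, input_size) are assumptions, not part of the library API. It accepts any callable with the documented forward signature, (x, h, c, n) -> (h, c, n).

```python
import torch


def run_sequence(cell, x_seq, states):
    """Unroll a single-step (x, h, c, n) -> (h, c, n) cell over a sequence.

    x_seq:  (batch_size, seq_len, input_size), batch-first (an assumption).
    states: initial (h, c, n), each of shape (batch_size, hidden_size).
    Returns all hidden states stacked to (batch_size, seq_len, hidden_size)
    and the final (h, c, n), so a following chunk can continue from them.
    """
    h, c, n = states
    outputs = []
    for t in range(x_seq.size(1)):
        # Feed one time step and carry all three states forward.
        h, c, n = cell(x_seq[:, t, :], h, c, n)
        outputs.append(h)
    return torch.stack(outputs, dim=1), (h, c, n)
```

With an mLSTMCell instance named cell, a call might look like run_sequence(cell, x_seq, cell.init_hidden(x_seq.size(0), device=x_seq.device)). A Python loop is the most transparent option; it is not the fastest, because every step runs as a separate set of kernel launches.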

Code Reference

Source Location

Signature

class mLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2, layer_norm=True): ...
    def forward(self, x, h_prev, c_prev, n_prev): ...
    def init_hidden(self, batch_size, device=None): ...

Import

from pytorch_forecasting.layers._recurrent._mlstm.cell import mLSTMCell

I/O Contract

Inputs

__init__

Name | Type | Required | Description
input_size | int | Yes | Size of the input feature vector.
hidden_size | int | Yes | Number of hidden units in the LSTM cell.
dropout | float | No | Dropout rate applied to the inputs and hidden states. Defaults to 0.2.
layer_norm | bool | No | If True, apply layer normalization to the gates and interactions. Defaults to True.
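Because dropout is applied inside forward, the cell's output is stochastic in training mode. Assuming the cell implements this with a standard torch.nn.Dropout submodule (an assumption, not confirmed by this page), calling cell.eval() before inference disables it, since Module.eval() switches every submodule to evaluation mode. The snippet below demonstrates nn.Dropout on its own at the cell's default rate.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.2)  # same default rate as mLSTMCell's dropout argument
x = torch.ones(4, 8)

drop.train()  # training mode: elements zeroed with probability p, survivors scaled by 1/(1-p)
y_train = drop(x)

drop.eval()  # evaluation mode: dropout is the identity
y_eval = drop(x)

print(torch.equal(y_eval, x))  # True
```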

forward

Name | Type | Required | Description
x | torch.Tensor | Yes | Input tensor of shape (batch_size, input_size).
h_prev | torch.Tensor | Yes | Previous hidden state of shape (batch_size, hidden_size).
c_prev | torch.Tensor | Yes | Previous cell state of shape (batch_size, hidden_size).
n_prev | torch.Tensor | Yes | Previous normalizer state of shape (batch_size, hidden_size).

Outputs

forward

Name | Type | Description
h | torch.Tensor | Current hidden state of shape (batch_size, hidden_size).
c | torch.Tensor | Current cell state of shape (batch_size, hidden_size).
n | torch.Tensor | Current normalizer state of shape (batch_size, hidden_size).
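The input and output tables can be turned into an executable check, which is handy when swapping cells or changing sizes. The helper check_step_contract is hypothetical (not part of pytorch-forecasting); it builds inputs with the documented shapes, runs one step of any cell with the documented forward signature, and verifies that h, c, and n come back as (batch_size, hidden_size).

```python
import torch


def check_step_contract(cell, batch_size, input_size, hidden_size):
    """Hypothetical helper: run one step and verify the documented shapes."""
    # Inputs per the forward table: x is (B, input_size); h, c, n are (B, hidden_size).
    x = torch.randn(batch_size, input_size)
    h, c, n = (torch.zeros(batch_size, hidden_size) for _ in range(3))
    outputs = cell(x, h, c, n)
    if len(outputs) != 3:
        raise ValueError(f"expected (h, c, n), got {len(outputs)} tensors")
    expected = (batch_size, hidden_size)
    for name, t in zip(("h", "c", "n"), outputs):
        if tuple(t.shape) != expected:
            raise ValueError(f"{name} has shape {tuple(t.shape)}, expected {expected}")
    return outputs
```

For the cell in the usage example, the call would be check_step_contract(cell, 16, 32, 64).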

init_hidden

Name | Type | Description
(h, c, n) | tuple of torch.Tensor | Zero-initialized hidden, cell, and normalizer states, each of shape (batch_size, hidden_size).
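The init_hidden contract amounts to three independent zero tensors. The function zero_states below is an illustrative sketch consistent with that contract, not the library's code; device=None falls back to PyTorch's default device, as it does for torch.zeros.

```python
import torch


def zero_states(batch_size, hidden_size, device=None):
    """Sketch of the documented init_hidden contract: zero h, c, n of shape
    (batch_size, hidden_size), allocated on `device` (None -> default device)."""
    return tuple(torch.zeros(batch_size, hidden_size, device=device) for _ in range(3))


h, c, n = zero_states(16, 64)
```

When the input lives on a GPU, passing device=x.device (for the real cell, cell.init_hidden(batch_size, device=x.device)) keeps the states on the same device as the data, which forward requires.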

Usage Examples

import torch
from pytorch_forecasting.layers._recurrent._mlstm.cell import mLSTMCell

cell = mLSTMCell(input_size=32, hidden_size=64, dropout=0.1, layer_norm=True)

batch_size = 16
h, c, n = cell.init_hidden(batch_size)

x_t = torch.randn(batch_size, 32)  # single time step input
h, c, n = cell(x_t, h, c, n)
# h shape: (16, 64)
