Overview
mLSTMCell implements a matrix Long Short-Term Memory (mLSTM) cell that combines attention-style query, key, and value projections with LSTM gating and a matrix-memory update.
Description
mLSTMCell extends nn.Module and implements the mLSTM algorithm described in the xLSTM paper (https://arxiv.org/pdf/2407.10240). Unlike a standard LSTM, it uses query, key, and value projections (as in attention mechanisms) together with input, forget, and output gates to update its memory. The cell carries three states between time steps: the hidden state (h), the cell state (c), and the normalizer state (n). In the paper the mLSTM cell state is a matrix; in this implementation, per the I/O contract below, all three states are tensors of shape (batch_size, hidden_size). The cell supports optional layer normalization and dropout.
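For reference, the xLSTM paper defines the mLSTM recurrence with a d×d matrix memory C updated by the outer product of the value and key vectors, a normalizer vector n updated with the key, and a readout that queries C with q. The sketch below follows those paper equations for a single head. It is not the mLSTMCell source, which stores c with shape (batch_size, hidden_size). The names PaperStyleMLSTMStep and head_dim are hypothetical, and the paper's stabilizer state for exponential gates is omitted, so large input-gate pre-activations can overflow.

```python
import math

import torch
import torch.nn as nn


class PaperStyleMLSTMStep(nn.Module):
    """Hypothetical single-head mLSTM step following the xLSTM paper's
    matrix-memory equations (no exponential-gate stabilizer)."""

    def __init__(self, input_size: int, head_dim: int):
        super().__init__()
        self.head_dim = head_dim
        # Query, key, and value projections, as in attention.
        self.W_q = nn.Linear(input_size, head_dim)
        self.W_k = nn.Linear(input_size, head_dim)
        self.W_v = nn.Linear(input_size, head_dim)
        # Scalar input and forget gates, vector-valued output gate.
        self.w_i = nn.Linear(input_size, 1)
        self.w_f = nn.Linear(input_size, 1)
        self.W_o = nn.Linear(input_size, head_dim)

    def forward(self, x, C_prev, n_prev):
        # x: (B, input_size); C_prev: (B, d, d); n_prev: (B, d)
        q = self.W_q(x)
        k = self.W_k(x) / math.sqrt(self.head_dim)  # key scaled by 1/sqrt(d)
        v = self.W_v(x)
        i = torch.exp(self.w_i(x))  # exponential input gate, (B, 1)
        f = torch.sigmoid(self.w_f(x))  # sigmoid forget gate, (B, 1)
        o = torch.sigmoid(self.W_o(x))  # output gate, (B, d)

        # Matrix memory update: C_t = f * C_{t-1} + i * v k^T
        outer = torch.einsum("bi,bj->bij", v, k)
        C = f.unsqueeze(-1) * C_prev + i.unsqueeze(-1) * outer
        # Normalizer update: n_t = f * n_{t-1} + i * k
        n = f * n_prev + i * k
        # Readout: h_t = o * (C q) / max(|n^T q|, 1)
        num = torch.einsum("bij,bj->bi", C, q)
        den = torch.clamp((n * q).sum(-1, keepdim=True).abs(), min=1.0)
        h = o * (num / den)
        return h, C, n
```

The max(|nᵀq|, 1) denominator keeps the readout bounded when the normalizer is small, which is why the paper tracks n alongside C.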
Usage
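Use mLSTMCell as the per-time-step building block of matrix-memory LSTM layers: each forward call processes a single time step, so a layer iterates over the sequence and passes (h, c, n) from one step to the next. It suits sequence modeling tasks that benefit from a richer memory representation than a standard LSTM cell provides.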
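Because forward handles one time step, a layer built on mLSTMCell loops over the sequence and feeds each step's (h, c, n) into the next call. The sketch below assumes only the interface documented on this page: init_hidden(batch_size, device=None) returning (h, c, n), and cell(x_t, h, c, n) returning (h, c, n). The names unroll and ToyCell are hypothetical; ToyCell is a trivial stand-in so the example runs without pytorch-forecasting installed.

```python
import torch
import torch.nn as nn


def unroll(cell: nn.Module, x_seq: torch.Tensor) -> torch.Tensor:
    """Run a step-wise cell over a (batch, seq_len, input_size) sequence.

    Hypothetical helper; assumes the cell follows the mLSTMCell interface.
    """
    batch_size, seq_len, _ = x_seq.shape
    h, c, n = cell.init_hidden(batch_size, device=x_seq.device)
    outputs = []
    for t in range(seq_len):
        # Carry (h, c, n) forward from one time step to the next.
        h, c, n = cell(x_seq[:, t, :], h, c, n)
        outputs.append(h)
    # Stack per-step hidden states into (batch, seq_len, hidden_size).
    return torch.stack(outputs, dim=1)


class ToyCell(nn.Module):
    """Hypothetical stand-in with the same interface, for illustration only."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.proj = nn.Linear(input_size + hidden_size, hidden_size)

    def init_hidden(self, batch_size, device=None):
        zeros = torch.zeros(batch_size, self.hidden_size, device=device)
        return zeros, zeros.clone(), zeros.clone()

    def forward(self, x, h_prev, c_prev, n_prev):
        # Toy update rule; not the mLSTM equations.
        c = c_prev + torch.tanh(self.proj(torch.cat([x, h_prev], dim=-1)))
        n = n_prev + 1.0
        h = c / n
        return h, c, n
```

With pytorch-forecasting installed, swapping ToyCell for mLSTMCell(input_size=32, hidden_size=64) and passing a (16, 10, 32) input should, given the documented shapes, yield a (16, 10, 64) output.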
Code Reference
Source Location
Signature
```python
class mLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2, layer_norm=True): ...
    def forward(self, x, h_prev, c_prev, n_prev): ...
    def init_hidden(self, batch_size, device=None): ...
```
Import
```python
from pytorch_forecasting.layers._recurrent._mlstm.cell import mLSTMCell
```
I/O Contract
Inputs
__init__
| Name | Type | Required | Description |
|---|---|---|---|
| input_size | int | Yes | Size of the input feature vector. |
| hidden_size | int | Yes | Number of hidden units in the LSTM cell. |
| dropout | float | No | Dropout rate applied to inputs and hidden states. Defaults to 0.2. |
| layer_norm | bool | No | If True, apply layer normalization to gates and interactions. Defaults to True. |
forward
| Name | Type | Required | Description |
|---|---|---|---|
| x | torch.Tensor | Yes | Input tensor for one time step, of shape (batch_size, input_size). |
| h_prev | torch.Tensor | Yes | Previous hidden state of shape (batch_size, hidden_size). |
| c_prev | torch.Tensor | Yes | Previous cell state of shape (batch_size, hidden_size). |
| n_prev | torch.Tensor | Yes | Previous normalizer state of shape (batch_size, hidden_size). |
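Since forward operates on a single time step, every argument is 2-D and the three states share one shape. A pre-flight check of that contract might look like the following; check_forward_inputs is a hypothetical helper written from the table above, not part of pytorch-forecasting.

```python
import torch


def check_forward_inputs(x, h_prev, c_prev, n_prev, input_size, hidden_size):
    """Validate shapes against the documented forward() contract.

    Hypothetical helper: raises ValueError on the first mismatch.
    """
    batch_size = x.shape[0]
    expected = {
        "x": (x, (batch_size, input_size)),
        "h_prev": (h_prev, (batch_size, hidden_size)),
        "c_prev": (c_prev, (batch_size, hidden_size)),
        "n_prev": (n_prev, (batch_size, hidden_size)),
    }
    for name, (tensor, shape) in expected.items():
        if tuple(tensor.shape) != shape:
            raise ValueError(
                f"{name} has shape {tuple(tensor.shape)}, expected {shape}"
            )
```

Calling it before the first step surfaces a common mistake, passing a whole (batch, seq_len, input_size) sequence instead of one time slice, with a readable message.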
Outputs
forward
| Name | Type | Description |
|---|---|---|
| h | torch.Tensor | Current hidden state of shape (batch_size, hidden_size). |
| c | torch.Tensor | Current cell state of shape (batch_size, hidden_size). |
| n | torch.Tensor | Current normalizer state of shape (batch_size, hidden_size). |
init_hidden
| Name | Type | Description |
|---|---|---|
| (h, c, n) | tuple of torch.Tensor | Zero-initialized hidden, cell, and normalizer states, each of shape (batch_size, hidden_size). |
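Based only on the contract above, init_hidden behaves roughly like the sketch below. zero_states is a hypothetical stand-in; the actual method may differ in details such as dtype handling. Passing device creates the states directly where the model lives (for example "cuda"), and allocating each state separately means an in-place update to one cannot affect the others.

```python
import torch


def zero_states(batch_size, hidden_size, device=None):
    """Hypothetical equivalent of init_hidden: three separate zero tensors of
    shape (batch_size, hidden_size) for the hidden, cell, and normalizer states."""
    h = torch.zeros(batch_size, hidden_size, device=device)
    c = torch.zeros(batch_size, hidden_size, device=device)
    n = torch.zeros(batch_size, hidden_size, device=device)
    return h, c, n
```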
Usage Examples
```python
import torch
from pytorch_forecasting.layers._recurrent._mlstm.cell import mLSTMCell

cell = mLSTMCell(input_size=32, hidden_size=64, dropout=0.1, layer_norm=True)

batch_size = 16
# Zero-initialized hidden, cell, and normalizer states, each (16, 64)
h, c, n = cell.init_hidden(batch_size)

x_t = torch.randn(batch_size, 32)  # single time step input
h, c, n = cell(x_t, h, c, n)
# h shape: (16, 64)
```