Overview
sLSTMCell implements a stabilized LSTM cell with normalized exponential gating for improved training stability.
Description
sLSTMCell extends nn.Module and implements the sLSTM algorithm as described in the xLSTM paper (https://arxiv.org/pdf/2407.10240). For the input and forget gates, it replaces standard sigmoid gating with a normalized exponential gating mechanism: the gate pre-activations are centered, exponentiated, and then renormalized. This stabilization is intended to keep the exponential gates numerically well-behaved and to help prevent vanishing and exploding gradients. The input-to-hidden and hidden-to-hidden weight matrices each produce 4 * hidden_size outputs, so the pre-activations of all four gates (input, forget, cell candidate, output) are computed in a single linear pass. Layer normalization and dropout are optional.
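The gating step can be illustrated with a small NumPy sketch. It assumes that "centering" means subtracting the per-row maximum and "renormalizing" means dividing by the per-row sum, which turns the gate into a softmax over the hidden dimension. The library's exact normalization may differ; `normalized_exp_gate` below is a stand-alone illustration, not the method's source.

```python
import numpy as np


def normalized_exp_gate(pre_gate: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Center, exponentiate, and renormalize gate pre-activations row by row.

    Assumed scheme (illustrative only): subtract the row maximum, apply exp,
    then divide by the row sum, i.e. a numerically stable softmax.
    """
    centered = pre_gate - pre_gate.max(axis=-1, keepdims=True)  # row max becomes 0
    exp_gate = np.exp(centered)                                  # every entry in (0, 1]
    return exp_gate / (exp_gate.sum(axis=-1, keepdims=True) + eps)


# Large pre-activations: a plain np.exp(pre) would overflow to inf here.
pre = np.array([[1000.0, 999.0, 998.0]])
gate = normalized_exp_gate(pre)
print(gate)  # finite values in (0, 1) that sum to about 1
```

Under this assumed scheme, centering does not change the result mathematically, because the common factor exp(-max) cancels in the ratio; it only keeps every exponent non-positive so that exp() cannot overflow.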
Usage
sLSTMCell processes a single time step. Use it as the basic recurrent building block of stabilized LSTM layers, which call it once per step of a sequence. It is intended for cases where exponential gating is preferred over standard sigmoid gating for the sake of training stability.
Code Reference
Source Location
Signature
class sLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.0, use_layer_norm=True): ...
    def reset_parameters(self): ...
    def normalized_exp_gate(self, pre_gate): ...
    def forward(self, x, h_prev, c_prev): ...
    def init_hidden(self, batch_size, device=None): ...
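To make the signature concrete, here is a hedged NumPy sketch of what one `forward` step could look like. The class name `SLSTMCellSketch`, its random initialization, the placement of layer normalization (on the new cell and hidden states, without learnable scale or shift), and the max-centered softmax used as the normalized exponential gate are all assumptions for illustration. Dropout is omitted, and the gate order [input, forget, candidate, output] follows the description above.

```python
import numpy as np


def normalized_exp_gate(z, eps=1e-6):
    # Assumed scheme: center by the row max, exponentiate, divide by the row sum.
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / (e.sum(axis=-1, keepdims=True) + eps)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def layer_norm(z, eps=1e-5):
    # Layer normalization over the feature axis, without learnable scale/shift.
    mean = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    return (z - mean) / np.sqrt(var + eps)


class SLSTMCellSketch:
    """Hypothetical NumPy stand-in for the sLSTMCell interface (no dropout)."""

    def __init__(self, input_size, hidden_size, use_layer_norm=True, seed=0):
        rng = np.random.default_rng(seed)
        bound = 1.0 / np.sqrt(hidden_size)
        self.hidden_size = hidden_size
        self.use_layer_norm = use_layer_norm
        # Each matrix stacks the four gates: columns [input | forget | candidate | output].
        self.w_x = rng.uniform(-bound, bound, (input_size, 4 * hidden_size))
        self.w_h = rng.uniform(-bound, bound, (hidden_size, 4 * hidden_size))
        self.bias = np.zeros(4 * hidden_size)

    def init_hidden(self, batch_size):
        # Zero-initialized (h, c), each of shape (batch_size, hidden_size).
        return (np.zeros((batch_size, self.hidden_size)),
                np.zeros((batch_size, self.hidden_size)))

    def forward(self, x, h_prev, c_prev):
        # A single linear pass yields the pre-activations of all four gates.
        gates = x @ self.w_x + h_prev @ self.w_h + self.bias
        i, f, g, o = np.split(gates, 4, axis=-1)
        i = normalized_exp_gate(i)  # input gate
        f = normalized_exp_gate(f)  # forget gate
        g = np.tanh(g)              # cell candidate
        o = sigmoid(o)              # output gate
        c = f * c_prev + i * g
        if self.use_layer_norm:
            c = layer_norm(c)
        h = o * np.tanh(c)
        if self.use_layer_norm:
            h = layer_norm(h)
        return h, c


cell = SLSTMCellSketch(input_size=32, hidden_size=64)
h, c = cell.init_hidden(batch_size=16)
x_t = np.random.default_rng(1).standard_normal((16, 32))
h, c = cell.forward(x_t, h, c)
print(h.shape, c.shape)  # (16, 64) (16, 64)
```

Because each input goes through one matrix multiply that covers all four gates, splitting the result is cheaper than running four separate projections. In PyTorch, the same split would typically be written as `torch.chunk(gates, 4, dim=-1)`.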
Import
from pytorch_forecasting.layers._recurrent._slstm.cell import sLSTMCell
I/O Contract
Inputs
__init__
| Name | Type | Required | Description |
|---|---|---|---|
| input_size | int | Yes | Number of input features for the cell. |
| hidden_size | int | Yes | Number of features in the hidden state of the cell. |
| dropout | float | No | Dropout probability for the input and hidden state. Defaults to 0.0. |
| use_layer_norm | bool | No | Whether to use layer normalization for internal computations. Defaults to True. |
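The `dropout` parameter is documented as applying to both the input and the hidden state. The sketch below shows what that could mean for one step, assuming standard inverted dropout, which is the convention `torch.nn.Dropout` follows in training mode: each element is zeroed with probability p and the survivors are scaled by 1 / (1 - p). `inverted_dropout` is a hypothetical helper, not part of the library.

```python
import numpy as np


def inverted_dropout(a, p, rng, training=True):
    """Zero entries with probability p and rescale the rest by 1 / (1 - p)."""
    if not training or p == 0.0:
        return a  # evaluation mode (or p == 0): identity
    keep = rng.random(a.shape) >= p
    return a * keep / (1.0 - p)


rng = np.random.default_rng(0)
x_t = np.ones((16, 32))     # input at one time step
h_prev = np.ones((16, 64))  # previous hidden state
x_drop = inverted_dropout(x_t, p=0.1, rng=rng)
h_drop = inverted_dropout(h_prev, p=0.1, rng=rng)
# Each entry is now either 0 or 1 / 0.9, so the expected value stays 1.
print(x_drop.mean(), h_drop.mean())
```

Scaling during training means evaluation needs no correction, which is why the helper returns its input unchanged when `training=False`.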
forward
| Name | Type | Required | Description |
|---|---|---|---|
| x | torch.Tensor | Yes | Input tensor of shape (batch_size, input_size). |
| h_prev | torch.Tensor | Yes | Previous hidden state tensor of shape (batch_size, hidden_size). |
| c_prev | torch.Tensor | Yes | Previous cell state tensor of shape (batch_size, hidden_size). |
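The shape contract in the table above can be checked before the cell is called. `check_forward_inputs` is a hypothetical helper written for illustration. It only inspects `.ndim` and `.shape`, so it works for NumPy arrays and torch.Tensor objects alike.

```python
import numpy as np


def check_forward_inputs(x, h_prev, c_prev, input_size, hidden_size):
    """Raise ValueError unless the shapes match the documented forward() contract."""
    if x.ndim != 2 or x.shape[1] != input_size:
        raise ValueError(f"x must be (batch_size, {input_size}), got {tuple(x.shape)}")
    batch_size = x.shape[0]
    expected = (batch_size, hidden_size)
    for name, t in (("h_prev", h_prev), ("c_prev", c_prev)):
        if tuple(t.shape) != expected:
            raise ValueError(f"{name} must be {expected}, got {tuple(t.shape)}")
    return batch_size


batch = check_forward_inputs(np.zeros((16, 32)), np.zeros((16, 64)), np.zeros((16, 64)),
                             input_size=32, hidden_size=64)
print(batch)  # 16
```

The `x.ndim != 2` test catches a common mistake: passing a whole sequence of shape (batch_size, seq_len, input_size) where a single time step is expected.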
Outputs
forward
| Name | Type | Description |
|---|---|---|
| h | torch.Tensor | Updated hidden state tensor of shape (batch_size, hidden_size). |
| c | torch.Tensor | Updated cell state tensor of shape (batch_size, hidden_size). |
init_hidden
| Name | Type | Description |
|---|---|---|
| (h, c) | tuple of torch.Tensor | Tuple of zero-initialized hidden and cell states, each of shape (batch_size, hidden_size). |
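A zero initial state that matches this contract can be sketched in NumPy as follows. `init_hidden_sketch` is a hypothetical stand-in. The documented method also takes a `device` argument for tensor placement, which has no NumPy counterpart here; in PyTorch the analogous allocation would be `torch.zeros(batch_size, hidden_size, device=device)`.

```python
import numpy as np


def init_hidden_sketch(batch_size, hidden_size, dtype=np.float32):
    # Allocate h and c separately so an in-place update of one cannot alias the other.
    h = np.zeros((batch_size, hidden_size), dtype=dtype)
    c = np.zeros((batch_size, hidden_size), dtype=dtype)
    return h, c


h, c = init_hidden_sketch(batch_size=16, hidden_size=64)
h += 1.0  # modifies h only
print(h.shape, c.shape, float(c.sum()))  # (16, 64) (16, 64) 0.0
```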
Usage Examples
import torch
from pytorch_forecasting.layers._recurrent._slstm.cell import sLSTMCell
cell = sLSTMCell(input_size=32, hidden_size=64, dropout=0.1, use_layer_norm=True)
batch_size = 16
h, c = cell.init_hidden(batch_size)
x_t = torch.randn(batch_size, 32) # single time step input
h, c = cell(x_t, h, c)
# h shape: (16, 64)
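The example above advances the cell by one time step. To process a whole sequence, call the cell once per step and feed each (h, c) back in. The NumPy sketch below shows that loop. `toy_step` is a hypothetical stand-in with the same (x, h_prev, c_prev) -> (h, c) signature: it is a plain tanh recurrence, not the sLSTM update. With the real module, the call inside the loop would be `h, c = cell(xs[:, t, :], h, c)`.

```python
import numpy as np


def unroll(step, xs, h, c):
    """Run a cell step over xs of shape (batch, seq_len, input_size).

    Returns all hidden states, shape (batch, seq_len, hidden_size), and the final (h, c).
    """
    outputs = []
    for t in range(xs.shape[1]):
        h, c = step(xs[:, t, :], h, c)  # one call per time step
        outputs.append(h)
    return np.stack(outputs, axis=1), (h, c)


rng = np.random.default_rng(0)
w_x = rng.standard_normal((32, 64)) * 0.1
w_h = rng.standard_normal((64, 64)) * 0.1


def toy_step(x_t, h, c):
    # Hypothetical stand-in: a simple tanh recurrence, not the sLSTM equations.
    c = 0.5 * c + np.tanh(x_t @ w_x + h @ w_h)
    return np.tanh(c), c


xs = rng.standard_normal((16, 10, 32))  # batch of 16 sequences, 10 steps each
h0 = np.zeros((16, 64))
c0 = np.zeros((16, 64))
hs, (h_last, c_last) = unroll(toy_step, xs, h0, c0)
print(hs.shape)  # (16, 10, 64)
```

Stacking the per-step outputs along axis 1 keeps the batch-first layout of the input sequence.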