Implementation:Recommenders team Recommenders SASRec Model
| Knowledge Sources | |
|---|---|
| Domains | Sequential Recommendation, Self-Attention, Transformer Models |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
The SASRec module is a PyTorch implementation of the Self-Attentive Sequential Recommendation model, a Transformer-based architecture for predicting the next item in a user's interaction sequence.
Description
This file contains the complete implementation of the SASRec model (Kang and McAuley, ICDM 2018) along with all supporting Transformer components. The architecture processes user interaction sequences through a stack of self-attention blocks to learn sequential patterns for next-item prediction.
Core Components:
- MultiHeadAttention: Implements scaled dot-product attention with multiple heads, including key masking (to ignore padding), causal masking (future blinding via lower-triangular mask), query masking, and residual connections.
- PointWiseFeedForward: Two-layer 1D convolution network with ReLU activation, dropout, and residual connection, applied point-wise to each position in the sequence.
- EncoderLayer: Combines MultiHeadAttention and PointWiseFeedForward with custom LayerNormalization into a single Transformer encoder block.
- Encoder: Stacks multiple EncoderLayer blocks using nn.ModuleList for configurable depth.
- LayerNormalization: Custom implementation using learnable gamma and beta parameters with mean-variance normalization.
- pad_sequences: Utility function for padding/truncating sequences to a fixed length, supporting both pre and post padding strategies.
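The masking behaviour described for MultiHeadAttention can be sketched as follows. This is a minimal single-head illustration, not the module's actual code: the function name `masked_self_attention` and the boolean `pad_mask` argument are assumptions made for this example, and the learned query/key/value projections, multiple heads, and dropout are omitted.

```python
import math
import torch


def masked_self_attention(x, pad_mask):
    """Single-head scaled dot-product self-attention with key, causal and query masking.

    x:        (batch, seq_len, dim) embedded sequence
    pad_mask: (batch, seq_len) bool tensor, True where the position holds a real item
    """
    seq_len, dim = x.shape[1], x.shape[2]
    # Raw attention scores, scaled by sqrt(dim).
    scores = torch.matmul(x, x.transpose(1, 2)) / math.sqrt(dim)

    # Key masking: no query may attend to a padded key.
    scores = scores.masked_fill(~pad_mask[:, None, :], float("-inf"))
    # Causal masking: lower-triangular, so position i only sees positions <= i.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device))
    scores = scores.masked_fill(~causal, float("-inf"))

    weights = torch.softmax(scores, dim=-1)
    # Rows with every key masked produce NaN after softmax; replace them with zeros.
    weights = torch.nan_to_num(weights, nan=0.0)
    # Query masking: zero the attention rows of padded query positions.
    weights = weights * pad_mask[:, :, None].to(weights.dtype)

    # Residual connection around the attention output.
    return torch.matmul(weights, x) + x, weights
```

With pre-padded input, the leading padded positions end up with all-zero attention rows, and every real position distributes its attention only over real items at or before it.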
SASREC Model Architecture:
- Embedding Layer: Combines learned item embeddings (scaled by the square root of embedding_dim) with positional embeddings to encode both item identity and position in the sequence.
- Encoder Stack: Processes embedded sequences through multiple self-attention blocks with causal masking to enforce the autoregressive property.
- Prediction Layer: Computes logits via dot product between sequence embeddings and item embeddings for both positive and negative items.
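The embedding and prediction steps listed above might look roughly like the sketch below. `TinySeqEmbedder` is a hypothetical stand-in written for this page, not a class from the library; it assumes item id 0 is reserved for padding and skips the encoder stack entirely, so the logits here are computed straight from the embeddings.

```python
import torch
import torch.nn as nn


class TinySeqEmbedder(nn.Module):
    """Hypothetical illustration of the embedding and prediction layers only."""

    def __init__(self, item_num, seq_max_len, embedding_dim):
        super().__init__()
        # Assumption of this sketch: index 0 is the padding item.
        self.item_emb = nn.Embedding(item_num + 1, embedding_dim, padding_idx=0)
        self.pos_emb = nn.Embedding(seq_max_len, embedding_dim)
        self.scale = embedding_dim ** 0.5

    def embed(self, input_seq):
        # input_seq: (batch, seq_len) item ids, with seq_len <= seq_max_len.
        positions = torch.arange(input_seq.shape[1], device=input_seq.device)
        # Scaled item embeddings plus positional embeddings (broadcast over the batch).
        seq = self.item_emb(input_seq) * self.scale + self.pos_emb(positions)
        # Zero out padded positions so they carry no signal.
        mask = (input_seq != 0).unsqueeze(-1).to(seq.dtype)
        return seq * mask

    def logits(self, seq_repr, items):
        # Dot product between each position's representation and one item embedding.
        # seq_repr: (batch, seq_len, dim); items: (batch, seq_len) -> (batch, seq_len)
        return (seq_repr * self.item_emb(items)).sum(dim=-1)
```

Calling `logits` once with the positive items and once with the negative items yields the two logit tensors that the prediction layer description refers to.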
Training and Evaluation:
- Loss function: Binary cross-entropy with negative sampling, where positive items are the next items in the sequence and negatives are randomly sampled.
- train_model: Full training loop with Adam optimizer, L2 regularization via weight decay, GPU support, periodic validation evaluation (NDCG@10, HR@10), and progress bars via tqdm.
- evaluate / evaluate_valid: Batched evaluation computing NDCG@10 and HR@10 metrics with configurable negative sampling for ranking.
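The binary cross-entropy objective with negative sampling can be written compactly. The sketch below is an assumption about the general shape of such a loss rather than a copy of `loss_function`: the function name is made up for this example, `istarget` is taken to be a 0/1 float mask over non-padded positions, and `logsigmoid` is used as a numerically stable way to compute the log terms.

```python
import torch
import torch.nn.functional as F


def bce_negative_sampling_loss(pos_logits, neg_logits, istarget):
    """Binary cross-entropy over (positive, negative) logit pairs.

    pos_logits, neg_logits: (batch, seq_len) scores for the true next item
                            and for a randomly sampled negative item
    istarget:               (batch, seq_len) float mask, 1.0 at real positions
    """
    # -log(sigmoid(pos)): push positive logits up.
    pos_loss = -F.logsigmoid(pos_logits)
    # -log(1 - sigmoid(neg)) == -log(sigmoid(-neg)): push negative logits down.
    neg_loss = -F.logsigmoid(-neg_logits)
    # Average over real (non-padded) positions only.
    return ((pos_loss + neg_loss) * istarget).sum() / istarget.sum().clamp(min=1.0)
```

When all logits are zero, each real position contributes `2 * log(2)`, which gives a quick sanity check for an untrained model.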
The SASREC class also serves as the base class for the SSEPT model, which extends it with user embeddings.
Usage
Use this model for sequential recommendation tasks where user behavior is modeled as an ordered sequence of item interactions. It is particularly effective for datasets where the temporal ordering of interactions matters and the goal is to predict the next item a user will interact with. It is well suited to implicit-feedback scenarios with large item catalogs.
Code Reference
Source Location
- Repository: Recommenders
- File: recommenders/models/sasrec/model.py
- Lines: 1-972
Signature
class MultiHeadAttention(nn.Module):
    def __init__(self, attention_dim, num_heads, dropout_rate)
    def forward(self, queries, keys)

class PointWiseFeedForward(nn.Module):
    def __init__(self, embedding_dim, conv_dims, dropout_rate)
    def forward(self, x)

class EncoderLayer(nn.Module):
    def __init__(self, seq_max_len, embedding_dim, attention_dim, num_heads, conv_dims, dropout_rate)
    def forward(self, x, training, mask)

class Encoder(nn.Module):
    def __init__(self, num_layers, seq_max_len, embedding_dim, attention_dim, num_heads, conv_dims, dropout_rate)
    def forward(self, x, training, mask)

class LayerNormalization(nn.Module):
    def __init__(self, seq_max_len, embedding_dim, epsilon)
    def forward(self, x)

def pad_sequences(sequences, maxlen, padding='pre', truncating='pre', value=0)

class SASREC(nn.Module):
    def __init__(self, **kwargs)
    def embedding(self, input_seq)
    def forward(self, x, training=True)
    def predict(self, inputs)
    def loss_function(self, pos_logits, neg_logits, istarget)
    def create_combined_dataset(self, u, seq, pos, neg)
    def train_model(self, dataset, sampler, num_epochs=10, batch_size=128,
                    learning_rate=0.001, val_epoch=0, eval_batch_size=256, verbose=True)
    def evaluate(self, dataset, seed=None, eval_batch_size=256)
    def evaluate_valid(self, dataset, seed=None, eval_batch_size=256)
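To make the `padding` and `truncating` arguments of `pad_sequences` concrete, here is a pure-Python sketch of the behaviour they are described as controlling. The name `pad_sequences_sketch` is hypothetical, the function returns nested lists rather than an array, and it assumes `maxlen` is positive and that `'pre'` truncation drops the oldest items; the library function may differ in these details.

```python
def pad_sequences_sketch(sequences, maxlen, padding="pre", truncating="pre", value=0):
    """Pad or truncate each sequence to exactly `maxlen` items (maxlen > 0)."""
    padded = []
    for seq in sequences:
        seq = list(seq)
        if len(seq) > maxlen:
            # 'pre' drops items from the front (keeps the most recent ones);
            # anything else drops items from the end.
            seq = seq[-maxlen:] if truncating == "pre" else seq[:maxlen]
        fill = [value] * (maxlen - len(seq))
        # 'pre' puts the padding before the items; anything else puts it after.
        padded.append(fill + seq if padding == "pre" else seq + fill)
    return padded
```

For example, `pad_sequences_sketch([[1, 2, 3], [4]], 2)` keeps the last two items of the first sequence and left-pads the second, giving `[[2, 3], [0, 4]]`.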
Import
from recommenders.models.sasrec.model import SASREC
from recommenders.models.sasrec.model import MultiHeadAttention, PointWiseFeedForward
from recommenders.models.sasrec.model import Encoder, EncoderLayer, LayerNormalization
from recommenders.models.sasrec.model import pad_sequences
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| item_num | int | Yes | Number of items in the dataset |
| seq_max_len | int | No | Maximum sequence length for user history; default 100 |
| num_blocks | int | No | Number of Transformer encoder blocks; default 2 |
| embedding_dim | int | No | Item embedding dimension; default 100 |
| attention_dim | int | No | Transformer attention dimension; default 100 |
| attention_num_heads | int | No | Number of attention heads; default 1 |
| conv_dims | list | No | Dimensions of the feedforward layers; default [100, 100] |
| dropout_rate | float | No | Dropout probability; default 0.5 |
| l2_reg | float | No | L2 regularization coefficient (applied as weight decay); default 0.0 |
| num_neg_test | int | No | Number of negative examples used during evaluation; default 100 |
Outputs
| Name | Type | Description |
|---|---|---|
| forward() | (torch.Tensor, torch.Tensor, torch.Tensor) | Positive logits, negative logits, and target mask for loss computation |
| predict() | torch.Tensor | Logits of shape (batch, num_candidates) for candidate items |
| train_model() | dict | Training history with 'loss', 'val_ndcg', and 'val_hr' lists |
| evaluate() | tuple of float | (NDCG@10, HR@10) metrics on the test set |
| evaluate_valid() | tuple of float | (NDCG@10, HR@10) metrics on the validation set |
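The two metrics returned by `evaluate()` and `evaluate_valid()` can be illustrated with a small helper. This is a sketch under stated assumptions, not the library's evaluation code: `ndcg_hr_at_k` is a hypothetical name, and each user is assumed to have exactly one held-out item whose 0-based rank among its candidates (the true item plus the sampled negatives) has already been computed.

```python
import math


def ndcg_hr_at_k(ranks, k=10):
    """Average NDCG@k and HR@k for one held-out item per user.

    ranks: iterable of 0-based ranks of the held-out item among its candidates.
    """
    ranks = list(ranks)
    if not ranks:
        return 0.0, 0.0
    ndcg = 0.0
    hits = 0.0
    for rank in ranks:
        if rank < k:
            # With a single relevant item the ideal DCG is 1, so NDCG is just the discount.
            ndcg += 1.0 / math.log2(rank + 2)
            hits += 1.0
    return ndcg / len(ranks), hits / len(ranks)
```

A held-out item ranked first contributes 1.0 to both metrics, while one ranked outside the top `k` contributes nothing to either.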
Usage Examples
Basic Usage
from recommenders.models.sasrec.model import SASREC

# Initialize the SASRec model
model = SASREC(
    item_num=10000,
    seq_max_len=50,
    num_blocks=2,
    embedding_dim=64,
    attention_dim=64,
    attention_num_heads=2,
    conv_dims=[64, 64],
    dropout_rate=0.2,
    l2_reg=1e-6,
    num_neg_test=100,
)

# Train the model.
# `dataset` and `warp_sampler` are not defined in this snippet; they are
# assumed to have been prepared beforehand by the caller.
history = model.train_model(
    dataset=dataset,
    sampler=warp_sampler,
    num_epochs=20,
    batch_size=128,
    learning_rate=0.001,
    val_epoch=5,
)

# Evaluate on the test set
ndcg, hr = model.evaluate(dataset, seed=42)
print(f"Test NDCG@10: {ndcg:.4f}, HR@10: {hr:.4f}")

# Predict for specific inputs.
# `input_sequences` and `candidate_items` are likewise supplied by the caller.
predictions = model.predict({
    "input_seq": input_sequences,
    "candidate": candidate_items,
})