

Implementation:Recommenders team Recommenders SASRec Model



Knowledge Sources
Domains: Sequential Recommendation, Self-Attention, Transformer Models
Last Updated: 2026-02-10 00:00 GMT

Overview

The SASRec module implements the Self-Attentive Sequential Recommendation model in PyTorch, a Transformer-based architecture for predicting the next item in a user's interaction sequence.

Description

This file contains the complete implementation of the SASRec model (Kang and McAuley, ICDM 2018) along with all supporting Transformer components. The architecture processes user interaction sequences through a stack of self-attention blocks to learn sequential patterns for next-item prediction.

Core Components:

  • MultiHeadAttention: Implements scaled dot-product attention with multiple heads, including key masking (to ignore padding), causal masking (future blinding via lower-triangular mask), query masking, and residual connections.
  • PointWiseFeedForward: Two-layer 1D convolution network with ReLU activation, dropout, and residual connection, applied point-wise to each position in the sequence.
  • EncoderLayer: Combines MultiHeadAttention and PointWiseFeedForward with custom LayerNormalization into a single Transformer encoder block.
  • Encoder: Stacks multiple EncoderLayer blocks using nn.ModuleList for configurable depth.
  • LayerNormalization: Custom implementation using learnable gamma and beta parameters with mean-variance normalization.
  • pad_sequences: Utility function for padding/truncating sequences to a fixed length, supporting both pre and post padding strategies.
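To illustrate the padding and truncation behaviour described for pad_sequences, here is a minimal pure-Python sketch. The function name pad_sequences_sketch is hypothetical, and the 'pre'/'post' semantics shown (pad or cut at the front versus the back) are an assumption based on the common Keras-style convention, not a copy of the library code.

```python
def pad_sequences_sketch(sequences, maxlen, padding="pre", truncating="pre", value=0):
    """Pad or truncate each sequence to exactly `maxlen` items (illustrative only)."""
    if maxlen < 1:
        raise ValueError("maxlen must be at least 1")
    if padding not in ("pre", "post") or truncating not in ("pre", "post"):
        raise ValueError("padding and truncating must be 'pre' or 'post'")
    padded = []
    for seq in sequences:
        seq = list(seq)
        if len(seq) > maxlen:
            # Assumed semantics: 'pre' drops the oldest items, 'post' drops the newest
            seq = seq[-maxlen:] if truncating == "pre" else seq[:maxlen]
        fill = [value] * (maxlen - len(seq))
        # 'pre' puts the padding in front, so the most recent item stays last
        padded.append(fill + seq if padding == "pre" else seq + fill)
    return padded


batch = pad_sequences_sketch([[3, 7], [1, 2, 3, 4, 5]], maxlen=4)
print(batch)  # [[0, 0, 3, 7], [2, 3, 4, 5]]
```

With 'pre' padding the most recent interaction always sits at the last position, which is convenient when the final hidden state is used to score the next item.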

SASREC Model Architecture:

  • Embedding Layer: Combines learned item embeddings (scaled by the square root of embedding_dim) with positional embeddings to encode both item identity and position in the sequence.
  • Encoder Stack: Processes embedded sequences through multiple self-attention blocks with causal masking to enforce the autoregressive property.
  • Prediction Layer: Computes logits via dot product between sequence embeddings and item embeddings for both positive and negative items.
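The causal masking used by the encoder stack can be illustrated with a small pure-Python sketch of scaled dot-product attention. It assumes single-head self-attention (queries and keys of equal length) and uses made-up embeddings. The helper name causal_attention_weights is hypothetical, and the model itself applies the mask to tensors rather than to Python lists.

```python
import math


def causal_attention_weights(queries, keys):
    """Softmax attention weights where position i only sees positions 0..i."""
    scale = math.sqrt(len(keys[0]))
    weights = []
    for i, query in enumerate(queries):
        # Scores for allowed positions only; later positions are masked out
        scores = [
            sum(q * k for q, k in zip(query, keys[j])) / scale
            for j in range(i + 1)
        ]
        peak = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        row = [e / total for e in exps]
        # Masked (future) positions receive exactly zero weight
        weights.append(row + [0.0] * (len(keys) - i - 1))
    return weights


# Toy sequence of three 2-dimensional item embeddings (made-up values)
seq = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
attn = causal_attention_weights(seq, seq)
for row in attn:
    print([round(w, 3) for w in row])
```

Each row sums to one and is zero above the diagonal, so the representation at position t depends only on items up to t. This is what allows every position to serve as a training example for next-item prediction.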

Training and Evaluation:

  • Loss function: Binary cross-entropy with negative sampling, where positive items are the next items in the sequence and negatives are randomly sampled.
  • train_model: Full training loop with Adam optimizer, L2 regularization via weight decay, GPU support, periodic validation evaluation (NDCG@10, HR@10), and progress bars via tqdm.
  • evaluate / evaluate_valid: Batched evaluation computing NDCG@10 and HR@10 metrics with configurable negative sampling for ranking.
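The binary cross-entropy objective with negative sampling can be sketched as follows for a single flattened sequence. The argument names mirror the loss_function signature, but the reduction shown (mean over non-padding positions) is an assumption in the style of the original SASRec objective, and the function name is hypothetical.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def bce_negative_sampling_loss(pos_logits, neg_logits, istarget):
    """Mean BCE over positions where istarget is 1 (padding positions are ignored)."""
    total, count = 0.0, 0.0
    for pos, neg, mask in zip(pos_logits, neg_logits, istarget):
        # Positive item is pushed toward label 1, sampled negative toward label 0;
        # log(1 - sigmoid(neg)) equals log_sigmoid(-neg)
        total -= mask * (log_sigmoid(pos) + log_sigmoid(-neg))
        count += mask
    return total / count if count else 0.0


# Two positions: the second one is padding and does not contribute
loss = bce_negative_sampling_loss([5.0, 0.0], [-5.0, 0.0], [1, 0])
print(round(loss, 4))
```

The istarget mask keeps padded positions from diluting the loss, which matters when many sequences are much shorter than seq_max_len.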

The SASREC class also serves as the base class for the SSEPT model, which extends it with user embeddings.

Usage

Use this model for sequential recommendation tasks where user behavior is modeled as an ordered sequence of item interactions. It is particularly effective for datasets where the temporal ordering of interactions is important and the goal is to predict the next item a user will interact with. It is suitable for implicit feedback scenarios with large item catalogs.

Code Reference

Source Location

Signature

class MultiHeadAttention(nn.Module):
    def __init__(self, attention_dim, num_heads, dropout_rate)
    def forward(self, queries, keys)

class PointWiseFeedForward(nn.Module):
    def __init__(self, embedding_dim, conv_dims, dropout_rate)
    def forward(self, x)

class EncoderLayer(nn.Module):
    def __init__(self, seq_max_len, embedding_dim, attention_dim, num_heads, conv_dims, dropout_rate)
    def forward(self, x, training, mask)

class Encoder(nn.Module):
    def __init__(self, num_layers, seq_max_len, embedding_dim, attention_dim, num_heads, conv_dims, dropout_rate)
    def forward(self, x, training, mask)

class LayerNormalization(nn.Module):
    def __init__(self, seq_max_len, embedding_dim, epsilon)
    def forward(self, x)

def pad_sequences(sequences, maxlen, padding='pre', truncating='pre', value=0)

class SASREC(nn.Module):
    def __init__(self, **kwargs)
    def embedding(self, input_seq)
    def forward(self, x, training=True)
    def predict(self, inputs)
    def loss_function(self, pos_logits, neg_logits, istarget)
    def create_combined_dataset(self, u, seq, pos, neg)
    def train_model(self, dataset, sampler, num_epochs=10, batch_size=128,
                    learning_rate=0.001, val_epoch=0, eval_batch_size=256, verbose=True)
    def evaluate(self, dataset, seed=None, eval_batch_size=256)
    def evaluate_valid(self, dataset, seed=None, eval_batch_size=256)

Import

from recommenders.models.sasrec.model import SASREC
from recommenders.models.sasrec.model import MultiHeadAttention, PointWiseFeedForward
from recommenders.models.sasrec.model import Encoder, EncoderLayer, LayerNormalization
from recommenders.models.sasrec.model import pad_sequences

I/O Contract

Inputs

Name | Type | Required | Description
item_num | int | Yes | Number of items in the dataset
seq_max_len | int | No | Maximum sequence length for user history; default 100
num_blocks | int | No | Number of Transformer encoder blocks; default 2
embedding_dim | int | No | Item embedding dimension; default 100
attention_dim | int | No | Transformer attention dimension; default 100
attention_num_heads | int | No | Number of attention heads; default 1
conv_dims | list | No | Dimensions of the feedforward layers; default [100, 100]
dropout_rate | float | No | Dropout probability; default 0.5
l2_reg | float | No | L2 regularization coefficient (applied as weight decay); default 0.0
num_neg_test | int | No | Number of negative examples used during evaluation; default 100

Outputs

Name | Type | Description
forward() | (torch.Tensor, torch.Tensor, torch.Tensor) | Positive logits, negative logits, and target mask for loss computation
predict() | torch.Tensor | Logits of shape (batch, num_candidates) for candidate items
train_model() | dict | Training history with 'loss', 'val_ndcg', and 'val_hr' lists
evaluate() | tuple of float | (NDCG@10, HR@10) metrics on the test set
evaluate_valid() | tuple of float | (NDCG@10, HR@10) metrics on the validation set
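As a rough illustration of the two metrics returned by evaluate() and evaluate_valid(), the sketch below computes NDCG@k and HR@k when each user has one held-out item ranked against sampled negatives. It assumes the true item's score comes first in each candidate list. The helper names are hypothetical, and the library's batched implementation may differ in details such as tie handling.

```python
import math


def rank_of_first(scores):
    """0-based rank of scores[0] among all scores (higher score is better)."""
    return sum(1 for s in scores[1:] if s > scores[0])


def ndcg_hr_at_k(score_lists, k=10):
    """Average NDCG@k and HR@k with a single relevant item per list."""
    if not score_lists:
        raise ValueError("score_lists must not be empty")
    ndcg = hits = 0.0
    for scores in score_lists:
        rank = rank_of_first(scores)
        if rank < k:
            # With one relevant item, DCG is 1 / log2(rank + 2) and the ideal DCG is 1
            ndcg += 1.0 / math.log2(rank + 2)
            hits += 1.0
    n = len(score_lists)
    return ndcg / n, hits / n


# Two users: the true item ranks first for one and third for the other
example = [[0.9, 0.1, 0.2], [0.1, 0.5, 0.3]]
print(ndcg_hr_at_k(example, k=10))  # (0.75, 1.0)
```

HR@k only records whether the true item appears in the top k, while NDCG@k also rewards placing it nearer the top.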

Usage Examples

Basic Usage

from recommenders.models.sasrec.model import SASREC

# Initialize the SASRec model
model = SASREC(
    item_num=10000,
    seq_max_len=50,
    num_blocks=2,
    embedding_dim=64,
    attention_dim=64,
    attention_num_heads=2,
    conv_dims=[64, 64],
    dropout_rate=0.2,
    l2_reg=1e-6,
    num_neg_test=100,
)

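# Train the model; `dataset` and `warp_sampler` are assumed to be prepared beforehand
# (their construction is not shown on this page)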
history = model.train_model(
    dataset=dataset,
    sampler=warp_sampler,
    num_epochs=20,
    batch_size=128,
    learning_rate=0.001,
    val_epoch=5,
)

# Evaluate on the test set
ndcg, hr = model.evaluate(dataset, seed=42)
print(f"Test NDCG@10: {ndcg:.4f}, HR@10: {hr:.4f}")

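# Predict for specific inputs; `input_sequences` and `candidate_items` are assumed to be
# pre-built batches of item IDs (not shown on this page)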
predictions = model.predict({
    "input_seq": input_sequences,
    "candidate": candidate_items,
})
