Implementation:LLMBook zh LLMBook zh github io LlamaDecoderLayer
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Model_Architecture |
| Last Updated | 2026-02-08 04:29 GMT |
Overview
A single LLaMA Transformer decoder block, implemented as a custom PyTorch `nn.Module`.
Description
The LlamaDecoderLayer class implements a single Transformer decoder layer in the Pre-Norm style: normalization is applied to the input of each sub-layer rather than after the residual addition. It contains: (1) `input_layernorm` (RMSNorm applied before self-attention), (2) `self_attn` (LlamaAttention module), (3) `post_attention_layernorm` (RMSNorm applied before the MLP), and (4) `mlp` (LlamaMLP feed-forward network). The forward pass normalizes the input, applies self-attention, and adds the result back to the input (first residual connection). It then normalizes that sum, applies the MLP, and adds a second residual connection. Multiple instances of this layer are stacked to form the complete LlamaModel.
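Written as equations, the forward pass described above can be summarized as follows. The notation is introduced here for illustration only: $x$ is the layer input, $h$ the intermediate state after attention, $y$ the layer output, $d$ corresponds to `hidden_size`, $g$ is the learned RMSNorm gain vector, and $\epsilon$ corresponds to `rms_norm_eps`.

$$
\begin{aligned}
h &= x + \operatorname{Attn}\big(\operatorname{RMSNorm}_{\text{input}}(x)\big) \\
y &= h + \operatorname{MLP}\big(\operatorname{RMSNorm}_{\text{post\_attn}}(h)\big) \\
\operatorname{RMSNorm}(x) &= \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot g
\end{aligned}
$$

Unlike LayerNorm, RMSNorm does not subtract the mean; it only rescales by the root mean square.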
Usage
Import this class when studying the internal structure of each Transformer block in LLaMA-family models. Each decoder layer is constructed from a config and its layer index. Inside the model stack the layers run in order, each consuming the hidden states produced by the previous layer.
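To make the stacking concrete, here is a minimal, self-contained sketch in plain PyTorch. It is not the book's code: `ToyConfig`, `RMSNorm`, `ToyDecoderLayer`, and `ToyModel` are hypothetical names; rotary embeddings and KV caching are omitted; attention uses `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.0+) with its built-in causal masking; and the MLP is assumed to be the SwiGLU form used in LLaMA.

```python
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import nn


@dataclass
class ToyConfig:
    # Hypothetical stand-in for LlamaConfig (field names mirror Hugging Face's).
    hidden_size: int = 64
    intermediate_size: int = 172
    num_attention_heads: int = 4
    num_hidden_layers: int = 2
    rms_norm_eps: float = 1e-6


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))  # learned gain g
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Divide by the root mean square over the hidden dimension (no mean subtraction).
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


class ToyDecoderLayer(nn.Module):
    """Pre-Norm block: h = x + Attn(Norm(x)), y = h + MLP(Norm(h)). No RoPE, no KV cache."""

    def __init__(self, config: ToyConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx  # stored only; unused in this sketch
        d = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = d // self.num_heads
        self.input_layernorm = RMSNorm(d, config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(d, config.rms_norm_eps)
        self.qkv_proj = nn.Linear(d, 3 * d, bias=False)
        self.o_proj = nn.Linear(d, d, bias=False)
        self.gate_proj = nn.Linear(d, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(d, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, d, bias=False)

    def self_attn(self, x: torch.Tensor) -> torch.Tensor:
        b, s, d = x.shape
        q, k, v = self.qkv_proj(x).split(d, dim=-1)
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        q, k, v = (t.reshape(b, s, self.num_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(b, s, d))

    def mlp(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU feed-forward: down(silu(gate(x)) * up(x)).
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

    def forward(self, hidden_states: torch.Tensor) -> tuple:
        h = hidden_states + self.self_attn(self.input_layernorm(hidden_states))
        out = h + self.mlp(self.post_attention_layernorm(h))
        return (out,)  # a tuple, mirroring the documented return type


class ToyModel(nn.Module):
    def __init__(self, config: ToyConfig):
        super().__init__()
        self.layers = nn.ModuleList(
            [ToyDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:  # each layer consumes the previous layer's output
            hidden_states = layer(hidden_states)[0]
        return hidden_states


config = ToyConfig()
model = ToyModel(config)
x = torch.randn(2, 5, config.hidden_size)
print(model(x).shape)  # torch.Size([2, 5, 64])
```

Because each sub-layer only adds its output to the residual stream, zeroing `o_proj` and `down_proj` in this sketch turns a layer into the identity map, which is one way to see what the two residual connections do.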
Code Reference
Source Location
- Repository: LLMBook-zh
- File: code/5.6 LLaMALayer.py
- Lines: 1-43
Signature
```python
from typing import Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig  # assumption: the config class comes from Hugging Face Transformers


class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        """
        Args:
            config: LlamaConfig with hidden_size, rms_norm_eps, etc.
            layer_idx: Index of this layer in the model stack.
        """
        ...

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, ...]:
        """
        Args:
            hidden_states: Input tensor of shape (batch, seq_len, hidden_size).
            attention_mask: Causal attention mask.
            position_ids: Position indices for RoPE.
        Returns:
            Tuple whose first element is the output hidden states of shape (batch, seq_len, hidden_size).
        """
        ...
```
Import
```python
from torch import nn
# LlamaDecoderLayer is defined locally in code/5.6 LLaMALayer.py
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| config | LlamaConfig | Yes | Model configuration (constructor) |
| layer_idx | int | Yes | Layer index in model stack (constructor) |
| hidden_states | torch.Tensor | Yes | Input hidden states (batch, seq_len, hidden_size) |
| attention_mask | torch.Tensor | No | Causal attention mask |
| position_ids | torch.LongTensor | No | Position indices for RoPE |
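The `position_ids` input feeds the rotary position embedding (RoPE) applied inside `self_attn`. The sketch below shows how position indices become rotation angles, assuming the Hugging Face-style `rotate_half` formulation with the common default base of 10000; the book's `LlamaAttention` may differ in detail, and `rope_cos_sin` and `apply_rope` are hypothetical helper names.

```python
import torch


def rope_cos_sin(position_ids: torch.Tensor, head_dim: int, base: float = 10000.0):
    # Frequencies theta_i = base^(-2i / head_dim) for i = 0 .. head_dim/2 - 1.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    # Angle for each (position, frequency) pair: (batch, seq_len, head_dim/2).
    freqs = position_ids[..., None].float() * inv_freq
    emb = torch.cat((freqs, freqs), dim=-1)  # dims i and i + head_dim/2 share a frequency
    return emb.cos(), emb.sin()


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1): pairs dim i with dim i + head_dim/2.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (batch, heads, seq_len, head_dim); add a head axis so cos/sin broadcast.
    return x * cos[:, None] + rotate_half(x) * sin[:, None]


position_ids = torch.arange(6).unsqueeze(0)  # (1, 6): positions 0..5
cos, sin = rope_cos_sin(position_ids, head_dim=8)
q = torch.randn(1, 2, 6, 8)  # (batch, heads, seq_len, head_dim)
q_rot = apply_rope(q, cos, sin)
```

Each pair of dimensions is rotated by an angle proportional to the position, so vector norms are preserved and the dot product between a rotated query and a rotated key depends only on their relative offset.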
Outputs
| Name | Type | Description |
|---|---|---|
| hidden_states | torch.Tensor | Output hidden states (batch, seq_len, hidden_size) after attention + MLP with residual connections; returned as the first element of a tuple |
Usage Examples
```python
import torch
from torch import nn

# The decoder layer is typically not used standalone.
# It is instantiated inside LlamaModel, one instance per layer:
# self.layers = nn.ModuleList(
#     [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
# )

# Forward pass through a single layer; the output hidden states are element 0 of the returned tuple:
# hidden_states = layer(hidden_states, attention_mask=mask, position_ids=pos_ids)[0]
```
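The `mask` and `pos_ids` arguments in the call above have to be built by the caller. The sketch below assumes the Hugging Face-style convention: a 4-D additive float mask of shape (batch, 1, seq_len, seq_len) holding 0 for visible positions and the dtype's minimum for future positions, plus `position_ids` of shape (batch, seq_len). The exact mask format the book's code expects is an assumption, and `make_causal_mask` is a hypothetical helper.

```python
import torch


def make_causal_mask(batch: int, seq_len: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # 0 on and below the diagonal (past and current tokens), dtype minimum above it (future tokens).
    neg = torch.finfo(dtype).min
    mask = torch.triu(torch.full((seq_len, seq_len), neg, dtype=dtype), diagonal=1)
    return mask[None, None].expand(batch, 1, seq_len, seq_len)


batch, seq_len, hidden_size = 2, 4, 64
hidden_states = torch.randn(batch, seq_len, hidden_size)
mask = make_causal_mask(batch, seq_len)
pos_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)  # (batch, seq_len)

# Adding the mask to attention scores before softmax gives future positions zero weight.
scores = torch.zeros(batch, 1, seq_len, seq_len) + mask
weights = scores.softmax(dim=-1)
```

Using the dtype's finite minimum rather than `-inf` is a common choice because it avoids NaNs if an entire row were ever masked.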