Overview
Llama2_Model implements the full Llama2 transformer architecture for TorchServe with distributed tensor parallelism support. It defines the complete model hierarchy from low-level components (RMSNorm, rotary embeddings) through attention and feed-forward blocks to the full Transformer and high-level Llama wrapper class. The module includes parallelization functions for distributing the model across multiple GPUs.
Description
The llama2.py module contains the complete Llama2 model implementation organized as a layered class hierarchy. Each class encapsulates a specific component of the transformer architecture, and utility functions provide rotary positional embeddings and tensor parallelism support.
Architecture Hierarchy
- ModelArgs (dataclass): Configuration parameters for the model (dimensions, layers, heads, vocabulary size)
- RMSNorm: Root Mean Square Layer Normalization (replacing LayerNorm)
- Attention: Multi-head attention with rotary positional embeddings (RoPE), key-value caching, and grouped-query attention (GQA) via repeat_kv()
- FeedForward: SwiGLU feed-forward network with gated linear units
- TransformerBlock: Single transformer layer combining Attention + FeedForward + RMSNorm
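- Transformer: Full model combining token embedding, precomputed rotary frequencies (freqs_cis; RoPE has no learned positional embedding), the stack of TransformerBlocks, a final RMSNorm, and the output projection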
- Llama: High-level wrapper that combines Transformer with tokenizer and generation logic
Parallelism Support
- tp_llama(): Entry point for creating a tensor-parallel Llama model
- parallelize_llama_attn_block(): Distributes attention and feed-forward weights across GPUs using column and row parallel linear layers
Usage
from examples.large_models.tp_llama.llama2 import (
    ModelArgs,
    Transformer,
    Llama,
    tp_llama,
)
Code Reference
Source Location
| File | Lines | Contents |
|---|---|---|
| examples/large_models/tp_llama/llama2.py | L1-701 | Entire module (repository: pytorch/serve) |
| examples/large_models/tp_llama/llama2.py | L43-55 | ModelArgs dataclass |
| examples/large_models/tp_llama/llama2.py | L59-70 | RMSNorm class |
| examples/large_models/tp_llama/llama2.py | L111-210 | Attention class |
| examples/large_models/tp_llama/llama2.py | L213-236 | FeedForward class |
| examples/large_models/tp_llama/llama2.py | L239-278 | TransformerBlock class |
| examples/large_models/tp_llama/llama2.py | L281-391 | Transformer class |
| examples/large_models/tp_llama/llama2.py | L611-701 | Llama class |
Signature
@dataclass
class ModelArgs:
    """
    Configuration dataclass for Llama2 model architecture.
    Attributes:
        dim (int): Model embedding dimension.
        n_layers (int): Number of transformer layers.
        n_heads (int): Number of attention heads.
        n_kv_heads (int|None): Number of key-value heads for GQA (None = MHA, i.e. same as n_heads).
        vocab_size (int): Vocabulary size (-1 = filled in later from the tokenizer).
        multiple_of (int): FFN hidden dim is rounded up to a multiple of this value.
        ffn_dim_multiplier (float|None): Optional FFN dimension scaling factor.
        norm_eps (float): Epsilon for RMSNorm.
        max_batch_size (int): Maximum batch size for KV cache allocation.
        max_seq_len (int): Maximum sequence length (also sizes the KV cache).
    """
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: int | None = None
    vocab_size: int = -1
    multiple_of: int = 256
    ffn_dim_multiplier: float | None = None
    norm_eps: float = 1e-5
    max_batch_size: int = 32
    max_seq_len: int = 2048
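max_batch_size and max_seq_len matter mostly because they size the pre-allocated KV cache. As a rough estimate, assuming each layer holds one key and one value tensor of shape (max_batch_size, max_seq_len, n_kv_heads, head_dim) as in Meta's reference Llama 2 code, with head_dim = dim // n_heads, the total cache size can be computed as below. The trimmed ModelArgs stand-in and the kv_cache_bytes helper are hypothetical, and under tensor parallelism each GPU would hold only its share of the heads.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelArgs:
    # Trimmed stand-in with only the fields that affect the cache size.
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    max_batch_size: int = 32
    max_seq_len: int = 2048


def kv_cache_bytes(args: ModelArgs, bytes_per_elem: int = 2) -> int:
    """Hypothetical helper: bytes for all layers' key and value caches.

    Assumes one key and one value tensor per layer, each of shape
    (max_batch_size, max_seq_len, n_kv_heads, head_dim).
    """
    # None means plain multi-head attention: one KV head per query head.
    n_kv_heads = args.n_kv_heads if args.n_kv_heads is not None else args.n_heads
    head_dim = args.dim // args.n_heads
    per_tensor = args.max_batch_size * args.max_seq_len * n_kv_heads * head_dim
    return 2 * args.n_layers * per_tensor * bytes_per_elem  # 2 = keys + values


GiB = 2 ** 30
print(kv_cache_bytes(ModelArgs()) / GiB)  # 32.0  (7B shape, defaults, float16)
print(kv_cache_bytes(ModelArgs(dim=8192, n_layers=80, n_heads=64, n_kv_heads=8)) / GiB)  # 20.0
```

Under these assumptions the 7B shape with the default max_batch_size=32 reserves 32 GiB of float16 cache, while the 70B shape reserves 20 GiB despite having more layers, because GQA keeps only 8 KV heads.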
class RMSNorm(torch.nn.Module):
    """
    Root Mean Square Layer Normalization.
    Normalizes inputs by their RMS value and applies a learnable scale.
    Cheaper than LayerNorm because it skips mean subtraction and has no bias term.
    Args:
        dim (int): Feature dimension to normalize.
        eps (float): Small constant for numerical stability.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        ...
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ...
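The normalization can be written out on a single feature vector. This is a dependency-free sketch of the math only (plain Python lists instead of tensors; rms_norm is a hypothetical name), not the module's forward().

```python
import math
from typing import List


def rms_norm(x: List[float], weight: List[float], eps: float = 1e-6) -> List[float]:
    """Divide a feature vector by its root mean square, then apply the learned scale.

    Unlike LayerNorm there is no mean subtraction and no bias.
    """
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


out = rms_norm([3.0, 4.0], weight=[1.0, 1.0], eps=0.0)
print(out)  # the mean of the squared outputs is 1
```

Because only the scale is normalized, multiplying the input by a constant leaves the output unchanged (up to eps).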
class Attention(torch.nn.Module):
    """
    Multi-head attention with rotary positional embeddings.
    Supports grouped-query attention (GQA) where n_kv_heads < n_heads.
    Uses pre-allocated KV cache for efficient autoregressive generation.
    Args:
        args (ModelArgs): Model configuration.
    """
    def __init__(self, args: ModelArgs):
        ...
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """
        Compute attention output.
        Args:
            x (torch.Tensor): Input tensor of shape (batch, seq_len, dim).
            start_pos (int): Current position in the sequence (for KV cache).
            freqs_cis (torch.Tensor): Precomputed rotary embedding frequencies.
            mask (torch.Tensor|None): Causal attention mask.
        Returns:
            torch.Tensor: Attention output of shape (batch, seq_len, dim).
        """
        ...
class FeedForward(torch.nn.Module):
    """
    SwiGLU feed-forward network.
    Uses gated linear units with SiLU activation:
        output = w2(silu(w1(x)) * w3(x))
    Args:
        dim (int): Input/output dimension.
        hidden_dim (int): Base hidden dimension, adjusted by ffn_dim_multiplier and multiple_of.
        multiple_of (int): Round hidden_dim up to this multiple.
        ffn_dim_multiplier (float|None): Optional hidden dim scaling.
    """
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int,
                 ffn_dim_multiplier: float | None):
        ...
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ...
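How FeedForward arrives at its hidden size is easy to miss. The sketch below assumes the sizing rule from Meta's reference Llama 2 code (TransformerBlock passes hidden_dim = 4 * dim, which is scaled by 2/3, optionally by ffn_dim_multiplier, then rounded up to multiple_of) and spells out the SwiGLU formula on plain lists; ffn_hidden_dim and swiglu are hypothetical names.

```python
import math
from typing import List, Optional


def ffn_hidden_dim(dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float] = None) -> int:
    """Hidden size of the SwiGLU FFN under the assumed reference sizing rule."""
    hidden_dim = 4 * dim                      # value assumed to be passed in by TransformerBlock
    hidden_dim = int(2 * hidden_dim / 3)      # 3 matrices instead of 2, so shrink by 2/3
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # Round up to the next multiple of multiple_of.
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def swiglu(x: List[float], w1, w2, w3) -> List[float]:
    """Compute w2(silu(w1 x) * w3 x); each weight is a list of rows."""
    def matvec(w, v):
        return [sum(a * b for a, b in zip(row, v)) for row in w]
    gate = [silu(g) for g in matvec(w1, x)]
    up = matvec(w3, x)
    return matvec(w2, [g * u for g, u in zip(gate, up)])


print(ffn_hidden_dim(4096, 256))         # 11008  (Llama-2-7B)
print(ffn_hidden_dim(5120, 256))         # 13824  (Llama-2-13B)
print(ffn_hidden_dim(8192, 4096, 1.3))   # 28672  (Llama-2-70B)
```

The three results match the published Llama 2 FFN sizes, which is why a 70B configuration needs multiple_of=4096 and ffn_dim_multiplier=1.3 for its weights to line up with the checkpoint.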
class TransformerBlock(torch.nn.Module):
    """
    Single transformer layer: Attention + FeedForward with residual connections.
    Architecture (pre-norm; each residual adds the un-normalized input back):
        h = x + Attention(RMSNorm(x))
        out = h + FeedForward(RMSNorm(h))
    Args:
        layer_id (int): Layer index.
        args (ModelArgs): Model configuration.
    """
    def __init__(self, layer_id: int, args: ModelArgs):
        ...
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: torch.Tensor | None,
    ) -> torch.Tensor:
        ...
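The pre-norm wiring is easiest to see with the sublayers abstracted away. In this sketch the attention, feed-forward and both norms are plain callables (the attention_norm/ffn_norm argument names follow the attributes of Meta's reference TransformerBlock and are assumed here), and the start_pos, freqs_cis and mask arguments that the real Attention needs are omitted.

```python
from typing import Callable, List

Vec = List[float]


def transformer_block(
    x: Vec,
    attention: Callable[[Vec], Vec],
    feed_forward: Callable[[Vec], Vec],
    attention_norm: Callable[[Vec], Vec],
    ffn_norm: Callable[[Vec], Vec],
) -> Vec:
    """Each sublayer sees a normalized input; its output is added to the raw residual stream."""
    h = [a + b for a, b in zip(x, attention(attention_norm(x)))]
    return [a + b for a, b in zip(h, feed_forward(ffn_norm(h)))]


# Toy stand-ins so the data flow can be traced by hand.
def double(v: Vec) -> Vec:
    return [2 * e for e in v]


def identity(v: Vec) -> Vec:
    return list(v)


print(transformer_block([1.0, 2.0], double, double, identity, identity))  # [9.0, 18.0]
```

Tracing it: the attention branch turns [1, 2] into h = [3, 6], and the feed-forward branch adds [6, 12] on top.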
class Transformer(torch.nn.Module):
    """
    Full Llama2 transformer model.
    Combines token embedding, precomputed rotary (RoPE) frequencies, stacked
    TransformerBlocks, final RMSNorm, and linear output projection.
    Args:
        params (ModelArgs): Model configuration dataclass.
    """
    def __init__(self, params: ModelArgs):
        ...
    def forward(self, tokens: torch.Tensor, start_pos: int) -> torch.Tensor:
        """
        Forward pass through the full transformer.
        Args:
            tokens (torch.Tensor): Input token IDs of shape (batch, seq_len).
            start_pos (int): Absolute position of the first token in tokens,
                i.e. the number of positions already held in the KV cache.
        Returns:
            torch.Tensor: Logits of shape (batch, seq_len, vocab_size).
        """
        ...
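start_pos is also what makes the causal mask line up with the KV cache. Meta's reference code builds an additive mask of shape (seq_len, start_pos + seq_len) whenever more than one token is processed, and no mask for a single decode step; the pure-Python sketch below reproduces that shape under the assumption that this module does the same (causal_mask is a hypothetical name).

```python
import math
from typing import List, Optional


def causal_mask(seqlen: int, start_pos: int) -> Optional[List[List[float]]]:
    """Additive mask for seqlen new tokens following start_pos cached ones.

    Row i is the new token at absolute position start_pos + i; column j is key
    position j. Cached keys (j < start_pos) are always visible; new keys are
    visible only up to the token's own position.
    """
    if seqlen == 1:
        return None  # one new token may attend to everything already cached
    rows = []
    for i in range(seqlen):
        cached = [0.0] * start_pos
        new = [0.0 if j <= i else -math.inf for j in range(seqlen)]
        rows.append(cached + new)
    return rows


for row in causal_mask(seqlen=3, start_pos=2):
    print(row)
# [0.0, 0.0, 0.0, -inf, -inf]
# [0.0, 0.0, 0.0, 0.0, -inf]
# [0.0, 0.0, 0.0, 0.0, 0.0]
```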
class Llama:
    """
    High-level Llama2 wrapper combining Transformer, tokenizer, and generation.
    Provides text_completion() and chat_completion() methods that handle
    tokenization, generation, and decoding.
    Args:
        model (Transformer): The transformer model instance.
        tokenizer: SentencePiece tokenizer.
    """
    def __init__(self, model: Transformer, tokenizer):
        ...
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """
    Precompute rotary embedding frequency tensor.
    Args:
        dim (int): Head dimension (dim // n_heads).
        end (int): Maximum sequence length.
        theta (float): Base frequency (default 10000.0).
    Returns:
        torch.Tensor: Complex frequency tensor of shape (end, dim // 2).
    """
    ...
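Each entry of the frequency table is a unit complex number exp(1j * t * theta ** (-2k / dim)) for position t and pair index k. A pure-Python equivalent, assuming the standard RoPE frequency schedule used by Meta's reference implementation (precompute_freqs_cis_py is a hypothetical name that returns nested lists instead of a complex tensor):

```python
import cmath
from typing import List


def precompute_freqs_cis_py(dim: int, end: int, theta: float = 10000.0) -> List[List[complex]]:
    """Entry [t][k] = exp(1j * t * theta ** (-2k / dim)) for t < end, k < dim // 2."""
    # One rotation speed per pair of features; early pairs rotate fastest.
    freqs = [1.0 / (theta ** (2 * k / dim)) for k in range(dim // 2)]
    return [[cmath.exp(1j * t * f) for f in freqs] for t in range(end)]


table = precompute_freqs_cis_py(dim=8, end=4)
print(len(table), len(table[0]))  # 4 4
```

Row 0 is all ones (no rotation at position 0), and every entry has magnitude 1, so applying it rotates features without rescaling them.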
def apply_rotary_emb(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary positional embeddings to query and key tensors.
    Args:
        xq (torch.Tensor): Query tensor.
        xk (torch.Tensor): Key tensor.
        freqs_cis (torch.Tensor): Precomputed frequency tensor.
    Returns:
        tuple: Rotated (query, key) tensors.
    """
    ...
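RoPE is applied to queries and keys, not values, because it makes attention scores depend only on the relative position of the two tokens. The pure-Python sketch below pairs adjacent features as complex numbers (the pairing used by Meta's reference apply_rotary_emb, assumed to carry over here) and checks that property; rope_freqs, apply_rotary and dot are hypothetical names.

```python
import cmath
from typing import List


def rope_freqs(dim: int, pos: int, theta: float = 10000.0) -> List[complex]:
    """Rotation factors for one position (one row of the frequency table)."""
    return [cmath.exp(1j * pos / theta ** (2 * k / dim)) for k in range(dim // 2)]


def apply_rotary(x: List[float], freqs: List[complex]) -> List[float]:
    """Rotate adjacent pairs (x[2k], x[2k+1]) as complex numbers by freqs[k]."""
    out: List[float] = []
    for k, f in enumerate(freqs):
        z = complex(x[2 * k], x[2 * k + 1]) * f
        out.extend([z.real, z.imag])
    return out


def dot(a: List[float], b: List[float]) -> float:
    return sum(p * q for p, q in zip(a, b))


q = [0.3, -1.2, 0.5, 0.8]
k = [1.0, 0.4, -0.7, 0.2]
# Query at position 5 vs key at 3, and query at 12 vs key at 10: same offset of 2.
s1 = dot(apply_rotary(q, rope_freqs(4, 5)), apply_rotary(k, rope_freqs(4, 3)))
s2 = dot(apply_rotary(q, rope_freqs(4, 12)), apply_rotary(k, rope_freqs(4, 10)))
print(abs(s1 - s2) < 1e-9)  # True
```

Per pair, the score is Re(q * conj(k) * exp(1j * (m - n) * w)), so only the offset m - n enters.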
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat key-value heads for grouped-query attention.
    When n_kv_heads < n_heads, each KV head is repeated n_rep times
    to match the number of query heads.
    Args:
        x (torch.Tensor): KV tensor of shape (batch, n_kv_heads, seq_len, head_dim).
        n_rep (int): Number of times to repeat each KV head.
    Returns:
        torch.Tensor: Expanded tensor of shape (batch, n_heads, seq_len, head_dim).
    """
    ...
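Which query heads share a KV head depends on the repetition order. Assuming the interleaved order of Meta's reference repeat_kv (each head repeated n_rep times consecutively, like repeat_interleave rather than tiling), query head h reads KV head h // n_rep. The sketch works on a plain list of heads so it does not depend on where the head axis sits in the module's tensors; repeat_kv_heads is a hypothetical name.

```python
from typing import List, TypeVar

T = TypeVar("T")


def repeat_kv_heads(kv_heads: List[T], n_rep: int) -> List[T]:
    """Repeat each KV head n_rep times in order: [h0, h0, h1, h1, ...] for n_rep=2."""
    if n_rep == 1:
        return list(kv_heads)  # plain multi-head attention: nothing to expand
    return [head for head in kv_heads for _ in range(n_rep)]


n_heads, n_kv_heads = 8, 2
n_rep = n_heads // n_kv_heads
expanded = repeat_kv_heads(["kv0", "kv1"], n_rep)
print(expanded)  # ['kv0', 'kv0', 'kv0', 'kv0', 'kv1', 'kv1', 'kv1', 'kv1']
# Query head h reads KV head h // n_rep.
print(all(expanded[h] == f"kv{h // n_rep}" for h in range(n_heads)))  # True
```

For the Llama-2-70B shape (64 query heads, 8 KV heads) this gives n_rep = 8.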
def tp_llama(model_args, checkpoint_path, tokenizer_path, world_size):
    """
    Create a tensor-parallel Llama model distributed across multiple GPUs.
    Args:
        model_args (ModelArgs): Model configuration.
        checkpoint_path (str): Path to model checkpoint.
        tokenizer_path (str): Path to tokenizer model file.
        world_size (int): Number of GPUs for tensor parallelism.
    Returns:
        Llama: Distributed Llama model instance.
    """
    ...
def parallelize_llama_attn_block(model, world_size):
    """
    Distribute attention and feed-forward weights across GPUs.
    Converts linear layers to column-parallel and row-parallel variants
    for tensor parallelism.
    Args:
        model (Transformer): The transformer model to parallelize.
        world_size (int): Number of parallel GPUs.
    """
    ...
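Pairing a column-parallel layer with a row-parallel one needs only a single reduction, which can be checked on one process. The sketch below (pure Python, no torch.distributed; all names hypothetical) shards the up-projection by output columns and the down-projection by input rows, the Megatron-style layout that column/row parallel linear layers implement, and shows that the summed shard outputs equal the unsharded result. A ReLU stands in for the SwiGLU gate; any elementwise function works because each rank holds whole hidden units. This illustrates the idea rather than this module's code.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def relu(m: Matrix) -> Matrix:
    return [[max(0.0, v) for v in row] for row in m]


def split_cols(w: Matrix, parts: int) -> List[Matrix]:
    n = len(w[0]) // parts
    return [[row[p * n:(p + 1) * n] for row in w] for p in range(parts)]


def split_rows(w: Matrix, parts: int) -> List[Matrix]:
    n = len(w) // parts
    return [w[p * n:(p + 1) * n] for p in range(parts)]


x = [[1.0, 2.0]]                                              # (batch=1, dim=2)
w_up = [[1.0, 0.0, 2.0, -7.0], [0.5, 1.0, 0.0, 3.0]]          # (dim=2, hidden=4)
w_down = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5], [2.0, 0.0]]    # (hidden=4, dim=2)
world_size = 2

reference = matmul(relu(matmul(x, w_up)), w_down)

# Each simulated rank holds a column shard of w_up and the matching row shard of w_down.
partials = [matmul(relu(matmul(x, up)), down)
            for up, down in zip(split_cols(w_up, world_size), split_rows(w_down, world_size))]
# Summing partial outputs plays the role of the all-reduce after the row-parallel layer.
combined = [[sum(p[0][j] for p in partials) for j in range(len(reference[0]))]]
print(reference, combined)  # [[0.0, 7.0]] [[0.0, 7.0]]
```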
Import
import torch
from dataclasses import dataclass
from typing import Optional
from examples.large_models.tp_llama.llama2 import (
    ModelArgs,
    RMSNorm,
    Attention,
    FeedForward,
    TransformerBlock,
    Transformer,
    Llama,
    precompute_freqs_cis,
    apply_rotary_emb,
    repeat_kv,
    tp_llama,
    parallelize_llama_attn_block,
)
I/O Contract
| Class / Function | Input | Output | Notes |
|---|---|---|---|
| ModelArgs() | Keyword arguments for model dimensions | ModelArgs dataclass instance | Lines 43-55; all fields have defaults |
| RMSNorm(dim, eps) | dim: feature dimension; eps: stability constant | Module; forward takes a torch.Tensor and returns one of the same shape | Lines 59-70 |
| Attention(args).forward(x, start_pos, freqs_cis, mask) | Input tensor (batch, seq_len, dim), start position, RoPE frequencies, mask | torch.Tensor (batch, seq_len, dim) | Lines 111-210; uses KV cache |
| FeedForward(dim, hidden_dim, multiple_of, ffn_dim_multiplier).forward(x) | torch.Tensor (batch, seq_len, dim) | torch.Tensor (batch, seq_len, dim) | Lines 213-236; SwiGLU activation |
| TransformerBlock(layer_id, args).forward(x, start_pos, freqs_cis, mask) | Input tensor, start position, frequencies, mask | torch.Tensor (batch, seq_len, dim) | Lines 239-278; Attention + FFN + residual |
| Transformer(params).forward(tokens, start_pos) | Token IDs (batch, seq_len), start position | Logits (batch, seq_len, vocab_size) | Lines 281-391; full model forward |
| Llama(model, tokenizer) | Transformer model, SentencePiece tokenizer | Llama wrapper instance | Lines 611-701; high-level API |
| precompute_freqs_cis(dim, end, theta) | Head dim, max seq len, base frequency | Complex tensor (end, dim // 2) | Rotary embedding frequencies |
| apply_rotary_emb(xq, xk, freqs_cis) | Query tensor, key tensor, frequency tensor | Tuple of rotated (query, key) tensors | Applies RoPE to Q and K |
| repeat_kv(x, n_rep) | KV tensor (batch, n_kv_heads, seq, head_dim), repetition count | Expanded tensor (batch, n_heads, seq, head_dim) | For grouped-query attention |
| tp_llama(model_args, checkpoint_path, tokenizer_path, world_size) | Model config, paths, GPU count | Distributed Llama instance | Creates tensor-parallel model |
| parallelize_llama_attn_block(model, world_size) | Transformer model, GPU count | None (modifies model in place) | Column/row parallel conversion |
Model Configurations
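| Model | dim | n_layers | n_heads | n_kv_heads | multiple_of | ffn_dim_multiplier | vocab_size | max_seq_len (ModelArgs default) |
|---|---|---|---|---|---|---|---|---|
| Llama-2-7B | 4096 | 32 | 32 | 32 (MHA) | 256 | None | 32000 | 2048 |
| Llama-2-13B | 5120 | 40 | 40 | 40 (MHA) | 256 | None | 32000 | 2048 |
| Llama-2-70B | 8192 | 80 | 64 | 8 (GQA) | 4096 | 1.3 | 32000 | 2048 |
max_seq_len only bounds the pre-allocated KV cache; it is a runtime setting, not a property of the checkpoint (Llama 2 models were trained with a 4,096-token context).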
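These dimensions, plus multiple_of and ffn_dim_multiplier (256 and none for 7B and 13B; 4096 and 1.3 for 70B, per Meta's published params.json files), are enough to recover each model's size. The sketch below counts weights assuming Meta's reference layout: bias-free linear layers, an input embedding and output projection that are not tied, two RMSNorm weights per layer plus a final one, and an FFN hidden size of int(2 * 4 * dim / 3), scaled by ffn_dim_multiplier and rounded up to multiple_of. The helper names are hypothetical.

```python
from typing import Optional


def ffn_hidden_dim(dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]) -> int:
    hidden = int(2 * (4 * dim) / 3)
    if ffn_dim_multiplier is not None:
        hidden = int(ffn_dim_multiplier * hidden)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)


def llama2_param_count(dim: int, n_layers: int, n_heads: int, n_kv_heads: int,
                       vocab_size: int, multiple_of: int,
                       ffn_dim_multiplier: Optional[float] = None) -> int:
    head_dim = dim // n_heads
    kv_dim = n_kv_heads * head_dim
    hidden = ffn_hidden_dim(dim, multiple_of, ffn_dim_multiplier)
    attn = 2 * dim * dim + 2 * dim * kv_dim   # wq, wo full size; wk, wv shrink under GQA
    ffn = 3 * dim * hidden                    # w1, w2, w3
    norms = 2 * dim                           # one RMSNorm before attention, one before the FFN
    per_layer = attn + ffn + norms
    # Token embedding + separate output projection + final RMSNorm.
    return n_layers * per_layer + 2 * vocab_size * dim + dim


print(llama2_param_count(4096, 32, 32, 32, 32000, 256))         # 6738415616
print(llama2_param_count(8192, 80, 64, 8, 32000, 4096, 1.3))    # 68976648192
```

6,738,415,616 is the parameter count usually quoted for Llama-2-7B; if this module's layout matches the assumptions above, the print statement in Example 1 should report the same figure.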
Usage Examples
Example 1: Creating a Llama2-7B model
from examples.large_models.tp_llama.llama2 import ModelArgs, Transformer
args = ModelArgs(
    dim=4096,
    n_layers=32,
    n_heads=32,
    n_kv_heads=32,
    vocab_size=32000,
    max_seq_len=2048,
    max_batch_size=4,  # a small batch keeps the pre-allocated KV cache small
)
# Materializing ~6.7B parameters in float32 needs roughly 27 GB of memory.
model = Transformer(args)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
Example 2: Tensor-parallel model creation
from examples.large_models.tp_llama.llama2 import ModelArgs, tp_llama
args = ModelArgs(
    dim=8192,
    n_layers=80,
    n_heads=64,
    n_kv_heads=8,
    multiple_of=4096,        # with ffn_dim_multiplier, gives the 70B FFN size (28672)
    ffn_dim_multiplier=1.3,
    vocab_size=32000,
    max_seq_len=2048,
)
# Create a tensor-parallel Llama-2-70B across 8 GPUs.
# Assumption: run with one process per GPU, e.g. launched via torchrun --nproc-per-node=8.
llama = tp_llama(
    model_args=args,
    checkpoint_path="/models/llama-2-70b/consolidated.pth",
    tokenizer_path="/models/llama-2-70b/tokenizer.model",
    world_size=8,
)
Example 3: Understanding the forward pass
import torch
from examples.large_models.tp_llama.llama2 import ModelArgs, Transformer
# Deliberately small configuration: the full 7B shape with the default
# max_batch_size=32 would need on the order of 90 GB for float32 weights plus KV cache.
args = ModelArgs(
    dim=512,
    n_layers=2,
    n_heads=8,
    vocab_size=32000,
    max_batch_size=2,
    max_seq_len=64,
)
model = Transformer(args).to("cuda")
# Input: batch of 2 sequences, each 10 tokens
tokens = torch.randint(0, 32000, (2, 10), device="cuda")
# Forward pass at position 0 (prefill, the first generation step)
with torch.inference_mode():
    logits = model(tokens, start_pos=0)
# logits shape: (2, 10, 32000)
# Next-token prediction: take the logits at the last position
next_token_logits = logits[:, -1, :]  # shape: (2, 32000)
next_tokens = torch.argmax(next_token_logits, dim=-1)  # shape: (2,)
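Example 3 covers only the first (prefill) step. Generation then continues by feeding one token at a time with start_pos advanced past the positions already cached, the pattern used by the generation loop in Meta's reference Llama 2 code (assumed, not confirmed, to match Llama.text_completion here). The sketch shows only that bookkeeping, with a toy function standing in for Transformer.forward; greedy_generate and toy_forward are hypothetical.

```python
from typing import Callable, List

# Stand-in signature for Transformer.forward: a chunk of token ids plus start_pos,
# returning one row of logits per input token.
ForwardFn = Callable[[List[int], int], List[List[float]]]


def greedy_generate(forward: ForwardFn, prompt: List[int], max_new_tokens: int) -> List[int]:
    tokens = list(prompt)
    prev_pos = 0
    for _ in range(max_new_tokens):
        # First step: the whole prompt at start_pos=0 (prefill).
        # Later steps: only the newest token, positioned after the cached ones.
        logits = forward(tokens[prev_pos:], prev_pos)
        last = logits[-1]
        next_token = max(range(len(last)), key=last.__getitem__)  # greedy argmax
        prev_pos = len(tokens)
        tokens.append(next_token)
    return tokens


calls = []


def toy_forward(chunk: List[int], start_pos: int) -> List[List[float]]:
    calls.append((len(chunk), start_pos))
    vocab = 10
    # Always favour "previous token + 1" so the output is easy to predict.
    return [[1.0 if v == (t + 1) % vocab else 0.0 for v in range(vocab)] for t in chunk]


print(greedy_generate(toy_forward, [3, 4, 5], max_new_tokens=3))  # [3, 4, 5, 6, 7, 8]
print(calls)  # [(3, 0), (1, 3), (1, 4)]
```

The recorded calls show the contract: one prefill call of length 3 at start_pos 0, then single-token calls whose start_pos equals the number of tokens already processed.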