

Implementation:Pytorch Serve Llama2 Model

From Leeroopedia
Revision as of 13:46, 16 February 2026 by Admin (auto-imported from implementations/Pytorch_Serve_Llama2_Model.md)
Domains: LLM_Architecture, Tensor_Parallelism, Model_Serving
Last Updated: 2026-02-13 18:52 GMT

Overview

Llama2_Model implements the full Llama2 transformer architecture for TorchServe with distributed tensor parallelism support. It defines the complete model hierarchy from low-level components (RMSNorm, rotary embeddings) through attention and feed-forward blocks to the full Transformer and high-level Llama wrapper class. The module includes parallelization functions for distributing the model across multiple GPUs.

Description

The llama2.py module contains the complete Llama2 model implementation organized as a layered class hierarchy. Each class encapsulates a specific component of the transformer architecture, and utility functions provide rotary positional embeddings and tensor parallelism support.

Architecture Hierarchy

  • ModelArgs (dataclass): Configuration parameters for the model (dimensions, layers, heads, vocabulary size)
  • RMSNorm: Root Mean Square Layer Normalization (replacing LayerNorm)
  • Attention: Multi-head attention with rotary positional embeddings (RoPE), key-value caching, and grouped-query attention (GQA) via repeat_kv()
  • FeedForward: SwiGLU feed-forward network (a SiLU-gated linear unit)
  • TransformerBlock: Single pre-norm transformer layer: RMSNorm + Attention, then RMSNorm + FeedForward, each wrapped in a residual connection
  • Transformer: Full transformer stack: token embedding, precomputed rotary frequencies, stacked TransformerBlocks, final RMSNorm, and output projection
  • Llama: High-level wrapper that combines Transformer with tokenizer and generation logic
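The RMSNorm component listed above can be illustrated with a minimal sketch. It assumes the formulation used in Meta's reference Llama 2 code (divide by the root mean square of the features, multiply by a learnable per-feature weight, no mean subtraction and no bias); `RMSNormSketch` is a hypothetical name, not the class in llama2.py.

```python
import torch


class RMSNormSketch(torch.nn.Module):
    """Hypothetical RMSNorm sketch: x / sqrt(mean(x^2) + eps) * weight."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Learnable per-feature scale, initialised to ones.
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the reciprocal RMS over the last (feature) axis in float32.
        rms_inv = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        # Normalise, cast back to the input dtype, then apply the learnable scale.
        return (x.float() * rms_inv).type_as(x) * self.weight


norm = RMSNormSketch(dim=8)
x = torch.randn(2, 3, 8)
y = norm(x)
print(y.shape)  # torch.Size([2, 3, 8])
```

With the weight at its initial value of ones, every output vector has a mean square of (almost exactly) 1, which is the property LayerNorm's variance scaling provides, minus the centring step.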

Parallelism Support

  • tp_llama(): Entry point for creating a tensor-parallel Llama model
  • parallelize_llama_attn_block(): Distributes attention and feed-forward weights across GPUs using column and row parallel linear layers
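Why a column-parallel layer followed by a row-parallel layer needs only one reduction can be checked numerically. The following single-process sketch does not use torch.distributed (each loop iteration stands in for one GPU) and assumes the usual Megatron-style pairing for the SwiGLU block: w1 and w3 split along their output (hidden) features, w2 split along its input features, and partial outputs summed by an all-reduce. It illustrates the arithmetic only and is not the code inside parallelize_llama_attn_block(); the same pattern is commonly applied to attention by splitting the query/key/value projections by head and the output projection row-wise.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, hidden, world_size = 16, 32, 4
opts = dict(dtype=torch.float64)  # float64 keeps rounding noise out of the comparison

# Full (unsharded) SwiGLU weights in nn.Linear layout: (out_features, in_features).
w1 = torch.randn(hidden, dim, **opts)
w3 = torch.randn(hidden, dim, **opts)
w2 = torch.randn(dim, hidden, **opts)
x = torch.randn(2, 5, dim, **opts)

# Reference result on a single device: w2(silu(w1(x)) * w3(x)).
full = F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)

shard = hidden // world_size
partials = []
for rank in range(world_size):  # each iteration plays the role of one GPU
    part = slice(rank * shard, (rank + 1) * shard)
    # Column-parallel w1/w3: this rank owns a slice of the hidden features,
    # so SiLU and the gating product are computed locally with no communication.
    h_local = F.silu(F.linear(x, w1[part])) * F.linear(x, w3[part])
    # Row-parallel w2: this rank owns the matching slice of w2's input features
    # and produces a partial sum of the final output.
    partials.append(F.linear(h_local, w2[:, part]))

# Summing the partials stands in for the all-reduce across ranks.
sharded = torch.stack(partials).sum(dim=0)
print(torch.allclose(full, sharded))  # True
```

The element-wise SiLU and gating work per hidden unit, which is why the hidden axis is the natural one to shard: the only cross-GPU step is the final sum.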

Usage

from examples.large_models.tp_llama.llama2 import (
    ModelArgs,
    Transformer,
    Llama,
    tp_llama,
)

Code Reference

Source Location

File | Lines | Description
examples/large_models/tp_llama/llama2.py | L1-701 | Entire module (pytorch/serve repository)
examples/large_models/tp_llama/llama2.py | L43-55 | ModelArgs dataclass
examples/large_models/tp_llama/llama2.py | L59-70 | RMSNorm class
examples/large_models/tp_llama/llama2.py | L111-210 | Attention class
examples/large_models/tp_llama/llama2.py | L213-236 | FeedForward class
examples/large_models/tp_llama/llama2.py | L239-278 | TransformerBlock class
examples/large_models/tp_llama/llama2.py | L281-391 | Transformer class
examples/large_models/tp_llama/llama2.py | L611-701 | Llama class

Signature

@dataclass
class ModelArgs:
    """
    Configuration dataclass for Llama2 model architecture.

    Attributes:
        dim (int): Model embedding dimension.
        n_layers (int): Number of transformer layers.
        n_heads (int): Number of attention heads.
        n_kv_heads (int|None): Number of key-value heads for GQA (None = MHA).
        vocab_size (int): Vocabulary size (-1 is a placeholder; set it to the tokenizer's vocabulary size before building the model).
        multiple_of (int): FFN hidden dim rounded to this multiple.
        ffn_dim_multiplier (float|None): Optional FFN dimension scaling factor.
        norm_eps (float): Epsilon for RMSNorm.
        max_batch_size (int): Maximum batch size for KV cache allocation.
        max_seq_len (int): Maximum sequence length.
    """
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: int | None = None
    vocab_size: int = -1
    multiple_of: int = 256
    ffn_dim_multiplier: float | None = None
    norm_eps: float = 1e-5
    max_batch_size: int = 32
    max_seq_len: int = 2048


class RMSNorm(torch.nn.Module):
    """
    Root Mean Square Layer Normalization.

    Normalizes inputs by their RMS value and applies a learnable scale.
    More efficient than LayerNorm as it omits mean subtraction.

    Args:
        dim (int): Feature dimension to normalize.
        eps (float): Small constant for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        ...

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ...


class Attention(torch.nn.Module):
    """
    Multi-head attention with rotary positional embeddings.

    Supports grouped-query attention (GQA) where n_kv_heads < n_heads.
    Uses pre-allocated KV cache for efficient autoregressive generation.

    Args:
        args (ModelArgs): Model configuration.
    """

    def __init__(self, args: ModelArgs):
        ...

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """
        Compute attention output.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, seq_len, dim).
            start_pos (int): Current position in the sequence (for KV cache).
            freqs_cis (torch.Tensor): Precomputed rotary embedding frequencies.
            mask (torch.Tensor|None): Causal attention mask.

        Returns:
            torch.Tensor: Attention output of shape (batch, seq_len, dim).
        """
        ...


class FeedForward(torch.nn.Module):
    """
    SwiGLU feed-forward network.

    Uses gated linear units with SiLU activation:
    output = w2(silu(w1(x)) * w3(x))

    Args:
        dim (int): Input/output dimension.
        hidden_dim (int): Intermediate hidden dimension.
        multiple_of (int): Round hidden_dim up to this multiple.
        ffn_dim_multiplier (float|None): Optional hidden dim scaling.
    """

    def __init__(self, dim: int, hidden_dim: int, multiple_of: int,
                 ffn_dim_multiplier: float | None):
        ...

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ...


class TransformerBlock(torch.nn.Module):
    """
    Single transformer layer: Attention + FeedForward with residual connections.

    Architecture:
        x -> RMSNorm -> Attention -> + (residual) -> RMSNorm -> FeedForward -> + (residual)

    Args:
        layer_id (int): Layer index.
        args (ModelArgs): Model configuration.
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        ...

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: torch.Tensor | None,
    ) -> torch.Tensor:
        ...


class Transformer(torch.nn.Module):
    """
    Full Llama2 transformer model.

    Combines token embedding, precomputed rotary frequencies (RoPE, applied
    inside each Attention layer), stacked TransformerBlocks, final RMSNorm,
    and linear output projection.

    Args:
        params (ModelArgs): Model configuration dataclass.
    """

    def __init__(self, params: ModelArgs):
        ...

    def forward(self, tokens: torch.Tensor, start_pos: int) -> torch.Tensor:
        """
        Forward pass through the full transformer.

        Args:
            tokens (torch.Tensor): Input token IDs of shape (batch, seq_len).
            start_pos (int): Current position for KV cache indexing.

        Returns:
            torch.Tensor: Logits of shape (batch, seq_len, vocab_size).
        """
        ...


class Llama:
    """
    High-level Llama2 wrapper combining Transformer, tokenizer, and generation.

    Provides text_completion() and chat_completion() methods that handle
    tokenization, generation, and decoding.

    Args:
        model (Transformer): The transformer model instance.
        tokenizer: SentencePiece tokenizer.
    """

    def __init__(self, model: Transformer, tokenizer):
        ...


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """
    Precompute rotary embedding frequency tensor.

    Args:
        dim (int): Head dimension (dim // n_heads).
        end (int): Maximum sequence length.
        theta (float): Base frequency (default 10000.0).

    Returns:
        torch.Tensor: Complex frequency tensor of shape (end, dim // 2).
    """
    ...


def apply_rotary_emb(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary positional embeddings to query and key tensors.

    Args:
        xq (torch.Tensor): Query tensor.
        xk (torch.Tensor): Key tensor.
        freqs_cis (torch.Tensor): Precomputed frequency tensor.

    Returns:
        tuple: Rotated (query, key) tensors.
    """
    ...


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat key-value heads for grouped-query attention.

    When n_kv_heads < n_heads, each KV head is repeated n_rep times
    to match the number of query heads.

    Args:
        x (torch.Tensor): KV tensor of shape (batch, seq_len, n_kv_heads, head_dim).
        n_rep (int): Number of times to repeat each KV head.

    Returns:
        torch.Tensor: Expanded tensor of shape (batch, seq_len, n_kv_heads * n_rep, head_dim).
    """
    ...


def tp_llama(model_args, checkpoint_path, tokenizer_path, world_size):
    """
    Create a tensor-parallel Llama model distributed across multiple GPUs.

    Args:
        model_args (ModelArgs): Model configuration.
        checkpoint_path (str): Path to model checkpoint.
        tokenizer_path (str): Path to tokenizer model file.
        world_size (int): Number of GPUs for tensor parallelism.

    Returns:
        Llama: Distributed Llama model instance.
    """
    ...


def parallelize_llama_attn_block(model, world_size):
    """
    Distribute attention and feed-forward weights across GPUs.

    Converts linear layers to column-parallel and row-parallel variants
    for tensor parallelism.

    Args:
        model (Transformer): The transformer model to parallelize.
        world_size (int): Number of parallel GPUs.
    """
    ...
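The rotary-embedding helpers above are listed only by signature. The sketch below mirrors the complex-number formulation used in Meta's reference Llama 2 code, which these signatures match: each pair of channels is treated as a complex number and multiplied by e^(i * position * frequency). The `_sketch` names are hypothetical, and the (batch, seq_len, n_heads, head_dim) input layout is an assumption.

```python
import torch


def precompute_freqs_cis_sketch(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # One frequency per channel pair: theta^(-2i/dim) for i in [0, dim/2).
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    t = torch.arange(end, dtype=torch.float32)
    angles = torch.outer(t, freqs)  # (end, dim // 2)
    # Unit-magnitude complex numbers e^(i * angle).
    return torch.polar(torch.ones_like(angles), angles)


def apply_rotary_emb_sketch(xq, xk, freqs_cis):
    # View the last axis of (batch, seq_len, n_heads, head_dim) as head_dim // 2 complex pairs.
    xq_c = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_c = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Broadcast freqs_cis of shape (seq_len, head_dim // 2) over batch and heads.
    f = freqs_cis[None, :, None, :]
    # Rotate, then turn the complex pairs back into interleaved real channels.
    xq_out = torch.view_as_real(xq_c * f).flatten(3)
    xk_out = torch.view_as_real(xk_c * f).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


head_dim, seq_len = 8, 6
freqs_cis = precompute_freqs_cis_sketch(head_dim, seq_len)
xq = torch.randn(2, seq_len, 4, head_dim)
xk = torch.randn(2, seq_len, 4, head_dim)
q_rot, k_rot = apply_rotary_emb_sketch(xq, xk, freqs_cis)
print(freqs_cis.shape)  # torch.Size([6, 4])
```

Because each pair is multiplied by a unit complex number, the rotation preserves vector norms and leaves position 0 unchanged. In Meta's reference code the Transformer slices freqs_cis[start_pos : start_pos + seq_len] before passing it to each layer, so cached and new tokens are rotated by their absolute positions.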

Import

import torch
from dataclasses import dataclass
from typing import Optional

from examples.large_models.tp_llama.llama2 import (
    ModelArgs,
    RMSNorm,
    Attention,
    FeedForward,
    TransformerBlock,
    Transformer,
    Llama,
    precompute_freqs_cis,
    apply_rotary_emb,
    repeat_kv,
    tp_llama,
    parallelize_llama_attn_block,
)

I/O Contract

Class / Function | Input | Output | Notes
ModelArgs() | Keyword arguments for model dimensions | ModelArgs dataclass instance | Lines 43-55; all fields have defaults
RMSNorm(dim, eps) | dim: feature dimension; eps: stability constant | Module; forward maps torch.Tensor to torch.Tensor of the same shape | Lines 59-70
Attention(args).forward(x, start_pos, freqs_cis, mask) | x (batch, seq_len, dim), start position, rotary frequencies, mask | torch.Tensor (batch, seq_len, dim) | Lines 111-210; uses KV cache
FeedForward(dim, hidden_dim, multiple_of, ffn_dim_multiplier).forward(x) | torch.Tensor (batch, seq_len, dim) | torch.Tensor (batch, seq_len, dim) | Lines 213-236; SwiGLU activation
TransformerBlock(layer_id, args).forward(x, start_pos, freqs_cis, mask) | x (batch, seq_len, dim), start position, rotary frequencies, mask | torch.Tensor (batch, seq_len, dim) | Lines 239-278; Attention + FFN + residual
Transformer(params).forward(tokens, start_pos) | Token IDs (batch, seq_len), start position | Logits (batch, seq_len, vocab_size) | Lines 281-391; full model forward
Llama(model, tokenizer) | Transformer model, SentencePiece tokenizer | Llama wrapper instance | Lines 611-701; high-level API
precompute_freqs_cis(dim, end, theta) | Head dim, max seq len, base frequency | Complex tensor (end, dim // 2) | Rotary embedding frequencies
apply_rotary_emb(xq, xk, freqs_cis) | Query tensor, key tensor, frequency tensor | Tuple of rotated (query, key) tensors | Applies RoPE to Q and K
repeat_kv(x, n_rep) | KV tensor (batch, seq_len, n_kv_heads, head_dim), repetition count | Expanded tensor (batch, seq_len, n_kv_heads * n_rep, head_dim) | For grouped-query attention
tp_llama(model_args, checkpoint_path, tokenizer_path, world_size) | Model config, checkpoint and tokenizer paths, GPU count | Distributed Llama instance | Creates tensor-parallel model
parallelize_llama_attn_block(model, world_size) | Transformer model, GPU count | None (modifies the model in place) | Column/row parallel conversion
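The repeat_kv() entry expands a small number of key/value heads so that each query head has a matching KV head. A minimal sketch, assuming the layout of Meta's reference Llama 2 code, (batch, seq_len, n_kv_heads, head_dim), where heads sit on dim 2 (for a head-major layout the same expansion applies on dim 1); `repeat_kv_sketch` is a hypothetical name:

```python
import torch


def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, seq_len, n_kv_heads, head_dim) -> (batch, seq_len, n_kv_heads * n_rep, head_dim)
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x  # plain multi-head attention: nothing to repeat
    # Add a repeat axis after the KV-head axis, expand it without copying, then merge.
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


# Llama-2-70B ratio: 64 query heads share 8 KV heads, so n_rep = 64 // 8 = 8.
kv = torch.randn(1, 5, 8, 128)
expanded = repeat_kv_sketch(kv, n_rep=64 // 8)
print(expanded.shape)  # torch.Size([1, 5, 64, 128])
```

The result equals torch.repeat_interleave(x, n_rep, dim=2): query head h reads KV head h // n_rep. The expand-then-reshape form avoids materialising a copy until the reshape needs one.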

Model Configurations

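Model | dim | n_layers | n_heads | n_kv_heads | vocab_size | max_seq_len
Llama-2-7B | 4096 | 32 | 32 | 32 | 32000 | 2048
Llama-2-13B | 5120 | 40 | 40 | 40 | 32000 | 2048
Llama-2-70B | 8192 | 80 | 64 | 8 | 32000 | 2048

An n_kv_heads value equal to n_heads means standard multi-head attention; Llama-2-70B uses grouped-query attention, with each of its 8 KV heads shared by 8 query heads. max_seq_len is the ModelArgs value that sizes the KV cache (here the dataclass default of 2048); the Llama 2 checkpoints themselves were trained with a 4096-token context.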
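The configuration table does not list the feed-forward settings, which also differ per model: Meta's released params.json files use multiple_of=256 for 7B and 13B, and multiple_of=4096 with ffn_dim_multiplier=1.3 for 70B. Assuming FeedForward applies the same sizing rule as Meta's reference Llama 2 code (start from 4 * dim, scale by 2/3, optionally scale by ffn_dim_multiplier, then round up to multiple_of), the hidden width for each configuration can be computed with the hypothetical helper below:

```python
from typing import Optional


def ffn_hidden_dim(dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float] = None) -> int:
    # Start from the conventional 4 * dim, then shrink by 2/3 so SwiGLU's three
    # weight matrices cost roughly as much as a standard two-matrix FFN.
    hidden_dim = int(2 * (4 * dim) / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # Round up to the next multiple of `multiple_of`.
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


print(ffn_hidden_dim(4096, 256))        # 11008 (7B)
print(ffn_hidden_dim(5120, 256))        # 13824 (13B)
print(ffn_hidden_dim(8192, 4096, 1.3))  # 28672 (70B)
```

These widths must match the w1/w2/w3 shapes stored in the checkpoint. A 70B ModelArgs that leaves multiple_of and ffn_dim_multiplier at their defaults would yield 22016 instead of 28672, so the weights would not fit the model.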

Usage Examples

Example 1: Creating a Llama2-7B model

from examples.large_models.tp_llama.llama2 import ModelArgs, Transformer

args = ModelArgs(
    dim=4096,
    n_layers=32,
    n_heads=32,
    n_kv_heads=32,
    vocab_size=32000,
    max_seq_len=2048,
    max_batch_size=4,
)

model = Transformer(args)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
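The max_batch_size and max_seq_len fields in Example 1 size the pre-allocated KV cache. Assuming Attention holds one key cache and one value cache per layer, each of shape (max_batch_size, max_seq_len, n_kv_heads, head_dim) as in Meta's reference code, the footprint can be estimated with the hypothetical helper below (an fp16 cache, 2 bytes per element, is also an assumption):

```python
def kv_cache_bytes(n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim, bytes_per_elem):
    # Keys and values (factor 2) for every layer, each of shape
    # (max_batch_size, max_seq_len, n_kv_heads, head_dim).
    return 2 * n_layers * max_batch_size * max_seq_len * n_kv_heads * head_dim * bytes_per_elem


# Example 1 settings: dim=4096 and n_heads=32 give head_dim=128; fp16 cache.
size = kv_cache_bytes(32, 4, 2048, 32, 4096 // 32, 2)
print(size / 2**30)  # 4.0 (GiB)
```

For the 70B configuration at the same batch size and sequence length, n_kv_heads=8 gives 2.5 GiB, compared with 20 GiB if all 64 heads kept their own keys and values; this is the main memory saving of grouped-query attention. If KV heads are split across GPUs under tensor parallelism, each GPU holds only its share.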

Example 2: Tensor-parallel model creation

from examples.large_models.tp_llama.llama2 import ModelArgs, tp_llama

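args = ModelArgs(
    dim=8192,
    n_layers=80,
    n_heads=64,
    n_kv_heads=8,
    vocab_size=32000,
    multiple_of=4096,        # Llama-2-70B FFN sizing, as in Meta's params.json
    ffn_dim_multiplier=1.3,  # together these give an FFN hidden dim of 28672
    max_seq_len=2048,
)

# Create a tensor-parallel Llama-2-70B across 8 GPUs.
# Tensor parallelism runs one process per GPU (e.g. launched with torchrun --nproc_per_node=8).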
llama = tp_llama(
    model_args=args,
    checkpoint_path="/models/llama-2-70b/consolidated.pth",
    tokenizer_path="/models/llama-2-70b/tokenizer.model",
    world_size=8,
)

Example 3: Understanding the forward pass

import torch
from examples.large_models.tp_llama.llama2 import ModelArgs, Transformer

args = ModelArgs(dim=4096, n_layers=32, n_heads=32, vocab_size=32000)
model = Transformer(args).to("cuda")

# Input: batch of 2 sequences, each 10 tokens
tokens = torch.randint(0, 32000, (2, 10), device="cuda")

# Prefill: process the whole prompt at position 0, filling the KV cache
logits = model(tokens, start_pos=0)
# logits shape: (2, 10, 32000)

# Greedy next-token prediction: take logits at the last position
next_token_logits = logits[:, -1, :]  # shape: (2, 32000)
next_tokens = torch.argmax(next_token_logits, dim=-1)  # shape: (2,)

# Decode step: feed only the new token; start_pos points just past the cached prompt
logits = model(next_tokens[:, None], start_pos=tokens.shape[1])
# logits shape: (2, 1, 32000)
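When the prompt is processed with start_pos=0 and later tokens are fed one at a time, the attention mask has to account for positions already in the KV cache. Assuming the Transformer builds its mask as Meta's reference code does (an upper-triangular block of -inf for the new tokens, left-padded with zeros for the start_pos cached positions, and skipped when only one token is fed), the construction looks like this; `build_mask` is a hypothetical helper:

```python
import torch


def build_mask(seqlen: int, start_pos: int) -> torch.Tensor:
    # Causal block for the new tokens: -inf strictly above the diagonal, 0 elsewhere.
    mask = torch.full((seqlen, seqlen), float("-inf"))
    mask = torch.triu(mask, diagonal=1)
    # New tokens may attend to every cached position, so prepend zeros for them.
    # Final shape: (seqlen, start_pos + seqlen), matching the scores against all keys.
    return torch.hstack([torch.zeros((seqlen, start_pos)), mask])


# Three new tokens arriving after two cached ones.
print(build_mask(3, 2))
```

The mask is added to the attention scores before the softmax, so -inf entries receive zero weight. For a single new token the mask is all zeros, which is why the reference code can pass mask=None in that case.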
