Implementation:VainF Torch Pruning Eval PPL

From Leeroopedia


Property       Value
Source         Torch-Pruning (https://github.com/VainF/Torch-Pruning)
Domains        NLP, Evaluation
Last Updated   2026-02-08 00:00 GMT

Overview

A concrete tool, provided in the Torch-Pruning examples, for evaluating the perplexity of an LLM on WikiText-2.

Description

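eval_ppl orchestrates perplexity evaluation: it obtains the tokenized WikiText-2 test split via get_loaders(), which dispatches to get_wikitext2(), and then computes perplexity with eval_ppl_wikitext(). That function processes the test text in non-overlapping windows of model.seqlen tokens and computes the next-token cross-entropy loss at every predicted position. Trailing tokens that do not fill a complete window are not scored.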
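The windowing arithmetic can be illustrated with a small pure-Python sketch. window_plan is a hypothetical helper (not part of Torch-Pruning), and the 10,000-token total is an arbitrary example, not the size of WikiText-2. It returns the number of evaluated windows, how many tail tokens are dropped, and the token span each batch slices from the flat test encoding.

```python
def window_plan(total_tokens: int, seqlen: int, bs: int = 1):
    """Hypothetical helper: plan non-overlapping evaluation windows.

    Mirrors the indexing used by eval_ppl_wikitext: window k covers tokens
    [k * seqlen, (k + 1) * seqlen), and a batch groups up to `bs` windows.
    """
    nsamples = total_tokens // seqlen           # only full windows are evaluated
    dropped = total_tokens - nsamples * seqlen  # tail tokens that are never scored
    batches = []
    for i in range(0, nsamples, bs):
        j = min(i + bs, nsamples)
        # (start, end, windows_in_batch) sliced from the flat token tensor
        batches.append((i * seqlen, j * seqlen, j - i))
    return nsamples, dropped, batches


nsamples, dropped, batches = window_plan(total_tokens=10_000, seqlen=4096, bs=1)
print(nsamples, dropped)  # 2 1808
print(batches)            # [(0, 4096, 1), (4096, 8192, 1)]
```

With a larger bs the same windows are simply grouped into fewer, larger forward passes.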

The evaluation pipeline consists of three cooperating functions:

  1. get_wikitext2(nsamples, seed, seqlen, tokenizer) -- Loads the train and test splits of WikiText-2, tokenizes each split as one continuous token sequence, and draws nsamples random training windows of length seqlen (reproducible via seed). Returns the list of calibration samples and the full tokenized test encoding.
  2. eval_ppl(args, model, tokenizer, device) -- Top-level entry point. Calls get_loaders() to obtain the test encoding, then delegates to eval_ppl_wikitext() inside a torch.no_grad() context.
  3. eval_ppl_wikitext(model, testenc, bs, device) -- The core computation loop. Splits the test encoding into nsamples = total_tokens // model.seqlen non-overlapping windows and processes them bs at a time. For each batch, it runs a forward pass, computes the shifted cross-entropy loss (logits[:, :-1] against inputs[:, 1:]), scales the mean loss by model.seqlen * batch_size, and accumulates the result. The final perplexity is the exponential of the accumulated total divided by nsamples * model.seqlen.
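The accumulation in step 3 can be written out explicitly. The notation here is introduced for illustration: T is the number of test tokens, L = model.seqlen, N = nsamples, and ℓ_k is the mean cross-entropy over the L - 1 predicted tokens of window k. Each window's mean loss is multiplied by L and the total is divided by N·L, so the factor L cancels:

```latex
N = \left\lfloor T / L \right\rfloor, \qquad
\ell_k = -\frac{1}{L-1}\sum_{t=2}^{L}\log p_\theta\!\left(x^{(k)}_t \mid x^{(k)}_{<t}\right)

\mathrm{PPL} = \exp\!\left(\frac{\sum_{k=1}^{N} L\,\ell_k}{N L}\right)
             = \exp\!\left(\frac{1}{N}\sum_{k=1}^{N}\ell_k\right)
```

Because every window contributes the same number of predictions, this also equals the exponentiated mean token-level loss over all N(L - 1) scored tokens, and the result does not depend on bs.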

Code Reference

Source files:

  • examples/LLMs/prune_llm.py, lines 91-106 (eval_ppl)
  • examples/LLMs/prune_llm.py, lines 34-53 (get_wikitext2)
  • examples/LLMs/prune_llm.py, lines 160-206 (eval_ppl_wikitext)
  • examples/LLMs/eval_ppl.py (standalone evaluation script with identical functions)

Import: These are in example scripts, not the core library. Copy/adapt from examples/LLMs/prune_llm.py or examples/LLMs/eval_ppl.py.

Function Signatures

def eval_ppl(args, model, tokenizer, device=torch.device("cuda:0")):
    """Evaluate perplexity on WikiText-2.

    Obtains the tokenized WikiText-2 test set via get_loaders(), then computes
    perplexity with eval_ppl_wikitext() over non-overlapping windows of
    model.seqlen tokens inside a torch.no_grad() context.

    Returns:
        ppl_test (float): Perplexity on the WikiText-2 test set.
    """

def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    """Load and tokenize WikiText-2.

    Returns:
        trainloader: List of (input_ids, targets) tuples for calibration.
        testenc: The full tokenized test set as a single sequence;
            eval_ppl_wikitext reads its .input_ids.
    """

def eval_ppl_wikitext(model, testenc, bs=1, device=None):
    """Compute perplexity on tokenized WikiText-2.

    Processes the test encoding in non-overlapping windows of model.seqlen
    tokens, computing cross-entropy loss at each position.

    Returns:
        ppl (float): Perplexity on the WikiText-2 test set.
    """
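The calibration sampling described for get_wikitext2 can be approximated in pure Python. This is a sketch under stated assumptions: sample_calibration_windows is hypothetical, it works on a plain list of token ids instead of tokenizer tensors, it uses a local random.Random(seed) rather than seeding the global generator, and it does not build target tensors.

```python
import random


def sample_calibration_windows(token_ids, nsamples, seed, seqlen):
    """Hypothetical sketch: draw `nsamples` random contiguous windows of `seqlen` tokens.

    Seeding makes the chosen start offsets reproducible across runs.
    """
    if len(token_ids) <= seqlen:
        raise ValueError("need more than seqlen tokens to sample a window")
    rng = random.Random(seed)  # local RNG, so global random state is untouched
    samples = []
    for _ in range(nsamples):
        # randint's upper bound is inclusive, so start + seqlen stays in range
        start = rng.randint(0, len(token_ids) - seqlen - 1)
        samples.append(token_ids[start:start + seqlen])
    return samples


tokens = list(range(1000))  # stand-in for a tokenized training split
a = sample_calibration_windows(tokens, nsamples=4, seed=0, seqlen=128)
b = sample_calibration_windows(tokens, nsamples=4, seed=0, seqlen=128)
print(a == b, [len(w) for w in a])  # True [128, 128, 128, 128]
```

Using a local generator is a design choice for the sketch; it keeps the sampling reproducible without affecting other code that relies on the global random module.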

Core Computation (eval_ppl_wikitext)

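import torch
import torch.nn as nn


def eval_ppl_wikitext(model, testenc, bs=1, device=None):
    # Flat token tensor of shape [1, total_tokens]
    testenc = testenc.input_ids
    # Number of full, non-overlapping windows; leftover tail tokens are skipped
    nsamples = testenc.numel() // model.seqlen
    nlls = []

    for i in range(0, nsamples, bs):
        j = min(i + bs, nsamples)
        # Slice windows i..j-1 and fold them into a [batch, seqlen] tensor
        inputs = testenc[:, (i * model.seqlen):(j * model.seqlen)].to(device)
        inputs = inputs.reshape(j - i, model.seqlen)

        lm_logits = model(inputs).logits

        # Position t predicts token t+1: drop the last logit and the first label
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = inputs[:, 1:]

        # Mean cross-entropy over all predicted tokens in the batch
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
        )

        # Scale the mean by seqlen tokens per window; the factor cancels below
        neg_log_likelihood = loss.float() * model.seqlen * (j - i)
        nlls.append(neg_log_likelihood)

    # Perplexity = exp(total NLL / (nsamples * seqlen))
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    torch.cuda.empty_cache()
    return ppl.item()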
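A quick way to sanity-check a copy of this loop is to run it on a toy model whose answer is known. The sketch below assumes CPU execution and uses two hypothetical names: TinyLM, an embedding plus linear head that returns an object with a .logits field, and window_ppl, a condensed restatement of the loop above wrapped in torch.no_grad() (normally supplied by eval_ppl). With all-zero logits every next-token distribution is uniform, so the cross-entropy is log(V) and the perplexity should equal the vocabulary size V.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn


class TinyLM(nn.Module):
    """Hypothetical toy causal LM: embedding + linear head, HF-style .logits output."""

    def __init__(self, vocab_size: int, hidden: int, seqlen: int):
        super().__init__()
        self.seqlen = seqlen  # the evaluation loop reads the window size from here
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids):
        return SimpleNamespace(logits=self.head(self.embed(input_ids)))


@torch.no_grad()
def window_ppl(model, input_ids, bs=1):
    """Condensed restatement of the eval_ppl_wikitext loop (CPU only)."""
    nsamples = input_ids.numel() // model.seqlen
    loss_fct = nn.CrossEntropyLoss()
    nlls = []
    for i in range(0, nsamples, bs):
        j = min(i + bs, nsamples)
        inputs = input_ids[:, i * model.seqlen:j * model.seqlen].reshape(j - i, model.seqlen)
        logits = model(inputs).logits
        loss = loss_fct(logits[:, :-1, :].reshape(-1, logits.size(-1)),
                        inputs[:, 1:].reshape(-1))
        nlls.append(loss.float() * model.seqlen * (j - i))
    return torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)).item()


torch.manual_seed(0)
vocab, seqlen = 50, 16
ids = torch.randint(0, vocab, (1, 5 * seqlen + 7))  # 7 trailing tokens are dropped
model = TinyLM(vocab, hidden=8, seqlen=seqlen).eval()

# Batch size affects memory use, not the result, because all windows are equal length
print(window_ppl(model, ids, bs=1), window_ppl(model, ids, bs=2))

# Uniform (all-zero) logits give cross-entropy log(V), so perplexity is about V
nn.init.zeros_(model.head.weight)
nn.init.zeros_(model.head.bias)
print(window_ppl(model, ids))  # ≈ 50.0
```

If a copied evaluation loop does not report roughly V on a uniform model, the shifting or the normalization has most likely been altered.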

I/O Contract

Direction  Name       Type                 Description
Input      model      nn.Module            Causal language model with an integer .seqlen attribute (window size) whose forward output exposes .logits
Input      tokenizer  PreTrainedTokenizer  Hugging Face tokenizer used to encode the WikiText-2 text
Input      nsamples   int                  get_wikitext2 parameter: number of random calibration windows drawn from the training split (default: 128)
Input      seed       int                  get_wikitext2 parameter: random seed for reproducible window selection (default: 0)
Input      seqlen     int                  get_wikitext2 parameter: tokens per window, typically model.seqlen (e.g., 2048 or 4096)
Output     ppl_test   float                Perplexity on the WikiText-2 test set (lower is better)
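The model row of the contract is the easiest one to miss when copying these functions: Hugging Face models do not define a .seqlen attribute, so the caller attaches one, as the usage example does with min(4096, model.config.max_position_embeddings). attach_eval_seqlen below is a hypothetical helper that packages that step; the default cap of 4096 comes from that usage example, not from the Torch-Pruning code.

```python
from types import SimpleNamespace


def attach_eval_seqlen(model, cap=4096):
    """Hypothetical helper: set model.seqlen, capped by the positional limit.

    eval_ppl_wikitext reads model.seqlen as its window size, which Hugging Face
    models do not provide on their own.
    """
    max_pos = getattr(model.config, "max_position_embeddings", None)
    if max_pos is None:
        raise AttributeError("model.config has no max_position_embeddings")
    model.seqlen = min(cap, max_pos)
    return model.seqlen


# Stand-in object with the same attribute layout as a Hugging Face model config
m = SimpleNamespace(config=SimpleNamespace(max_position_embeddings=2048))
print(attach_eval_seqlen(m))  # 2048
```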

Usage Examples

Evaluating a Pruned LLM

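import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load pruned model (or original for baseline)
model = AutoModelForCausalLM.from_pretrained(
    "pruned_llama_model/",
    torch_dtype=torch.float16,
    device_map="auto",
)
# Evaluation window size; Hugging Face models do not define .seqlen themselves
model.seqlen = min(4096, model.config.max_position_embeddings)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_fast=False)
device = torch.device("cuda:0")

# Copy eval_ppl and its dependencies (get_loaders, get_wikitext2,
# eval_ppl_wikitext) from examples/LLMs/prune_llm.py.
# eval_ppl takes the script's argparse namespace as `args`; an empty namespace
# is an assumption that only works if your copy does not read fields from it.
args = argparse.Namespace()
ppl_test = eval_ppl(args, model, tokenizer, device)
print(f"WikiText-2 perplexity: {ppl_test:.2f}")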

Full Prune-and-Evaluate Workflow

From examples/LLMs/prune_llm.py (lines 378-385):

# After pruning is complete...
del pruner
torch.cuda.empty_cache()
model.eval()

num_params = sum(p.numel() for p in model.parameters())
print(f"num_params {num_params}")

ppl_test = eval_ppl(args, model, tokenizer, device)
print(f"wikitext perplexity {ppl_test}")
