Implementation: VainF Torch Pruning Eval PPL
| Property | Value |
|---|---|
| Source | [Torch-Pruning](https://github.com/VainF/Torch-Pruning) |
| Domains | NLP, Evaluation |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Concrete tool for evaluating LLM perplexity on WikiText-2 provided by the Torch-Pruning examples.
Description
eval_ppl orchestrates perplexity evaluation: it loads WikiText-2 via get_wikitext2(), tokenizes it, then computes perplexity via eval_ppl_wikitext() which processes the text in non-overlapping windows of model.seqlen tokens, computing cross-entropy loss at each position.
The evaluation pipeline consists of three cooperating functions:
- `get_wikitext2(nsamples, seed, seqlen, tokenizer)` -- Loads the WikiText-2 dataset (train and test splits), tokenizes both into continuous token sequences, and creates random training samples of length `seqlen`. Returns a training loader and the full tokenized test encoding.
- `eval_ppl(args, model, tokenizer, device)` -- Top-level entry point. Calls `get_loaders()` to obtain the test encoding, then delegates to `eval_ppl_wikitext()` inside a `torch.no_grad()` context.
- `eval_ppl_wikitext(model, testenc, bs, device)` -- The core computation loop. Splits the test encoding into `nsamples = total_tokens // model.seqlen` non-overlapping windows. For each window, it runs a forward pass, computes shifted cross-entropy loss (`logits[:, :-1]` vs `labels[:, 1:]`), accumulates negative log-likelihoods, and exponentiates the mean NLL to produce the final perplexity.
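The calibration sampling step of `get_wikitext2()` can be sketched in plain Python. This is a minimal illustration, not the library code: `sample_calibration_windows` is a hypothetical name, the real function operates on torch tensors, and the `-100` target masking is an assumption based on the wanda-style loaders such example scripts are commonly derived from.

```python
import random

def sample_calibration_windows(token_ids, nsamples, seed, seqlen):
    # Hypothetical sketch: draw nsamples random windows of seqlen tokens
    # from one long token stream, seeded for reproducibility.
    random.seed(seed)
    windows = []
    for _ in range(nsamples):
        start = random.randint(0, len(token_ids) - seqlen - 1)
        inputs = token_ids[start:start + seqlen]
        # Assumed wanda-style targets: a copy of the inputs with every
        # position except the last masked to -100 (ignored by the loss).
        targets = [-100] * (seqlen - 1) + [inputs[-1]]
        windows.append((inputs, targets))
    return windows

# Example: 4 calibration windows of 16 tokens from a 1,000-token stream
loader = sample_calibration_windows(list(range(1000)), nsamples=4, seed=0, seqlen=16)
```

The seeded RNG makes the calibration set reproducible across runs, which matters when comparing pruning configurations.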
Code Reference
Source files:
- `examples/LLMs/prune_llm.py`, lines 91-106 (`eval_ppl`)
- `examples/LLMs/prune_llm.py`, lines 34-53 (`get_wikitext2`)
- `examples/LLMs/prune_llm.py`, lines 160-206 (`eval_ppl_wikitext`)
- `examples/LLMs/eval_ppl.py` (standalone evaluation script with identical functions)
Import: These functions live in the example scripts, not the core library. Copy or adapt them from `examples/LLMs/prune_llm.py` or `examples/LLMs/eval_ppl.py`.
Function Signatures
```python
def eval_ppl(args, model, tokenizer, device=torch.device("cuda:0")):
    """Evaluate perplexity on WikiText-2.

    Loads the WikiText-2 test set, tokenizes it using the provided tokenizer,
    and computes perplexity using non-overlapping windows of model.seqlen tokens.
    """

def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    """Load and tokenize WikiText-2.

    Returns:
        trainloader: List of (input_ids, targets) tuples for calibration.
        testenc: TokenizerWrapper containing full tokenized test set.
    """

def eval_ppl_wikitext(model, testenc, bs=1, device=None):
    """Compute perplexity on tokenized WikiText-2.

    Processes the test encoding in non-overlapping windows of model.seqlen
    tokens, computing cross-entropy loss at each position.

    Returns:
        ppl (float): Perplexity on the WikiText-2 test set.
    """
```
Core Computation (eval_ppl_wikitext)
```python
import torch
import torch.nn as nn

def eval_ppl_wikitext(model, testenc, bs=1, device=None):
    testenc = testenc.input_ids
    nsamples = testenc.numel() // model.seqlen
    nlls = []
    for i in range(0, nsamples, bs):
        j = min(i + bs, nsamples)
        # Slice out (j - i) non-overlapping windows of seqlen tokens each
        inputs = testenc[:, (i * model.seqlen):(j * model.seqlen)].to(device)
        inputs = inputs.reshape(j - i, model.seqlen)
        lm_logits = model(inputs).logits
        # Shift so each position predicts the next token
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = inputs[:, 1:]
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
        )
        # Re-weight the mean loss by the number of tokens in this batch
        neg_log_likelihood = loss.float() * model.seqlen * (j - i)
        nlls.append(neg_log_likelihood)
    # Perplexity = exp(total NLL / total token count)
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    torch.cuda.empty_cache()
    return ppl.item()
```
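The final aggregation can be checked with stdlib arithmetic. This is a toy sketch with a hypothetical helper name (`aggregate_perplexity`): each batch contributes its mean cross-entropy rescaled by `seqlen * batch_size`, and perplexity is the exponential of the token-averaged sum. Note that the loop above multiplies by `model.seqlen` even though the shifted loss averages over `seqlen - 1` positions; the sketch preserves that behavior.

```python
import math

def aggregate_perplexity(window_losses, seqlen, batch_sizes):
    # Mirror eval_ppl_wikitext's bookkeeping: each batch's mean loss is
    # scaled back up to a total NLL, summed, normalized, exponentiated.
    nlls = [loss * seqlen * bs for loss, bs in zip(window_losses, batch_sizes)]
    return math.exp(sum(nlls) / (sum(batch_sizes) * seqlen))

# A uniform mean loss of 2.0 nats per token gives ppl = e^2 ~= 7.389
ppl = aggregate_perplexity([2.0, 2.0, 2.0], seqlen=2048, batch_sizes=[1, 1, 1])
```

The token-weighted sum matters when the last batch is smaller than `bs`: each window still contributes in proportion to its token count rather than each batch counting equally.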
I/O Contract
| Direction | Name | Type | Description |
|---|---|---|---|
| Input | model | nn.Module | Language model with a .seqlen attribute (int) defining the window size, whose forward pass returns an object with .logits |
| Input | tokenizer | PreTrainedTokenizer | HuggingFace tokenizer for encoding the WikiText-2 text |
| Input | nsamples | int | Number of random calibration samples to extract from training set (default: 128) |
| Input | seed | int | Random seed for reproducible sample selection (default: 0) |
| Input | seqlen | int | Sequence length per window, typically model.seqlen (e.g., 2048 or 4096) |
| Output | ppl_test | float | Perplexity on the WikiText-2 test set (lower is better) |
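The model side of this contract is duck-typed: `eval_ppl_wikitext` only needs a `.seqlen` attribute and a call that returns an object with `.logits`. A minimal stdlib stand-in (hypothetical `ToyLM`, not part of the library) makes that explicit:

```python
from types import SimpleNamespace

class ToyLM:
    # Hypothetical minimal model honoring the I/O contract above:
    # a .seqlen attribute plus a callable returning .logits.
    vocab_size = 8

    def __init__(self, seqlen=16):
        self.seqlen = seqlen

    def __call__(self, inputs):
        # Uniform placeholder logits of shape (batch, seq, vocab);
        # a real nn.Module would return a torch tensor here.
        batch = [[[0.0] * self.vocab_size for _ in row] for row in inputs]
        return SimpleNamespace(logits=batch)

model = ToyLM(seqlen=4)
out = model([[1, 2, 3, 4]])
```

A tensor version of these uniform logits is a handy sanity check: a uniform distribution over a vocabulary of 8 has cross-entropy ln 8 per token, so a copied `eval_ppl_wikitext` should report a perplexity of exactly 8.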
Usage Examples
Evaluating a Pruned LLM
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load pruned model (or original for baseline)
model = AutoModelForCausalLM.from_pretrained(
    "pruned_llama_model/",
    torch_dtype=torch.float16,
    device_map="auto",
)
model.seqlen = min(4096, model.config.max_position_embeddings)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_fast=False)
device = torch.device("cuda:0")
# Copy eval_ppl and its dependencies from examples/LLMs/prune_llm.py;
# `args` is the script's argparse namespace.
ppl_test = eval_ppl(args, model, tokenizer, device)
print(f"WikiText-2 perplexity: {ppl_test:.2f}")
```
Full Prune-and-Evaluate Workflow
From examples/LLMs/prune_llm.py (lines 378-385):
```python
# After pruning is complete...
del pruner
torch.cuda.empty_cache()
model.eval()
num_params = sum(p.numel() for p in model.parameters())
print(f"num_params {num_params}")
ppl_test = eval_ppl(args, model, tokenizer, device)
print(f"wikitext perplexity {ppl_test}")
```