
Implementation:Mit han lab Llm awq Wikitext eval loop

From Leeroopedia

Overview

Inline evaluation pattern for computing WikiText-2 perplexity on quantized models in the llm-awq library.

Source

awq/entry.py, Lines 301-333

Doc Type

This is a Pattern Doc: the evaluation is an inline code block in awq/entry.py, not a standalone function.

Pattern

# Load WikiText-2 test data
testenc = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
testenc = enc("\n\n".join(testenc["text"]), return_tensors="pt")
model.seqlen = 2048
testenc = testenc.input_ids.to(model.device)
nsamples = testenc.numel() // model.seqlen

# Evaluate sliding windows
nlls = []
for i in range(nsamples):
    batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)]
    with torch.no_grad():
        lm_logits = model(batch).logits
    shift_logits = lm_logits[:, :-1, :].contiguous().float()
    shift_labels = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)][:, 1:]
    loss = nn.CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    nlls.append(loss.float() * model.seqlen)

ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))

Import

import torch
import torch.nn as nn
from datasets import load_dataset

I/O

Inputs:

  • Quantized model (on GPU) with a .device attribute and forward pass returning .logits
  • Tokenizer (enc) capable of encoding text to PyTorch tensors

Output:

  • PPL scalar (float): the perplexity value
  • Saved as {"ppl": float} to a JSON results file
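The JSON write itself falls outside the excerpted lines; a minimal sketch of the described output format, assuming a hypothetical results.json path:

```python
import json

ppl = 5.68  # hypothetical PPL value produced by the evaluation loop

# Persist the result in the {"ppl": float} shape described above
with open("results.json", "w") as f:
    json.dump({"ppl": float(ppl)}, f)
```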

How It Works

  1. WikiText-2 test split is loaded and all text is joined with double newlines
  2. The full text is tokenized into a single flat tensor of token IDs
  3. The tensor is split into nsamples non-overlapping windows of 2048 tokens each
  4. For each window, the model's logits are computed via forward pass
  5. Logits are shifted by one position (lm_logits[:, :-1, :]) to align predictions with targets
  6. Cross-entropy loss is computed and scaled by model.seqlen
  7. Final PPL is computed as exp(sum(nlls) / total_tokens)
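The seven steps above can be wrapped into a reusable helper; a minimal sketch, assuming only that the model's forward pass returns an object with a .logits attribute (the name eval_wikitext_ppl is illustrative, not part of llm-awq):

```python
import torch
import torch.nn as nn

def eval_wikitext_ppl(model, testenc_ids, seqlen=2048):
    """Perplexity over non-overlapping windows of a flat token-ID tensor.

    testenc_ids: LongTensor of shape (1, total_tokens).
    """
    nsamples = testenc_ids.numel() // seqlen
    nlls = []
    for i in range(nsamples):
        batch = testenc_ids[:, i * seqlen:(i + 1) * seqlen]
        with torch.no_grad():
            lm_logits = model(batch).logits
        # Align logits at position t with labels at position t+1
        shift_logits = lm_logits[:, :-1, :].contiguous().float()
        shift_labels = batch[:, 1:]
        loss = nn.CrossEntropyLoss()(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
        # Scale the per-token mean loss back up by the window length
        nlls.append(loss.float() * seqlen)
    return torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen)).item()
```

Sanity check on the design: a model that assigns uniform logits over a vocabulary of size V yields a per-token loss of ln(V), so the helper should return a perplexity of exactly V.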

Related Pages

Knowledge Sources

Domains

  • NLP
  • Evaluation
