Principle:Tensorflow Tfjs Autoregressive Text Generation
Summary
Autoregressive text generation produces text token by token with a causal language model. The concept is library-agnostic: the model emits one token at a time, feeding each generated token back as input for the next step, until a stop condition is met.
Theory
Autoregressive generation is the standard method for producing text from decoder-only transformer models like GPT-2. The generation loop produces one token per iteration, conditioning each new prediction on all previously generated tokens.
The generation process follows these steps:
- Initialize KV cache for efficient self-attention computation.
- Forward pass through the model to obtain logits for the next token position.
- Sample or select the next token from the logits using a decoding strategy (greedy, top-k, nucleus sampling).
- Append the generated token to the input sequence.
- Update KV cache to avoid recomputation of previous positions.
- Repeat until an end token is generated or the maximum length is reached.
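The loop above can be sketched in plain JavaScript without any model library. This is a minimal, illustrative sketch: `nextTokenLogits` stands in for a real model's forward pass (here a toy deterministic function), and the `EOS` id is an arbitrary placeholder, not GPT-2's actual end-of-text id.

```javascript
// Toy stand-in for a model forward pass: always favors (last token + 1),
// capped at the EOS id, so generation terminates deterministically.
const EOS = 5;

function nextTokenLogits(tokens) {
  const vocabSize = 8;
  const logits = new Array(vocabSize).fill(0);
  const last = tokens[tokens.length - 1];
  logits[Math.min(last + 1, EOS)] = 10;
  return logits;
}

// Greedy decoding: pick the index of the highest logit.
function argmax(xs) {
  return xs.reduce((best, x, i) => (x > xs[best] ? i : best), 0);
}

function generate(prompt, maxLength) {
  const tokens = [...prompt];
  while (tokens.length < maxLength) {
    const logits = nextTokenLogits(tokens); // forward pass for next position
    const next = argmax(logits);            // decoding strategy (greedy here)
    tokens.push(next);                      // feed token back as input
    if (next === EOS) break;                // stop condition: end token
  }
  return tokens;
}

console.log(generate([0], 10)); // → [0, 1, 2, 3, 4, 5]
```

Swapping `argmax` for a sampling function (top-k, nucleus) changes the decoding strategy without touching the loop structure.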
KV Caching
KV (Key-Value) caching is a critical optimization for autoregressive generation. Without caching, generating a sequence of length n requires O(n^3) total computation, because each step re-runs attention over all previous positions. With KV caching, each step computes keys and values only for the new token and attends over the cached ones, reducing per-step attention cost from O(n^2) to O(n) and total generation cost from O(n^3) to O(n^2).
| Aspect | Without KV Cache | With KV Cache |
|---|---|---|
| Per-step computation | Recompute all positions | Compute only new position |
| Per-step complexity | O(n^2) (full attention recompute) | O(n) (attention over cached keys) |
| Total generation complexity | O(n^3) | O(n^2) |
| Memory usage | Lower (recomputed each step) | Higher (cached K,V tensors stored) |
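A minimal single-head sketch of the cached decode step, using plain arrays rather than a real attention implementation: each call appends the new token's key and value to the cache and attends over all cached positions, so the step costs O(n) instead of recomputing the full n x n score matrix.

```javascript
function softmax(xs) {
  const m = Math.max(...xs);
  const exps = xs.map((x) => Math.exp(x - m));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// The cache grows by one (key, value) pair per generated token; this is the
// memory-for-computation trade-off in the table above.
const cache = { keys: [], values: [] };

function decodeStep(cache, q, k, v) {
  cache.keys.push(k);   // store K,V for the new position only
  cache.values.push(v);
  const dot = (a, b) => a.reduce((s, ai, i) => s + ai * b[i], 0);
  // Scaled dot-product scores against every cached key: O(n) work per step.
  const scores = cache.keys.map((key) => dot(q, key) / Math.sqrt(q.length));
  const weights = softmax(scores);
  // Output is the attention-weighted sum of cached values.
  return cache.values[0].map((_, d) =>
    weights.reduce((s, w, i) => s + w * cache.values[i][d], 0)
  );
}
```

With a single cached position the attention weight is 1, so the output equals that position's value vector; as the cache grows, each new query mixes all cached values.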
Decoding Strategies
Several strategies exist for selecting the next token from the logit distribution:
| Strategy | Description | Properties |
|---|---|---|
| Greedy | Select the token with the highest logit | Deterministic, fast, may produce repetitive text |
| Top-k | Sample from the top k highest-probability tokens | Balances diversity and quality |
| Nucleus (Top-p) | Sample from the smallest set of tokens whose cumulative probability exceeds p | Adaptive vocabulary size per step |
| Temperature | Scale logits by 1/T before softmax; T < 1 sharpens, T > 1 flattens | Controls randomness of sampling |
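The strategies in the table operate on a raw logits array. A hedged sketch of each (function names are illustrative, not a library API; `rand` is injectable so sampling can be made deterministic for testing):

```javascript
// Temperature is folded into softmax: T < 1 sharpens, T > 1 flattens.
function softmax(logits, temperature = 1) {
  const scaled = logits.map((x) => x / temperature);
  const m = Math.max(...scaled);
  const exps = scaled.map((x) => Math.exp(x - m));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// Draw an index from a probability distribution.
function sampleIndex(probs, rand) {
  let r = rand();
  for (let i = 0; i < probs.length; i++) {
    r -= probs[i];
    if (r <= 0) return i;
  }
  return probs.length - 1;
}

// Greedy: deterministic argmax over logits.
function greedy(logits) {
  return logits.reduce((best, x, i) => (x > logits[best] ? i : best), 0);
}

// Top-k: keep the k highest-logit tokens, renormalize, then sample.
function topK(logits, k, rand = Math.random) {
  const order = logits.map((_, i) => i).sort((a, b) => logits[b] - logits[a]);
  const kept = order.slice(0, k);
  const probs = softmax(kept.map((i) => logits[i]));
  return kept[sampleIndex(probs, rand)];
}

// Nucleus (top-p): smallest prefix of tokens, by descending probability,
// whose cumulative mass reaches p; renormalize and sample within it.
function topP(logits, p, rand = Math.random) {
  const probs = softmax(logits);
  const order = probs.map((_, i) => i).sort((a, b) => probs[b] - probs[a]);
  const kept = [];
  let mass = 0;
  for (const i of order) {
    kept.push(i);
    mass += probs[i];
    if (mass >= p) break;
  }
  const renorm = softmax(kept.map((i) => logits[i]));
  return kept[sampleIndex(renorm, rand)];
}
```

Note that top-k with k = 1 and top-p with a very small p both degenerate to greedy decoding, which is a convenient sanity check.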
Stop Conditions
Generation terminates when any of the following conditions is met:
- The model generates the end-of-sequence token (e.g., `<|endoftext|>` for GPT-2).
- The generated sequence reaches the maximum length limit.
- An application-specific stopping criterion is triggered.
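The three conditions fold naturally into a single predicate checked once per loop iteration. A minimal sketch (the function name and `extraCriterion` hook are illustrative, not part of any library):

```javascript
// Returns true when generation should terminate.
function shouldStop(tokens, eosId, maxLength, extraCriterion = () => false) {
  return (
    tokens[tokens.length - 1] === eosId || // end-of-sequence token generated
    tokens.length >= maxLength ||          // maximum length reached
    extraCriterion(tokens)                 // application-specific criterion
  );
}
```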
Key Properties
- Sequential by nature: Each token depends on all previous tokens, limiting parallelization during generation.
- KV cache trade-off: Trading memory for computation enables practical generation speeds.
- Decoding strategy matters: The choice of sampling method significantly impacts output quality and diversity.
- Prompt-conditioned: The initial prompt tokens are processed in parallel (prefill phase), then generation proceeds autoregressively.
Implementation
Implementation:Tensorflow_Tfjs_GPT2CausalLM_Generate