Principle:Tensorflow Tfjs Causal Language Model Head
Summary
Causal language model head refers to the output projection layer that converts a transformer backbone's hidden states into next-token predictions. This is a library-agnostic concept: a causal LM head projects hidden representations back to vocabulary logits, typically using weight tying (reusing the token embedding weights transposed) to reduce parameters and improve performance.
Theory
The causal language model head is the final component in an autoregressive language model, converting contextualized hidden states into token probabilities.
The process works as follows:
- Input: The backbone produces final hidden states of shape [batch, seq_len, hidden_dim].
- Projection: The hidden states are projected to vocabulary logits via a reverse embedding (weight tying):
- logits = hidden_states × embedding_weightsᵀ
- Output: Logits of shape [batch, seq_len, vocab_size] where each position i predicts the token at position i+1.
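The projection step above can be sketched in plain TypeScript (toy sizes, batch dimension omitted; no tfjs dependency, just the underlying matrix product against the transposed embedding matrix):

```typescript
type Matrix = number[][];

// LM-head projection with weight tying:
// logits[t][v] = Σ_d hidden[t][d] * E[v][d]
// i.e. hidden_states × Eᵀ, where E is the [vocab_size, hidden_dim]
// token embedding matrix reused as the output projection.
function lmHead(hidden: Matrix, E: Matrix): Matrix {
  return hidden.map(h =>
    E.map(row => row.reduce((sum, w, d) => sum + w * h[d], 0))
  );
}

// Toy example: vocab_size = 3, hidden_dim = 2, seq_len = 2.
const E: Matrix = [
  [1, 0], // embedding of token 0
  [0, 1], // embedding of token 1
  [1, 1], // embedding of token 2
];
const hidden: Matrix = [
  [0.5, -0.5],
  [2.0, 1.0],
];

const logits = lmHead(hidden, E);
console.log(logits); // shape [seq_len, vocab_size] = [2, 3]
```

Each row of the output scores every vocabulary token for one sequence position; applying softmax to a row yields that position's next-token distribution.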
Weight Tying
Weight tying is a technique where the output projection matrix shares weights with the input token embedding matrix (transposed). Given the token embedding matrix E of shape [vocab_size, hidden_dim]:
- Embedding: token_id → E[token_id] (row lookup)
- Reverse Embedding: hidden_state → hidden_state × Eᵀ (matrix multiplication)
| Aspect | Without Weight Tying | With Weight Tying |
|---|---|---|
| Parameters | Separate embedding + projection matrices | Single shared embedding matrix |
| Parameter Count | 2 × V × d_model | V × d_model |
| Performance | Baseline | Generally improved (regularization effect) |
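The parameter counts in the table can be checked with GPT-2's published sizes (V = 50257 tokens, d_model = 768):

```typescript
// Worked parameter-count comparison for weight tying (GPT-2 sizes).
const V = 50257;      // vocabulary size
const dModel = 768;   // hidden dimension

const tied = V * dModel;        // single shared embedding matrix
const untied = 2 * V * dModel;  // separate embedding + projection matrices
const saved = untied - tied;    // savings = V × d_model

console.log(saved); // 38597376, i.e. ~38.6M parameters saved
```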
Causal Masking
The term causal refers to the constraint that each position's prediction depends only on previous positions. The backbone enforces this via causal attention masking, and the LM head simply projects the already-causally-constrained hidden states to logits.
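For illustration, the causal constraint the backbone enforces amounts to a lower-triangular attention mask (this lives in the backbone's attention layers, not in the head itself):

```typescript
// Lower-triangular causal mask: mask[i][j] === true means position i
// may attend to position j, so each position sees only itself and the past.
function causalMask(seqLen: number): boolean[][] {
  return Array.from({ length: seqLen }, (_, i) =>
    Array.from({ length: seqLen }, (_, j) => j <= i)
  );
}

console.log(causalMask(3));
// [ [ true, false, false ],
//   [ true, true,  false ],
//   [ true, true,  true  ] ]
```

Because the hidden state at position i was computed only from positions ≤ i, projecting it to logits yields a legitimate prediction for the token at position i+1.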
Key Properties
- Weight tying: Reduces model parameters by V × d_model (e.g., 50257 × 768 ≈ 38.6M parameters saved for GPT-2).
- Autoregressive: Each position predicts the next token, enabling left-to-right text generation.
- No additional learnable parameters: The reverse embedding introduces no new weights beyond those already in the embedding layer.
- Softmax-ready: Output logits can be converted to probabilities via softmax for sampling or loss computation.
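The softmax conversion in the last point can be sketched for a single position's logits (a hypothetical helper, using the standard max-subtraction trick for numerical stability):

```typescript
// Numerically stable softmax: converts one position's raw logits
// into a next-token probability distribution.
function softmax(logits: number[]): number[] {
  const max = Math.max(...logits);                 // subtract max so exp() cannot overflow
  const exps = logits.map(x => Math.exp(x - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map(e => e / sum);
}

const probs = softmax([2, 1, 3]);
console.log(probs.reduce((a, b) => a + b, 0)); // probabilities sum to 1
```

During generation, the distribution for the final position is then sampled (or argmax-ed) to pick the next token; during training, the same logits feed a cross-entropy loss against the shifted targets.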
Implementation
Implementation:Tensorflow_Tfjs_GPT2CausalLM_Constructor