Implementation:Tensorflow Tfjs GPT2CausalLM Constructor
Summary
GPT2CausalLM is the causal language model class for GPT-2 in TensorFlow.js. It attaches a language modeling head to a GPT2Backbone for next-token prediction, using ReverseEmbedding for weight tying (reusing the token embedding weights transposed to project hidden states back to vocabulary logits).
API
new GPT2CausalLM(args: GPT2CausalLMArgs) + ReverseEmbedding
Source
tfjs-layers/src/layers/nlp/models/gpt2/gpt2_causal_lm.ts:L159-224 (GPT2CausalLM)
tfjs-layers/src/layers/nlp/models/gpt2/gpt2_causal_lm.ts:L40-56 (ReverseEmbedding)
Type
API Doc
Signatures
GPT2CausalLM
interface GPT2CausalLMArgs {
backbone: GPT2Backbone;
preprocessor?: GPT2Preprocessor;
}
class GPT2CausalLM extends GenerativeTask {
constructor(args: GPT2CausalLMArgs)
callWithCache(tokenIds: Tensor, cache: Tensor, cacheUpdateIndex: number): [Tensor, Tensor, Tensor]
generateStep(inputs: NamedTensorMap, endTokenId: number): NamedTensorMap
}
ReverseEmbedding (Weight Tying)
interface ReverseEmbeddingArgs extends LayerArgs {
embedding: Embedding;
}
class ReverseEmbedding extends Layer {
call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[]
}
Constructor Parameters
GPT2CausalLMArgs
| Parameter | Type | Required | Description |
|---|---|---|---|
| backbone | GPT2Backbone | Yes | The GPT-2 backbone providing contextualized hidden states |
| preprocessor | GPT2Preprocessor | No | Optional preprocessor for tokenizing and packing input text |
ReverseEmbeddingArgs
| Parameter | Type | Required | Description |
|---|---|---|---|
| embedding | Embedding | Yes | The token embedding layer whose weights are reused (transposed) for projection |
Methods
| Method | Signature | Description |
|---|---|---|
| callWithCache | callWithCache(tokenIds: Tensor, cache: Tensor, cacheUpdateIndex: number): [Tensor, Tensor, Tensor] | Forward pass with a KV cache for efficient autoregressive generation |
| generateStep | generateStep(inputs: NamedTensorMap, endTokenId: number): NamedTensorMap | Single step of autoregressive generation |
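To illustrate the cacheUpdateIndex parameter, here is a minimal sketch (not the real implementation) of the kind of update callWithCache performs: the key/value entries for the current position are written into the cache at cacheUpdateIndex, so earlier positions are never recomputed. Plain number arrays stand in for tf.Tensor; the type name Cache and the function updateCache are illustrative assumptions.

```typescript
// Hypothetical sketch of a KV-cache update at cacheUpdateIndex.
// Shapes simplified: one cache row per sequence position.
type Cache = number[][];  // [maxSeqLen][headDim]

function updateCache(cache: Cache, newEntry: number[],
                     cacheUpdateIndex: number): Cache {
  const updated = cache.map(row => row.slice());  // copy; tfjs tensors are immutable
  updated[cacheUpdateIndex] = newEntry.slice();   // write the current position only
  return updated;
}

// Generating token 3: only slot 3 is filled in; slots 0-2 stay cached.
const cache: Cache = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0, 0]];
const next = updateCache(cache, [0.7, 0.8], 3);
console.log(next[3]);  // [0.7, 0.8]
```

This is why autoregressive decoding with a cache is O(1) attention-state work per new token instead of reprocessing the whole prefix.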
I/O
- Inputs: a GPT2Backbone instance plus an optional GPT2Preprocessor
- Outputs: a GPT2CausalLM model with the following capabilities:
  - callWithCache(): returns [logits, hiddenStates, updatedCache]
  - generateStep(): returns an updated NamedTensorMap with generated token IDs
  - buildCache(): initializes the KV cache for autoregressive generation
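The capabilities above can be sketched as a single greedy decoding step: take the logits for the current position, pick the argmax token, and signal completion once endTokenId is produced. This is a hedged illustration only; the real generateStep operates on a NamedTensorMap with token IDs and padding masks, and the helper names here (argmax, greedyStep) are assumptions.

```typescript
// Hypothetical sketch of one greedy generation step.
function argmax(logits: number[]): number {
  let best = 0;
  for (let i = 1; i < logits.length; i++) {
    if (logits[i] > logits[best]) best = i;
  }
  return best;
}

function greedyStep(tokenIds: number[], logits: number[], endTokenId: number):
    {tokenIds: number[], done: boolean} {
  const nextToken = argmax(logits);  // pick the highest-scoring vocabulary entry
  return {tokenIds: [...tokenIds, nextToken], done: nextToken === endTokenId};
}

const step = greedyStep([15496, 995], [0.1, 2.5, 0.3], 50256);
console.log(step.tokenIds);  // [15496, 995, 1]
```

A full generation loop would repeat this step, feeding the updated cache back into callWithCache, until done is true or a maximum length is reached.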
ReverseEmbedding Details
The ReverseEmbedding layer implements weight tying by taking a reference to the backbone's token embedding layer. During the forward pass, it computes:
logits = hidden_states × embedding.weightsᵀ
This avoids introducing additional parameters for the output projection, since the transposed embedding weights serve as the projection matrix.
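A minimal sketch of that projection, with plain arrays standing in for tf.Tensor (the real ReverseEmbedding does the equivalent with a transposed tf.matmul):

```typescript
// Weight tying: project hidden states back onto the vocabulary by
// multiplying with the transpose of the embedding matrix.
function reverseEmbed(hidden: number[][], embedding: number[][]): number[][] {
  // hidden: [seqLen][hiddenDim], embedding: [vocabSize][hiddenDim]
  // result: [seqLen][vocabSize] = hidden × embeddingᵀ
  return hidden.map(h =>
    embedding.map(e => e.reduce((sum, w, i) => sum + w * h[i], 0))
  );
}

// hiddenDim = 2, vocabSize = 3: each logit is the dot product of the
// hidden state with one token's embedding vector.
const embedding = [[1, 0], [0, 1], [1, 1]];
const logits = reverseEmbed([[2, 3]], embedding);
console.log(logits);  // [[2, 3, 5]]
```

Because each logit is a dot product between the hidden state and a token's embedding, tokens whose embeddings align with the hidden state score highest, which is exactly the behavior weight tying exploits.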
Example
const causalLM = new GPT2CausalLM({
backbone: backbone,
preprocessor: preprocessor,
});
Implements
Principle:Tensorflow_Tfjs_Causal_Language_Model_Head
Environment:Tensorflow_Tfjs_Browser_Runtime
Related Pages
Environments
- Environment:Tensorflow_Tfjs_Browser_Runtime -- Browser runtime (WebGL / WebGPU / WASM / CPU backends)