
Implementation:Tensorflow Tfjs GPT2CausalLM Constructor

From Leeroopedia


Summary

GPT2CausalLM is the causal language model class for GPT-2 in TensorFlow.js. It attaches a language modeling head to a GPT2Backbone for next-token prediction, using ReverseEmbedding for weight tying (reusing the token embedding weights transposed to project hidden states back to vocabulary logits).

API

new GPT2CausalLM(args: GPT2CausalLMArgs), together with the internal ReverseEmbedding layer used for weight tying

Source

  • tfjs-layers/src/layers/nlp/models/gpt2/gpt2_causal_lm.ts:L159-224 (GPT2CausalLM)
  • tfjs-layers/src/layers/nlp/models/gpt2/gpt2_causal_lm.ts:L40-56 (ReverseEmbedding)

Type

API Doc

Signatures

GPT2CausalLM

interface GPT2CausalLMArgs {
  backbone: GPT2Backbone;
  preprocessor?: GPT2Preprocessor;
}

class GPT2CausalLM extends GenerativeTask {
  constructor(args: GPT2CausalLMArgs)
  callWithCache(tokenIds: Tensor, cache: Tensor, cacheUpdateIndex: number): [Tensor, Tensor, Tensor]
  generateStep(inputs: NamedTensorMap, endTokenId: number): NamedTensorMap
}

ReverseEmbedding (Weight Tying)

interface ReverseEmbeddingArgs extends LayerArgs {
  embedding: Embedding;
}

class ReverseEmbedding extends Layer {
  call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[]
}

Constructor Parameters

GPT2CausalLMArgs

Parameter    | Type             | Required | Description
backbone     | GPT2Backbone     | Yes      | The GPT-2 backbone providing contextualized hidden states
preprocessor | GPT2Preprocessor | No       | Optional preprocessor for tokenizing and packing input text

ReverseEmbeddingArgs

Parameter | Type      | Required | Description
embedding | Embedding | Yes      | The token embedding layer whose weights are reused (transposed) for projection

Methods

Method        | Signature                                                                                          | Description
callWithCache | callWithCache(tokenIds: Tensor, cache: Tensor, cacheUpdateIndex: number): [Tensor, Tensor, Tensor] | Forward pass with a KV cache for efficient autoregressive generation
generateStep  | generateStep(inputs: NamedTensorMap, endTokenId: number): NamedTensorMap                           | Single step of autoregressive generation

I/O

  • Inputs: GPT2Backbone instance + optional GPT2Preprocessor
  • Outputs: A GPT2CausalLM model with the following capabilities:
    • callWithCache(): Returns [logits, hidden_states, updated_cache]
    • generateStep(): Returns updated NamedTensorMap with generated token IDs
    • buildCache(): Initializes the KV cache for autoregressive generation
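The cache-based flow these methods enable can be sketched with a stub in place of the real model call. Plain numbers stand in for Tensors here; `generate`, `stepFn`, and the cache shape are illustrative placeholders, not the tfjs-layers API:

```typescript
// Each step consumes one token plus the KV cache built from all earlier
// tokens, and returns the predicted next token and the updated cache.
type StepFn = (tokenId: number, cache: number[], cacheUpdateIndex: number) =>
  { nextToken: number; cache: number[] };

function generate(promptIds: number[], endTokenId: number,
                  maxLen: number, stepFn: StepFn): number[] {
  const output = [...promptIds];
  let cache: number[] = [];
  for (let i = 0; i < maxLen; i++) {
    // Feed token i; the cache holds state for tokens 0..i-1, so each
    // step processes only the newest token (the point of callWithCache).
    const res = stepFn(output[i], cache, i);
    cache = res.cache;
    if (i === output.length - 1) {
      // Past the prompt: start appending generated tokens.
      if (res.nextToken === endTokenId || output.length >= maxLen) break;
      output.push(res.nextToken);
    }
  }
  return output;
}

// Toy stub model: "predict" tokenId + 1, append the token to the cache.
const ids = generate([5], 0, 4,
  (tokenId, cache) => ({ nextToken: tokenId + 1, cache: [...cache, tokenId] }));
console.log(ids); // → [5, 6, 7, 8]
```

In the real class, generateStep() wraps one iteration of this loop and buildCache() produces the initial cache from the prompt.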

ReverseEmbedding Details

The ReverseEmbedding layer implements weight tying by taking a reference to the backbone's token embedding layer. During the forward pass, it computes:

logits = hidden_states × embedding.weightsᵀ

This avoids introducing additional parameters for the output projection, since the transposed embedding weights serve as the projection matrix.
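As a minimal, framework-free sketch of that projection (plain arrays standing in for Tensors; the function name is illustrative, not the tfjs-layers API):

```typescript
type Matrix = number[][];

// Weight tying: reuse the [vocabSize, hiddenDim] embedding matrix,
// transposed, as the output projection.
// logits[t][v] = Σ_h hidden[t][h] * embedding[v][h]  (hidden × embeddingᵀ)
function reverseEmbed(hidden: Matrix, embedding: Matrix): Matrix {
  return hidden.map(row =>
    embedding.map(embRow =>
      row.reduce((acc, x, h) => acc + x * embRow[h], 0)
    )
  );
}

// Toy example: 2 hidden states of dim 2, vocabulary of 3 tokens.
const embedding: Matrix = [
  [1, 0],  // token 0
  [0, 1],  // token 1
  [1, 1],  // token 2
];
const hidden: Matrix = [
  [2, 3],
  [1, 0],
];
const logits = reverseEmbed(hidden, embedding);
console.log(logits); // → [[2, 3, 5], [1, 0, 1]]
```

Each output row has one logit per vocabulary entry, with no new trainable parameters beyond the embedding itself.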

Example

// Assumes `backbone` is an already-constructed GPT2Backbone and
// `preprocessor` an optional GPT2Preprocessor; both are placeholders here.
const causalLM = new GPT2CausalLM({
  backbone: backbone,
  preprocessor: preprocessor,
});

Implements

Principle:Tensorflow_Tfjs_Causal_Language_Model_Head

Environment:Tensorflow_Tfjs_Browser_Runtime

Domains

NLP Language_Modeling

Sources

TensorFlow.js

Metadata

2026-02-10 00:00 GMT
