
Implementation:Tensorflow Tfjs GPT2CausalLM Generate

From Leeroopedia


Summary

GPT2CausalLM.generate and GPT2CausalLM.generateStep implement autoregressive text generation for GPT-2 in TensorFlow.js. The generate method (inherited from GenerativeTask) orchestrates the generation loop, while generateStep performs a single forward pass with KV caching. callWithCache provides the low-level cached forward pass through the backbone and LM head.

Note: generate() throws NotImplementedError in the base GenerativeTask class (L113-115). This is an in-progress NLP module.

API

GPT2CausalLM.generate(inputs: Tensor, maxLength?: number) (inherited from GenerativeTask) and GPT2CausalLM.generateStep(inputs: NamedTensorMap, endTokenId: number): NamedTensorMap

Source

  • tfjs-layers/src/layers/nlp/models/generative_task.ts:L36-116 (GenerativeTask with generate at L113-115)
  • tfjs-layers/src/layers/nlp/models/gpt2/gpt2_causal_lm.ts:L217-222 (generateStep)
  • tfjs-layers/src/layers/nlp/models/gpt2/gpt2_causal_lm.ts:L191-197 (callWithCache)

Type

API Doc

Signatures

GenerativeTask Base Class

class GenerativeTask extends Task {
  generate(inputs: Tensor, maxLength?: number): void
  generateStep(inputs: NamedTensorMap, endTokenId: number): NamedTensorMap
  makeGenerateFunction(): GenerateFn
}

GPT2CausalLM Overrides

class GPT2CausalLM extends GenerativeTask {
  generateStep(inputs: NamedTensorMap, endTokenId: number): NamedTensorMap
  callWithCache(tokenIds: Tensor, cache: Tensor, cacheUpdateIndex: number): [Tensor, Tensor, Tensor]
  private buildCache(tokenIds: Tensor): [Tensor, Tensor]
}
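For intuition, the base/override split above can be sketched in plain TypeScript. This is a hedged toy, not the tfjs source: NotImplementedError, ToyCausalLM, and the number[]-based signatures are invented stand-ins for the real Tensor-based API.

```typescript
// Toy sketch of the GenerativeTask / GPT2CausalLM split (illustrative only).
class NotImplementedError extends Error {}

abstract class GenerativeTask {
  // The base class declares the entry point but does not implement it yet,
  // mirroring generate() throwing NotImplementedError in the source.
  generate(_inputs: number[], _maxLength?: number): void {
    throw new NotImplementedError("generate() is not implemented");
  }
  // Subclasses supply the single-step forward pass.
  abstract generateStep(inputs: number[], endTokenId: number): number[];
}

class ToyCausalLM extends GenerativeTask {
  // Toy override: "predict" the next token as last + 1 and append it.
  generateStep(inputs: number[], _endTokenId: number): number[] {
    return [...inputs, inputs[inputs.length - 1] + 1];
  }
}

const lm = new ToyCausalLM();
const stepped = lm.generateStep([5], 0);  // one autoregressive step
let threw = false;
try {
  lm.generate([5]);                       // base-class generate still throws
} catch (e) {
  threw = e instanceof NotImplementedError;
}
```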

Methods

  • generate — generate(inputs: Tensor, maxLength?: number): void — orchestrates the autoregressive generation loop; throws NotImplementedError in the base class
  • generateStep — generateStep(inputs: NamedTensorMap, endTokenId: number): NamedTensorMap — performs a single autoregressive generation step: forward pass with KV cache, token selection, and cache update
  • callWithCache — callWithCache(tokenIds: Tensor, cache: Tensor, cacheUpdateIndex: number): [Tensor, Tensor, Tensor] — low-level cached forward pass returning [logits, hidden_states, updated_cache]
  • buildCache — private buildCache(tokenIds: Tensor): [Tensor, Tensor] — initializes the KV cache tensors for the prompt (prefill phase)
  • makeGenerateFunction — makeGenerateFunction(): GenerateFn — creates a callable generation function from the model

Generation Flow

The autoregressive generation process involves these steps:

  1. Build cache: buildCache(tokenIds) initializes KV cache by running the prompt through the backbone.
  2. Generate step: generateStep(inputs, endTokenId) performs one iteration:
    1. Calls callWithCache(tokenIds, cache, cacheUpdateIndex) for a forward pass.
    2. Selects the next token from the output logits.
    3. Updates the cache with the new key-value pairs.
    4. Returns updated NamedTensorMap with the new token appended.
  3. Repeat: Steps are repeated until the end token is generated or maxLength is reached.
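The flow above can be sketched as a toy loop in plain TypeScript. This is a conceptual illustration, not the tfjs implementation: the "model" inside callWithCache, the number[]-based cache, and the greedy argmax selection are all invented stand-ins, with the function names chosen to mirror the source.

```typescript
// Toy autoregressive loop mirroring generate()/generateStep() (illustrative only).
type Cache = number[][];  // stand-in for the KV cache tensor

// Fake cached forward pass: writes "KV entries" at cacheUpdateIndex and
// returns next-token "logits". Toy model: next token = (last + 1) mod 4.
function callWithCache(
  tokenIds: number[], cache: Cache, cacheUpdateIndex: number,
): { logits: number[]; cache: Cache } {
  cache[cacheUpdateIndex] = tokenIds.slice();
  const last = tokenIds[tokenIds.length - 1];
  const logits = [0, 0, 0, 0].map((_, i) => (i === (last + 1) % 4 ? 1 : 0));
  return { logits, cache };
}

function argmax(xs: number[]): number {
  return xs.indexOf(Math.max(...xs));
}

// One generation step: cached forward pass, then greedy token selection.
function generateStep(tokens: number[], cache: Cache, index: number): number {
  const { logits } = callWithCache([tokens[tokens.length - 1]], cache, index);
  return argmax(logits);
}

// Orchestrating loop: prefill the cache, then step until the end token
// is generated or maxLength is reached.
function generate(prompt: number[], endTokenId: number, maxLength: number): number[] {
  const cache: Cache = [];
  callWithCache(prompt, cache, 0);  // prefill phase ("buildCache")
  const tokens = prompt.slice();
  while (tokens.length < maxLength) {
    const next = generateStep(tokens, cache, tokens.length);
    tokens.push(next);
    if (next === endTokenId) break;
  }
  return tokens;
}

const out = generate([0], 3, 8);  // stops once token 3 (the "end token") appears
```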

I/O

  • Inputs: Preprocessed token IDs as Tensor, maximum generation length
  • Outputs: Generated token IDs as Tensor; decode to text via tokenizer.detokenize()
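The I/O contract above can be illustrated with a toy round-trip. The vocabulary and the tokenize/detokenize helpers here are invented stand-ins for the real GPT-2 BPE tokenizer; only the shape of the contract (token IDs in, token IDs out, then decode) reflects the source.

```typescript
// Toy illustration of the generate() I/O contract (illustrative only).
const vocab = ["The", "future", "of", "AI", "is", "bright", "<end>"];

// Stand-in for the preprocessor: text to token IDs.
function tokenize(text: string): number[] {
  return text.split(" ").map((w) => vocab.indexOf(w));
}

// Stand-in for tokenizer.detokenize(): token IDs back to text.
function detokenize(ids: number[]): string {
  return ids
    .filter((id) => vocab[id] !== "<end>")  // drop the end token
    .map((id) => vocab[id])
    .join(" ");
}

const promptIds = tokenize("The future of AI");
const generatedIds = [...promptIds, 4, 5, 6];  // pretend generate() appended tokens
const text = detokenize(generatedIds);
```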

callWithCache Details

  • tokenIds (Tensor) — Token IDs for the current generation step (a single token during generation, the full prompt during prefill)
  • cache (Tensor) — The KV cache tensor from previous steps
  • cacheUpdateIndex (number) — The position index at which to write the new KV entries in the cache

Returns a tuple of:

  • Logits: Tensor of shape [batch, 1, vocab_size] (next-token logits)
  • Hidden states: Tensor of the backbone's output
  • Updated cache: Tensor with new KV entries written at cacheUpdateIndex
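How cacheUpdateIndex addresses the write position can be sketched with a toy cache. The real cache is a single Tensor; here an array of per-position key/value entries stands in for it, and makeCache/updateCache are invented names, not tfjs APIs.

```typescript
// Toy sketch of writing new KV entries at cacheUpdateIndex (illustrative only).
type KVEntry = { key: number[]; value: number[] };
type ToyCache = (KVEntry | null)[];

// Preallocate one slot per position, up to the maximum sequence length.
function makeCache(maxLength: number): ToyCache {
  return new Array(maxLength).fill(null);
}

// Write this step's key/value at the given position and return the updated
// cache, as callWithCache does in its third return value. Tensors are
// immutable, so the update is modeled as copy-then-write.
function updateCache(cache: ToyCache, index: number, entry: KVEntry): ToyCache {
  const next = cache.slice();
  next[index] = entry;
  return next;
}

let cache = makeCache(4);
cache = updateCache(cache, 0, { key: [1], value: [10] });  // prefill position 0
cache = updateCache(cache, 1, { key: [2], value: [20] });  // generation step 1
```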

Example

// Conceptual usage only: generate() currently throws NotImplementedError
// in the base GenerativeTask class, so this snippet does not run yet.
const inputTokens = preprocessor.call(tf.tensor1d(['The future of AI'], 'string'));
// Once implemented, generate() would run the autoregressive loop, up to 50 tokens.
const generated = causalLM.generate(inputTokens, 50);
// Decode the generated token IDs back to text.
const text = tokenizer.detokenize(generated);

Implements

Principle:Tensorflow_Tfjs_Autoregressive_Text_Generation

Environment:Tensorflow_Tfjs_Browser_Runtime Heuristic:Tensorflow_Tfjs_Memory_Management_With_Tidy

Domains

NLP Text_Generation

Sources

TensorFlow.js

Metadata

2026-02-10 00:00 GMT
