Implementation:Tensorflow Tfjs GPT2CausalLM Generate
Summary
GPT2CausalLM.generate and GPT2CausalLM.generateStep implement autoregressive text generation for GPT-2 in TensorFlow.js. The generate method (inherited from GenerativeTask) orchestrates the generation loop, while generateStep performs a single forward pass with KV caching. callWithCache provides the low-level cached forward pass through the backbone and LM head.
Note: generate() throws NotImplementedError in the base GenerativeTask class (L113-115). This is an in-progress NLP module.
API
GPT2CausalLM.generate(inputs: Tensor, maxLength?: number) (inherited from GenerativeTask) + GPT2CausalLM.generateStep(inputs, endTokenId): NamedTensorMap
Source
- tfjs-layers/src/layers/nlp/models/generative_task.ts:L36-116 (GenerativeTask, with generate at L113-115)
- tfjs-layers/src/layers/nlp/models/gpt2/gpt2_causal_lm.ts:L217-222 (generateStep)
- tfjs-layers/src/layers/nlp/models/gpt2/gpt2_causal_lm.ts:L191-197 (callWithCache)
Type
API Doc
Signatures
GenerativeTask Base Class
class GenerativeTask extends Task {
generate(inputs: Tensor, maxLength?: number): void
generateStep(inputs: NamedTensorMap, endTokenId: number): NamedTensorMap
makeGenerateFunction(): GenerateFn
}
GPT2CausalLM Overrides
class GPT2CausalLM extends GenerativeTask {
generateStep(inputs: NamedTensorMap, endTokenId: number): NamedTensorMap
callWithCache(tokenIds: Tensor, cache: Tensor, cacheUpdateIndex: number): [Tensor, Tensor, Tensor]
private buildCache(tokenIds: Tensor): [Tensor, Tensor]
}
Methods
| Method | Signature | Description |
|---|---|---|
| generate | generate(inputs: Tensor, maxLength?: number): void | Orchestrates the autoregressive generation loop. Throws NotImplementedError in the base class. |
| generateStep | generateStep(inputs: NamedTensorMap, endTokenId: number): NamedTensorMap | Performs a single autoregressive generation step: forward pass with KV cache, token selection, and cache update. |
| callWithCache | callWithCache(tokenIds: Tensor, cache: Tensor, cacheUpdateIndex: number): [Tensor, Tensor, Tensor] | Low-level cached forward pass returning [logits, hidden_states, updated_cache]. |
| buildCache | private buildCache(tokenIds: Tensor): [Tensor, Tensor] | Initializes the KV cache tensors for the prompt (prefill phase). |
| makeGenerateFunction | makeGenerateFunction(): GenerateFn | Creates a callable generation function from the model. |
Generation Flow
The autoregressive generation process involves these steps:
- Build cache: buildCache(tokenIds) initializes the KV cache by running the prompt through the backbone (prefill).
- Generate step: generateStep(inputs, endTokenId) performs one iteration:
  - Calls callWithCache(tokenIds, cache, cacheUpdateIndex) for a forward pass.
  - Selects the next token from the output logits.
  - Updates the cache with the new key-value pairs.
  - Returns an updated NamedTensorMap with the new token appended.
- Repeat: steps are repeated until the end token is generated or maxLength is reached.
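The prefill-then-decode loop above can be sketched with plain arrays standing in for tensors. Everything here is a hypothetical stub for illustration: forwardWithCache mimics callWithCache's role (write KV entries at an index, return next-token logits), and greedy argmax stands in for the model's token-selection strategy; none of it is the actual tfjs-layers API.

```typescript
// Minimal sketch of autoregressive generation with a KV cache.
// `forwardWithCache` and `Cache` are hypothetical stand-ins, not tfjs APIs.
type Cache = number[][];

function argmax(logits: number[]): number {
  let best = 0;
  for (let i = 1; i < logits.length; i++) if (logits[i] > logits[best]) best = i;
  return best;
}

// Stub forward pass: writes a fake KV pair at `index` and returns logits
// that favor tokenId + 1, so generation simply counts upward.
function forwardWithCache(tokenId: number, cache: Cache, index: number): number[] {
  cache[index] = [tokenId, tokenId]; // pretend key/value written at `index`
  const logits = new Array(8).fill(0);
  logits[(tokenId + 1) % 8] = 1;
  return logits;
}

function generate(prompt: number[], maxLength: number, endTokenId: number): number[] {
  const cache: Cache = new Array(maxLength).fill(null).map(() => [0, 0]);
  const output = [...prompt];
  // Prefill: run every prompt token through the stub to populate the cache.
  let logits: number[] = [];
  prompt.forEach((t, i) => { logits = forwardWithCache(t, cache, i); });
  // Decode: one token per step until endTokenId or maxLength is reached.
  while (output.length < maxLength) {
    const next = argmax(logits);
    output.push(next);
    if (next === endTokenId) break;
    logits = forwardWithCache(next, cache, output.length - 1);
  }
  return output;
}

console.log(generate([1, 2], 6, 7)); // → [ 1, 2, 3, 4, 5, 6 ]
```

The key structural point mirrored from the real implementation: the prompt is processed once to fill the cache, and each subsequent step feeds only the newest token plus the cache back into the forward pass.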
I/O
- Inputs: preprocessed token IDs as a Tensor, plus an optional maximum generation length.
- Outputs: generated token IDs as a Tensor; decode to text via tokenizer.detokenize().
callWithCache Details
| Parameter | Type | Description |
|---|---|---|
| tokenIds | Tensor | Token IDs for the current generation step (a single token during generation, the full prompt during prefill) |
| cache | Tensor | The KV cache tensor from previous steps |
| cacheUpdateIndex | number | The position index at which to write the new KV entries in the cache |
Returns a tuple of:
- Logits: Tensor of shape [batch, 1, vocab_size] (next-token logits)
- Hidden states: Tensor of the backbone's output
- Updated cache: Tensor with the new KV entries written at cacheUpdateIndex
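The role of cacheUpdateIndex can be illustrated with a plain-array sketch. This is a hypothetical simplification (one layer, scalar key/value per position); the real cache is a stacked tensor, and tfjs tensors are immutable, which is why callWithCache returns an updated cache rather than mutating in place.

```typescript
// Hypothetical cache layout: one [key, value] pair per position.
type KVCache = Array<[number, number]>;

// Writes a new key/value pair at `cacheUpdateIndex` and returns a fresh
// cache, mirroring the third element of callWithCache's return tuple.
function updateCache(
  cache: KVCache,
  key: number,
  value: number,
  cacheUpdateIndex: number,
): KVCache {
  const updated = cache.map((kv) => [...kv] as [number, number]); // copy: tensors are immutable
  updated[cacheUpdateIndex] = [key, value];
  return updated;
}

const cache: KVCache = [[1, 1], [2, 2], [0, 0]];
const next = updateCache(cache, 3, 3, 2);
console.log(next); // [[1,1],[2,2],[3,3]]; the original cache is unchanged
```

During prefill, buildCache effectively performs this write for every prompt position at once; during decoding, each generateStep writes a single position.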
Example
// Conceptual usage; GenerativeTask.generate currently throws NotImplementedError
const inputTokens = preprocessor.call(tf.tensor1d(['The future of AI'], 'string'));
const generated = causalLM.generate(inputTokens, 50); // generate up to 50 tokens
const text = tokenizer.detokenize(generated);         // decode token IDs back to text
Implements
- Principle:Tensorflow_Tfjs_Autoregressive_Text_Generation
- Environment:Tensorflow_Tfjs_Browser_Runtime
- Heuristic:Tensorflow_Tfjs_Memory_Management_With_Tidy
Related Pages
Environments
- Environment:Tensorflow_Tfjs_Browser_Runtime -- Browser runtime (WebGL / WebGPU / WASM / CPU backends)
Heuristics
- Heuristic:Tensorflow_Tfjs_Memory_Management_With_Tidy -- Wrap predictions in tf.tidy() to prevent memory leaks