Implementation:Tensorflow Tfjs MultiheadAttention Test
| Knowledge Sources | |
|---|---|
| Domains | Testing, Layers_API |
| Last Updated | 2026-02-10 06:00 GMT |
Overview
This test suite validates the MultiHeadAttention layer, a core component of transformer architectures. The tests cover non-masked attention with various value dimensions and output shapes, self-attention, attention score retrieval, masked attention (causal masks, padding masks), dropout during training, output shape customization, serialization/deserialization, and high-dimensional attention inputs. The layer implements scaled dot-product attention with multiple parallel attention heads.
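To make the computation under test concrete, here is a minimal, dependency-free TypeScript sketch of scaled dot-product attention, run independently once per head. It assumes the query, key, and value inputs have already been projected per head. The real layer learns those projections, and they are omitted here. Masked positions are excluded by adding a large negative value to their logits before the softmax; that is a common convention assumed here, not something taken from the tfjs source. All function names are hypothetical.

```typescript
// A matrix is a row-major array of rows.
type Matrix = number[][];

// Numerically stable softmax over one row of logits.
function softmax(row: number[]): number[] {
  const max = Math.max(...row);
  const exps = row.map((x) => Math.exp(x - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// Single-head scaled dot-product attention (hypothetical helper).
// q: [queryLength][d], k: [valueLength][d], v: [valueLength][dv].
// mask (optional): [queryLength][valueLength], true = may attend.
function scaledDotProductAttention(
  q: Matrix,
  k: Matrix,
  v: Matrix,
  mask?: boolean[][],
): {output: Matrix; scores: Matrix} {
  const scale = 1 / Math.sqrt(q[0].length);
  const scores = q.map((qi, i) =>
    softmax(
      k.map((kj, j) => {
        const logit = qi.reduce((acc, x, t) => acc + x * kj[t], 0) * scale;
        // A large negative logit drives the softmax weight to ~0.
        return mask !== undefined && !mask[i][j] ? logit - 1e9 : logit;
      }),
    ),
  );
  // Each output row is the score-weighted sum of the value rows.
  const output = scores.map((row) =>
    v[0].map((_, c) => row.reduce((acc, p, j) => acc + p * v[j][c], 0)),
  );
  return {output, scores};
}

// Multi-head wrapper: every head attends independently over its own inputs.
// Inputs are [numHeads][length][dim]; scores come back as
// [numHeads][queryLength][valueLength] for a single example.
function multiHeadAttention(
  qs: Matrix[],
  ks: Matrix[],
  vs: Matrix[],
  mask?: boolean[][],
): {outputs: Matrix[]; scores: Matrix[]} {
  const results = qs.map((q, h) =>
    scaledDotProductAttention(q, ks[h], vs[h], mask),
  );
  return {
    outputs: results.map((r) => r.output),
    scores: results.map((r) => r.scores),
  };
}
```

The per-head score layout mirrors the `[batch, numHeads, queryLength, valueLength]` coefficient shape the tests assert, minus the batch axis.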
Code Reference
Source Location: tfjs-layers/src/layers/nlp/multihead_attention_test.ts (518 lines)
Repository: GitHub
Test Describe Blocks
- MultiHeadAttention: main test block, including:
  - Non Masked Attention: key/value same projection and different projection
  - Non-masked self attention (single input)
  - Attention scores retrieval (callAndReturnAttentionScores)
  - Attention scores with separate values
  - Masked attention (causal mask, query mask, value mask)
  - Dropout during training
  - Output shape customization
  - Serialization round-trip (getConfig, fromConfig)
  - Memory leak verification
- High Dimensional Attention: attention with higher-dimensional inputs
- AttentionSubclass: subclassing and embedding integration
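The serialization round-trip tests follow the getConfig/fromConfig pattern. The sketch below illustrates that pattern with a hypothetical `ToyAttention` class and `AttentionConfig` interface, which are not the tfjs classes. The config is a plain object that survives a JSON round-trip and rebuilds an equivalent, distinct instance.

```typescript
// Hypothetical config; fields mirror the options named in this page,
// not the full tfjs MultiHeadAttention config.
interface AttentionConfig {
  numHeads: number;
  keyDim: number;
  valueDim?: number;
  dropout?: number;
  outputShape?: number[];
}

// Toy stand-in for a layer that supports a config round-trip.
class ToyAttention {
  private readonly config: AttentionConfig;

  constructor(config: AttentionConfig) {
    // Copy so later mutation of the caller's object cannot leak in.
    this.config = {...config, outputShape: config.outputShape?.slice()};
  }

  // Return a plain, JSON-serializable copy of the constructor arguments.
  getConfig(): AttentionConfig {
    return {...this.config, outputShape: this.config.outputShape?.slice()};
  }

  // Rebuild an equivalent instance from a config object.
  static fromConfig(config: AttentionConfig): ToyAttention {
    return new ToyAttention(config);
  }
}

// Serialize the config to JSON and back, then rebuild the layer.
function roundTrip(layer: ToyAttention): ToyAttention {
  const json = JSON.stringify(layer.getConfig());
  return ToyAttention.fromConfig(JSON.parse(json) as AttentionConfig);
}
```

A round-trip test then only needs to compare the two configs field by field; it does not need to inspect internal state.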
I/O Contract
Inputs to tests:
- Query tensors: shape [batch, queryLength, features] (e.g., [1, 40, 80])
- Value tensors: shape [batch, valueLength, features] (e.g., [1, 20, 80])
- Optional key tensors (defaults to value)
- Attention masks: boolean tensors for masking specific positions
- Configuration: numHeads, keyDim, valueDim, outputShape, dropout
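The masked-attention cases combine several kinds of mask. The following is a minimal sketch that assumes the common convention that `true` means a position may be attended to. A causal mask lets query position i see only value positions j <= i. Padding masks for the query and the value are broadcast into a single [queryLength][valueLength] mask and combined with the causal mask by logical AND. The helper names and the `opts` object are hypothetical.

```typescript
// Lower-triangular causal mask: query i may attend to value j only if j <= i.
function causalMask(queryLength: number, valueLength: number): boolean[][] {
  return Array.from({length: queryLength}, (_, i) =>
    Array.from({length: valueLength}, (_, j) => j <= i),
  );
}

// Combine optional padding masks (true = real token) and an optional causal
// constraint into one [queryLength][valueLength] mask via logical AND.
function combineMasks(
  queryLength: number,
  valueLength: number,
  opts: {queryMask?: boolean[]; valueMask?: boolean[]; causal?: boolean} = {},
): boolean[][] {
  const causal = opts.causal ? causalMask(queryLength, valueLength) : undefined;
  return Array.from({length: queryLength}, (_, i) =>
    Array.from({length: valueLength}, (_, j) =>
      (opts.queryMask ? opts.queryMask[i] : true) &&
      (opts.valueMask ? opts.valueMask[j] : true) &&
      (causal ? causal[i][j] : true),
    ),
  );
}
```

A padded query row ends up fully masked, and a padded value column is masked for every query.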
Expected outputs/assertions:
- Output shape: [batch, queryLength, outputDim] (e.g., [1, 40, 80])
- Attention coefficient shape: [batch, numHeads, queryLength, valueLength] (e.g., [1, 12, 40, 40])
- Masked attention produces different outputs than unmasked attention
- Dropout with rate 1.0 zeros out attention; rate 0.0 produces identical results
- Serialization preserves all configuration parameters
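The shape assertions follow directly from the configuration. The hypothetical helper below computes the expected output and attention-score shapes for rank-3 inputs. It assumes that the output's last dimension equals the query feature size unless `outputShape` overrides it. Higher-rank inputs, covered by the High Dimensional Attention block, are outside the scope of this sketch. For the self-attention example (a [1, 40, 80] query with 12 heads), it reproduces the [1, 40, 80] output and [1, 12, 40, 40] score shapes listed above.

```typescript
// Shapes use null for an unknown (symbolic) batch dimension.
type Shape = (number | null)[];

// Hypothetical subset of the layer configuration that affects shapes.
interface ShapeConfig {
  numHeads: number;
  outputShape?: number[];
}

// Expected shapes for rank-3 inputs [batch, length, features].
function expectedShapes(
  config: ShapeConfig,
  queryShape: Shape,
  valueShape: Shape,
): {output: Shape; scores: Shape} {
  if (queryShape.length !== 3 || valueShape.length !== 3) {
    throw new Error('this sketch only handles rank-3 inputs');
  }
  const [batch, queryLength, queryFeatures] = queryShape;
  const valueLength = valueShape[1];
  // Without an outputShape override, the output keeps the query's feature size.
  const tail: Shape = config.outputShape ?? [queryFeatures];
  return {
    output: [batch, queryLength, ...tail],
    // One [queryLength, valueLength] score matrix per head.
    scores: [batch, config.numHeads, queryLength, valueLength],
  };
}
```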
Usage Example
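```typescript
// Import paths assume this file sits in tfjs-layers/src/layers/nlp/.
import {ones, Tensor} from '@tensorflow/tfjs-core';

import {input} from '../../exports';
import {describeMathCPUAndGPU} from '../../utils/test_utils';

import {MultiHeadAttention} from './multihead_attention';

describeMathCPUAndGPU('MultiHeadAttention', () => {
  it('non masked self attention', () => {
    const testLayer = new MultiHeadAttention({numHeads: 12, keyDim: 64});
    // Symbolic input; the batch dimension is left unspecified (null).
    const query = input({shape: [40, 80]});
    const output = testLayer.apply(query, {value: query}) as Tensor;
    expect(output.shape).toEqual([null, 40, 80]);
  });

  it('attention scores', () => {
    const testLayer = new MultiHeadAttention({numHeads: 12, keyDim: 64});
    // Concrete input: batch of 1, 40 positions, 80 features.
    const query = ones([1, 40, 80]);
    const [output, coef] =
        testLayer.callAndReturnAttentionScores(query, {value: query});
    expect(output.shape).toEqual([1, 40, 80]);
    // One [queryLength, valueLength] score matrix per head.
    expect(coef.shape).toEqual([1, 12, 40, 40]);
  });
});
```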
Test Coverage Summary
| Category | Count | Details |
|---|---|---|
| Non-masked Attention | 4+ | Same/different projections, self-attention |
| Attention Scores | 3+ | Score retrieval with query, key, value |
| Masked Attention | 5+ | Causal mask, query mask, value mask |
| Dropout | 2+ | Rate 0.0 and 1.0 verification |
| Serialization | 2+ | Config round-trip |
| Memory | 1+ | Leak verification |
| High Dimensional | 2+ | Higher-rank input tensors |
| Test Environment | n/a | Mixed: CPU and GPU backends |