Principle: Fastai Fastbook Text Classifier Training
| Knowledge Sources | |
|---|---|
| Domains | Natural Language Processing, Text Classification, Transfer Learning |
| Last Updated | 2026-02-09 17:00 GMT |
Overview
Text classifier training is the final stage of the ULMFiT pipeline: a classification head is attached to the fine-tuned language model's encoder, and the full model is then trained on labeled data with gradual unfreezing, achieving high classification accuracy while avoiding catastrophic forgetting.
Description
This is Stage 3 of the ULMFiT three-stage approach. The classifier is built by:
- Loading the fine-tuned encoder: The LSTM encoder weights saved from the language model fine-tuning stage are loaded into the classifier. This provides the classifier with rich, domain-adapted text representations.
- Adding a classification head: A pooling and linear classification layer is added on top of the encoder. The classification head uses a concat pooling mechanism that concatenates the final hidden state, the max-pooled hidden states across all time steps, and the mean-pooled hidden states. This gives the classifier access to both the final context and global document-level information.
- Training with gradual unfreezing: Rather than training all layers at once (which risks catastrophic forgetting), layers are unfrozen progressively from the last layer to the first, with each group trained for one or more epochs before the next group is unfrozen.
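The three steps above map onto a short fastai sequence. A minimal sketch, assuming fastai v2, a classification `DataLoaders` (`dls_clas`, sharing the language model's vocabulary) and an encoder saved as `'finetuned'` during the LM fine-tuning stage; epoch counts and learning rates follow the fastbook IMDb example:

```python
from fastai.text.all import *

# Build the classifier: AWD-LSTM encoder + concat-pooling classification head.
# `dls_clas` is an assumed classification DataLoaders with the LM vocab.
learn = text_classifier_learner(dls_clas, AWD_LSTM, drop_mult=0.5,
                                metrics=accuracy)
learn = learn.load_encoder('finetuned')  # load the Stage-2 encoder weights

# Gradual unfreezing: head first, then progressively deeper layer groups.
learn.fit_one_cycle(1, 2e-2)                          # head only (frozen start)
learn.freeze_to(-2)                                   # + last LSTM layer
learn.fit_one_cycle(1, slice(1e-2 / (2.6**4), 1e-2))
learn.freeze_to(-3)                                   # + last two LSTM layers
learn.fit_one_cycle(1, slice(5e-3 / (2.6**4), 5e-3))
learn.unfreeze()                                      # all layers
learn.fit_one_cycle(2, slice(1e-3 / (2.6**4), 1e-3))
```

This sketch is not runnable without the IMDb data pipeline from the earlier stages; it only illustrates the order of calls.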
Usage
Use text classifier training when:
- You have completed the language model fine-tuning stage and have a saved encoder.
- You have labeled training data for your classification task.
- You want to achieve state-of-the-art text classification accuracy using the full ULMFiT transfer learning pipeline.
Theoretical Basis
Classifier Architecture
The text classifier consists of the encoder (from the fine-tuned LM) plus a classification head:
Classifier Architecture:
Input tokens: [t_1, t_2, ..., t_n]
|
Embedding Layer (400 dim)
|
LSTM Layer 1 (1150 hidden)
|
LSTM Layer 2 (1150 hidden)
|
LSTM Layer 3 (400 hidden, sized to match the embedding for weight tying)
| ^--- Encoder (loaded from LM) ---^
Concat Pooling
|
[h_n ; max_pool(H) ; mean_pool(H)] -> 3 * 400 = 1200 dim
|
Linear (1200 -> 50) + BatchNorm + ReLU + Dropout
|
Linear (50 -> num_classes)
|
Softmax
|
Output: class probabilities
Concat Pooling
Standard sequence classification takes only the final hidden state h_n as the document representation. This discards information from earlier time steps. Concat pooling addresses this by combining three signals:
Given hidden states H = [h_1, h_2, ..., h_n] from the final LSTM layer:
h_final = h_n # Last hidden state (recent context)
h_max = max_pool(H, dim=time) # Maximum activation per dimension
h_mean = mean_pool(H, dim=time) # Average activation per dimension
pooled = concatenate(h_final, h_max, h_mean)
# Shape: 3 * output_dim = 3 * 400 = 1200
This captures the final summary, the strongest signals, and the overall average across the entire document.
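The pooling above can be written in a few lines of NumPy. A minimal single-document sketch (the real fastai head operates on batched, padding-masked tensors); the 400-dimensional hidden state matches the final-layer output size of fastai's AWD-LSTM:

```python
import numpy as np

def concat_pool(H: np.ndarray) -> np.ndarray:
    """H: (seq_len, hidden_dim) hidden states from the final LSTM layer."""
    h_final = H[-1]         # last time step (recent context)
    h_max = H.max(axis=0)   # strongest activation per dimension
    h_mean = H.mean(axis=0) # average activation per dimension
    return np.concatenate([h_final, h_max, h_mean])  # (3 * hidden_dim,)

H = np.random.randn(25, 400).astype(np.float32)  # 25 tokens, 400-dim states
pooled = concat_pool(H)
assert pooled.shape == (1200,)
```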
Gradual Unfreezing
Gradual unfreezing is the key technique that prevents catastrophic forgetting during classifier training. The layers are organized into groups and unfrozen one at a time:
FUNCTION gradual_unfreeze_training(classifier, data, epochs_per_stage):
    # All encoder layers start frozen; only the classification head is trainable
    # Stage 1: Train only the classification head
    classifier.freeze()
    classifier.fit_one_cycle(epochs_per_stage[0], lr=2e-2)

    # Stage 2: Unfreeze last LSTM layer + head
    classifier.freeze_to(-2)
    classifier.fit_one_cycle(epochs_per_stage[1], lr=slice(1e-2 / (2.6**4), 1e-2))

    # Stage 3: Unfreeze last two LSTM layers + head
    classifier.freeze_to(-3)
    classifier.fit_one_cycle(epochs_per_stage[2], lr=slice(5e-3 / (2.6**4), 5e-3))

    # Stage 4: Unfreeze all layers
    classifier.unfreeze()
    classifier.fit_one_cycle(epochs_per_stage[3], lr=slice(1e-3 / (2.6**4), 1e-3))
Discriminative Learning Rates with Slicing
When multiple layer groups are trainable, each group receives a different learning rate. The slice(low, high) notation in fastai distributes learning rates geometrically across layer groups:
Given slice(lr_low, lr_high) with N layer groups:
lr_group_0 = lr_low # Earliest layer (most general)
lr_group_1 = lr_low * r
lr_group_2 = lr_low * r^2
...
lr_group_(N-1) = lr_high # Latest layer (most task-specific)
where r = (lr_high / lr_low)^(1/(N-1))
This ensures that earlier layers (which learned general language features) change slowly, while later layers (which need to adapt to the classification task) change more quickly.
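The geometric spacing above can be checked in a few lines of pure Python (`geometric_lrs` is an illustrative helper, not a fastai API):

```python
def geometric_lrs(lr_low: float, lr_high: float, n_groups: int) -> list:
    """Spread learning rates geometrically from lr_low to lr_high."""
    if n_groups == 1:
        return [lr_high]
    r = (lr_high / lr_low) ** (1 / (n_groups - 1))  # common ratio
    return [lr_low * r**i for i in range(n_groups)]

# With slice(1e-2 / 2.6**4, 1e-2) and 5 layer groups, the common ratio is 2.6,
# matching the ULMFiT heuristic of dividing each earlier group's lr by 2.6.
lrs = geometric_lrs(1e-2 / 2.6**4, 1e-2, 5)
```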
Expected Performance
On the IMDb sentiment classification benchmark:
| Stage | Layers Unfrozen | Typical Accuracy |
|---|---|---|
| Head only (1 epoch) | Classification head | ~87-89% |
| Last LSTM unfrozen (1 epoch) | Head + LSTM-3 | ~91-92% |
| Last two LSTMs unfrozen (1 epoch) | Head + LSTM-3 + LSTM-2 | ~93% |
| All layers unfrozen (2 epochs) | All layers | ~94-95% |
The ULMFiT paper reported 95.4% accuracy on IMDb, which was state-of-the-art at the time of publication.