Implementation:Fastai Fastbook Text Classifier Learner
| Knowledge Sources | |
|---|---|
| Domains | Natural Language Processing, Text Classification, Transfer Learning |
| Last Updated | 2026-02-09 17:00 GMT |
Overview
A concrete fastai tool for creating and training a text classification model that reuses a fine-tuned language-model encoder and trains it with gradual unfreezing.
Description
The text_classifier_learner function creates a Learner configured for text classification. It:
- Instantiates the specified architecture (AWD_LSTM) with a classification head that uses concat pooling (concatenation of final hidden state, max-pooled hidden states, and mean-pooled hidden states).
- Provides load_encoder to load the encoder weights saved during language model fine-tuning, ensuring the classifier starts with domain-adapted representations.
- Exposes freeze_to for implementing the gradual unfreezing schedule described in the ULMFiT paper.
- Supports discriminative learning rates via the slice(low, high) syntax in fit_one_cycle.
The classification head architecture is: concat pooling (3 × 400 = 1200 dimensions for the default AWD-LSTM hidden size) followed by a linear layer (1200 to 50) with batch normalization, ReLU, and dropout, and a final linear layer (50 to the number of classes).
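The pooling step can be sketched in NumPy. This is an illustrative reimplementation under the assumption of the default 400-dimensional hidden state, not the fastai source (fastai's version additionally masks padding tokens):

```python
import numpy as np

def concat_pool(hidden_states: np.ndarray) -> np.ndarray:
    """Concat pooling: (seq_len, 400) hidden states -> (1200,) features."""
    last = hidden_states[-1]             # final hidden state
    max_p = hidden_states.max(axis=0)    # max-pooled over time steps
    mean_p = hidden_states.mean(axis=0)  # mean-pooled over time steps
    return np.concatenate([last, max_p, mean_p])

feats = concat_pool(np.random.randn(72, 400))
print(feats.shape)  # (1200,)
```

The tripled feature vector is what makes the first linear layer of the head 1200-dimensional.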
Usage
Use text_classifier_learner as the final stage of the ULMFiT pipeline. It requires:
- Classifier DataLoaders (from DataBlock with TextBlock and CategoryBlock).
- A saved encoder file from the language model fine-tuning stage.
The standard training procedure involves loading the encoder, then performing gradual unfreezing over several stages.
Code Reference
Source Location
- Repository: fastbook
- File: translations/cn/10_nlp.md (lines 713-761)
- Library module: fastai.text.learner
Signature
def text_classifier_learner(
    dls: DataLoaders,  # Classification DataLoaders
    arch: callable = AWD_LSTM,  # Model architecture
    seq_len: int = 72,  # Input sequence length
    config: dict = None,  # Architecture config overrides
    backwards: bool = False,  # Use a backwards model
    pretrained: bool = True,  # Use pretrained backbone
    drop_mult: float = 0.5,  # Dropout multiplier for all dropout layers
    n_out: int = None,  # Number of output classes (inferred from dls)
    lin_ftrs: list = None,  # Hidden sizes for classifier head linear layers
    ps: list = None,  # Dropout probabilities for classifier head
    max_len: int = 72 * 20,  # Maximum sequence length (longer docs are truncated)
    loss_func: callable = None,  # Loss function (default: CrossEntropyLossFlat)
    opt_func: callable = Adam,  # Optimizer
    lr: float = 0.001,  # Base learning rate
    splitter: callable = awd_lstm_clas_split,  # Parameter group splitter
    cbs: list = None,  # Additional callbacks
    metrics: list = None,  # Metrics to track
    path: Path = None,  # Model save path
    model_dir: str = 'models',  # Subdirectory for saved models
    wd: float = None,  # Weight decay
    moms: tuple = (0.95, 0.85, 0.95)  # Momentum schedule
) -> Learner
# Key Learner methods for classifier training:
Learner.load_encoder(name: str)
Learner.freeze()
Learner.freeze_to(n: int)
Learner.unfreeze()
Learner.fit_one_cycle(n_epoch, lr_max=None, div=25., div_final=1e5, pct_start=0.25)
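The freezing semantics can be illustrated with a toy list of parameter groups. The group names here are assumptions for illustration only; the real split comes from the splitter (e.g. awd_lstm_clas_split):

```python
# Toy illustration of freeze_to(n): groups[:n] become frozen, the rest train.
groups = ["embedding", "lstm_0", "lstm_1", "lstm_2", "head"]

def freeze_to(n: int) -> dict:
    """Return {group_name: is_trainable} after freezing groups[:n]."""
    frozen = set(groups[:n])  # n=-2 freezes everything but the last two groups
    return {g: g not in frozen for g in groups}

print(freeze_to(-2))  # only lstm_2 and head remain trainable
print(freeze_to(0))   # nothing frozen: equivalent to unfreeze()
```

This is why negative indices appear in the gradual-unfreezing schedule: each step moves the boundary one group closer to the input.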
Import
from fastai.text.all import text_classifier_learner, AWD_LSTM
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| dls | DataLoaders | Yes | Classification DataLoaders created by DataBlock with TextBlock and CategoryBlock. |
| arch | callable | No | Model architecture class. Default: AWD_LSTM. |
| drop_mult | float | No | Multiplier for all dropout rates. Default: 0.5. Higher values increase regularization. |
| metrics | list | No | Metrics to track. Common choice: [accuracy]. |
| name | str | Yes (for load_encoder) | Name of the saved encoder file (without .pth extension). Must match the name used in save_encoder during LM fine-tuning. |
| n | int | Yes (for freeze_to) | Layer group index to freeze up to. -2 freezes all but the last two groups; -3 freezes all but the last three. |
| n_epoch | int | Yes (for fit_one_cycle) | Number of epochs to train. |
| lr_max | float or slice | No (for fit_one_cycle) | Learning rate or learning rate range. Use slice(low, high) for discriminative rates across layer groups. |
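How slice(low, high) turns into per-group learning rates can be sketched as follows. fastai spaces the rates evenly on a multiplicative scale; this is a reimplementation of the idea, not the library code:

```python
def discriminative_lrs(low: float, high: float, n_groups: int) -> list:
    """Spread learning rates geometrically from the earliest
    parameter group (low) to the classifier head (high)."""
    if n_groups == 1:
        return [high]
    mult = (high / low) ** (1 / (n_groups - 1))
    return [low * mult ** i for i in range(n_groups)]

# slice(1e-2/(2.6**4), 1e-2) over 5 parameter groups:
for g, lr in enumerate(discriminative_lrs(1e-2 / 2.6**4, 1e-2, 5)):
    print(f"group {g}: lr = {lr:.2e}")
```

With five groups and a 2.6**4 span, each group trains 2.6 times faster than the one below it.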
Outputs
| Name | Type | Description |
|---|---|---|
| learn | Learner | A configured Learner for text classification with the AWD-LSTM encoder and concat pooling classification head. |
| predictions | tensor | Class probabilities of shape (n_samples, n_classes) from learn.get_preds; learn.predict returns the decoded label, label index, and probabilities for a single input. |
| training metrics | dict | Per-epoch training loss, validation loss, and accuracy logged during fit_one_cycle. |
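A self-contained sketch of consuming get_preds-style output (the classes list and probability values below are illustrative assumptions, in IMDB's label order):

```python
import numpy as np

classes = ["neg", "pos"]           # category vocab order for IMDB
probs = np.array([[0.98, 0.02],    # shape (n_samples, n_classes),
                  [0.04, 0.96]])   # as learn.get_preds would return

# argmax over the class axis recovers the predicted label per sample
labels = [classes[i] for i in probs.argmax(axis=1)]
print(labels)  # ['neg', 'pos']
```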
Usage Examples
Basic Usage
from fastai.text.all import *
path = untar_data(URLs.IMDB)
# Assume dls_lm was already created and LM was fine-tuned
# Create classifier DataLoaders
dls_clas = DataBlock(
    blocks=(TextBlock.from_folder(path, vocab=dls_lm.vocab), CategoryBlock),
    get_items=partial(get_text_files, folders=['train', 'test']),
    get_y=parent_label,
    splitter=GrandparentSplitter(valid_name='test')
).dataloaders(path, path=path, bs=128, seq_len=72)
# Create the text classifier learner
learn = text_classifier_learner(
    dls_clas,
    AWD_LSTM,
    drop_mult=0.5,
    metrics=[accuracy]
)
# Load the fine-tuned encoder
learn.load_encoder('finetuned')
# Train the classifier head (the encoder layers are frozen by default)
learn.fit_one_cycle(1, 2e-2)
Full Gradual Unfreezing Schedule
from fastai.text.all import *
path = untar_data(URLs.IMDB)
# Create classifier DataLoaders (assumes dls_lm.vocab is available)
dls_clas = DataBlock(
    blocks=(TextBlock.from_folder(path, vocab=dls_lm.vocab), CategoryBlock),
    get_items=partial(get_text_files, folders=['train', 'test']),
    get_y=parent_label,
    splitter=GrandparentSplitter(valid_name='test')
).dataloaders(path, path=path, bs=128, seq_len=72)
# Create learner and load encoder
learn = text_classifier_learner(
    dls_clas,
    AWD_LSTM,
    drop_mult=0.5,
    metrics=[accuracy]
)
learn.load_encoder('finetuned')
# Stage 1: Train only the classifier head (all encoder layers frozen)
learn.fit_one_cycle(1, 2e-2)
# Expected: ~93% accuracy
# Stage 2: Unfreeze the last encoder layer group
learn.freeze_to(-2)
learn.fit_one_cycle(1, slice(1e-2/(2.6**4), 1e-2))
# Expected: ~93-94% accuracy
# Stage 3: Unfreeze one more layer group
learn.freeze_to(-3)
learn.fit_one_cycle(1, slice(5e-3/(2.6**4), 5e-3))
# Expected: ~94% accuracy
# Stage 4: Unfreeze all layers
learn.unfreeze()
learn.fit_one_cycle(2, slice(1e-3/(2.6**4), 1e-3))
# Expected: ~94-95% accuracy
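The recurring 2.6**4 divisor in the slices above follows the ULMFiT heuristic of shrinking the learning rate by a factor of 2.6 per layer group; spanning the four groups below the head gives the lower bound of the slice:

```python
# ULMFiT rule of thumb: each earlier layer group trains 2.6x slower.
base_lr = 1e-2
lower = base_lr / 2.6**4  # four groups below the classifier head
print(f"slice({lower:.3e}, {base_lr:.0e})")  # slice(2.188e-04, 1e-02)
```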
Making Predictions
from fastai.text.all import *
# After training, make predictions on new text
prediction = learn.predict("This movie was absolutely wonderful! Great acting and plot.")
print(prediction)
# Example output (exact probabilities will vary):
# ('pos', TensorCategory(1), tensor([0.0432, 0.9568]))
# (predicted_label, label_index, class_probabilities)
# Predict on another review
prediction = learn.predict("Terrible film. Waste of time. Do not watch.")
print(prediction)
# Example output: ('neg', TensorCategory(0), tensor([0.9821, 0.0179]))
Getting Predictions on the Full Validation Set
from fastai.text.all import *
# Get predictions and targets for the full validation set
preds, targets = learn.get_preds(dl=dls_clas.valid)
print(f"Predictions shape: {preds.shape}")
# Output: torch.Size([25000, 2])
print(f"Targets shape: {targets.shape}")
# Output: torch.Size([25000])
# Calculate accuracy manually
predicted_classes = preds.argmax(dim=1)
accuracy = (predicted_classes == targets).float().mean()
print(f"Validation accuracy: {accuracy:.4f}")
# Output: ~0.94
Complete End-to-End Pipeline
from fastai.text.all import *
# ---- Stage 1: Data Preparation ----
path = untar_data(URLs.IMDB)
# ---- Stage 2: Language Model Fine-tuning ----
dls_lm = DataBlock(
    blocks=TextBlock.from_folder(path, is_lm=True),
    get_items=get_text_files,
    splitter=RandomSplitter(0.1)
).dataloaders(path, path=path, bs=128, seq_len=80)
learn_lm = language_model_learner(
    dls_lm, AWD_LSTM, drop_mult=0.3,
    metrics=[accuracy, Perplexity()]
)
learn_lm.fit_one_cycle(1, 2e-2)
learn_lm.unfreeze()
learn_lm.fit_one_cycle(10, 2e-3)
learn_lm.save_encoder('finetuned')
# ---- Stage 3: Classifier Training ----
dls_clas = DataBlock(
    blocks=(TextBlock.from_folder(path, vocab=dls_lm.vocab), CategoryBlock),
    get_items=partial(get_text_files, folders=['train', 'test']),
    get_y=parent_label,
    splitter=GrandparentSplitter(valid_name='test')
).dataloaders(path, path=path, bs=128, seq_len=72)
learn_clas = text_classifier_learner(
    dls_clas, AWD_LSTM, drop_mult=0.5,
    metrics=[accuracy]
)
learn_clas.load_encoder('finetuned')
# Gradual unfreezing
learn_clas.fit_one_cycle(1, 2e-2)
learn_clas.freeze_to(-2)
learn_clas.fit_one_cycle(1, slice(1e-2/(2.6**4), 1e-2))
learn_clas.freeze_to(-3)
learn_clas.fit_one_cycle(1, slice(5e-3/(2.6**4), 5e-3))
learn_clas.unfreeze()
learn_clas.fit_one_cycle(2, slice(1e-3/(2.6**4), 1e-3))
# Final accuracy: ~94%
Related Pages
Implements Principle
Requires Environment
- Environment:Fastai_Fastbook_Python_FastAI_Environment
- Environment:Fastai_Fastbook_CUDA_GPU_Environment
- Environment:Fastai_Fastbook_NLP_SpaCy_Environment