
Implementation:Fastai Fastbook Text Classifier Learner

From Leeroopedia


Knowledge Sources
Domains Natural Language Processing, Text Classification, Transfer Learning
Last Updated 2026-02-09 17:00 GMT

Overview

text_classifier_learner is a concrete tool from the fastai library for creating and training a text classification model. It leverages a fine-tuned language-model encoder and trains the classifier with gradual unfreezing.

Description

The text_classifier_learner function creates a Learner configured for text classification. It:

  • Instantiates the specified architecture (AWD_LSTM) with a classification head that uses concat pooling (concatenation of final hidden state, max-pooled hidden states, and mean-pooled hidden states).
  • Provides load_encoder to load the encoder weights saved during language model fine-tuning, ensuring the classifier starts with domain-adapted representations.
  • Exposes freeze_to for implementing the gradual unfreezing schedule described in the ULMFiT paper.
  • Supports discriminative learning rates via the slice(low, high) syntax in fit_one_cycle.

The classification head architecture is: concat pooling (1200-dimensional for the default AWD-LSTM, i.e. 3 × the 400-dimensional hidden state) followed by a linear layer (1200 to 50), batch normalization, ReLU, dropout, and a final linear layer (50 to number of classes).
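The head described above can be sketched in plain PyTorch. This is a minimal illustration of the structure, not fastai's actual `PoolingLinearClassifier` code; the 400-unit hidden size, 50-unit hidden layer, 2 classes, and dropout probability are assumptions matching the AWD-LSTM defaults in this page:

```python
import torch
import torch.nn as nn

class ConcatPoolingHead(nn.Module):
    """Sketch of a concat-pooling classification head:
    [last hidden; max-pool; mean-pool] -> Linear -> BN -> ReLU -> Dropout -> Linear."""
    def __init__(self, hidden=400, lin=50, n_classes=2, p=0.1):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(hidden * 3, lin),   # 1200 -> 50
            nn.BatchNorm1d(lin),
            nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(lin, n_classes),    # 50 -> number of classes
        )

    def forward(self, outputs):
        # outputs: (batch, seq_len, hidden) activations from the encoder
        last = outputs[:, -1]              # final hidden state
        mx = outputs.max(dim=1).values     # max-pooled over time
        avg = outputs.mean(dim=1)          # mean-pooled over time
        return self.layers(torch.cat([last, mx, avg], dim=1))

head = ConcatPoolingHead()
x = torch.randn(4, 72, 400)   # batch of 4, seq_len 72, hidden size 400
print(head(x).shape)          # torch.Size([4, 2])
```

Concatenating the three pooled views is what triples the input width of the first linear layer (3 × 400 = 1200).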

Usage

Use text_classifier_learner as the final stage of the ULMFiT pipeline. It requires:

  1. Classifier DataLoaders (from DataBlock with TextBlock and CategoryBlock).
  2. A saved encoder file from the language model fine-tuning stage.

The standard training procedure involves loading the encoder, then performing gradual unfreezing over several stages.
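The semantics of freeze_to can be illustrated with a generic PyTorch sketch (an illustration only, not fastai's implementation; the four toy layer groups are an assumption standing in for the embedding, LSTM layers, and head): parameter groups before index n get requires_grad disabled, so with n = -2 only the last two groups train.

```python
import torch.nn as nn

def freeze_to(param_groups, n):
    """Freeze every parameter group before index n; groups[n:] stay trainable.
    freeze_to(groups, -2) therefore trains only the last two groups."""
    for i, group in enumerate(param_groups):
        trainable = i >= n % len(param_groups)
        for p in group.parameters():
            p.requires_grad_(trainable)

# Four toy "layer groups" standing in for embedding, two LSTM layers, and the head
groups = [nn.Linear(4, 4) for _ in range(4)]
freeze_to(groups, -2)
print([all(p.requires_grad for p in g.parameters()) for g in groups])
# [False, False, True, True]
```

Gradual unfreezing then amounts to calling this with -2, then -3, then unfreezing everything, training for an epoch or two at each stage.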

Code Reference

Source Location

  • Repository: fastbook
  • File: translations/cn/10_nlp.md (lines 713-761)
  • Library module: fastai.text.learner

Signature

def text_classifier_learner(
    dls: DataLoaders,                     # Classification DataLoaders
    arch: callable = AWD_LSTM,            # Model architecture
    seq_len: int = 72,                    # Input sequence length
    config: dict = None,                  # Architecture config overrides
    backwards: bool = False,              # Use a backwards model
    pretrained: bool = True,              # Use pretrained backbone
    drop_mult: float = 0.5,              # Dropout multiplier for all dropout layers
    n_out: int = None,                    # Number of output classes (inferred from dls)
    lin_ftrs: list = None,                # Hidden sizes for classifier head linear layers
    ps: list = None,                      # Dropout probabilities for classifier head
    max_len: int = 72 * 20,              # Maximum sequence length (longer docs are truncated)
    loss_func: callable = None,           # Loss function (default: CrossEntropyLossFlat)
    opt_func: callable = Adam,            # Optimizer
    lr: float = 0.001,                    # Base learning rate
    splitter: callable = awd_lstm_clas_split,  # Parameter group splitter
    cbs: list = None,                     # Additional callbacks
    metrics: list = None,                 # Metrics to track
    path: Path = None,                    # Model save path
    model_dir: str = 'models',           # Subdirectory for saved models
    wd: float = None,                     # Weight decay
    moms: tuple = (0.95, 0.85, 0.95)     # Momentum schedule
) -> Learner

# Key Learner methods for classifier training:
Learner.load_encoder(name: str)
Learner.freeze()
Learner.freeze_to(n: int)
Learner.unfreeze()
Learner.fit_one_cycle(n_epoch, lr_max=None, div=25., div_final=1e5, pct_start=0.25)

Import

from fastai.text.all import text_classifier_learner, AWD_LSTM

I/O Contract

Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| dls | DataLoaders | Yes | Classification DataLoaders created by DataBlock with TextBlock and CategoryBlock. |
| arch | callable | No | Model architecture class. Default: AWD_LSTM. |
| drop_mult | float | No | Multiplier for all dropout rates. Default: 0.5. Higher values increase regularization. |
| metrics | list | No | Metrics to track. Common choice: [accuracy]. |
| name | str | Yes (for load_encoder) | Name of the saved encoder file (without .pth extension). Must match the name used in save_encoder during LM fine-tuning. |
| n | int | Yes (for freeze_to) | Layer group index to freeze up to. -2 freezes all but the last two groups; -3 freezes all but the last three. |
| n_epoch | int | Yes (for fit_one_cycle) | Number of epochs to train. |
| lr_max | float or slice | No (for fit_one_cycle) | Learning rate or learning rate range. Use slice(low, high) for discriminative rates across layer groups. |
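Passing slice(low, high) spreads learning rates geometrically across the parameter groups. The sketch below shows the idea (a pure-Python illustration of this behavior; the five-group count and the helper name `even_mults` are assumptions). With slice(1e-2/(2.6**4), 1e-2) and five groups, adjacent groups differ by the factor 2.6 used throughout the ULMFiT schedule on this page:

```python
def even_mults(start, stop, n):
    """n learning rates spaced evenly on a log scale from start to stop."""
    if n == 1:
        return [stop]
    mult = (stop / start) ** (1 / (n - 1))
    return [start * mult ** i for i in range(n)]

low, high = 1e-2 / 2.6**4, 1e-2
lrs = even_mults(low, high, 5)
print([f"{lr:.5f}" for lr in lrs])
# The earliest group (embeddings) gets the lowest rate, the head gets 1e-2,
# and each later group's rate is 2.6x the previous one.
```

This is why the staged fit_one_cycle calls later on the page divide the top rate by 2.6**4: the deepest layers encode general language features and should change least.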

Outputs

| Name | Type | Description |
|------|------|-------------|
| learn | Learner | A configured Learner for text classification with the AWD-LSTM encoder and concat pooling classification head. |
| predictions | tensor | Class probabilities. learn.get_preds returns a tensor of shape (n_samples, n_classes); learn.predict returns a (label, index, probabilities) tuple for a single item. |
| training metrics | dict | Per-epoch training loss, validation loss, and accuracy logged during fit_one_cycle. |

Usage Examples

Basic Usage

from fastai.text.all import *

path = untar_data(URLs.IMDB)

# Assume dls_lm was already created and LM was fine-tuned
# Create classifier DataLoaders
dls_clas = DataBlock(
    blocks=(TextBlock.from_folder(path, vocab=dls_lm.vocab), CategoryBlock),
    get_y=parent_label,
    splitter=GrandparentSplitter(valid_name='test')
).dataloaders(path, path=path, bs=128, seq_len=72)

# Create the text classifier learner
learn = text_classifier_learner(
    dls_clas,
    AWD_LSTM,
    drop_mult=0.5,
    metrics=[accuracy]
)

# Load the fine-tuned encoder
learn.load_encoder('finetuned')

# Train (simple approach - just unfreeze and train)
learn.fit_one_cycle(1, 2e-2)

Full Gradual Unfreezing Schedule

from fastai.text.all import *

path = untar_data(URLs.IMDB)

# Create classifier DataLoaders (assumes dls_lm.vocab is available)
dls_clas = DataBlock(
    blocks=(TextBlock.from_folder(path, vocab=dls_lm.vocab), CategoryBlock),
    get_y=parent_label,
    splitter=GrandparentSplitter(valid_name='test')
).dataloaders(path, path=path, bs=128, seq_len=72)

# Create learner and load encoder
learn = text_classifier_learner(
    dls_clas,
    AWD_LSTM,
    drop_mult=0.5,
    metrics=[accuracy]
)
learn.load_encoder('finetuned')

# Stage 1: Train only the classifier head (all encoder layers frozen)
learn.fit_one_cycle(1, 2e-2)
# Expected: ~87-89% accuracy

# Stage 2: Unfreeze last layer group
learn.freeze_to(-2)
learn.fit_one_cycle(1, slice(1e-2/(2.6**4), 1e-2))
# Expected: ~91-92% accuracy

# Stage 3: Unfreeze one more layer group
learn.freeze_to(-3)
learn.fit_one_cycle(1, slice(5e-3/(2.6**4), 5e-3))
# Expected: ~93% accuracy

# Stage 4: Unfreeze all layers
learn.unfreeze()
learn.fit_one_cycle(2, slice(1e-3/(2.6**4), 1e-3))
# Expected: ~94-95% accuracy

Making Predictions

from fastai.text.all import *

# After training, make predictions on new text
prediction = learn.predict("This movie was absolutely wonderful! Great acting and plot.")
print(prediction)
# Example output (exact probabilities will vary):
# ('pos', TensorText(1), TensorText([0.0432, 0.9568]))
# (predicted_label, label_index, class_probabilities)

# Predict on another review
prediction = learn.predict("Terrible film. Waste of time. Do not watch.")
print(prediction)
# Example output: ('neg', TensorText(0), TensorText([0.9821, 0.0179]))
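The probability tensor in predict's return value maps back to labels through the classifier vocab. A self-contained sketch (the vocab ['neg', 'pos'] is an assumption matching the IMDB folder names; in fastai the class names live in dls_clas.vocab):

```python
import torch

vocab = ['neg', 'pos']                   # assumed class vocab for IMDB
probs = torch.tensor([0.0432, 0.9568])   # class probabilities, as from learn.predict
idx = probs.argmax().item()              # index of the most likely class
label = vocab[idx]
print(label)                             # pos
```

The label index in the predict tuple is exactly this argmax, and the decoded string is the vocab entry at that index.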

Getting Predictions on the Full Validation Set

from fastai.text.all import *

# Get predictions and targets for the full validation set
preds, targets = learn.get_preds(dl=dls_clas.valid)
print(f"Predictions shape: {preds.shape}")
# Output: torch.Size([25000, 2])

print(f"Targets shape: {targets.shape}")
# Output: torch.Size([25000])

# Calculate accuracy manually
predicted_classes = preds.argmax(dim=1)
accuracy = (predicted_classes == targets).float().mean()
print(f"Validation accuracy: {accuracy:.4f}")
# Output: ~0.94

Complete End-to-End Pipeline

from fastai.text.all import *

# ---- Stage 1: Data Preparation ----
path = untar_data(URLs.IMDB)

# ---- Stage 2: Language Model Fine-tuning ----
dls_lm = DataBlock(
    blocks=TextBlock.from_folder(path, is_lm=True),
    get_items=get_text_files,
    splitter=RandomSplitter(0.1)
).dataloaders(path, path=path, bs=128, seq_len=80)

learn_lm = language_model_learner(
    dls_lm, AWD_LSTM, drop_mult=0.3,
    metrics=[accuracy, Perplexity()]
)
learn_lm.fit_one_cycle(1, 2e-2)
learn_lm.unfreeze()
learn_lm.fit_one_cycle(10, 2e-3)
learn_lm.save_encoder('finetuned')

# ---- Stage 3: Classifier Training ----
dls_clas = DataBlock(
    blocks=(TextBlock.from_folder(path, vocab=dls_lm.vocab), CategoryBlock),
    get_y=parent_label,
    splitter=GrandparentSplitter(valid_name='test')
).dataloaders(path, path=path, bs=128, seq_len=72)

learn_clas = text_classifier_learner(
    dls_clas, AWD_LSTM, drop_mult=0.5,
    metrics=[accuracy]
)
learn_clas.load_encoder('finetuned')

# Gradual unfreezing
learn_clas.fit_one_cycle(1, 2e-2)
learn_clas.freeze_to(-2)
learn_clas.fit_one_cycle(1, slice(1e-2/(2.6**4), 1e-2))
learn_clas.freeze_to(-3)
learn_clas.fit_one_cycle(1, slice(5e-3/(2.6**4), 5e-3))
learn_clas.unfreeze()
learn_clas.fit_one_cycle(2, slice(1e-3/(2.6**4), 1e-3))

# Final accuracy: ~94%

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
