Principle:Speechbrain Speechbrain CTC Training Loop

From Leeroopedia


Principle Name: CTC_Training_Loop
Description: Connectionist Temporal Classification training with automatic alignment learning
Domains: ASR, Sequence_Modeling, Deep_Learning
Knowledge Sources: Graves et al. 2006, "Connectionist Temporal Classification"
Related Implementation: Implementation:Speechbrain_Speechbrain_Brain_Fit_CTC

Overview

Connectionist Temporal Classification (CTC) is a training criterion that enables sequence-to-sequence model training without requiring explicit alignment between input and output sequences. This is essential for automatic speech recognition (ASR), where the input audio frames and the output text tokens have different lengths and no known frame-to-character correspondence. CTC introduces a special blank token and uses dynamic programming to marginalize over all possible valid alignments, allowing the model to learn the alignment implicitly during training.

Mathematical Foundation

The Alignment Problem

In ASR, the input is a sequence of acoustic frames x = (x_1, x_2, ..., x_T) of length T, and the target is a sequence of tokens y = (y_1, y_2, ..., y_U) of length U, where typically T >> U. The challenge is that we do not know which frames correspond to which tokens.

CTC Formulation

CTC solves this by defining a many-to-one mapping from frame-level paths to output sequences. A path is a sequence of labels (including a special blank token) of the same length as the input:

pi = (pi_1, pi_2, ..., pi_T)  where pi_t is in {blank, y_1, ..., y_U}

A collapsing function B(pi) removes repeated labels and blanks to produce the output sequence. For example:

B(a, a, blank, b, blank, c, c) = "abc"
B(blank, blank, a, a, b, b, b) = "ab"
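
A minimal Python sketch of the collapsing function (the blank token is written here as "-"; the label names are illustrative):

```python
def ctc_collapse(path, blank="-"):
    """Apply the CTC collapsing function B: merge repeated labels, then drop blanks."""
    out = []
    prev = None
    for label in path:
        if label != prev and label != blank:
            out.append(label)
        prev = label
    return "".join(out)

# The two examples from the text:
print(ctc_collapse(["a", "a", "-", "b", "-", "c", "c"]))  # abc
print(ctc_collapse(["-", "-", "a", "a", "b", "b", "b"]))  # ab
```

Because a blank separates repeats, B(a, -, a) yields "aa" while B(a, a) yields "a" — this is how CTC can emit doubled characters.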

The CTC loss marginalizes over all valid paths that collapse to the target:

p(y|x) = sum over all pi where B(pi)=y of: product from t=1 to T of p(pi_t | x)

The loss is the negative log-likelihood:

L_CTC = -log p(y|x)
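
For toy input lengths, this marginalization can be computed by brute force, enumerating every frame-level path and summing the probabilities of those that collapse to the target. The sketch below is purely illustrative (the path count grows exponentially in T); it exists only to make the formula concrete:

```python
import itertools, math

def collapse(path, blank=0):
    # B: merge consecutive repeats, then remove blanks
    return tuple(l for i, l in enumerate(path)
                 if l != blank and (i == 0 or l != path[i - 1]))

def ctc_loss_brute_force(log_probs, target, blank=0):
    """-log p(y|x) by summing over every frame-level path (toy sizes only).

    log_probs: T x V per-frame log-probabilities; target: label ids without blanks.
    """
    T, V = len(log_probs), len(log_probs[0])
    p = 0.0
    for path in itertools.product(range(V), repeat=T):
        if collapse(path, blank) == tuple(target):
            p += math.exp(sum(log_probs[t][path[t]] for t in range(T)))
    return -math.log(p)

# T=2 frames, vocabulary {blank=0, 'a'=1}, uniform 0.5/0.5 per frame.
# Paths collapsing to ('a',): (0,1), (1,0), (1,1), so p(y|x) = 3 * 0.25 = 0.75.
uniform = [[math.log(0.5), math.log(0.5)]] * 2
loss = ctc_loss_brute_force(uniform, (1,))  # -log(0.75)
```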

Forward-Backward Algorithm

Direct enumeration of all valid paths is computationally intractable. CTC uses a dynamic programming algorithm (analogous to the forward-backward algorithm in HMMs) to efficiently compute the total probability and its gradient. This algorithm runs in O(T * U) time, making CTC training practical.
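
The forward (alpha) half of that recursion can be sketched as follows, operating on the blank-interleaved label sequence [blank, y_1, blank, y_2, ..., y_U, blank]. This is a simplified illustration assuming per-frame log-probabilities; production implementations such as torch.nn.CTCLoss also run the backward pass to obtain gradients:

```python
import math

def ctc_forward_log_prob(log_probs, target, blank=0):
    """log p(y|x) via the CTC forward (alpha) recursion in O(T * U) time.

    log_probs: T x V per-frame log-probabilities; target: label ids without blanks.
    """
    ext = [blank]                       # blank-interleaved label sequence
    for y in target:
        ext += [y, blank]
    S, T = len(ext), len(log_probs)
    NEG_INF = float("-inf")

    def logsumexp(*xs):
        m = max(xs)
        return NEG_INF if m == NEG_INF else m + math.log(sum(math.exp(x - m) for x in xs))

    # alpha[s]: log-prob of all path prefixes that end in ext[s] at the current frame
    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0][ext[0]]
    if S > 1:
        alpha[1] = log_probs[0][ext[1]]
    for t in range(1, T):
        new = [NEG_INF] * S
        for s in range(S):
            terms = [alpha[s]]
            if s >= 1:
                terms.append(alpha[s - 1])
            # A blank may be skipped only between two *different* labels
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                terms.append(alpha[s - 2])
            new[s] = logsumexp(*terms) + log_probs[t][ext[s]]
        alpha = new
    # A valid path ends in the final label or the final blank
    return logsumexp(alpha[S - 1], alpha[S - 2]) if S > 1 else alpha[S - 1]
```

On the two-frame uniform example above, this returns log(0.75), matching the brute-force marginalization.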

Training Loop Architecture

In SpeechBrain, the CTC training loop is orchestrated by Brain.fit(), which implements the following epoch-level cycle:

Per-Epoch Flow

for each epoch:
    1. on_stage_start(TRAIN, epoch)
    2. for each batch in train_set:
        a. compute_forward(batch, TRAIN)     -> predictions
        b. compute_objectives(predictions, batch, TRAIN) -> loss
        c. loss.backward()                   -> gradients
        d. gradient clipping (max_grad_norm)
        e. optimizer.step()                  -> update weights
    3. on_stage_end(TRAIN, avg_train_loss, epoch)
    4. on_stage_start(VALID, epoch)
    5. for each batch in valid_set:
        a. compute_forward(batch, VALID)     -> predictions
        b. compute_objectives(predictions, batch, VALID) -> loss + metrics
    6. on_stage_end(VALID, avg_valid_loss, epoch)
        - learning rate scheduling
        - checkpointing (save best by WER)

CTC-Specific Forward Pass

In the CTC ASR recipe, compute_forward() implements:

  1. Waveform augmentation (training only) -- speed perturbation, frequency dropping (masking random frequency bands), and time dropping (masking random spans of frames)
  2. Feature extraction -- wav2vec2 encoder processes raw waveforms into contextualized representations
  3. Encoder DNN -- a multi-layer network with batch normalization and dropout further processes wav2vec2 features
  4. CTC linear layer -- projects encoder output to token vocabulary size
  5. Log-softmax -- produces log-probabilities over the token vocabulary at each time step
  6. Decoding (validation/test only) -- greedy decoding during validation, beam search during testing
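
Greedy decoding is simply the frame-wise argmax followed by the CTC collapsing function. A minimal sketch (argmax is unchanged whether the scores are probabilities or log-probabilities):

```python
def ctc_greedy_decode(frames, blank=0):
    """Greedy CTC decoding: frame-wise argmax, merge repeats, drop blanks.

    frames: T x V per-frame scores (probabilities or log-probabilities).
    """
    best = [max(range(len(frame)), key=frame.__getitem__) for frame in frames]
    decoded, prev = [], None
    for label in best:
        if label != prev and label != blank:
            decoded.append(label)
        prev = label
    return decoded

# Argmax path [1, 1, 0, 2, 2] collapses to [1, 2]
frames = [[0.1, 0.8, 0.1], [0.1, 0.7, 0.2], [0.9, 0.05, 0.05],
          [0.1, 0.2, 0.7], [0.2, 0.1, 0.7]]
```

Beam search, used at test time, instead keeps the k most probable collapsed prefixes per frame and sums the probabilities of paths that collapse to the same prefix.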

CTC-Specific Objective Computation

In compute_objectives():

  1. CTC loss is computed between the log-probability sequence and the target token sequence using speechbrain.nnet.losses.ctc_loss
  2. During validation and testing, decoded token sequences are converted back to words and WER/CER metrics are accumulated
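
WER itself is the word-level edit distance divided by the number of reference words. The sketch below is a simplified stand-in for the library's metric accumulator, shown only to make the quantity concrete:

```python
def wer(reference, hypothesis):
    """Word error rate: (substitutions + insertions + deletions) / reference words."""
    ref, hyp = reference.split(), hypothesis.split()
    # Standard edit-distance dynamic program over words
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,        # deletion
                          d[i][j - 1] + 1,        # insertion
                          d[i - 1][j - 1] + cost) # substitution / match
    return d[len(ref)][len(hyp)] / len(ref)
```

CER is the same computation applied to characters instead of words.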

Optimizer Configuration

The CTC ASR recipe uses a dual optimizer setup:

  • wav2vec2 optimizer (AdamW) -- with a small learning rate (1e-4) and weight decay, applied to the pretrained wav2vec2 parameters
  • Model optimizer (Adadelta) -- with a higher learning rate (1.0), applied to the encoder DNN and CTC linear layer

A warmup schedule is employed where the wav2vec2 optimizer is frozen for the first N steps (default 500), allowing the randomly initialized downstream layers to reach a reasonable state before fine-tuning the pretrained encoder begins.
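
The freeze-then-finetune scheme amounts to gating one of the two optimizers on a step counter. A sketch of the idea (illustrative names and structure; the actual recipe's mechanism differs in detail):

```python
class WarmupGate:
    """Skip the pretrained-encoder optimizer for the first `warmup_steps` updates.

    model_opt and wav2vec_opt stand for the two optimizers; each only
    needs a .step() method for this sketch.
    """
    def __init__(self, model_opt, wav2vec_opt, warmup_steps=500):
        self.model_opt = model_opt
        self.wav2vec_opt = wav2vec_opt
        self.warmup_steps = warmup_steps
        self.step_count = 0

    def step(self):
        self.step_count += 1
        self.model_opt.step()                    # downstream layers always update
        if self.step_count > self.warmup_steps:  # encoder fine-tuning starts later
            self.wav2vec_opt.step()

class CountingOpt:
    """Stand-in optimizer that just counts its updates."""
    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1
```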

Learning Rate Scheduling

The NewBobScheduler is used for both optimizers. It monitors the validation loss and reduces the learning rate by a factor (annealing_factor) when improvement falls below a threshold (improvement_threshold). The wav2vec2 and model schedulers can have different annealing factors to control their adaptation rates independently.
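
The annealing rule can be sketched as follows. This is a simplified model of the behavior, not the actual NewBobScheduler code, which has additional options:

```python
class NewBobLike:
    """Simplified new-bob annealing: multiply the LR by `annealing_factor`
    whenever the relative validation-loss improvement falls below the threshold."""
    def __init__(self, lr, improvement_threshold=0.0025, annealing_factor=0.8):
        self.lr = lr
        self.threshold = improvement_threshold
        self.factor = annealing_factor
        self.prev_loss = None

    def update(self, valid_loss):
        if self.prev_loss is not None:
            improvement = (self.prev_loss - valid_loss) / self.prev_loss
            if improvement < self.threshold:   # plateau: anneal the LR
                self.lr *= self.factor
        self.prev_loss = valid_loss
        return self.lr
```

Running two such schedulers with different annealing factors lets the wav2vec2 and downstream learning rates decay at independent speeds.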

Checkpointing Strategy

After each validation epoch, the checkpoint is saved with the WER metric as metadata. The save_and_keep_only(min_keys=["WER"]) strategy keeps only the checkpoint with the lowest WER, ensuring that the best model is always preserved.
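
The keep-only-best logic reduces to a comparison against the stored minimum. A sketch that mirrors the behavior (hypothetical class; not the real Checkpointer, which also manages files on disk):

```python
class BestCheckpointKeeper:
    """Keep only the checkpoint with the lowest value of a metric (e.g. WER)."""
    def __init__(self, min_key="WER"):
        self.min_key = min_key
        self.best = None  # (metric value, checkpoint payload)

    def save(self, checkpoint, meta):
        value = meta[self.min_key]
        if self.best is None or value < self.best[0]:
            self.best = (value, checkpoint)  # the older checkpoint is discarded
        return self.best[1]
```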

Mixed Precision Training

SpeechBrain supports mixed precision (fp16, bf16) training through PyTorch's native AMP (Automatic Mixed Precision). The precision setting in the YAML configuration controls this. Mixed precision reduces memory usage and increases throughput on supported hardware while maintaining training stability through loss scaling.
