Implementation:Speechbrain Speechbrain Brain Fit CTC

From Leeroopedia


Field Value
Implementation Name Brain_Fit_CTC
API Signature Brain.fit(self, epoch_counter, train_set, valid_set=None, progressbar=None, train_loader_kwargs={}, valid_loader_kwargs={})
Source File speechbrain/core.py:L1488-1567 (fit method). Recipe: recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py:L43 (compute_forward), L71 (compute_objectives)
Import from speechbrain.core import Brain
Type API Doc
Related Principle Principle:Speechbrain_Speechbrain_CTC_Training_Loop

Description

Brain.fit() is the main training method that orchestrates the complete training and validation loop. It iterates over epochs using the provided epoch_counter, calling fit_batch() for each training batch and evaluate_batch() for each validation batch. For CTC ASR training, the subclass's compute_forward() and compute_objectives() methods implement the CTC-specific forward pass and loss computation.
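
A minimal sketch of the subclassing contract that fit() relies on (method bodies are placeholders, not the recipe code):

from speechbrain.core import Brain

class ASR(Brain):
    def compute_forward(self, batch, stage):
        # waveform -> acoustic model -> CTC log-probabilities
        ...

    def compute_objectives(self, predictions, batch, stage):
        # CTC loss (plus WER/CER tracking outside TRAIN)
        ...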

Inputs

Parameter Type Default Description
epoch_counter iterable (required) An iterable that yields epoch numbers. Typically an EpochCounter instance defined in the YAML configuration, which also handles the epoch limit and resumption after checkpoint recovery.
train_set Dataset or DataLoader (required) Training data. If a DynamicItemDataset is provided, a DataLoader is automatically created using train_loader_kwargs.
valid_set Dataset or DataLoader None Validation data. If a DynamicItemDataset is provided, a DataLoader is automatically created using valid_loader_kwargs.
progressbar bool None Whether to display progress bars. If None, determined by the noprogressbar run option.
train_loader_kwargs dict {} Keyword arguments passed to make_dataloader() for the training DataLoader. Common keys: batch_size, num_workers, shuffle, batch_sampler.
valid_loader_kwargs dict {} Keyword arguments passed to make_dataloader() for the validation DataLoader.
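
For illustration, both kwargs arguments are plain dictionaries forwarded to make_dataloader(); the values below are hypothetical, not the recipe defaults:

train_dataloader_opts = {"batch_size": 8, "num_workers": 4, "shuffle": True}
valid_dataloader_opts = {"batch_size": 8, "num_workers": 4}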

Outputs

The fit() method does not return a value. Its effects are:

  • Trained model weights -- all registered modules are updated through gradient descent
  • Checkpoints -- saved to the checkpointer directory, with the best model selected by WER
  • Training logs -- loss, WER, CER, and learning rate values logged per epoch
  • Updated schedulers -- learning rate schedulers are stepped based on validation loss
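
Because the best checkpoint is selected by WER, it can be restored later through the checkpointer; a minimal sketch (assuming the standard Checkpointer API):

# Sketch: restore the best checkpoint (lowest WER), e.g. before evaluation
hparams["checkpointer"].recover_if_possible(min_key="WER")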

Execution Flow

fit()
  |
  +-- Convert datasets to DataLoaders if needed
  +-- on_fit_start()           # Initialize optimizers, recover from checkpoint
  |
  +-- for each epoch in epoch_counter:
  |     |
  |     +-- _fit_train(train_set, epoch)
  |     |     |
  |     |     +-- on_stage_start(TRAIN, epoch)
  |     |     +-- modules.train()
  |     |     +-- for each batch:
  |     |     |     +-- fit_batch(batch, TRAIN)
  |     |     |           +-- compute_forward(batch, TRAIN)   -> p_ctc, wav_lens, None
  |     |     |           +-- compute_objectives(preds, batch, TRAIN) -> CTC loss
  |     |     |           +-- loss.backward()
  |     |     |           +-- gradient clipping (max_grad_norm=5.0)
  |     |     |           +-- optimizer.step()
  |     |     +-- on_stage_end(TRAIN, avg_loss, epoch)
  |     |
  |     +-- _fit_valid(valid_set, epoch)
  |           |
  |           +-- on_stage_start(VALID, epoch)  # Initialize WER/CER metrics
  |           +-- modules.eval()
  |           +-- torch.no_grad()
  |           +-- for each batch:
  |           |     +-- evaluate_batch(batch, VALID)
  |           |           +-- compute_forward(batch, VALID)   -> p_ctc, wav_lens, p_tokens
  |           |           +-- compute_objectives(preds, batch, VALID) -> CTC loss + WER/CER
  |           +-- on_stage_end(VALID, avg_loss, epoch)
  |                 +-- LR scheduling (NewBobScheduler)
  |                 +-- Logging (loss, WER, CER, LR)
  |                 +-- Checkpointing (save best by WER)
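
The per-batch training step in this flow can be summarized with the following simplified sketch (fit_batch_sketch is illustrative only; the real Brain.fit_batch additionally handles mixed precision, gradient accumulation, and the freeze_optimizers() hook):

import torch
import speechbrain as sb

def fit_batch_sketch(self, batch):
    # Forward pass and CTC loss via the subclass hooks
    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
    loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
    loss.backward()
    # Gradient clipping with the max_grad_norm run option (default 5.0)
    torch.nn.utils.clip_grad_norm_(self.modules.parameters(), self.max_grad_norm)
    # Step every active optimizer (model optimizer, plus wav2vec2 after warmup)
    for opt in self.optimizers_dict.values():
        opt.step()
        opt.zero_grad()
    self.optimizer_step += 1
    return loss.detach().cpu()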

CTC-Specific compute_forward

The ASR subclass implements compute_forward() for the CTC pipeline:

def compute_forward(self, batch, stage):
    """Forward: waveform -> wav2vec2 -> encoder DNN -> CTC logits."""
    batch = batch.to(self.device)
    wavs, wav_lens = batch.sig
    wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)

    # Data augmentation (training only)
    if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
        wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)

    # Feature extraction and encoding
    feats = self.modules.wav2vec2(wavs, wav_lens)   # Pretrained features
    x = self.modules.enc(feats)                      # Encoder DNN
    logits = self.modules.ctc_lin(x)                 # CTC output projection
    p_ctc = self.hparams.log_softmax(logits)         # Log-probabilities

    # Decoding for metrics (not during training)
    p_tokens = None
    if stage == sb.Stage.VALID:
        p_tokens = sb.decoders.ctc_greedy_decode(
            p_ctc, wav_lens, blank_id=self.hparams.blank_index
        )
    elif stage == sb.Stage.TEST:
        p_tokens = test_searcher(p_ctc, wav_lens)

    return p_ctc, wav_lens, p_tokens
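
Note that greedy CTC decoding is run only during validation to keep the loop fast, while the test stage uses a CTC beam searcher (test_searcher) to produce higher-quality hypotheses; during training no decoding is performed at all.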

CTC-Specific compute_objectives

def compute_objectives(self, predictions, batch, stage):
    """Compute CTC loss and track error metrics."""
    p_ctc, wav_lens, p_tokens = predictions
    ids = batch.id
    tokens, tokens_lens = batch.tokens

    # Replicate labels for augmented samples
    if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
        tokens = self.hparams.wav_augment.replicate_labels(tokens)
        tokens_lens = self.hparams.wav_augment.replicate_labels(tokens_lens)

    # CTC loss computation
    loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)

    # WER/CER tracking (validation and test only)
    if stage != sb.Stage.TRAIN:
        if stage == sb.Stage.VALID:
            predicted_words = self.tokenizer(
                p_tokens, task="decode_from_list"
            )
        elif stage == sb.Stage.TEST:
            predicted_words = [hyp[0].text.split(" ") for hyp in p_tokens]

        target_words = undo_padding(tokens, tokens_lens)
        target_words = self.tokenizer(target_words, task="decode_from_list")

        self.wer_metric.append(ids, predicted_words, target_words)
        self.cer_metric.append(ids, predicted_words, target_words)

    return loss
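
compute_objectives relies on undo_padding to turn the padded target tensor back into variable-length lists before decoding. A self-contained toy example of that helper (values are illustrative):

import torch
from speechbrain.utils.data_utils import undo_padding

# Two padded sequences with true lengths 3 and 2 (relative lengths 1.0 and 2/3)
tokens = torch.tensor([[5, 6, 7], [8, 9, 0]])
tokens_lens = torch.tensor([1.0, 2 / 3])
print(undo_padding(tokens, tokens_lens))  # [[5, 6, 7], [8, 9]]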

Epoch-End Callbacks

on_stage_start (VALID)

def on_stage_start(self, stage, epoch):
    if stage != sb.Stage.TRAIN:
        self.cer_metric = self.hparams.cer_computer()
        self.wer_metric = self.hparams.error_rate_computer()
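
In the recipe's YAML, cer_computer and error_rate_computer typically resolve to ErrorRateStats factories; a hedged sketch of what these callbacks construct:

from speechbrain.utils.metric_stats import ErrorRateStats

wer_metric = ErrorRateStats()                   # word-level error rate
cer_metric = ErrorRateStats(split_tokens=True)  # split words into characters for CER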

on_stage_end (VALID)

def on_stage_end(self, stage, stage_loss, epoch):
    stage_stats = {"loss": stage_loss}
    if stage == sb.Stage.TRAIN:
        self.train_stats = stage_stats
    else:
        stage_stats["CER"] = self.cer_metric.summarize("error_rate")
        stage_stats["WER"] = self.wer_metric.summarize("error_rate")

    if stage == sb.Stage.VALID:
        # Learning rate annealing
        old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
            stage_stats["loss"]
        )
        old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
            stage_stats["loss"]
        )
        sb.nnet.schedulers.update_learning_rate(
            self.model_optimizer, new_lr_model
        )
        if not self.hparams.wav2vec2.freeze:
            sb.nnet.schedulers.update_learning_rate(
                self.wav2vec_optimizer, new_lr_wav2vec
            )

        # Logging
        self.hparams.train_logger.log_stats(
            stats_meta={
                "epoch": epoch,
                "lr_model": old_lr_model,
                "lr_wav2vec": old_lr_wav2vec,
            },
            train_stats=self.train_stats,
            valid_stats=stage_stats,
        )

        # Save checkpoint (keep only the best by WER)
        self.checkpointer.save_and_keep_only(
            meta={"WER": stage_stats["WER"]},
            min_keys=["WER"],
        )
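
The lr_annealing_model and lr_annealing_wav2vec callbacks are typically NewBobScheduler instances, which lower the learning rate whenever the validation loss stops improving. An illustrative standalone sketch (initial_value and annealing_factor are assumptions):

from speechbrain.nnet.schedulers import NewBobScheduler

scheduler = NewBobScheduler(initial_value=1.0, annealing_factor=0.8)
old_lr, new_lr = scheduler(metric_value=2.3)  # first call: no annealing yet
old_lr, new_lr = scheduler(metric_value=2.4)  # loss got worse -> new_lr = old_lr * 0.8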

Dual Optimizer Setup

The CTC recipe overrides init_optimizers() to create separate optimizers:

def init_optimizers(self):
    # Wav2vec2 optimizer (only if not frozen)
    if not self.hparams.wav2vec2.freeze:
        self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
            self.modules.wav2vec2.parameters()
        )

    # Model optimizer (encoder DNN + CTC linear)
    self.model_optimizer = self.hparams.model_opt_class(
        self.hparams.model.parameters()
    )

    # Register with checkpointer for resumption
    if self.checkpointer is not None:
        self.checkpointer.add_recoverable("modelopt", self.model_optimizer)

    self.optimizers_dict = {
        "model_optimizer": self.model_optimizer,
    }
    if not self.hparams.wav2vec2.freeze:
        self.optimizers_dict["wav2vec_optimizer"] = self.wav2vec_optimizer
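
For reference, model_opt_class and wav2vec_opt_class are partial optimizer constructors defined in the YAML. Based on the configuration table below, a rough Python equivalent would be:

import functools
import torch

# Assumed Python equivalents of the YAML optimizer classes
model_opt_class = functools.partial(torch.optim.Adadelta, lr=1.0)
wav2vec_opt_class = functools.partial(torch.optim.AdamW, lr=0.0001)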

Warmup and Freezing

def freeze_optimizers(self, optimizers):
    """Freeze wav2vec2 optimizer during warmup phase."""
    valid_optimizers = {}
    if not self.hparams.wav2vec2.freeze:
        if self.optimizer_step >= self.hparams.warmup_steps:
            valid_optimizers["wav2vec_optimizer"] = optimizers[
                "wav2vec_optimizer"
            ]
    valid_optimizers["model_optimizer"] = optimizers["model_optimizer"]
    return valid_optimizers
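
With the default warmup_steps of 500 (see the YAML table below), only the model optimizer steps for the first 500 optimizer updates; after warmup, the wav2vec2 optimizer is included and fine-tuning of the pretrained encoder begins.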

Usage Example

# Complete training invocation from the recipe
asr_brain = ASR(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)
asr_brain.tokenizer = tokenizer

# Training
asr_brain.fit(
    asr_brain.hparams.epoch_counter,
    train_data,
    valid_data,
    train_loader_kwargs=train_dataloader_opts,
    valid_loader_kwargs=valid_dataloader_opts,
)
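
After fit() returns, the recipe typically evaluates on a held-out test set, restoring the best checkpoint by WER. A hedged sketch (test_data and test_dataloader_opts are assumed names):

# Evaluation restores the best checkpoint by WER before testing
asr_brain.evaluate(
    test_data,
    min_key="WER",
    test_loader_kwargs=test_dataloader_opts,
)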

Key YAML Configuration Values

Key Typical Value Description
number_of_epochs 30 Maximum epochs to train
optimizer_step_limit 75000 Maximum optimizer steps (early stopping)
lr 1.0 Model optimizer learning rate (Adadelta)
lr_wav2vec 0.0001 Wav2vec2 optimizer learning rate (AdamW)
warmup_steps 500 Steps before wav2vec2 optimizer is activated
precision "fp16" Mixed precision mode
dynamic_batching True Use duration-based dynamic batching
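
These keys live in the recipe's YAML hyperparameter file and are resolved with HyperPyYAML before fit() is called. A typical loading sketch using the standard recipe boilerplate:

import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# Parse CLI arguments, then load the YAML hyperparameters with any overrides
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file) as fin:
    hparams = load_hyperpyyaml(fin, overrides)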

Dependencies

  • speechbrain.nnet.losses.ctc_loss -- CTC loss function wrapping PyTorch's torch.nn.functional.ctc_loss
  • speechbrain.decoders.ctc_greedy_decode -- greedy CTC decoding for validation
  • speechbrain.decoders.ctc.CTCBeamSearcher -- beam search decoding for testing
  • speechbrain.nnet.schedulers.NewBobScheduler -- learning rate annealing
  • speechbrain.utils.metric_stats.ErrorRateStats -- WER/CER metric accumulation
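
For reference, the ctc_cost callback wraps speechbrain.nnet.losses.ctc_loss, which takes relative lengths rather than absolute frame counts. A self-contained toy example (shapes and values are illustrative):

import torch
from speechbrain.nnet.losses import ctc_loss

# Toy batch: 2 utterances, 50 frames, vocabulary of 30 tokens (index 0 = blank)
log_probs = torch.randn(2, 50, 30).log_softmax(dim=-1)
targets = torch.randint(1, 30, (2, 10))
input_lens = torch.tensor([1.0, 0.8])   # relative input lengths
target_lens = torch.tensor([1.0, 0.7])  # relative target lengths
loss = ctc_loss(log_probs, targets, input_lens, target_lens, blank_index=0)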

Related Pages

  • Principle:Speechbrain_Speechbrain_CTC_Training_Loop