Implementation:Speechbrain Speechbrain Brain Fit CTC
| Field | Value |
|---|---|
| Implementation Name | Brain_Fit_CTC |
| API Signature | `Brain.fit(self, epoch_counter, train_set, valid_set=None, progressbar=None, train_loader_kwargs={}, valid_loader_kwargs={})` |
| Source File | speechbrain/core.py:L1488-1567 (fit method). Recipe: recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py:L43 (compute_forward), L71 (compute_objectives) |
| Import | `from speechbrain.core import Brain` |
| Type | API Doc |
| Related Principle | Principle:Speechbrain_Speechbrain_CTC_Training_Loop |
Description
Brain.fit() is the main training method that orchestrates the complete training and validation loop. It iterates over epochs using the provided epoch_counter, calling fit_batch() for each training batch and evaluate_batch() for each validation batch. For CTC ASR training, the subclass's compute_forward() and compute_objectives() methods implement the CTC-specific forward pass and loss computation.
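The orchestration described above can be condensed into a few lines of plain Python. This is an illustrative stand-in, not SpeechBrain code: `TinyBrain` and its methods only mirror the epoch/batch control flow of `Brain.fit()`, without devices, AMP, or checkpointing.

```python
# Minimal sketch of the fit() orchestration (illustrative only; the real
# Brain.fit() also builds DataLoaders, manages devices, and checkpoints).
class TinyBrain:
    def __init__(self):
        self.train_losses = []
        self.valid_losses = []

    def fit_batch(self, batch):
        # Stand-in for forward pass + loss + backward + optimizer step
        return sum(batch) / len(batch)

    def evaluate_batch(self, batch):
        # Stand-in for forward pass + metric accumulation (no gradients)
        return sum(batch) / len(batch)

    def fit(self, epoch_counter, train_set, valid_set=None):
        for epoch in epoch_counter:          # outer loop over epochs
            for batch in train_set:          # _fit_train
                self.train_losses.append(self.fit_batch(batch))
            if valid_set is not None:        # _fit_valid
                for batch in valid_set:
                    self.valid_losses.append(self.evaluate_batch(batch))

brain = TinyBrain()
brain.fit(range(1, 3), train_set=[[1.0, 2.0], [3.0]], valid_set=[[2.0]])
print(len(brain.train_losses), len(brain.valid_losses))  # 4 2
```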
Inputs
| Parameter | Type | Default | Description |
|---|---|---|---|
| epoch_counter | iterable | (required) | An iterable that yields epoch numbers. Typically an EpochCounter instance from the YAML configuration, which also handles the epoch limit and resumption. |
| train_set | Dataset or DataLoader | (required) | Training data. If a DynamicItemDataset is provided, a DataLoader is created automatically using train_loader_kwargs. |
| valid_set | Dataset or DataLoader | None | Validation data. If a DynamicItemDataset is provided, a DataLoader is created automatically using valid_loader_kwargs. |
| progressbar | bool | None | Whether to display progress bars. If None, determined by the noprogressbar run option. |
| train_loader_kwargs | dict | {} | Keyword arguments passed to make_dataloader() for the training DataLoader. Common keys: batch_size, num_workers, shuffle, batch_sampler. |
| valid_loader_kwargs | dict | {} | Keyword arguments passed to make_dataloader() for the validation DataLoader. |
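In practice the loader kwargs are plain dicts forwarded to make_dataloader(). A small illustration (the key names follow torch.utils.data.DataLoader; the values here are examples, not recipe defaults):

```python
# Illustrative DataLoader kwargs, as they might be passed to fit().
train_dataloader_opts = {
    "batch_size": 8,     # utterances per batch
    "num_workers": 4,    # parallel data-loading processes
    "shuffle": True,     # reshuffle training data each epoch
}
valid_dataloader_opts = {
    "batch_size": 8,
    "num_workers": 4,    # validation is typically not shuffled
}
```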
Outputs
The fit() method does not return a value. Its effects are:
- Trained model weights -- all registered modules are updated through gradient descent
- Checkpoints -- saved to the checkpointer directory, with the best model selected by WER
- Training logs -- loss, WER, CER, and learning rate values logged per epoch
- Updated schedulers -- learning rate schedulers are stepped based on validation loss
Execution Flow
```
fit()
|
+-- Convert datasets to DataLoaders if needed
+-- on_fit_start()                    # Initialize optimizers, recover from checkpoint
|
+-- for each epoch in epoch_counter:
|   |
|   +-- _fit_train(train_set, epoch)
|   |   |
|   |   +-- on_stage_start(TRAIN, epoch)
|   |   +-- modules.train()
|   |   +-- for each batch:
|   |   |   +-- fit_batch(batch, TRAIN)
|   |   |       +-- compute_forward(batch, TRAIN) -> p_ctc, wav_lens, None
|   |   |       +-- compute_objectives(preds, batch, TRAIN) -> CTC loss
|   |   |       +-- loss.backward()
|   |   |       +-- gradient clipping (max_grad_norm=5.0)
|   |   |       +-- optimizer.step()
|   |   +-- on_stage_end(TRAIN, avg_loss, epoch)
|   |
|   +-- _fit_valid(valid_set, epoch)
|       |
|       +-- on_stage_start(VALID, epoch)  # Initialize WER/CER metrics
|       +-- modules.eval()
|       +-- torch.no_grad()
|       +-- for each batch:
|       |   +-- evaluate_batch(batch, VALID)
|       |       +-- compute_forward(batch, VALID) -> p_ctc, wav_lens, p_tokens
|       |       +-- compute_objectives(preds, batch, VALID) -> CTC loss + WER/CER
|       +-- on_stage_end(VALID, avg_loss, epoch)
|           +-- LR scheduling (NewBobScheduler)
|           +-- Logging (loss, WER, CER, LR)
|           +-- Checkpointing (save best by WER)
```
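The gradient-clipping step above (max_grad_norm=5.0) rescales the whole gradient vector when its global L2 norm exceeds the threshold. A pure-Python sketch of the idea (the real training loop uses `torch.nn.utils.clip_grad_norm_`; `clip_grad_norm` here is an illustrative helper):

```python
import math

def clip_grad_norm(grads, max_norm=5.0):
    """Sketch of global-norm gradient clipping: if the L2 norm of all
    gradients exceeds max_norm, scale every gradient down uniformly."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm

# Gradients with norm sqrt(9 + 16 + 144) = 13.0 get rescaled to norm 5.0
clipped, norm = clip_grad_norm([3.0, 4.0, 12.0])
print(norm)                                              # 13.0
print(round(math.sqrt(sum(g * g for g in clipped)), 6))  # 5.0
```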
CTC-Specific compute_forward
The ASR subclass implements compute_forward() for the CTC pipeline:
```python
def compute_forward(self, batch, stage):
    """Forward: waveform -> wav2vec2 -> encoder DNN -> CTC logits."""
    batch = batch.to(self.device)
    wavs, wav_lens = batch.sig
    wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
    # Data augmentation (training only)
    if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
        wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
    # Feature extraction and encoding
    feats = self.modules.wav2vec2(wavs, wav_lens)  # Pretrained features
    x = self.modules.enc(feats)                    # Encoder DNN
    logits = self.modules.ctc_lin(x)               # CTC output projection
    p_ctc = self.hparams.log_softmax(logits)       # Log-probabilities
    # Decoding for metrics (not during training)
    p_tokens = None
    if stage == sb.Stage.VALID:
        p_tokens = sb.decoders.ctc_greedy_decode(
            p_ctc, wav_lens, blank_id=self.hparams.blank_index
        )
    elif stage == sb.Stage.TEST:
        p_tokens = test_searcher(p_ctc, wav_lens)
    return p_ctc, wav_lens, p_tokens
```
CTC-Specific compute_objectives
```python
def compute_objectives(self, predictions, batch, stage):
    """Compute CTC loss and track error metrics."""
    p_ctc, wav_lens, p_tokens = predictions
    ids = batch.id
    tokens, tokens_lens = batch.tokens
    # Replicate labels for augmented samples
    if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
        tokens = self.hparams.wav_augment.replicate_labels(tokens)
        tokens_lens = self.hparams.wav_augment.replicate_labels(tokens_lens)
    # CTC loss computation
    loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
    # WER/CER tracking (validation and test only)
    if stage != sb.Stage.TRAIN:
        if stage == sb.Stage.VALID:
            predicted_words = self.tokenizer(
                p_tokens, task="decode_from_list"
            )
        elif stage == sb.Stage.TEST:
            predicted_words = [hyp[0].text.split(" ") for hyp in p_tokens]
        target_words = undo_padding(tokens, tokens_lens)
        target_words = self.tokenizer(target_words, task="decode_from_list")
        self.wer_metric.append(ids, predicted_words, target_words)
        self.cer_metric.append(ids, predicted_words, target_words)
    return loss
```
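The undo_padding() call above strips padding from the label tensor before decoding. A pure-Python sketch of its behavior, assuming SpeechBrain's convention of relative lengths (each length is a fraction of the padded maximum; the function name mirrors the real helper, but this body is illustrative):

```python
def undo_padding(batch, lengths):
    """Sketch of undo_padding: recover the unpadded sequences from a
    padded batch, where lengths are relative (fraction of max length)."""
    max_len = len(batch[0])
    return [seq[: int(round(rel * max_len))] for seq, rel in zip(batch, lengths)]

# Two sequences padded to length 4, with relative lengths 3/4 and 1/4
tokens = [[5, 7, 9, 0], [3, 0, 0, 0]]
tokens_lens = [0.75, 0.25]
print(undo_padding(tokens, tokens_lens))  # [[5, 7, 9], [3]]
```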
Epoch-End Callbacks
on_stage_start (VALID)
```python
def on_stage_start(self, stage, epoch):
    if stage != sb.Stage.TRAIN:
        self.cer_metric = self.hparams.cer_computer()
        self.wer_metric = self.hparams.error_rate_computer()
```
on_stage_end (VALID)
```python
def on_stage_end(self, stage, stage_loss, epoch):
    stage_stats = {"loss": stage_loss}
    if stage == sb.Stage.TRAIN:
        self.train_stats = stage_stats
    else:
        stage_stats["CER"] = self.cer_metric.summarize("error_rate")
        stage_stats["WER"] = self.wer_metric.summarize("error_rate")
    if stage == sb.Stage.VALID:
        # Learning rate annealing (one scheduler per optimizer)
        old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
            stage_stats["loss"]
        )
        old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
            stage_stats["loss"]
        )
        sb.nnet.schedulers.update_learning_rate(
            self.model_optimizer, new_lr_model
        )
        if not self.hparams.wav2vec2.freeze:
            sb.nnet.schedulers.update_learning_rate(
                self.wav2vec_optimizer, new_lr_wav2vec
            )
        # Logging
        self.hparams.train_logger.log_stats(
            stats_meta={"epoch": epoch, "lr_model": old_lr_model},
            train_stats=self.train_stats,
            valid_stats=stage_stats,
        )
        # Save checkpoint (keep only the best by WER)
        self.checkpointer.save_and_keep_only(
            meta={"WER": stage_stats["WER"]},
            min_keys=["WER"],
        )
```
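The lr_annealing callables above follow the new-bob strategy: anneal the learning rate only when the relative improvement in validation loss falls below a threshold. A self-contained sketch of that policy (NewBobSketch and its parameter names are illustrative, not the real NewBobScheduler API, though the defaults mirror common new-bob settings):

```python
class NewBobSketch:
    """Illustrative new-bob annealing: multiply the LR by annealing_factor
    when the relative improvement drops below improvement_threshold.
    Returns (old_lr, new_lr), matching the pattern used in on_stage_end."""

    def __init__(self, lr, annealing_factor=0.5, improvement_threshold=0.0025):
        self.lr = lr
        self.factor = annealing_factor
        self.threshold = improvement_threshold
        self.prev_loss = None

    def __call__(self, loss):
        old_lr = self.lr
        if self.prev_loss is not None:
            improvement = (self.prev_loss - loss) / self.prev_loss
            if improvement < self.threshold:   # too little progress
                self.lr = old_lr * self.factor
        self.prev_loss = loss
        return old_lr, self.lr

sched = NewBobSketch(lr=1.0)
print(sched(10.0))   # (1.0, 1.0)  first epoch: no history yet
print(sched(9.0))    # (1.0, 1.0)  10% improvement -> keep LR
print(sched(8.999))  # (1.0, 0.5)  ~0.01% improvement -> anneal
```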
Dual Optimizer Setup
The CTC recipe overrides init_optimizers() to create separate optimizers:
```python
def init_optimizers(self):
    # Wav2vec2 optimizer (only if not frozen)
    if not self.hparams.wav2vec2.freeze:
        self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
            self.modules.wav2vec2.parameters()
        )
    # Model optimizer (encoder DNN + CTC linear)
    self.model_optimizer = self.hparams.model_opt_class(
        self.hparams.model.parameters()
    )
    # Register both optimizers with the checkpointer for resumption
    if self.checkpointer is not None:
        self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
        if not self.hparams.wav2vec2.freeze:
            self.checkpointer.add_recoverable(
                "wav2vec_opt", self.wav2vec_optimizer
            )
    self.optimizers_dict = {
        "model_optimizer": self.model_optimizer,
    }
    if not self.hparams.wav2vec2.freeze:
        self.optimizers_dict["wav2vec_optimizer"] = self.wav2vec_optimizer
```
Warmup and Freezing
```python
def freeze_optimizers(self, optimizers):
    """Freeze wav2vec2 optimizer during warmup phase."""
    valid_optimizers = {}
    if not self.hparams.wav2vec2.freeze:
        if self.optimizer_step >= self.hparams.warmup_steps:
            valid_optimizers["wav2vec_optimizer"] = optimizers[
                "wav2vec_optimizer"
            ]
    valid_optimizers["model_optimizer"] = optimizers["model_optimizer"]
    return valid_optimizers
```
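The gating condition can be isolated as a tiny pure function (illustrative: `active_optimizers` is not SpeechBrain API, it only reproduces the decision logic of freeze_optimizers above):

```python
def active_optimizers(optimizer_step, warmup_steps=500, wav2vec2_frozen=False):
    """Which optimizers step at a given point in training: the wav2vec2
    optimizer only activates once optimizer_step reaches warmup_steps."""
    active = []
    if not wav2vec2_frozen and optimizer_step >= warmup_steps:
        active.append("wav2vec_optimizer")
    active.append("model_optimizer")
    return active

print(active_optimizers(100))  # ['model_optimizer']
print(active_optimizers(500))  # ['wav2vec_optimizer', 'model_optimizer']
```

So during the first warmup_steps updates only the randomly initialized encoder DNN and CTC head are trained, protecting the pretrained wav2vec2 weights from large early gradients.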
Usage Example
```python
# Complete training invocation from the recipe
asr_brain = ASR(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)
asr_brain.tokenizer = tokenizer

# Training
asr_brain.fit(
    asr_brain.hparams.epoch_counter,
    train_data,
    valid_data,
    train_loader_kwargs=train_dataloader_opts,
    valid_loader_kwargs=valid_dataloader_opts,
)
```
Key YAML Configuration Values
| Key | Typical Value | Description |
|---|---|---|
| number_of_epochs | 30 | Maximum epochs to train |
| optimizer_step_limit | 75000 | Maximum optimizer steps (early stopping) |
| lr | 1.0 | Model optimizer learning rate (Adadelta) |
| lr_wav2vec | 0.0001 | Wav2vec2 optimizer learning rate (AdamW) |
| warmup_steps | 500 | Steps before wav2vec2 optimizer is activated |
| precision | "fp16" | Mixed precision mode |
| dynamic_batching | True | Use duration-based dynamic batching |
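As a sketch, these keys might sit in a recipe's hparams YAML as follows (illustrative layout, not a verbatim recipe file; the EpochCounter wiring follows the usual SpeechBrain `!new:`/`!ref` pattern):

```yaml
# Illustrative hparams fragment (values from the "Typical Value" column)
number_of_epochs: 30
optimizer_step_limit: 75000
lr: 1.0                 # Adadelta LR for the encoder DNN + CTC head
lr_wav2vec: 0.0001      # AdamW LR for the wav2vec2 front-end
warmup_steps: 500
precision: fp16
dynamic_batching: True

# The epoch_counter passed to fit() is typically built from these keys:
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
```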
Dependencies
- `speechbrain.nnet.losses.ctc_loss` -- CTC loss function wrapping PyTorch's `torch.nn.functional.ctc_loss`
- `speechbrain.decoders.ctc_greedy_decode` -- greedy CTC decoding for validation
- `speechbrain.decoders.ctc.CTCBeamSearcher` -- beam search decoding for testing
- `speechbrain.nnet.schedulers.NewBobScheduler` -- learning rate annealing
- `speechbrain.utils.metric_stats.ErrorRateStats` -- WER/CER metric accumulation
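The greedy decoding used at validation time reduces to two rules applied to the per-frame argmax path: collapse consecutive repeats, then drop blanks. A pure-Python sketch (illustrative; the real `ctc_greedy_decode` operates on batched log-probability tensors with relative lengths):

```python
def ctc_greedy_decode_sketch(best_path, blank_id=0):
    """Sketch of CTC greedy decoding on a single argmax path:
    collapse consecutive repeated tokens, then remove blanks."""
    out, prev = [], None
    for tok in best_path:
        if tok != prev and tok != blank_id:
            out.append(tok)
        prev = tok
    return out

# Frame-wise argmax path (blank_id=0): note the blank between the two 3s
# keeps them as separate output tokens, while the repeated 5s collapse.
print(ctc_greedy_decode_sketch([0, 3, 3, 0, 3, 5, 5, 0]))  # [3, 3, 5]
```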