Implementation:Speechbrain Speechbrain Train CommonVoice Seq2Seq Wav2Vec
| Knowledge Sources | |
|---|---|
| Domains | ASR, Training |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Concrete tool, provided by the SpeechBrain library, for training a sequence-to-sequence ASR system with a pretrained wav2vec2 encoder on the CommonVoice dataset.
Description
This training script implements a sequence-to-sequence ASR system that uses a pretrained wav2vec2 model as the encoder instead of a traditional CRDNN. The default configuration uses an XLSR wav2vec2 model (e.g., facebook/wav2vec2-large-xlsr-53-french) for multilingual feature extraction. The decoder is GRU-based and decoded with beam search, and the system is trained with a joint CTC and negative log-likelihood (NLL) loss on BPE sub-word units. Because the encoder starts from self-supervised pretraining rather than random initialization, this approach typically performs better than training the whole model from scratch.
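The joint objective described above can be sketched as a weighted sum of the two losses. This is a minimal illustration only: the function name `joint_loss` and the parameter name `ctc_weight` are assumptions made for this example, and the loss values are toy numbers. The script itself computes the individual CTC and NLL terms with SpeechBrain's loss functions, which are not reproduced here.

```python
def joint_loss(loss_ctc: float, loss_seq: float, ctc_weight: float) -> float:
    """Interpolate a CTC loss and a seq2seq NLL loss.

    ctc_weight is an assumed hyperparameter in [0, 1]:
    0 disables the CTC term, 1 disables the seq2seq term.
    """
    if not 0.0 <= ctc_weight <= 1.0:
        raise ValueError("ctc_weight must lie in [0, 1]")
    return ctc_weight * loss_ctc + (1.0 - ctc_weight) * loss_seq


# Toy per-batch loss values, chosen only for illustration.
example = joint_loss(loss_ctc=2.0, loss_seq=1.0, ctc_weight=0.25)
print(example)
```

A small CTC weight keeps the attention decoder as the main training signal while the CTC term encourages monotonic alignments in the encoder output.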
Usage
Use this script to train a wav2vec2-based seq2seq ASR model on any CommonVoice language. Run it with: python train_with_wav2vec.py hparams/train_with_wav2vec2.yaml.
Code Reference
Source Location
- Repository: SpeechBrain
- File: recipes/CommonVoice/ASR/seq2seq/train_with_wav2vec.py
Signature
class ASR(sb.core.Brain):
    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batches to the output probabilities."""
        ...

    def compute_objectives(self, predictions, batch, stage):
        ...
Import
import speechbrain as sb
from speechbrain.core import Brain
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| hparams_file | str | Yes | Path to the YAML hyperparameter configuration file (e.g., hparams/train_with_wav2vec2.yaml) |
| batch.sig | tuple | Yes | Waveform tensor and lengths from the dataloader |
| batch.tokens_bos | tuple | Yes | BPE tokens with beginning-of-sequence marker |
| batch.tokens_eos | tuple | Yes | BPE tokens with end-of-sequence marker |
| batch.tokens | tuple | Yes | BPE tokens without special markers (for CTC) |
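The batch fields in the table above can be illustrated with plain Python lists. This is a hedged sketch, not the recipe's data pipeline: the helper names `pad_with_relative_lengths` and `token_variants`, the token IDs, and the `bos_index`/`eos_index` values are all hypothetical. It assumes the usual seq2seq convention in which the BOS-prefixed sequence feeds the decoder and the EOS-suffixed sequence is the NLL target, and it shows lengths expressed relative to the longest item in the batch.

```python
from typing import Dict, List, Tuple


def pad_with_relative_lengths(
    waveforms: List[List[float]],
) -> Tuple[List[List[float]], List[float]]:
    """Zero-pad waveforms to the longest one and return relative lengths."""
    max_len = max(len(w) for w in waveforms)
    padded = [w + [0.0] * (max_len - len(w)) for w in waveforms]
    rel_lens = [len(w) / max_len for w in waveforms]
    return padded, rel_lens


def token_variants(
    tokens: List[int], bos_index: int, eos_index: int
) -> Dict[str, List[int]]:
    """Return the three token sequences listed in the input table."""
    return {
        "tokens_bos": [bos_index] + tokens,  # assumed decoder input
        "tokens_eos": tokens + [eos_index],  # assumed seq2seq (NLL) target
        "tokens": list(tokens),              # CTC target, no special markers
    }


# Two toy "waveforms" of different lengths and one toy BPE sequence.
sig, wav_lens = pad_with_relative_lengths([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6]])
variants = token_variants([17, 4, 9], bos_index=1, eos_index=2)
print(wav_lens)
print(variants["tokens_bos"], variants["tokens_eos"], variants["tokens"])
```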
Outputs
| Name | Type | Description |
|---|---|---|
| p_ctc | tensor | CTC log-probabilities over the token vocabulary |
| p_seq | tensor | Seq2seq output log-probabilities from the decoder (consumed by the NLL loss) |
| wav_lens | tensor | Relative lengths of the input waveforms |
| model checkpoint | file | Saved model parameters including fine-tuned wav2vec2 weights |
| WER/CER metrics | float | Word error rate and character error rate on dev/test sets |
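For readers unfamiliar with the reported metrics, the sketch below shows how WER and CER are conventionally defined: the Levenshtein edit distance between hypothesis and reference, divided by the reference length, computed over words or characters respectively. It is a self-contained illustration with hypothetical helper names and a made-up sentence pair; the training script relies on SpeechBrain's own metric utilities rather than this code.

```python
from typing import Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance with unit cost for substitution, insertion, deletion."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost))
        prev = curr
    return prev[-1]


def error_rate(ref: Sequence[str], hyp: Sequence[str]) -> float:
    """Edit distance normalized by reference length, as a percentage."""
    if len(ref) == 0:
        raise ValueError("reference must not be empty")
    return 100.0 * edit_distance(ref, hyp) / len(ref)


# Made-up reference/hypothesis pair, for illustration only.
reference = "le chat dort sur le canapé"
hypothesis = "le chat dors sur canapé"

wer = error_rate(reference.split(), hypothesis.split())  # word level
cer = error_rate(list(reference), list(hypothesis))      # character level
print(f"WER: {wer:.2f}%  CER: {cer:.2f}%")
```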
Usage Examples
# Command-line usage
# python train_with_wav2vec.py hparams/train_with_wav2vec2.yaml

# Programmatic usage
import sys

import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# Parse the YAML path, runtime options, and any command-line overrides.
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file) as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# ASR is the Brain subclass defined in train_with_wav2vec.py (see Signature).
asr_brain = ASR(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts=run_opts,
    opt_class=hparams["opt_class"],
    checkpointer=hparams["checkpointer"],
)

# train_data and valid_data must be built beforehand by the recipe's
# data-preparation step, which is not shown here.
asr_brain.fit(
    hparams["epoch_counter"],
    train_data,
    valid_data,
    train_loader_kwargs=hparams["dataloader_options"],
    valid_loader_kwargs=hparams["test_dataloader_options"],
)