Implementation:Speechbrain Speechbrain Train Wav2Vec2 SSL
| Knowledge Sources | |
|---|---|
| Domains | ASR, Training |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A concrete tool, provided by the SpeechBrain library, for self-supervised pre-training of a wav2vec 2.0 model on the CommonVoice dataset.
Description
This recipe defines the W2VBrain class (a subclass of sb.core.Brain) for wav2vec 2.0 self-supervised pre-training. It wraps the HuggingFace Transformers implementation of wav2vec 2.0 within the SpeechBrain framework. The model uses contrastive learning to learn speech representations from unlabeled audio. During training, the forward pass returns only the contrastive loss computed by the HuggingFace model. During evaluation, it also returns the model output and the mask time indices, and the cosine similarity between the projected states and the quantized states is reported as an accuracy metric.
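The evaluation metric described above can be illustrated with a small worked example. This is a minimal plain-Python sketch, not the recipe's code: the recipe operates on tensors, while the sketch uses nested lists so it runs without any dependencies. The function names `cosine_similarity` and `masked_accuracy` are hypothetical, the numbers are made up, and restricting the mean to masked frames is an assumption based on the mask being part of the evaluation output.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def masked_accuracy(projected, quantized, mask):
    """Mean cosine similarity over the frames where mask is True.

    projected, quantized: [time][dim] lists for one utterance.
    mask: [time] list of booleans marking the masked frames.
    """
    sims = [
        cosine_similarity(p, q)
        for p, q, m in zip(projected, quantized, mask)
        if m
    ]
    return sum(sims) / len(sims)


# Toy utterance with three frames of dimension 2 (values are made up).
projected = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
quantized = [[2.0, 0.0], [3.0, 0.0], [5.0, 5.0]]
mask = [True, True, False]  # assumption: only masked frames are scored

accuracy = masked_accuracy(projected, quantized, mask)
print(accuracy)
```

The first masked frame is perfectly aligned (similarity 1.0) and the second is orthogonal (similarity 0.0), so the reported accuracy is 0.5. The third frame is ignored because it is not masked.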
Usage
Use this recipe to pre-train a wav2vec 2.0 model on CommonVoice (or any dataset with a JSON/CSV manifest). The resulting model can then be used as a pre-trained encoder for downstream ASR tasks.
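For a dataset other than CommonVoice, a manifest has to be prepared first. The following is a minimal sketch of building a CSV manifest with the standard library only. The column names `ID`, `duration`, and `wav` are assumptions and must match the keys that the recipe's data pipeline and hyperparameter file expect. The helper names and the silent placeholder audio are hypothetical and exist only to make the sketch runnable.

```python
import csv
import os
import tempfile
import wave


def write_silent_wav(path, seconds, sample_rate=16000):
    """Write a mono 16-bit WAV of silence (placeholder audio for the demo)."""
    with wave.open(path, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(b"\x00\x00" * int(seconds * sample_rate))


def wav_duration(path):
    """Return the duration of a WAV file in seconds."""
    with wave.open(path, "rb") as wav_file:
        return wav_file.getnframes() / wav_file.getframerate()


def build_manifest(wav_paths, csv_path):
    """Write one manifest row per audio file (assumed columns: ID, duration, wav)."""
    with open(csv_path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=["ID", "duration", "wav"])
        writer.writeheader()
        for path in wav_paths:
            utt_id = os.path.splitext(os.path.basename(path))[0]
            writer.writerow(
                {"ID": utt_id, "duration": f"{wav_duration(path):.2f}", "wav": path}
            )


# Create two placeholder utterances and a manifest in a temporary folder.
workdir = tempfile.mkdtemp()
paths = []
for name, seconds in [("utt1", 1.0), ("utt2", 0.5)]:
    wav_path = os.path.join(workdir, f"{name}.wav")
    write_silent_wav(wav_path, seconds)
    paths.append(wav_path)

manifest_path = os.path.join(workdir, "train.csv")
build_manifest(paths, manifest_path)

with open(manifest_path, newline="") as handle:
    rows = list(csv.DictReader(handle))
print(rows)
```

Because pre-training is self-supervised, no transcript column is needed in this sketch. Only the audio path and its duration are recorded.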
Code Reference
Source Location
- Repository: SpeechBrain
- File: recipes/CommonVoice/self-supervised-learning/wav2vec2/train_hf_wav2vec2.py
Signature
class W2VBrain(sb.core.Brain):
    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batches to the w2v2 loss."""
        ...

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss (CTC+NLL) given predictions and targets."""
        ...
Import
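# Run as a recipe script from the repository root (the script is executed, not imported)
python recipes/CommonVoice/self-supervised-learning/wav2vec2/train_hf_wav2vec2.py recipes/CommonVoice/self-supervised-learning/wav2vec2/hparams/hyperparams.yaml --data_folder /path/to/CommonVoice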
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| batch.sig | torch.Tensor | Yes | Input waveform signal |
Outputs
| Name | Type | Description |
|---|---|---|
| loss | torch.Tensor | Contrastive loss from wav2vec 2.0 |
| out | Wav2Vec2ForPreTrainingOutput | Model output including projected and quantized states (eval only) |
| mask | torch.Tensor | Mask time indices used for contrastive loss (eval only) |
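The stage-dependent contract in the tables above can be sketched as follows. This is an illustrative stand-in, not the recipe's implementation: `Stage` mimics sb.Stage, `fake_model` is a hypothetical replacement for the wrapped wav2vec 2.0 model, and the loss value is made up. It only shows which values flow from the forward pass to the objective in each stage.

```python
from enum import Enum, auto
from types import SimpleNamespace


class Stage(Enum):
    """Stand-in for sb.Stage so the sketch runs without SpeechBrain."""

    TRAIN = auto()
    VALID = auto()
    TEST = auto()


def fake_model(wavs):
    """Hypothetical stand-in for the wrapped wav2vec 2.0 model."""
    out = SimpleNamespace(loss=0.25)  # made-up contrastive loss
    mask = [[True, False] for _ in wavs]  # made-up mask time indices
    return out, mask


def compute_forward(wavs, stage):
    """Return only the loss when training, else the full output and the mask."""
    out, mask = fake_model(wavs)
    if stage == Stage.TRAIN:
        return out.loss
    return out, mask


def compute_objectives(predictions, stage):
    """Return the loss; outside training it is read from the model output."""
    if stage == Stage.TRAIN:
        return predictions
    out, _mask = predictions
    return out.loss


wavs = [[0.0, 0.1], [0.2, 0.3]]  # two toy "waveforms"
train_loss = compute_objectives(compute_forward(wavs, Stage.TRAIN), Stage.TRAIN)
valid_loss = compute_objectives(compute_forward(wavs, Stage.VALID), Stage.VALID)
print(train_loss, valid_loss)
```

Returning the bare loss during training keeps the training step minimal, while the richer evaluation output carries what is needed to compute the accuracy metric.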
Usage Examples
python train_hf_wav2vec2.py hparams/hyperparams.yaml --data_folder /path/to/CommonVoice