Implementation:Speechbrain Speechbrain Whisper Dataio Prepare
| Field | Value |
|---|---|
| API | dataio_prepare(hparams, tokenizer) -> tuple(DynamicItemDataset, DynamicItemDataset, DynamicItemDataset) |
| Source | recipes/CommonVoice/ASR/transformer/train_with_whisper.py:L160-252 |
| Import | Recipe-specific function (defined in train_with_whisper.py) |
| Type | API Doc |
| Inputs | hparams dict with CSV paths and sorting configuration; Whisper tokenizer from the loaded model |
| Outputs | Three DynamicItemDataset instances (train, valid, test) with output keys: ["id", "sig", "tokens_list", "tokens_bos", "tokens_eos", "tokens"] |
| Related Principle | Principle:Speechbrain_Speechbrain_Whisper_Data_Tokenization_Pipeline |
Purpose
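Builds the SpeechBrain data pipeline for Whisper fine-tuning. It loads each audio file and resamples it to the target sample rate, then tokenizes the transcript with Whisper's built-in byte-level BPE tokenizer and wraps the result in Whisper's special tokens (start-of-transcript, language, task, no-timestamps, and end-of-text).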
Signature
def dataio_prepare(hparams, tokenizer):
    """This function prepares the datasets to be used in the brain class.
    It also defines the data processing pipeline through user-defined functions.
    Arguments
    ---------
    hparams : dict
        Hyperparameters dictionary loaded from YAML.
    tokenizer : WhisperTokenizer
        The Whisper tokenizer obtained from the loaded model.
    Returns
    -------
    train_data : DynamicItemDataset
    valid_data : DynamicItemDataset
    test_data : DynamicItemDataset
    """
Parameters
| Parameter | Type | Description |
|---|---|---|
| hparams | dict | Must contain: train_csv, valid_csv, test_csv (CSV file paths), data_folder (base path for audio), sample_rate (target sample rate, typically 16000), sorting ("ascending", "descending", or "random"), avoid_if_longer_than (max duration filter), train_loader_kwargs (DataLoader config) |
| tokenizer | WhisperTokenizer | The Whisper tokenizer obtained via hparams["whisper"].tokenizer. Contains encode(), decode(), normalize(), and build_inputs_with_special_tokens() methods |
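Since dataio_prepare reads these keys directly, a missing entry would typically surface only as a KeyError during dataset construction. The following is a minimal sketch of an up-front check, not part of the recipe: the names REQUIRED_KEYS, VALID_SORTING, missing_hparams, and check_hparams are hypothetical, and the key list simply mirrors the table above.

```python
# Hypothetical pre-flight check; the key list mirrors the Parameters table.
REQUIRED_KEYS = (
    "train_csv",
    "valid_csv",
    "test_csv",
    "data_folder",
    "sample_rate",
    "sorting",
    "avoid_if_longer_than",
    "train_loader_kwargs",
)
VALID_SORTING = ("ascending", "descending", "random")


def missing_hparams(hparams):
    """Return the required keys that are absent from hparams, in table order."""
    return [key for key in REQUIRED_KEYS if key not in hparams]


def check_hparams(hparams):
    """Raise ValueError if a required key is missing or sorting is invalid."""
    missing = missing_hparams(hparams)
    if missing:
        raise ValueError("hparams is missing: " + ", ".join(missing))
    if hparams["sorting"] not in VALID_SORTING:
        raise ValueError("sorting must be one of " + ", ".join(VALID_SORTING))


# Example values are placeholders, not taken from the recipe YAML.
example_hparams = {
    "train_csv": "train.csv",
    "valid_csv": "dev.csv",
    "test_csv": "test.csv",
    "data_folder": "/path/to/data",
    "sample_rate": 16000,
    "sorting": "ascending",
    "avoid_if_longer_than": 10.0,
    "train_loader_kwargs": {"batch_size": 8},
}
check_hparams(example_hparams)  # passes silently when everything is present
```

Calling check_hparams(hparams) right after loading the YAML would fail fast with a readable message instead of partway through dataset creation.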
Usage Example
from hyperpyyaml import load_hyperpyyaml

# dataio_prepare is defined in the recipe script (train_with_whisper.py),
# so this snippet is meant to run inside that script.

# Load hyperparameters
with open("hparams/train_hf_whisper.yaml") as fin:
    hparams = load_hyperpyyaml(fin)

# Get tokenizer from the Whisper model
tokenizer = hparams["whisper"].tokenizer

# Prepare datasets
train_data, valid_data, test_data = dataio_prepare(hparams, tokenizer)

# Access a sample (token IDs below are illustrative)
sample = train_data[0]
print(sample["id"])  # utterance ID string
print(sample["sig"].shape)  # audio tensor, e.g., torch.Size([48000])
print(sample["tokens_bos"])  # e.g., tensor([50258, 50265, 50360, 50364, 35309, 531])
print(sample["tokens_eos"])  # e.g., tensor([50265, 50360, 50364, 35309, 531, 50257])
print(sample["tokens_list"])  # e.g., [35309, 531]
Audio Pipeline
@sb.utils.data_pipeline.takes("wav")
@sb.utils.data_pipeline.provides("sig")
def audio_pipeline(wav):
    info = torchaudio.info(wav)
    sig = sb.dataio.dataio.read_audio(wav)
    if info.sample_rate != hparams["sample_rate"]:
        sig = torchaudio.transforms.Resample(
            info.sample_rate, hparams["sample_rate"]
        )(sig)
    return sig
The audio pipeline:
- Reads the audio file path from the wav column of the CSV.
- Loads the audio waveform using SpeechBrain's read_audio.
- Resamples to hparams["sample_rate"] (typically 16000 Hz) when the file's native sample rate differs.
- Provides the sig tensor as output.
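To make the effect of the resampling step on tensor shapes concrete, the sketch below estimates how many samples sig holds after the pipeline runs. It is plain arithmetic rather than recipe code: the helper name expected_num_samples is hypothetical, and rounding the output length up is an assumption about the resampler, so the true length may differ by a sample.

```python
def expected_num_samples(num_samples, orig_rate, target_rate):
    """Estimate the waveform length after the audio pipeline.

    Assumption: the resampler rounds the output length up. Integer
    arithmetic is used to avoid floating-point rounding surprises.
    """
    if orig_rate == target_rate:
        # The pipeline skips resampling when the rates already match.
        return num_samples
    return (num_samples * target_rate + orig_rate - 1) // orig_rate


# A hypothetical 3-second clip stored at 48 kHz holds 144000 samples;
# at a 16 kHz target rate the same 3 seconds need 48000 samples.
print(expected_num_samples(144_000, 48_000, 16_000))
```

This is consistent with the torch.Size([48000]) shape shown in the usage example, which corresponds to three seconds of audio at 16 kHz.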
Text Pipeline
@sb.utils.data_pipeline.takes("wrd")
@sb.utils.data_pipeline.provides(
    "wrd", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
)
def text_pipeline(wrd):
    if hasattr(hparams, "normalized_transcripts"):
        wrd = tokenizer.normalize(wrd)
    yield wrd
    tokens_list = tokenizer.encode(wrd, add_special_tokens=False)
    yield tokens_list
    tokens_list = tokenizer.build_inputs_with_special_tokens(tokens_list)
    tokens_bos = torch.LongTensor(tokens_list[:-1])
    yield tokens_bos
    tokens_eos = torch.LongTensor(tokens_list[1:])
    yield tokens_eos
    tokens = torch.LongTensor(tokens_list)
    yield tokens
The text pipeline:
- Takes the wrd field from the CSV.
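- Normalizes the text with Whisper's normalizer when the hasattr(hparams, "normalized_transcripts") check passes. Because hasattr tests attributes rather than dictionary keys, a plain-dict hparams that contains a normalized_transcripts key does not trigger this branch as written.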
- Encodes text to token IDs (without special tokens) using tokenizer.encode().
- Wraps with special tokens using tokenizer.build_inputs_with_special_tokens(), which adds [<|startoftranscript|>, <|language|>, <|task|>, <|notimestamps|>] at the beginning and [<|endoftext|>] at the end.
- Creates tokens_bos: all tokens except the last (for decoder input during teacher forcing).
- Creates tokens_eos: all tokens except the first (for loss target).
- Creates tokens: the complete token sequence (for evaluation).
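The one-position shift between tokens_bos and tokens_eos is easiest to see on plain lists. The sketch below is illustrative only: wrap_with_special_tokens is a hypothetical stand-in for tokenizer.build_inputs_with_special_tokens(), and the token IDs are the example values used on this page, which depend on the Whisper checkpoint and the configured language and task.

```python
# Illustrative IDs only; real values depend on the checkpoint and language.
PREFIX = [50258, 50265, 50360, 50364]  # startoftranscript, language, task, notimestamps
END_OF_TEXT = 50257


def wrap_with_special_tokens(token_ids):
    """Hypothetical stand-in for tokenizer.build_inputs_with_special_tokens."""
    return PREFIX + list(token_ids) + [END_OF_TEXT]


def teacher_forcing_pair(token_ids):
    """Return (tokens_bos, tokens_eos, tokens) as plain Python lists."""
    full = wrap_with_special_tokens(token_ids)
    return full[:-1], full[1:], full


tokens_bos, tokens_eos, tokens = teacher_forcing_pair([35309, 531])

# At step i the decoder reads tokens_bos[i] and is trained to predict
# tokens_eos[i], i.e. the next token of the full sequence.
for decoder_input, target in zip(tokens_bos, tokens_eos):
    print(decoder_input, "->", target)
```

Both lists have the same length, so they can be padded and batched together, and the final target is always the end-of-text token.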
Dataset Sorting
The function supports three sorting modes for the training set:
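- ascending: Sorts by duration (shortest first). Each batch then holds utterances of similar length, which reduces padding and speeds up training. Disables DataLoader shuffling so the sorted order is preserved.
- descending: Sorts by duration (longest first). Processing the longest batches first can, for example, surface out-of-memory errors early in training. Disables DataLoader shuffling so the sorted order is preserved.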
- random: No sorting; uses DataLoader shuffling.
Utterances longer than avoid_if_longer_than seconds (default: 10.0) are filtered out.
Validation data is always sorted by duration (ascending). Test data is unsorted.
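The interaction between the sorting mode, the duration filter, and the shuffle flag can be sketched in plain Python. This is a stand-in for illustration, not the SpeechBrain dataset API: prepare_train_rows is a hypothetical helper, rows are plain dicts with a duration field in seconds, and applying the filter in every mode with an inclusive bound is an assumption.

```python
def prepare_train_rows(rows, sorting, max_duration):
    """Return (rows, shuffle) for the three training-set sorting modes.

    Assumptions: the duration filter is inclusive and applied in every mode.
    """
    kept = [row for row in rows if row["duration"] <= max_duration]
    if sorting == "ascending":
        # Shortest first; shuffling is disabled to keep the sorted order.
        return sorted(kept, key=lambda row: row["duration"]), False
    if sorting == "descending":
        # Longest first; shuffling is disabled to keep the sorted order.
        return sorted(kept, key=lambda row: row["duration"], reverse=True), False
    if sorting == "random":
        # Order is left untouched; the DataLoader shuffles instead.
        return kept, True
    raise ValueError("sorting must be ascending, descending, or random")


rows = [
    {"id": "utt1", "duration": 4.2},
    {"id": "utt2", "duration": 12.5},  # exceeds the limit and is dropped
    {"id": "utt3", "duration": 1.7},
]
sorted_rows, shuffle = prepare_train_rows(rows, "ascending", max_duration=10.0)
print([row["id"] for row in sorted_rows], shuffle)
```

The returned shuffle flag corresponds to the shuffle entry that would be set in train_loader_kwargs before the training DataLoader is built.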
Output Keys
| Key | Type | Description |
|---|---|---|
| id | str | Utterance identifier from the CSV |
| sig | torch.Tensor | Audio waveform at hparams["sample_rate"] (typically 16 kHz) |
| tokens_list | list[int] | Raw token IDs without special tokens |
| tokens_bos | torch.LongTensor | Full sequence without its last token (end-of-text); used as decoder input |
| tokens_eos | torch.LongTensor | Full sequence without its first token (start-of-transcript); used as the loss target |
| tokens | torch.LongTensor | Full token sequence with all special tokens |