Implementation:Speechbrain Speechbrain Whisper Dataio Prepare

From Leeroopedia


Field | Value
API | dataio_prepare(hparams, tokenizer) -> tuple(DynamicItemDataset, DynamicItemDataset, DynamicItemDataset)
Source | recipes/CommonVoice/ASR/transformer/train_with_whisper.py:L160-252
Import | Recipe-specific function (defined in train_with_whisper.py)
Type | API Doc
Inputs | hparams dict with CSV paths and sorting configuration; Whisper tokenizer from the loaded model
Outputs | Three DynamicItemDataset instances (train, valid, test) with output keys: ["id", "sig", "tokens_list", "tokens_bos", "tokens_eos", "tokens"]
Related Principle | Principle:Speechbrain_Speechbrain_Whisper_Data_Tokenization_Pipeline

Purpose

Constructs the SpeechBrain data pipeline for Whisper fine-tuning, including audio loading with resampling and text tokenization using Whisper's built-in byte-level BPE tokenizer with proper special token handling.

Signature

def dataio_prepare(hparams, tokenizer):
    """This function prepares the datasets to be used in the brain class.
    It also defines the data processing pipeline through user-defined functions.

    Arguments
    ---------
    hparams : dict
        Hyperparameters dictionary loaded from YAML.
    tokenizer : WhisperTokenizer
        The Whisper tokenizer obtained from the loaded model.

    Returns
    -------
    train_data : DynamicItemDataset
    valid_data : DynamicItemDataset
    test_data : DynamicItemDataset
    """

Parameters

Parameter | Type | Description
hparams | dict | Must contain: train_csv, valid_csv, test_csv (CSV file paths), data_folder (base path for audio), sample_rate (target sample rate, typically 16000), sorting ("ascending", "descending", or "random"), avoid_if_longer_than (maximum utterance duration in seconds), train_loader_kwargs (DataLoader config).
tokenizer | WhisperTokenizer | The Whisper tokenizer obtained via hparams["whisper"].tokenizer. Provides the encode(), decode(), normalize(), and build_inputs_with_special_tokens() methods.
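
Since the function indexes hparams by key, a missing entry would typically surface as a KeyError partway through dataset construction. The following is a minimal sketch of an up-front check, assuming only the key names listed in the table above. The helper name check_hparams and the two constants are hypothetical and are not part of the recipe.

```python
# Hypothetical helper, not part of the SpeechBrain recipe: checks that the
# keys listed in the Parameters table are present before dataio_prepare runs.
REQUIRED_KEYS = (
    "train_csv", "valid_csv", "test_csv", "data_folder",
    "sample_rate", "sorting", "avoid_if_longer_than", "train_loader_kwargs",
)
VALID_SORTING = ("ascending", "descending", "random")


def check_hparams(hparams):
    """Return a list of problems found in hparams (empty if none)."""
    problems = [f"missing key: {k}" for k in REQUIRED_KEYS if k not in hparams]
    sorting = hparams.get("sorting")
    if sorting is not None and sorting not in VALID_SORTING:
        problems.append(f"unsupported sorting mode: {sorting!r}")
    return problems
```

Calling check_hparams(hparams) before dataio_prepare and raising if the returned list is non-empty reports every problem at once rather than one KeyError at a time.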

Usage Example

from hyperpyyaml import load_hyperpyyaml

# Load hyperparameters
with open("hparams/train_hf_whisper.yaml") as fin:
    hparams = load_hyperpyyaml(fin)

# Get tokenizer from the Whisper model
tokenizer = hparams["whisper"].tokenizer

# Prepare datasets
train_data, valid_data, test_data = dataio_prepare(hparams, tokenizer)

# Access a sample
sample = train_data[0]
print(sample["id"])          # utterance ID string
print(sample["sig"].shape)   # audio tensor, e.g., torch.Size([48000])
print(sample["tokens_bos"])  # e.g., tensor([50258, 50265, 50360, 50364, 35309, 531, 32828])
print(sample["tokens_eos"])  # e.g., tensor([50265, 50360, 50364, 35309, 531, 32828, 50257])
print(sample["tokens_list"]) # e.g., [35309, 531, 32828]

Audio Pipeline

@sb.utils.data_pipeline.takes("wav")
@sb.utils.data_pipeline.provides("sig")
def audio_pipeline(wav):
    info = torchaudio.info(wav)
    sig = sb.dataio.dataio.read_audio(wav)
    if info.sample_rate != hparams["sample_rate"]:
        sig = torchaudio.transforms.Resample(
            info.sample_rate, hparams["sample_rate"]
        )(sig)
    return sig

The audio pipeline:

  • Reads the audio file path from the wav column of the CSV.
  • Queries the file's native sample rate with torchaudio.info.
  • Loads the audio waveform using SpeechBrain's read_audio.
  • Resamples to hparams["sample_rate"] (typically 16000 Hz) only when the native rate differs.
  • Provides the resulting sig tensor as output.
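
To make the effect of the conditional resampling concrete, the sketch below computes the expected number of samples in sig for a given input. It is plain arithmetic, not a call into torchaudio; the function name resampled_length is hypothetical, and rounding the output length up is an assumption about the resampler's behavior.

```python
def resampled_length(num_samples, orig_rate, target_rate):
    """Expected sample count after the audio pipeline's resampling step."""
    if orig_rate == target_rate:
        # Mirrors the pipeline: matching rates skip resampling entirely.
        return num_samples
    # Assumption: the output length is rounded up (integer ceiling division).
    return (num_samples * target_rate + orig_rate - 1) // orig_rate


# Three seconds recorded at 48 kHz (144000 samples), target rate 16 kHz.
print(resampled_length(144000, 48000, 16000))  # 48000
# A file already at the target rate keeps its length.
print(resampled_length(48000, 16000, 16000))   # 48000
```

Both cases give a three-second signal of 48000 samples, the shape shown for sample["sig"] in the usage example.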

Text Pipeline

@sb.utils.data_pipeline.takes("wrd")
@sb.utils.data_pipeline.provides(
    "wrd", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
)
def text_pipeline(wrd):
    if hasattr(hparams, "normalized_transcripts"):
        wrd = tokenizer.normalize(wrd)
    yield wrd
    tokens_list = tokenizer.encode(wrd, add_special_tokens=False)
    yield tokens_list
    tokens_list = tokenizer.build_inputs_with_special_tokens(tokens_list)
    tokens_bos = torch.LongTensor(tokens_list[:-1])
    yield tokens_bos
    tokens_eos = torch.LongTensor(tokens_list[1:])
    yield tokens_eos
    tokens = torch.LongTensor(tokens_list)
    yield tokens

The text pipeline:

  • Takes the wrd field from the CSV.
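  • Optionally normalizes the text using Whisper's normalizer. The guard is hasattr(hparams, "normalized_transcripts"), which tests for an attribute rather than a dictionary key, so it does not fire when hparams is a plain dict.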
  • Encodes text to token IDs (without special tokens) using tokenizer.encode().
  • Wraps with special tokens using tokenizer.build_inputs_with_special_tokens(), which adds [<|startoftranscript|>, <|language|>, <|task|>, <|notimestamps|>] at the beginning and [<|endoftext|>] at the end.
  • Creates tokens_bos: all tokens except the last (for decoder input during teacher forcing).
  • Creates tokens_eos: all tokens except the first (for loss target).
  • Creates tokens: the complete token sequence (for evaluation).
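
The one-position shift between tokens_bos and tokens_eos can be sketched with plain lists. This is an illustration under stated assumptions, not the recipe's code: the function name split_for_teacher_forcing is hypothetical, and the token IDs are placeholder values echoing the usage example rather than output from a real tokenizer.

```python
def split_for_teacher_forcing(content_ids, prefix_ids, eos_id):
    """Build the full sequence, then derive decoder input and loss target."""
    full = list(prefix_ids) + list(content_ids) + [eos_id]
    tokens_bos = full[:-1]  # everything except the final end-of-text token
    tokens_eos = full[1:]   # everything except the leading start token
    return tokens_bos, tokens_eos, full


# Placeholder IDs: four prefix tokens, three content tokens, one end token.
prefix = [50258, 50265, 50360, 50364]
bos, eos, full = split_for_teacher_forcing([35309, 531, 32828], prefix, 50257)

# At each step i the decoder is fed bos[i] and is trained to predict eos[i].
for fed, target in zip(bos, eos):
    print(fed, "->", target)
```

Both derived sequences have the same length, one shorter than the full sequence, so they align position by position during loss computation.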

Dataset Sorting

The function supports three sorting modes for the training set:

  • ascending: Sorts by duration (shortest first). Batches then contain utterances of similar length, which reduces padding and speeds up training. Disables DataLoader shuffling.
  • descending: Sorts by duration (longest first). Can be useful, for example, to surface out-of-memory errors early, because the longest batches are processed first. Disables DataLoader shuffling.
  • random: No sorting; uses DataLoader shuffling.

Utterances longer than avoid_if_longer_than seconds (default: 10.0) are filtered out.

Validation data is always sorted by duration (ascending). Test data is unsorted.
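
The sorting and filtering behavior described above can be sketched on plain dictionaries. This is a simplified stand-in rather than SpeechBrain's implementation: the function name sort_and_filter is hypothetical, each item is assumed to carry a duration field in seconds, and applying the duration filter in every mode (including random) is an assumption.

```python
def sort_and_filter(items, sorting, max_duration):
    """Return (kept_items, shuffle_flag) for the training set."""
    if sorting not in ("ascending", "descending", "random"):
        raise ValueError(f"unsupported sorting mode: {sorting!r}")
    # Drop utterances longer than the limit (the limit itself is kept).
    kept = [it for it in items if it["duration"] <= max_duration]
    if sorting == "random":
        return kept, True  # order untouched; the DataLoader shuffles
    kept.sort(key=lambda it: it["duration"], reverse=(sorting == "descending"))
    return kept, False     # sorted order must not be shuffled away


utterances = [
    {"id": "a", "duration": 4.2},
    {"id": "b", "duration": 12.5},  # exceeds the 10.0 s limit
    {"id": "c", "duration": 1.3},
]
kept, shuffle = sort_and_filter(utterances, "ascending", 10.0)
print([it["id"] for it in kept], shuffle)  # ['c', 'a'] False
```

The returned flag corresponds to the shuffle entry of train_loader_kwargs: it is only left enabled when no sorted order needs to be preserved.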

Output Keys

Key | Type | Description
id | str | Utterance identifier from the CSV
sig | torch.Tensor | Audio waveform at hparams["sample_rate"] (typically 16 kHz)
tokens_list | list[int] | Raw token IDs without special tokens
tokens_bos | torch.LongTensor | Full sequence minus its last token; used as decoder input
tokens_eos | torch.LongTensor | Full sequence minus its first token; used as the loss target
tokens | torch.LongTensor | Full token sequence with all special tokens
