

Implementation:Speechbrain Speechbrain Tacotron2 Inference Pipeline

From Leeroopedia


Property           Value
Type               Pattern Doc
Repository         speechbrain/speechbrain
Source File        recipes/LibriTTS/TTS/mstacotron2/train.py:L270-504 (on_stage_end with inference), L425-503 (run_inference_sample)
Import             Recipe-specific (uses speechbrain.inference.vocoders.HIFIGAN for vocoder)
Related Principle  Principle:Speechbrain_Speechbrain_TTS_Inference_Pipeline

Description

The TTS inference pipeline is implemented within the Tacotron2Brain class as a combination of the on_stage_end and run_inference_sample methods. It orchestrates two-stage synthesis: (1) autoregressive mel-spectrogram generation with Tacotron2, and (2) waveform synthesis with HiFi-GAN. Autoregressive inference runs at the end of the validation and test stages to monitor training progress and produce final evaluation samples. In addition, at the end of every 10th training epoch the vocoder renders the teacher-forced prediction from the last training batch.

Key Methods

Vocoder Initialization (on_fit_start)

# Requires: from speechbrain.inference.vocoders import HIFIGAN
def on_fit_start(self):
    self.hparams.progress_sample_logger.reset()
    self.last_epoch = 0
    self.last_batch = None
    self.last_preds = None

    # Load pretrained HiFi-GAN vocoder
    if self.hparams.log_audio_samples:
        self.vocoder = HIFIGAN.from_hparams(
            source=self.hparams.vocoder,
            savedir=self.hparams.vocoder_savedir,
            run_opts={"device": self.device},
            freeze_params=True,
        )

    self.last_loss_stats = {}
    return super().on_fit_start()

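When log_audio_samples is enabled, the HiFi-GAN vocoder is loaded from the Hugging Face Hub (speechbrain/tts-hifigan-libritts-16kHz, supplied through the vocoder hyperparameter) with freeze_params=True. The vocoder is used only for inference, so its weights receive no gradients and are never updated during Tacotron2 training.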

Inference Sample Generation (run_inference_sample)

def run_inference_sample(self, stage):
    """Produces a sample in inference mode."""
    # Module-level imports assumed: os, torchaudio, speechbrain as sb
    if self.last_batch is None:
        return
    inputs, targets, _, labels, wavs, spk_embs, spk_ids = self.last_batch
    text_padded, input_lengths, _, _, _ = inputs

    # Stage 1: Autoregressive mel generation for the first utterance only;
    # slicing with [:1] keeps the batch dimension
    mel_out, _, _ = self.hparams.model.infer(
        text_padded[:1], spk_embs[:1], input_lengths[:1]
    )

    # Log mel-spectrogram
    self.hparams.progress_sample_logger.remember(
        inference_mel_out=self._get_spectrogram_sample(mel_out)
    )

    if stage == sb.Stage.VALID:
        # Per-epoch output directory, e.g. <progress_sample_path>/10/
        inf_sample_path = os.path.join(
            self.hparams.progress_sample_path,
            str(self.hparams.epoch_counter.current),
        )
        os.makedirs(inf_sample_path, exist_ok=True)

        # Save input text
        inf_sample_text = os.path.join(inf_sample_path, "inf_input_text.txt")
        with open(inf_sample_text, "w", encoding="utf-8") as f:
            f.write(labels[0])

        # Save input audio reference
        inf_input_audio_path = os.path.join(
            inf_sample_path, "inf_input_audio.wav"
        )
        torchaudio.save(
            inf_input_audio_path,
            sb.dataio.dataio.read_audio(wavs[0]).unsqueeze(0),
            self.hparams.sample_rate,
        )

        # Stage 2: Vocoder synthesis
        if self.hparams.log_audio_samples:
            waveform_ss = self.vocoder.decode_batch(mel_out)
            inf_sample_audio_path = os.path.join(
                inf_sample_path, "inf_output_audio.wav"
            )
            torchaudio.save(
                inf_sample_audio_path,
                waveform_ss.squeeze(1).cpu(),
                self.hparams.sample_rate,
            )

Training-Time Audio Monitoring (on_stage_end)

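def on_stage_end(self, stage, stage_loss, epoch):
    # Training: save a teacher-forced sample every 10 epochs
    if stage == sb.Stage.TRAIN and (
        self.hparams.epoch_counter.current % 10 == 0
    ):
        _, targets, _, labels, wavs, spk_embs, spk_ids = self.last_batch
        _, mel_out_postnet, _, _, pred_mel_lengths = self.last_preds

        train_sample_path = os.path.join(
            self.hparams.progress_sample_path,
            str(self.hparams.epoch_counter.current),
        )
        os.makedirs(train_sample_path, exist_ok=True)
        train_sample_audio_path = os.path.join(
            train_sample_path, "train_output_audio.wav"
        )

        if self.hparams.log_audio_samples:
            # [:1] keeps the [batch, n_mel, T] layout that decode_batch expects
            waveform_ss = self.vocoder.decode_batch(mel_out_postnet[:1])
            torchaudio.save(
                train_sample_audio_path,
                waveform_ss.squeeze(1).cpu(),
                self.hparams.sample_rate,
            )

    # Validation: run inference (checkpointing omitted from this excerpt)
    if stage == sb.Stage.VALID:
        output_progress_sample = (
            self.hparams.progress_samples
            and epoch % self.hparams.progress_samples_interval == 0
        )
        if output_progress_sample:
            self.run_inference_sample(sb.Stage.VALID)
            self.hparams.progress_sample_logger.save(epoch)

    # Test: Final evaluation
    if stage == sb.Stage.TEST:
        if self.hparams.progress_samples:
            self.run_inference_sample(sb.Stage.TEST)
            self.hparams.progress_sample_logger.save("test")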
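The sampling cadence in on_stage_end can be summarized with a small pure-Python helper. This is a sketch, not recipe code: sample_schedule is a hypothetical name, and the validation condition is modeled as progress_samples and epoch % progress_samples_interval == 0, where progress_samples_interval is an assumed hyperparameter name.

```python
# Hypothetical helper (not part of SpeechBrain) summarizing when on_stage_end
# writes progress samples. `progress_samples_interval` is an assumed hparam name.


def sample_schedule(num_epochs, progress_samples=True,
                    progress_samples_interval=1, train_every=10):
    """Return (train_epochs, valid_epochs) at which samples would be written."""
    epochs = range(1, num_epochs + 1)
    # TRAIN branch: teacher-forced sample every `train_every` epochs
    train_epochs = [e for e in epochs if e % train_every == 0]
    # VALID branch: autoregressive inference sample when flag and interval allow
    valid_epochs = [
        e for e in epochs
        if progress_samples and e % progress_samples_interval == 0
    ]
    return train_epochs, valid_epochs


train_epochs, valid_epochs = sample_schedule(30, progress_samples_interval=5)
print(train_epochs)  # [10, 20, 30]
print(valid_epochs)  # [5, 10, 15, 20, 25, 30]
```

With an interval of 5, epochs 10, 20, and 30 would get both training and inference samples, which is why a per-epoch directory can hold both train_* and inf_* files.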

Two-Stage Pipeline Detail

Stage 1: Tacotron2 Mel Generation

The model.infer() method performs autoregressive decoding:

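  1. Takes the padded token IDs text_padded[:1] (the first utterance of the batch; slicing keeps the batch dimension)
  2. Takes the speaker embedding spk_embs[:1] for voice conditioning
  3. Takes the input length input_lengths[:1], used to pack the encoder input and to mask padded positions during attention
  4. Returns a tuple (mel_outputs_postnet, mel_lengths, alignments); the pipeline keeps only the post-net mel-spectrogram
  5. The decoder generates frames until the sigmoid of the gate (stop-token) prediction exceeds gate_threshold (0.5) or max_decoder_steps (1500) is reached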
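The stopping rule in step 5 can be illustrated with a framework-free sketch. Here step_fn is a hypothetical stand-in for one decoder step (prenet, attention, and projection in the real model) that returns a frame and a gate logit; frames are plain integers rather than 80-dimensional mel vectors.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def decode_until_gate(step_fn, gate_threshold=0.5, max_decoder_steps=1500):
    """Autoregressive loop: feed each frame back in until the stop gate fires.

    Returns (frames, stopped_by_gate). `step_fn(prev_frame) -> (frame, gate_logit)`
    is a hypothetical stand-in for one Tacotron2 decoder step.
    """
    frames = []
    prev_frame = None  # stands in for the initial all-zero "go" frame
    for _ in range(max_decoder_steps):
        frame, gate_logit = step_fn(prev_frame)
        frames.append(frame)  # the frame is kept even on the stopping step
        if sigmoid(gate_logit) > gate_threshold:
            return frames, True
        prev_frame = frame  # autoregression: output becomes the next input
    return frames, False  # hit the step cap without the gate firing


def toy_step(prev_frame):
    # Toy decoder: frame index grows by 1, gate logit rises from -3 upward
    n = 0 if prev_frame is None else prev_frame + 1
    return n, n - 3.0


frames, stopped = decode_until_gate(toy_step)
print(frames, stopped)  # [0, 1, 2, 3, 4] True
```

Because the comparison is strict, a gate probability of exactly 0.5 (frame 3 in the toy run) does not stop decoding. If the gate never fires, the output is capped at max_decoder_steps frames, which bounds the length of runaway outputs.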

Stage 2: HiFi-GAN Waveform Synthesis

The vocoder converts the mel-spectrogram to audio:

# mel_out shape: [1, n_mel, T] (80 channels, T frames)
waveform_ss = self.vocoder.decode_batch(mel_out)
# waveform_ss shape: [1, 1, T*hop_length] (mono audio)

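The decode_batch method of the HIFIGAN inference class handles the generator internally: on its first call it removes weight normalization from the generator, and it runs the generator forward pass under torch.no_grad(). When mel_lens and hop_len are also passed, it masks the noise produced in padded regions; this pipeline synthesizes a single unpadded utterance and omits them.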
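The shape comments above imply a direct frames-to-samples relation. A quick arithmetic sketch, assuming a hop length of 256 samples and a 16 kHz sample rate for the LibriTTS HiFi-GAN (check the vocoder's hyperparameters before relying on these numbers):

```python
# Assumed vocoder settings: hop_length=256, sample_rate=16000.


def waveform_samples(n_frames, hop_length=256):
    """Samples produced by the vocoder for `n_frames` mel frames."""
    return n_frames * hop_length


def duration_seconds(n_frames, hop_length=256, sample_rate=16000):
    """Audio duration in seconds for `n_frames` mel frames."""
    return waveform_samples(n_frames, hop_length) / sample_rate


print(waveform_samples(500))   # 128000
print(duration_seconds(500))   # 8.0
# Upper bound implied by max_decoder_steps=1500:
print(duration_seconds(1500))  # 24.0
```

Under these assumptions, max_decoder_steps = 1500 caps a single synthesized utterance at about 24 seconds, so longer inputs would likely need to be split before inference.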

Output Artifacts

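Depending on the stage and the log_audio_samples setting, the pipeline saves:

File                    Description
inf_input_text.txt      Input text transcription of the inference sample
inf_input_audio.wav     Ground-truth audio of the inference sample, read from the dataset
inf_output_audio.wav    Audio synthesized by HiFi-GAN from the autoregressively inferred mel-spectrogram
train_input_text.txt    Input text transcription of the training sample
train_input_audio.wav   Ground-truth audio of the training sample
train_output_audio.wav  Audio synthesized by HiFi-GAN from the teacher-forced post-net mel-spectrogram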

Files are organized by epoch under the progress_sample_path directory:

results/tacotron2/1234/samples/
  10/
    train_input_text.txt
    train_input_audio.wav
    train_output_audio.wav
    inf_input_text.txt
    inf_input_audio.wav
    inf_output_audio.wav
  20/
    ...
  test/
    ...
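The per-epoch layout can be reproduced with os.path.join and os.makedirs. The sketch below is illustrative: sample_paths is a hypothetical helper, and the file names are taken from the directory listing above rather than from a recipe constant.

```python
import os
import tempfile

# File names as they appear in the directory listing (not a SpeechBrain constant)
ARTIFACTS = (
    "train_input_text.txt",
    "train_input_audio.wav",
    "train_output_audio.wav",
    "inf_input_text.txt",
    "inf_input_audio.wav",
    "inf_output_audio.wav",
)


def sample_paths(progress_sample_path, epoch_or_tag):
    """Create <progress_sample_path>/<epoch_or_tag>/ and map file names to paths."""
    sample_dir = os.path.join(progress_sample_path, str(epoch_or_tag))
    os.makedirs(sample_dir, exist_ok=True)  # safe to call repeatedly
    return {name: os.path.join(sample_dir, name) for name in ARTIFACTS}


with tempfile.TemporaryDirectory() as root:
    paths = sample_paths(root, 10)
    # Prints 10/inf_output_audio.wav on POSIX systems
    print(os.path.relpath(paths["inf_output_audio.wav"], root))
```

Passing "test" instead of an epoch number yields the test/ directory.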

Usage Example

Standalone Inference with Pretrained Models

import torch
import torchaudio
from speechbrain.inference.classifiers import EncoderClassifier
from speechbrain.inference.vocoders import HIFIGAN
from speechbrain.utils.text_to_sequence import text_to_sequence

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load speaker encoder
spk_encoder = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-ecapa-voxceleb",
    savedir="pretrained_models/spk_encoder",
    run_opts={"device": device},
)

# Load vocoder
vocoder = HIFIGAN.from_hparams(
    source="speechbrain/tts-hifigan-libritts-16kHz",
    savedir="pretrained_models/vocoder",
    run_opts={"device": device},
    freeze_params=True,
)

# Extract speaker embedding from reference audio (the encoder expects 16 kHz mono)
ref_audio, ref_sr = torchaudio.load("reference_speaker.wav")
ref_audio = ref_audio.mean(dim=0, keepdim=True)  # downmix to mono: [1, time]
if ref_sr != 16000:
    ref_audio = torchaudio.functional.resample(ref_audio, ref_sr, 16000)
spk_emb = spk_encoder.encode_batch(ref_audio.to(device))  # [1, 1, emb_dim]
spk_emb = spk_emb.squeeze()  # [emb_dim]

# Encode text (the text front end must match the one used to train the model)
text = "Hello, this is a test of the text to speech system."
text_seq = torch.IntTensor(
    text_to_sequence(text, ["english_cleaners"])
).unsqueeze(0).to(device)
input_lengths = torch.tensor([text_seq.size(1)]).to(device)

# Stage 1: Generate mel-spectrogram (requires a trained multi-speaker Tacotron2;
# `tacotron2_model` is a placeholder that is not defined in this snippet)
# mel_out, mel_lengths, alignments = tacotron2_model.infer(
#     text_seq, spk_emb.unsqueeze(0), input_lengths
# )

# Stage 2: Convert mel to waveform
# waveform = vocoder.decode_batch(mel_out)
# torchaudio.save("output.wav", waveform.squeeze(1).cpu(), 16000)

TensorBoard Integration

When use_tensorboard is enabled, the pipeline logs:

  • Audio: Both target and predicted audio for training and inference samples
  • Spectrograms: Target mel, predicted mel (post-net), and inference mel as figure images
  • Metrics: Loss statistics per epoch for train and validation stages
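
The inference-sample logging calls (variable names as in run_inference_sample):

if self.hparams.use_tensorboard:
    # Ground-truth waveform of the first utterance, shaped [1, time]
    target_audio = sb.dataio.dataio.read_audio(wavs[0]).unsqueeze(0)
    self.tensorboard_logger.log_audio(
        f"{stage}/inf_audio_target", target_audio, self.hparams.sample_rate
    )
    self.tensorboard_logger.log_audio(
        f"{stage}/inf_audio_pred", waveform_ss.squeeze(1), self.hparams.sample_rate
    )
    self.tensorboard_logger.log_figure(f"{stage}/inf_mel_target", targets[0][0])
    self.tensorboard_logger.log_figure(f"{stage}/inf_mel_pred", mel_out)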
