Implementation:Speechbrain Speechbrain Tacotron2 Inference Pipeline
| Property | Value |
|---|---|
| Type | Pattern Doc |
| Repository | speechbrain/speechbrain |
| Source File | recipes/LibriTTS/TTS/mstacotron2/train.py:L270-504 (on_stage_end with inference), L425-503 (run_inference_sample) |
| Import | Recipe-specific (uses speechbrain.inference.vocoders.HIFIGAN for vocoder) |
| Related Principle | Principle:Speechbrain_Speechbrain_TTS_Inference_Pipeline |
Description
The TTS inference pipeline is implemented within the Tacotron2Brain class as a combination of on_stage_end and run_inference_sample methods. It orchestrates the two-stage synthesis process: (1) autoregressive mel-spectrogram generation via Tacotron2, and (2) waveform synthesis via HiFi-GAN. The pipeline is triggered during validation and test stages to monitor training progress and produce final evaluation samples.
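As a hedged sketch of this orchestration, the two-stage flow reduces to "acoustic model produces mel, vocoder produces waveform". The stub classes below are toy stand-ins (the real models require trained checkpoints); only the call shapes mirror the recipe.

```python
# Minimal sketch of the two-stage flow: acoustic model -> vocoder.
# Tacotron2Stub and VocoderStub are illustrative stand-ins, not SpeechBrain API.
class Tacotron2Stub:
    def infer(self, text_seq, spk_emb, input_lengths):
        # Returns (mel_out, gate_out, alignments), like model.infer()
        return [[0.0] * 80], [0.9], [[1.0]]

class VocoderStub:
    def decode_batch(self, mel_out):
        # Returns a waveform-like list, one entry per mel frame
        return [sum(frame) for frame in mel_out]

def synthesize(acoustic_model, vocoder, text_seq, spk_emb, input_lengths):
    mel_out, gate_out, alignments = acoustic_model.infer(
        text_seq, spk_emb, input_lengths
    )
    return vocoder.decode_batch(mel_out)

waveform = synthesize(Tacotron2Stub(), VocoderStub(), "hello", None, 1)
```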
Key Methods
Vocoder Initialization (on_fit_start)
def on_fit_start(self):
    self.hparams.progress_sample_logger.reset()
    self.last_epoch = 0
    self.last_batch = None
    self.last_preds = None
    # Load pretrained HiFi-GAN vocoder
    if self.hparams.log_audio_samples:
        self.vocoder = HIFIGAN.from_hparams(
            source=self.hparams.vocoder,
            savedir=self.hparams.vocoder_savedir,
            run_opts={"device": self.device},
            freeze_params=True,
        )
    self.last_loss_stats = {}
    return super().on_fit_start()
The HiFi-GAN vocoder is loaded from HuggingFace Hub (speechbrain/tts-hifigan-libritts-16kHz) with frozen parameters (no gradient computation needed for inference).
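What `freeze_params=True` amounts to can be sketched as follows. This is an assumption about the effect (every parameter gets `requires_grad=False`, and the module is switched to eval mode), not a reproduction of SpeechBrain internals; the two-layer module is a toy stand-in for the vocoder.

```python
import torch

# Toy stand-in for a pretrained vocoder network
vocoder_like = torch.nn.Sequential(
    torch.nn.Linear(80, 256),
    torch.nn.Linear(256, 1),
)

# Assumed effect of freeze_params=True: no gradients are ever accumulated
# for the vocoder, since it is only used for inference.
for p in vocoder_like.parameters():
    p.requires_grad = False
vocoder_like.eval()  # also disables dropout / batch-norm statistic updates

frozen = all(not p.requires_grad for p in vocoder_like.parameters())
```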
Inference Sample Generation (run_inference_sample)
def run_inference_sample(self, stage):
    """Produces a sample in inference mode."""
    if self.last_batch is None:
        return
    inputs, targets, _, labels, wavs, spk_embs, spk_ids = self.last_batch
    text_padded, input_lengths, _, _, _ = inputs
    # Stage 1: Autoregressive mel generation (first sample of the batch only)
    mel_out, _, _ = self.hparams.model.infer(
        text_padded[:1], spk_embs[:1], input_lengths[:1]
    )
    # Log mel-spectrogram
    self.hparams.progress_sample_logger.remember(
        inference_mel_out=self._get_spectrogram_sample(mel_out)
    )
    if stage == sb.Stage.VALID:
        # Save input text
        inf_sample_text = os.path.join(
            self.hparams.progress_sample_path,
            str(self.hparams.epoch_counter.current),
            "inf_input_text.txt",
        )
        with open(inf_sample_text, "w", encoding="utf-8") as f:
            f.write(labels[0])
        # Save input audio reference (inf_input_audio_path is built with
        # os.path.join in the same way as inf_sample_text above; its
        # construction is elided in this excerpt)
        torchaudio.save(
            inf_input_audio_path,
            sb.dataio.dataio.read_audio(wavs[0]).unsqueeze(0),
            self.hparams.sample_rate,
        )
    # Stage 2: Vocoder synthesis (inf_sample_audio_path is built analogously)
    if self.hparams.log_audio_samples:
        waveform_ss = self.vocoder.decode_batch(mel_out)
        torchaudio.save(
            inf_sample_audio_path,
            waveform_ss.squeeze(1).cpu(),
            self.hparams.sample_rate,
        )
Training-Time Audio Monitoring (on_stage_end)
def on_stage_end(self, stage, stage_loss, epoch):
    # Training: save samples every 10 epochs
    if stage == sb.Stage.TRAIN and (
        self.hparams.epoch_counter.current % 10 == 0
    ):
        _, targets, _, labels, wavs, spk_embs, spk_ids = self.last_batch
        _, mel_out_postnet, _, _, pred_mel_lengths = self.last_preds
        if self.hparams.log_audio_samples:
            waveform_ss = self.vocoder.decode_batch(mel_out_postnet[0])
            torchaudio.save(
                train_sample_audio_path,
                waveform_ss.squeeze(1).cpu(),
                self.hparams.sample_rate,
            )
    # Validation: run inference and checkpoint
    if stage == sb.Stage.VALID:
        # output_progress_sample is derived from the progress_samples
        # hparams and the current epoch (its computation is elided here)
        if output_progress_sample:
            self.run_inference_sample(sb.Stage.VALID)
            self.hparams.progress_sample_logger.save(epoch)
    # Test: final evaluation
    if stage == sb.Stage.TEST:
        if self.hparams.progress_samples:
            self.run_inference_sample(sb.Stage.TEST)
            self.hparams.progress_sample_logger.save("test")
Two-Stage Pipeline Detail
Stage 1: Tacotron2 Mel Generation
The `model.infer()` method performs autoregressive decoding:
- Takes the encoded text `text_padded[:1]` (first sample from the batch)
- Takes the speaker embedding `spk_embs[:1]` for voice conditioning
- Takes the input length `input_lengths[:1]` for attention dimensioning
- Returns a tuple of `(mel_out, gate_out, alignments)`
- The decoder generates frames until the gate prediction exceeds `gate_threshold` (0.5) or `max_decoder_steps` (1500) is reached
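The stop condition of this decoding loop can be sketched in isolation. `decode_step` here is a toy stand-in for one step of the real decoder; the loop structure (generate a frame, then stop once the gate output exceeds the threshold or the step budget runs out) is the part that mirrors the recipe.

```python
# Hedged sketch of the autoregressive stop condition used by model.infer();
# decode_step is a hypothetical callable, not SpeechBrain API.
def autoregressive_decode(decode_step, gate_threshold=0.5, max_decoder_steps=1500):
    frames = []
    for step in range(max_decoder_steps):
        frame, gate = decode_step(step)
        frames.append(frame)
        if gate > gate_threshold:
            # Gate says "end of utterance": stop decoding
            break
    return frames

# Toy decode_step: gate rises linearly and first exceeds 0.5 at step 9,
# so exactly 10 frames (steps 0..9) are produced
frames = autoregressive_decode(lambda s: ([0.0] * 80, s / 16))
```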
Stage 2: HiFi-GAN Waveform Synthesis
The vocoder converts the mel-spectrogram to audio:
# mel_out shape: [1, n_mel, T] (80 channels, T frames)
waveform_ss = self.vocoder.decode_batch(mel_out)
# waveform_ss shape: [1, 1, T*hop_length] (mono audio)
The decode_batch method of the HIFIGAN inference class handles weight norm removal and the full generator forward pass internally.
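The shape bookkeeping between the two stages is simple arithmetic: each mel frame expands into `hop_length` audio samples. The `hop_length = 256` below is an assumed value for a 16 kHz HiFi-GAN config, not read from the recipe's hparams.

```python
# Shape arithmetic for Stage 2; hop_length = 256 is an assumption for
# illustration (the real value comes from the HiFi-GAN hyperparameters).
n_mel, T, hop_length = 80, 120, 256
mel_shape = (1, n_mel, T)                # vocoder input: [batch, mel channels, frames]
waveform_shape = (1, 1, T * hop_length)  # vocoder output: [batch, 1, samples]
```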
Output Artifacts
For each inference sample, the pipeline saves:
| File | Description |
|---|---|
| inf_input_text.txt | The input text transcription |
| inf_input_audio.wav | The ground-truth audio from the dataset |
| inf_output_audio.wav | The synthesized audio (mel -> HiFi-GAN) |
| train_input_audio.wav | Training sample ground-truth audio |
| train_output_audio.wav | Training sample synthesized audio |
Files are organized by epoch under the progress_sample_path directory:
results/tacotron2/1234/samples/
10/
train_input_text.txt
train_input_audio.wav
train_output_audio.wav
inf_input_text.txt
inf_input_audio.wav
inf_output_audio.wav
20/
...
test/
...
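A small helper can mirror this epoch-keyed layout. `sample_path` and its arguments are hypothetical names for illustration, not part of the recipe; the `os.path.join` pattern is the same one the recipe uses to build its artifact paths.

```python
import os
import tempfile

# Hypothetical helper mirroring the epoch-organized sample directory layout
def sample_path(root, epoch, filename):
    folder = os.path.join(root, str(epoch))
    os.makedirs(folder, exist_ok=True)
    return os.path.join(folder, filename)

# e.g. <root>/10/inf_input_text.txt for epoch 10
path = sample_path(tempfile.mkdtemp(), 10, "inf_input_text.txt")
```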
Usage Example
Standalone Inference with Pretrained Models
import torch
import torchaudio
from speechbrain.inference.classifiers import EncoderClassifier
from speechbrain.inference.vocoders import HIFIGAN
from speechbrain.utils.text_to_sequence import text_to_sequence
device = "cuda:0"
# Load speaker encoder
spk_encoder = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-ecapa-voxceleb",
    savedir="pretrained_models/spk_encoder",
    run_opts={"device": device},
)
# Load vocoder
vocoder = HIFIGAN.from_hparams(
    source="speechbrain/tts-hifigan-libritts-16kHz",
    savedir="pretrained_models/vocoder",
    run_opts={"device": device},
    freeze_params=True,
)
# Extract speaker embedding from reference audio
ref_audio, ref_sr = torchaudio.load("reference_speaker.wav")
if ref_sr != 16000:
    ref_audio = torchaudio.functional.resample(ref_audio, ref_sr, 16000)
spk_emb = spk_encoder.encode_batch(ref_audio.to(device))
spk_emb = spk_emb.squeeze()
# Encode text
text = "Hello, this is a test of the text to speech system."
text_seq = torch.IntTensor(
    text_to_sequence(text, ["english_cleaners"])
).unsqueeze(0).to(device)
input_lengths = torch.tensor([text_seq.size(1)]).to(device)
# Stage 1: Generate mel-spectrogram (requires a trained Tacotron2 model)
# mel_out, gate_out, alignments = tacotron2_model.infer(
#     text_seq, spk_emb.unsqueeze(0), input_lengths
# )
# Stage 2: Convert mel to waveform
# waveform = vocoder.decode_batch(mel_out)
# torchaudio.save("output.wav", waveform.squeeze(1).cpu(), 16000)
Tensorboard Integration
When use_tensorboard is enabled, the pipeline logs:
- Audio: Both target and predicted audio for training and inference samples
- Spectrograms: Target mel, predicted mel (post-net), and inference mel as figure images
- Metrics: Loss statistics per epoch for train and validation stages
if self.hparams.use_tensorboard:
    self.tensorboard_logger.log_audio(
        f"{stage}/inf_audio_target", target_audio, self.hparams.sample_rate
    )
    self.tensorboard_logger.log_audio(
        f"{stage}/inf_audio_pred", waveform_ss.squeeze(1), self.hparams.sample_rate
    )
    self.tensorboard_logger.log_figure(f"{stage}/inf_mel_target", targets[0][0])
    self.tensorboard_logger.log_figure(f"{stage}/inf_mel_pred", mel_out)
See Also
- Principle:Speechbrain_Speechbrain_TTS_Inference_Pipeline - Theoretical foundations of the two-stage TTS inference pipeline
- Implementation:Speechbrain_Speechbrain_Tacotron2Brain_Compute_Forward - Tacotron2 training that produces the acoustic model
- Implementation:Speechbrain_Speechbrain_HifiGanBrain_Fit_Batch - HiFi-GAN training that produces the vocoder
- Implementation:Speechbrain_Speechbrain_EncoderClassifier_Encode_Batch - Speaker embedding extraction for voice conditioning