
Implementation:Speechbrain Speechbrain Tacotron2Brain Compute Forward

From Leeroopedia


Type: API Doc
Repository: speechbrain/speechbrain
Source File: recipes/LibriTTS/TTS/mstacotron2/train.py:L32 (class), L56-79 (compute_forward), L81-96 (fit_batch), L98-159 (compute_objectives)
Import: Recipe-specific Brain subclass (not directly importable as a library)
Related Principle: Principle:Speechbrain_Speechbrain_Tacotron2_Acoustic_Model_Training

Class Definition

class Tacotron2Brain(sb.Brain):
    """The Brain implementation for Tacotron2"""

Tacotron2Brain extends SpeechBrain's Brain class with custom forward, loss, and training logic for the Zero-Shot Multi-Speaker Tacotron2 model.

API Signatures

compute_forward

def compute_forward(self, batch, stage):
    """Computes the forward pass

    Arguments
    ---------
    batch: tuple
        a single batch (the 10-element tuple produced by TextMelCollate)
    stage: speechbrain.Stage
        the training stage

    Returns
    -------
    the model output
    """

fit_batch

def fit_batch(self, batch):
    """Fits a single batch and applies annealing

    Arguments
    ---------
    batch: tuple
        a training batch

    Returns
    -------
    loss: torch.Tensor
        detached loss
    """

compute_objectives

def compute_objectives(self, predictions, batch, stage):
    """Computes the loss given the predicted and targeted outputs

    Arguments
    ---------
    predictions : tuple
        The model outputs (mel-spectrograms, gate predictions, alignments, and lengths)
    batch : tuple
        The 10-element batch produced by TextMelCollate
    stage : sb.Stage
        One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST

    Returns
    -------
    loss : torch.Tensor
        A one-element tensor used for backpropagating the gradient
    """

Description

The Tacotron2Brain class orchestrates the full training loop for the multi-speaker Tacotron2 acoustic model. It handles batch unpacking, forward pass computation through the Tacotron2 model, loss calculation, learning rate scheduling, and progress monitoring.

Batch Structure

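The raw input batch is a 10-element tuple. batch_to_device unpacks it, moves the tensors to the device, and regroups them into the 7-element tuple consumed by compute_forward and _compute_loss: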

(
    text_padded,      # Encoded text tokens [batch, max_text_len]
    input_lengths,    # Text sequence lengths [batch]
    mel_padded,       # Target mel-spectrogram [batch, n_mel, max_mel_len]
    gate_padded,      # Target gate (stop) values [batch, max_mel_len]
    output_lengths,   # Mel sequence lengths [batch]
    len_x,            # Total output length (sum of all output_lengths)
    labels,           # Raw text strings (list of str)
    wavs,             # Paths to source audio files (list of str)
    spk_embs,         # Precomputed speaker embeddings [batch, 192]
    spk_ids,          # Speaker ID strings (list of str)
)
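The regrouping from 10 elements to 7 can be sketched in plain Python. The helper name regroup_batch is hypothetical and device transfer is omitted. The layout of inputs follows the unpacking in _compute_loss below, with mel_padded assumed as its third element; the layout of the targets tuple is likewise an assumption.

```python
def regroup_batch(batch):
    """Hypothetical sketch: regroup the 10-element collated batch into the
    7-element effective batch used by compute_forward and _compute_loss.
    The real method also moves each tensor to the training device."""
    (
        text_padded,
        input_lengths,
        mel_padded,
        gate_padded,
        output_lengths,
        _len_x,
        labels,
        wavs,
        spk_embs,
        spk_ids,
    ) = batch
    max_len = max(input_lengths)  # longest text sequence in the batch
    # Assumed layout, consistent with
    # `text_padded, input_lengths, _, max_len, output_lengths = inputs`
    inputs = (text_padded, input_lengths, mel_padded, max_len, output_lengths)
    # Assumed layout of the training targets
    targets = (mel_padded, gate_padded, output_lengths)
    num_items = sum(output_lengths)  # total mel frames; should equal len_x
    return inputs, targets, num_items, labels, wavs, spk_embs, spk_ids
```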

The batch is assembled by the TextMelCollate class, which loads speaker embeddings from the precomputed pickle file.
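A TextMelCollate-style collate step can be sketched as follows. This is a hypothetical, framework-free illustration rather than the recipe's class: it assumes each item is a dict with uttid, text_ids, mel (an n_mel by T nested list), label, wav, and spk_id keys, and that the speaker-embedding pickle loads (e.g. via pickle.load) as a dict mapping utterance IDs to vectors. Sorting by descending text length and the gate convention (1 from the last real frame onward) follow the common Tacotron2 collate design.

```python
def collate_with_speaker_embeddings(items, spk_emb_table, pad_id=0):
    """Hypothetical sketch of a TextMelCollate-style collate function.

    items: list of dicts with keys 'uttid', 'text_ids', 'mel', 'label',
           'wav', 'spk_id' (assumed layout; 'mel' is n_mel lists of T floats)
    spk_emb_table: dict mapping uttid -> speaker embedding (assumed pickle layout)
    """
    # Longest text first, as in NVIDIA-style Tacotron2 collation
    items = sorted(items, key=lambda it: len(it["text_ids"]), reverse=True)
    input_lengths = [len(it["text_ids"]) for it in items]
    max_text = input_lengths[0]
    text_padded = [it["text_ids"] + [pad_id] * (max_text - len(it["text_ids"]))
                   for it in items]

    output_lengths = [len(it["mel"][0]) for it in items]
    max_mel = max(output_lengths)
    mel_padded, gate_padded = [], []
    for it, n_frames in zip(items, output_lengths):
        # Zero-pad every mel channel to max_mel frames
        mel_padded.append([row + [0.0] * (max_mel - n_frames) for row in it["mel"]])
        # Gate target: 0 while speaking, 1 from the last real frame onward
        gate_padded.append([0.0] * (n_frames - 1) + [1.0] * (max_mel - n_frames + 1))

    len_x = sum(output_lengths)  # total number of mel frames
    labels = [it["label"] for it in items]
    wavs = [it["wav"] for it in items]
    spk_embs = [spk_emb_table[it["uttid"]] for it in items]  # precomputed lookup
    spk_ids = [it["spk_id"] for it in items]
    return (text_padded, input_lengths, mel_padded, gate_padded, output_lengths,
            len_x, labels, wavs, spk_embs, spk_ids)
```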

Forward Pass

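The compute_forward method moves the batch to the device, reads the text lengths, and calls the model with alignments_dim set to the longest text sequence in the batch:

def compute_forward(self, batch, stage):
    # Move tensors to the device and regroup the 10-element batch
    effective_batch = self.batch_to_device(batch)
    inputs, y, num_items, _, _, spk_embs, spk_ids = effective_batch

    # Only the text lengths are needed here
    _, input_lengths, _, _, _ = inputs

    # Longest text sequence in the batch, passed to the model as alignments_dim
    max_input_length = input_lengths.max().item()

    # Teacher-forced forward pass conditioned on the speaker embeddings
    return self.modules.model(
        inputs, spk_embs, alignments_dim=max_input_length
    )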
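To make the role of alignments_dim concrete, here is a framework-free sketch of the same flow. stub_model is a hypothetical stand-in for the Tacotron2 module, and the assumption that the model zero-pads each alignment matrix along the input-token axis to alignments_dim columns is ours; padding to a fixed width is one way to give per-utterance alignments a common shape, for example when outputs from several devices are gathered.

```python
def stub_model(inputs, spk_embs, alignments_dim=None):
    """Hypothetical stand-in for the Tacotron2 module: uniform attention over
    each utterance's own tokens, zero-padded to alignments_dim columns."""
    _, input_lengths, _, _, output_lengths = inputs
    alignments = []
    for n_in, n_out in zip(input_lengths, output_lengths):
        matrix = [[1.0 / n_in] * n_in for _ in range(n_out)]  # [T_out x T_in]
        if alignments_dim is not None:
            matrix = [row + [0.0] * (alignments_dim - n_in) for row in matrix]
        alignments.append(matrix)
    return alignments


def forward_sketch(inputs, spk_embs, model):
    """Mirrors compute_forward without device handling."""
    _, input_lengths, _, _, _ = inputs
    max_input_length = max(input_lengths)  # longest text in the batch
    return model(inputs, spk_embs, alignments_dim=max_input_length)
```

With input lengths [3, 1], both returned alignment matrices have three columns; the shorter utterance's rows end in zeros.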

The model outputs a tuple of:

  • mel_out: Mel-spectrogram predicted by the decoder, before post-net refinement [batch, n_mel, T]
  • mel_out_postnet: Post-net refined mel-spectrogram [batch, n_mel, T]
  • gate_out: Stop token predictions [batch, T]
  • alignments: Attention alignment matrices [batch, T_out, T_in]
  • pred_mel_lengths: Predicted mel lengths [batch]

Loss Computation

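The compute_objectives method moves the raw batch to the device with batch_to_device and then delegates to _compute_loss, which receives the 7-element effective batch and applies the configured MSTacotron2.Loss criterion:

def _compute_loss(self, predictions, batch, stage):
    # batch is the 7-element effective batch returned by batch_to_device
    inputs, targets, num_items, labels, wavs, spk_embs, spk_ids = batch
    text_padded, input_lengths, _, max_len, output_lengths = inputs

    # The last two arguments are passed positionally so the call does not
    # depend on the criterion's parameter names: the speaker-embedding input
    # (unused, hence None) and the current epoch, which drives the
    # guided attention weight schedule
    loss_stats = self.hparams.criterion(
        predictions,
        targets,
        input_lengths,
        output_lengths,
        None,  # spk_emb_input
        self.last_epoch,  # current epoch
    )
    # scalarize (speechbrain.utils.data_utils) converts the loss components to floats for logging
    self.last_loss_stats[stage] = scalarize(loss_stats)
    return loss_stats.loss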

The loss combines:

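  • Mel MSE loss: computed on both the decoder output (before the post-net) and the post-net output
  • Gate BCE loss: weighted by gate_loss_weight
  • Guided attention loss: weighted by guided_attention_weight, with the weight decayed on a schedule driven by the current epoch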
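The three terms can be combined as in the following plain-Python sketch. It illustrates the idea under stated assumptions and is not MSTacotron2.Loss itself: tensors are flattened to lists, padding masks are ignored, the gate term is binary cross-entropy on raw logits, the guided attention term uses the common Gaussian diagonal penalty W[t][n] = 1 - exp(-(n/N - t/T)^2 / (2 sigma^2)), and the attention weight is held fixed instead of being scheduled by epoch. Default weights mirror the YAML configuration on this page.

```python
import math


def mse(a, b):
    """Mean squared error between two equal-length flat lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy from raw logits (numerically stable form)."""
    total = 0.0
    for z, t in zip(logits, targets):
        # max(z, 0) - z*t + log(1 + exp(-|z|)) == -[t*log(s) + (1-t)*log(1-s)], s = sigmoid(z)
        total += max(z, 0.0) - z * t + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


def guided_attention_loss(alignment, sigma=0.2):
    """Penalize attention far from the diagonal.
    alignment[t][n] is the weight on input token n at output step t."""
    T, N = len(alignment), len(alignment[0])
    total = 0.0
    for t in range(T):
        for n in range(N):
            w = 1.0 - math.exp(-((n / N - t / T) ** 2) / (2 * sigma ** 2))
            total += alignment[t][n] * w
    return total / (T * N)


def combined_loss(mel_out, mel_postnet, mel_target, gate_logits, gate_target,
                  alignment, gate_loss_weight=1.0, attn_weight=25.0, sigma=0.2):
    """Sketch of the total loss; attn_weight is fixed here (scheduled in the recipe)."""
    mel_loss = mse(mel_out, mel_target) + mse(mel_postnet, mel_target)
    gate_loss = gate_loss_weight * bce_with_logits(gate_logits, gate_target)
    attn_loss = attn_weight * guided_attention_loss(alignment, sigma)
    return mel_loss + gate_loss + attn_loss
```

A perfectly diagonal alignment incurs no guided attention penalty, while attention that runs against the diagonal is penalized heavily because of the large weight (25.0).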

Custom fit_batch

The fit_batch override runs the parent Brain.fit_batch (forward pass, loss, backward pass, optimizer step) and then advances the learning-rate scheduler once per batch:

def fit_batch(self, batch):
    result = super().fit_batch(batch)
    self.hparams.lr_annealing(self.optimizer)
    return result

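With the configuration below (n_warmup_steps: 4000), the Noam scheduler increases the learning rate linearly over the first 4000 steps and then decays it in proportion to the inverse square root of the step count.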
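The schedule can be written out as a small function. This is a sketch of the standard Noam formula, assuming (as we understand SpeechBrain's NoamScheduler to do when no model_size is given) normalization by sqrt(n_warmup_steps), so that the peak learning rate equals lr_initial at the end of warmup; the function name noam_lr is hypothetical. Because fit_batch calls the scheduler once per batch, step counts training batches, not epochs.

```python
def noam_lr(step, lr_initial=0.001, n_warmup_steps=4000):
    """Learning rate after `step` optimizer steps (step >= 1), Noam schedule.

    Linear warmup: step * n_warmup_steps**-1.5
    Decay:         step**-0.5
    The sqrt(n_warmup_steps) factor makes the peak (at step == n_warmup_steps)
    equal to lr_initial (assumed normalization).
    """
    scale = n_warmup_steps ** 0.5 * min(step ** -0.5, step * n_warmup_steps ** -1.5)
    return lr_initial * scale
```

With the defaults, the learning rate is 0.00025 after 1000 steps, peaks at 0.001 at step 4000, and falls to 0.0005 by step 16000.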

Usage Example

Complete Training Script

import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# Parse arguments and load hyperparameters
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file, encoding="utf-8") as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# Create experiment directory
sb.create_experiment_directory(
    experiment_directory=hparams["output_folder"],
    hyperparams_to_save=hparams_file,
    overrides=overrides,
)

# Prepare data (run on main process only)
from libritts_prepare import prepare_libritts
sb.utils.distributed.run_on_main(
    prepare_libritts,
    kwargs={
        "data_folder": hparams["data_folder"],
        "save_json_train": hparams["train_json"],
        "save_json_valid": hparams["valid_json"],
        "save_json_test": hparams["test_json"],
        "sample_rate": hparams["sample_rate"],
        "train_split": hparams["train_split"],
        "valid_split": hparams["valid_split"],
        "test_split": hparams["test_split"],
        "seed": hparams["seed"],
        "model_name": hparams["model"].__class__.__name__,
    },
)

# Compute speaker embeddings
from compute_speaker_embeddings import compute_speaker_embeddings
sb.utils.distributed.run_on_main(
    compute_speaker_embeddings,
    kwargs={
        "input_filepaths": [hparams["train_json"], hparams["valid_json"], hparams["test_json"]],
        "output_file_paths": [
            hparams["train_speaker_embeddings_pickle"],
            hparams["valid_speaker_embeddings_pickle"],
            hparams["test_speaker_embeddings_pickle"],
        ],
        "data_folder": hparams["data_folder"],
        "spk_emb_encoder_path": hparams["spk_emb_encoder"],
        "spk_emb_sr": hparams["spk_emb_sample_rate"],
        "mel_spec_params": {"custom_mel_spec_encoder": False, ...},  # remaining parameters elided in this excerpt
        "device": run_opts["device"],
    },
)

# Prepare datasets
datasets = dataio_prepare(hparams)

# Initialize Tacotron2Brain (the class documented on this page, defined in train.py)
tacotron2_brain = Tacotron2Brain(
    modules=hparams["modules"],
    opt_class=hparams["opt_class"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)

# Train
tacotron2_brain.fit(
    tacotron2_brain.hparams.epoch_counter,
    train_set=datasets["train"],
    valid_set=datasets["valid"],
    train_loader_kwargs=hparams["train_dataloader_opts"],
    valid_loader_kwargs=hparams["valid_dataloader_opts"],
)

# Test
tacotron2_brain.evaluate(
    datasets["test"],
    test_loader_kwargs=hparams["test_dataloader_opts"],
)

Command-Line Invocation

python train.py --device=cuda:0 --max_grad_norm=1.0 \
    --data_folder=/path/to/LibriTTS \
    hparams/train.yaml

Progress Monitoring

The on_stage_end method handles progress monitoring and checkpointing:

  • Every 10 epochs: Saves training sample spectrograms, alignment visualizations, input text, input audio, and (optionally) synthesized audio
  • Every progress_samples_interval epochs: Runs inference on a validation sample and saves the generated mel-spectrogram and audio
  • Checkpointing: Saves model checkpoints based on validation loss, with optional interval-based retention (e.g., keep every 50th epoch)
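The epoch-interval logic described above can be sketched as two hypothetical helpers. Neither is taken from train.py: the action names, the parameters progress_samples_interval, sample_interval, and keep_interval, and the rule of keeping the lowest-validation-loss checkpoint plus every keep_interval-th epoch are illustrative assumptions based on the description.

```python
def epoch_actions(epoch, progress_samples_interval, sample_interval=10):
    """Hypothetical: which monitoring actions fire at the end of `epoch`."""
    actions = []
    if epoch % sample_interval == 0:
        # Training spectrograms, alignment plots, input text and audio
        actions.append("save_training_samples")
    if epoch % progress_samples_interval == 0:
        # Inference on a validation sample: generated mel-spectrogram and audio
        actions.append("run_validation_inference")
    return actions


def checkpoints_to_keep(checkpoints, keep_interval=None):
    """Hypothetical retention rule.

    checkpoints: list of (epoch, valid_loss) pairs.
    Keeps the lowest-validation-loss checkpoint plus, when keep_interval is
    set, every checkpoint whose epoch is a multiple of keep_interval.
    """
    best_epoch = min(checkpoints, key=lambda ckpt: ckpt[1])[0]
    kept = {best_epoch}
    if keep_interval is not None:
        kept.update(epoch for epoch, _ in checkpoints if epoch % keep_interval == 0)
    return sorted(kept)
```

For example, if validation loss falls steadily over 120 epochs, keep_interval=50 retains epochs 50, 100, and the best checkpoint at 120.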

Key YAML Configuration

# Model definition
model: !new:speechbrain.lobes.models.MSTacotron2.Tacotron2
  mask_padding: True
  n_mel_channels: 80
  n_symbols: 148
  symbols_embedding_dim: 1024
  encoder_kernel_size: 5
  encoder_n_convolutions: 6
  encoder_embedding_dim: 1024
  attention_rnn_dim: 2048
  attention_dim: 256
  decoder_rnn_dim: 2048
  prenet_dim: 512
  max_decoder_steps: 1500
  gate_threshold: 0.5
  postnet_embedding_dim: 1024
  postnet_kernel_size: 5
  postnet_n_convolutions: 10
  spk_emb_size: 192

# Loss function
criterion: !new:speechbrain.lobes.models.MSTacotron2.Loss
  gate_loss_weight: 1.0
  guided_attention_weight: 25.0
  guided_attention_sigma: 0.2

# Optimizer with Noam scheduling
opt_class: !name:torch.optim.Adam
  lr: 0.001
  weight_decay: 0.000006

lr_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
  lr_initial: 0.001
  n_warmup_steps: 4000
