Implementation:Speechbrain Speechbrain Tacotron2Brain Compute Forward
| Property | Value |
|---|---|
| Type | API Doc |
| Repository | speechbrain/speechbrain |
| Source File | recipes/LibriTTS/TTS/mstacotron2/train.py:L32 (class), L56-79 (compute_forward), L81-96 (fit_batch), L98-159 (compute_objectives) |
| Import | Recipe-specific Brain subclass (not directly importable as a library) |
| Related Principle | Principle:Speechbrain_Speechbrain_Tacotron2_Acoustic_Model_Training |
Class Definition
class Tacotron2Brain(sb.Brain):
"""The Brain implementation for Tacotron2"""
Tacotron2Brain extends SpeechBrain's Brain class with custom forward, loss, and training logic for the Zero-Shot Multi-Speaker Tacotron2 model.
API Signatures
compute_forward
def compute_forward(self, batch, stage):
"""Computes the forward pass
Arguments
---------
batch: tuple
a single batch
stage: speechbrain.Stage
the training stage
Returns
-------
the model output
"""
fit_batch
def fit_batch(self, batch):
"""Fits a single batch and applies annealing
Arguments
---------
batch: tuple
a training batch
Returns
-------
loss: torch.Tensor
detached loss
"""
compute_objectives
def compute_objectives(self, predictions, batch, stage):
"""Computes the loss given the predicted and targeted outputs
Arguments
---------
predictions : torch.Tensor
The model generated mel-spectrograms and other metrics
batch : PaddedBatch
This batch object contains all the relevant tensors
stage : sb.Stage
One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST
Returns
-------
loss : torch.Tensor
A one-element tensor used for backpropagating the gradient
"""
Description
The Tacotron2Brain class orchestrates the full training loop for the multi-speaker Tacotron2 acoustic model. It handles batch unpacking, forward pass computation through the Tacotron2 model, loss calculation, learning rate scheduling, and progress monitoring.
Batch Structure
The input batch is a 10-element tuple, which batch_to_device moves to the device and unpacks:
(
text_padded, # Encoded text tokens [batch, max_text_len]
input_lengths, # Text sequence lengths [batch]
mel_padded, # Target mel-spectrogram [batch, n_mel, max_mel_len]
gate_padded, # Target gate (stop) values [batch, max_mel_len]
output_lengths, # Mel sequence lengths [batch]
len_x, # Total output length (sum of all output_lengths)
labels, # Raw text strings (list of str)
wavs, # Paths to source audio files (list of str)
spk_embs, # Precomputed speaker embeddings [batch, 192]
spk_ids, # Speaker ID strings (list of str)
)
The batch is assembled by the TextMelCollate class, which loads speaker embeddings from the precomputed pickle file.
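To make the layout concrete, here is a hypothetical sketch that assembles a dummy batch with the same 10-element structure. Plain Python lists stand in for the real torch.Tensors, and all values (lengths, paths, speaker IDs) are made up; only the element order and the 192-dim embedding size come from the text above.

```python
# Hypothetical illustration of the 10-element batch layout described above.
batch_size, max_text_len, n_mel, max_mel_len = 2, 7, 80, 50

text_padded = [[0] * max_text_len for _ in range(batch_size)]   # encoded tokens
input_lengths = [7, 5]                                          # text lengths
mel_padded = [[[0.0] * max_mel_len for _ in range(n_mel)]
              for _ in range(batch_size)]                       # target mels
gate_padded = [[0.0] * max_mel_len for _ in range(batch_size)]  # stop targets
output_lengths = [50, 42]                                       # mel lengths
len_x = sum(output_lengths)                                     # total frames
labels = ["hello world", "goodbye"]                             # raw text
wavs = ["spk1/utt1.wav", "spk2/utt2.wav"]                       # audio paths
spk_embs = [[0.0] * 192 for _ in range(batch_size)]             # 192-dim embeddings
spk_ids = ["spk1", "spk2"]

batch = (text_padded, input_lengths, mel_padded, gate_padded,
         output_lengths, len_x, labels, wavs, spk_embs, spk_ids)

print(len(batch), len_x)  # 10 92
```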
Forward Pass
The compute_forward method performs the following steps:
def compute_forward(self, batch, stage):
effective_batch = self.batch_to_device(batch)
inputs, y, num_items, _, _, spk_embs, spk_ids = effective_batch
_, input_lengths, _, _, _ = inputs
max_input_length = input_lengths.max().item()
return self.modules.model(
inputs, spk_embs, alignments_dim=max_input_length
)
The model outputs a tuple of:
- mel_out: Pre-net mel-spectrogram predictions [batch, n_mel, T]
- mel_out_postnet: Post-net refined mel-spectrogram [batch, n_mel, T]
- gate_out: Stop token predictions [batch, T]
- alignments: Attention alignment matrices [batch, T_out, T_in]
- pred_mel_lengths: Predicted mel lengths [batch]
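The gate_out values drive stopping at inference time: decoding halts once the stop probability crosses gate_threshold (0.5 in the YAML configuration below). The following pure-Python sketch of that thresholding is an illustration, not the recipe's code; the logits are made up.

```python
import math

def predicted_lengths(gate_out, threshold=0.5):
    """For each row of gate logits, return the first frame whose stop
    probability (sigmoid of the logit) exceeds the threshold, or the
    full length if no frame does."""
    lengths = []
    for row in gate_out:
        stop = next((t for t, g in enumerate(row)
                     if 1 / (1 + math.exp(-g)) > threshold), len(row))
        lengths.append(stop)
    return lengths

gate_out = [
    [-5.0, -4.0, -1.0, 2.0, 6.0],    # crosses 0.5 probability at frame 3
    [-6.0, -6.0, -5.0, -4.0, -3.0],  # never stops within these 5 frames
]
print(predicted_lengths(gate_out))  # [3, 5]
```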
Loss Computation
The compute_objectives method delegates to _compute_loss, which uses the configured MSTacotron2.Loss criterion:
def _compute_loss(self, predictions, batch, stage):
inputs, targets, num_items, labels, wavs, spk_embs, spk_ids = batch
text_padded, input_lengths, _, max_len, output_lengths = inputs
loss_stats = self.hparams.criterion(
predictions,
targets,
input_lengths,
output_lengths,
spk_emb_input=None,
current_epoch=self.last_epoch,
)
self.last_loss_stats[stage] = scalarize(loss_stats)
return loss_stats.loss
The loss combines:
- Mel MSE loss: On both pre-net and post-net outputs
- Gate BCE loss: Weighted by gate_loss_weight
- Guided attention loss: With scheduled weight decay
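As a rough illustration of the guided attention term, the sketch below computes the soft-diagonal penalty for a deliberately bad (uniform) alignment. This is a generic Tachibana-style formulation with sigma taken from the YAML (0.2), not the exact speechbrain.lobes.models.MSTacotron2.Loss implementation.

```python
import math

def guided_attention_weight(n, t, N, T, sigma=0.2):
    # Penalty is ~0 near the diagonal (monotonic alignment), ~1 far from it.
    return 1.0 - math.exp(-((n / N - t / T) ** 2) / (2 * sigma ** 2))

N, T = 10, 40  # input tokens, output mel frames
# Uniform (non-monotonic) attention: each frame attends equally to all tokens.
attn = [[1.0 / N] * N for _ in range(T)]
loss = sum(attn[t][n] * guided_attention_weight(n, t, N, T)
           for t in range(T) for n in range(N)) / T
print(loss)  # strictly positive; a sharp diagonal alignment would be near 0
```

Multiplying each attention weight by its distance-from-diagonal penalty rewards monotonic text-to-speech alignments, which is why the recipe decays this term's weight once alignments stabilize.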
Custom fit_batch
The fit_batch method extends the parent with learning rate annealing:
def fit_batch(self, batch):
result = super().fit_batch(batch)
self.hparams.lr_annealing(self.optimizer)
return result
The Noam scheduler warms up the learning rate over 4000 steps and then applies inverse square root decay.
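The schedule above can be sketched as follows: linear warmup to the initial learning rate over n_warmup_steps, then inverse-square-root decay. This mirrors the standard Noam formula with the YAML values below; the internals of speechbrain.nnet.schedulers.NoamScheduler may differ in detail.

```python
# Standard Noam learning-rate formula, parameterized with the YAML values.
def noam_lr(step, lr_initial=0.001, n_warmup_steps=4000):
    scale = n_warmup_steps ** 0.5 * min(step ** -0.5,
                                        step * n_warmup_steps ** -1.5)
    return lr_initial * scale

print(noam_lr(1))       # tiny LR at the first step
print(noam_lr(4000))    # peak: equals lr_initial (0.001)
print(noam_lr(16000))   # decayed to lr_initial / 2 (0.0005)
```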
Usage Example
Complete Training Script
import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
# Parse arguments and load hyperparameters
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file, encoding="utf-8") as fin:
hparams = load_hyperpyyaml(fin, overrides)
# Create experiment directory
sb.create_experiment_directory(
experiment_directory=hparams["output_folder"],
hyperparams_to_save=hparams_file,
overrides=overrides,
)
# Prepare data (run on main process only)
from libritts_prepare import prepare_libritts
sb.utils.distributed.run_on_main(
prepare_libritts,
kwargs={
"data_folder": hparams["data_folder"],
"save_json_train": hparams["train_json"],
"save_json_valid": hparams["valid_json"],
"save_json_test": hparams["test_json"],
"sample_rate": hparams["sample_rate"],
"train_split": hparams["train_split"],
"valid_split": hparams["valid_split"],
"test_split": hparams["test_split"],
"seed": hparams["seed"],
"model_name": hparams["model"].__class__.__name__,
},
)
# Compute speaker embeddings
from compute_speaker_embeddings import compute_speaker_embeddings
sb.utils.distributed.run_on_main(
compute_speaker_embeddings,
kwargs={
"input_filepaths": [hparams["train_json"], hparams["valid_json"], hparams["test_json"]],
"output_file_paths": [
hparams["train_speaker_embeddings_pickle"],
hparams["valid_speaker_embeddings_pickle"],
hparams["test_speaker_embeddings_pickle"],
],
"data_folder": hparams["data_folder"],
"spk_emb_encoder_path": hparams["spk_emb_encoder"],
"spk_emb_sr": hparams["spk_emb_sample_rate"],
"mel_spec_params": {"custom_mel_spec_encoder": False, ...},
"device": run_opts["device"],
},
)
# Prepare datasets (dataio_prepare is defined earlier in the recipe script)
datasets = dataio_prepare(hparams)
# Initialize Tacotron2Brain
tacotron2_brain = Tacotron2Brain(
modules=hparams["modules"],
opt_class=hparams["opt_class"],
hparams=hparams,
run_opts=run_opts,
checkpointer=hparams["checkpointer"],
)
# Train
tacotron2_brain.fit(
tacotron2_brain.hparams.epoch_counter,
train_set=datasets["train"],
valid_set=datasets["valid"],
train_loader_kwargs=hparams["train_dataloader_opts"],
valid_loader_kwargs=hparams["valid_dataloader_opts"],
)
# Test
tacotron2_brain.evaluate(
datasets["test"],
test_loader_kwargs=hparams["test_dataloader_opts"],
)
Command-Line Invocation
python train.py --device=cuda:0 --max_grad_norm=1.0 \
--data_folder=/path/to/LibriTTS \
hparams/train.yaml
Progress Monitoring
The on_stage_end method provides rich monitoring:
- Every 10 epochs: Saves training sample spectrograms, alignment visualizations, input text, input audio, and (optionally) synthesized audio
- Every progress_samples_interval epochs: Runs inference on a validation sample and saves the generated mel-spectrogram and audio
- Checkpointing: Saves model checkpoints based on validation loss, with optional interval-based retention (e.g., keep every 50th epoch)
Key YAML Configuration
# Model definition
model: !new:speechbrain.lobes.models.MSTacotron2.Tacotron2
mask_padding: True
n_mel_channels: 80
n_symbols: 148
symbols_embedding_dim: 1024
encoder_kernel_size: 5
encoder_n_convolutions: 6
encoder_embedding_dim: 1024
attention_rnn_dim: 2048
attention_dim: 256
decoder_rnn_dim: 2048
prenet_dim: 512
max_decoder_steps: 1500
gate_threshold: 0.5
postnet_embedding_dim: 1024
postnet_kernel_size: 5
postnet_n_convolutions: 10
spk_emb_size: 192
# Loss function
criterion: !new:speechbrain.lobes.models.MSTacotron2.Loss
gate_loss_weight: 1.0
guided_attention_weight: 25.0
guided_attention_sigma: 0.2
# Optimizer with Noam scheduling
opt_class: !name:torch.optim.Adam
lr: 0.001
weight_decay: 0.000006
lr_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
lr_initial: 0.001
n_warmup_steps: 4000
See Also
- Principle:Speechbrain_Speechbrain_Tacotron2_Acoustic_Model_Training - Theoretical foundations of Tacotron2 training
- Implementation:Speechbrain_Speechbrain_EncoderClassifier_Encode_Batch - Speaker embedding extraction used by this training recipe
- Implementation:Speechbrain_Speechbrain_Tacotron2_Inference_Pipeline - Inference pipeline using the trained model