Implementation:Speechbrain Speechbrain HifiGanBrain Fit Batch
| Property | Value |
|---|---|
| Type | API Doc |
| Repository | speechbrain/speechbrain |
| Source File | recipes/LibriTTS/vocoder/hifigan/train.py:L25 (class), L26-58 (compute_forward), L59-81 (compute_objectives), L82-111 (fit_batch) |
| Import | Recipe-specific Brain subclass (not directly importable as a library) |
| Related Principle | Principle:Speechbrain_Speechbrain_HiFi_GAN_Vocoder_Training |
Class Definition
class HifiGanBrain(sb.Brain):
"""Brain class for HiFi-GAN vocoder training with adversarial loss"""
HifiGanBrain extends SpeechBrain's Brain class to implement the GAN training loop, with separate generator and discriminator optimization steps within each batch.
API Signatures
compute_forward
def compute_forward(self, batch, stage):
"""Generates synthesized waveforms and computes discriminator scores.
Arguments
---------
batch : PaddedBatch
a single batch
stage : speechbrain.Stage
the training stage
Returns
-------
y_g_hat : torch.Tensor
Generated waveform
scores_fake : torch.Tensor
Discriminator scores for generated audio
feats_fake : torch.Tensor
Discriminator intermediate features for generated audio
scores_real : torch.Tensor
Discriminator scores for real audio
feats_real : torch.Tensor
Discriminator intermediate features for real audio
"""
compute_objectives
def compute_objectives(self, predictions, batch, stage):
"""Computes and combines generator and discriminator losses.
Returns
-------
loss : dict
Dictionary containing 'G_loss' and 'D_loss' keys
"""
fit_batch
def fit_batch(self, batch):
"""Train discriminator and generator adversarially.
Returns
-------
loss : torch.Tensor
Generator loss (detached)
"""
Description
The HifiGanBrain class implements the complete adversarial training loop for the HiFi-GAN vocoder. The key distinction from standard SpeechBrain Brain subclasses is the alternating generator/discriminator training within fit_batch, with separate optimizers, loss functions, and learning rate schedulers for each network.
Forward Pass
The compute_forward method performs three operations:
def compute_forward(self, batch, stage):
batch = batch.to(self.device)
x, _ = batch.mel # Mel-spectrogram input
y, _ = batch.sig # Ground truth waveform
# Generate synthesized waveform from mel
y_g_hat = self.modules.generator(x)[:, :, :y.size(2)]
# Score both real and fake with discriminator
scores_fake, feats_fake = self.modules.discriminator(y_g_hat.detach())
scores_real, feats_real = self.modules.discriminator(y)
return (y_g_hat, scores_fake, feats_fake, scores_real, feats_real)
The generated waveform is truncated to match the target length via [:, :, :y.size(2)].
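A minimal sketch of this truncation with dummy tensors (the extra samples can arise from transposed-convolution padding in the generator; the shapes here are illustrative, not taken from the recipe):

```python
import torch

# Stand-in batch: ground-truth waveforms and a slightly longer
# hypothetical generator output.
y = torch.randn(4, 1, 8192)        # ground-truth waveform batch
y_g_hat = torch.randn(4, 1, 8200)  # generator output, a few samples too long

# Slice the generated waveform to the target length, as in the recipe.
y_g_hat = y_g_hat[:, :, :y.size(2)]
print(tuple(y_g_hat.shape))  # (4, 1, 8192)
```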
Adversarial Training Loop
The fit_batch method implements the alternating optimization strategy:
def fit_batch(self, batch):
batch = batch.to(self.device)
y, _ = batch.sig
# Forward pass
outputs = self.compute_forward(batch, sb.core.Stage.TRAIN)
(y_g_hat, scores_fake, feats_fake, scores_real, feats_real) = outputs
# --- Step 1: Train Discriminator ---
loss_d = self.compute_objectives(outputs, batch, sb.core.Stage.TRAIN)["D_loss"]
self.optimizer_d.zero_grad()
loss_d.backward()
self.optimizer_d.step()
# --- Step 2: Train Generator ---
# Re-score with updated discriminator
scores_fake, feats_fake = self.modules.discriminator(y_g_hat)
scores_real, feats_real = self.modules.discriminator(y)
outputs = (y_g_hat, scores_fake, feats_fake, scores_real, feats_real)
loss_g = self.compute_objectives(outputs, batch, sb.core.Stage.TRAIN)["G_loss"]
self.optimizer_g.zero_grad()
loss_g.backward()
self.optimizer_g.step()
return loss_g.detach().cpu()
Key detail: After updating the discriminator, the generator's output is re-scored with the updated discriminator before computing the generator loss. This ensures the generator trains against the most up-to-date discriminator.
Loss Functions
Generator Loss
The generator loss combines multiple components configured in the YAML:
loss_g = self.hparams.generator_loss(
stage, y_hat, y, scores_fake, feats_fake, feats_real
)
Components and their weights:
- L1 mel-spectrogram loss (l1_spec_loss): weight 45 - L1 distance between mel-spectrograms of real and generated audio
- Feature matching loss (feat_match_loss): weight 10 - L1 distance between discriminator intermediate features
- MSE adversarial loss (mseg_loss): weight 1 - MSE between discriminator scores for generated audio and a target of 1.0
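The combination above can be sketched as a weighted sum. This is a minimal illustration with stand-in tensors, not the recipe's GeneratorLoss module (which lives in speechbrain.lobes.models.HifiGAN):

```python
import torch

# Weights as configured in the recipe's YAML.
l1_spec_loss_weight, feat_match_loss_weight, mseg_loss_weight = 45, 10, 1

# Stand-in tensors: mel-spectrograms, discriminator features, and scores.
mel_real = torch.randn(4, 80, 32)
mel_fake = mel_real + 0.1 * torch.randn_like(mel_real)
feats_real = [torch.randn(4, 16)]
feats_fake = [f + 0.05 * torch.randn_like(f) for f in feats_real]
scores_fake = [torch.rand(4, 1)]

# L1 distance between real and generated mel-spectrograms.
l1_spec = torch.nn.functional.l1_loss(mel_fake, mel_real)
# L1 distance between discriminator intermediate features.
feat_match = sum(
    torch.nn.functional.l1_loss(ff, fr) for ff, fr in zip(feats_fake, feats_real)
)
# Least-squares adversarial term: push fake scores toward 1.0.
adv = sum(torch.mean((s - 1.0) ** 2) for s in scores_fake)

g_loss = (
    l1_spec_loss_weight * l1_spec
    + feat_match_loss_weight * feat_match
    + mseg_loss_weight * adv
)
```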
Discriminator Loss
loss_d = self.hparams.discriminator_loss(scores_fake, scores_real)
- MSE discriminator loss (msed_loss): MSE between real scores (target: 1) and fake scores (target: 0)
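A hedged sketch of this least-squares objective with stand-in score tensors, not the recipe's DiscriminatorLoss module:

```python
import torch

# Stand-in discriminator scores for real and generated audio.
scores_real = [torch.rand(4, 1)]
scores_fake = [torch.rand(4, 1)]

# Push real scores toward 1 and fake scores toward 0.
d_loss = sum(
    torch.mean((sr - 1.0) ** 2) + torch.mean(sf ** 2)
    for sr, sf in zip(scores_real, scores_fake)
)
```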
Dual Optimizer Initialization
The init_optimizers method sets up separate optimizers and schedulers:
def init_optimizers(self):
(opt_g_class, opt_d_class, sch_g_class, sch_d_class) = self.opt_class
self.optimizer_g = opt_g_class(self.modules.generator.parameters())
self.optimizer_d = opt_d_class(self.modules.discriminator.parameters())
self.scheduler_g = sch_g_class(self.optimizer_g)
self.scheduler_d = sch_d_class(self.optimizer_d)
Both use AdamW with learning rate 0.0002 and ExponentialLR decay (gamma=0.9999).
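Outside of HyperPyYAML, the four opt_class entries could be approximated with functools.partial. This is an illustrative sketch of the pattern, not the recipe's YAML definitions:

```python
import functools
import torch

# Partially-applied constructors, mirroring what HyperPyYAML produces.
opt_class_generator = functools.partial(torch.optim.AdamW, lr=0.0002)
opt_class_discriminator = functools.partial(torch.optim.AdamW, lr=0.0002)
sch_class_generator = functools.partial(
    torch.optim.lr_scheduler.ExponentialLR, gamma=0.9999
)
sch_class_discriminator = functools.partial(
    torch.optim.lr_scheduler.ExponentialLR, gamma=0.9999
)

# init_optimizers then calls each constructor with the remaining argument.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer_g = opt_class_generator(params)
scheduler_g = sch_class_generator(optimizer_g)
```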
Batch Structure
The batch contains two dynamic items provided by the data pipeline:
@sb.utils.data_pipeline.takes("wav", "segment")
@sb.utils.data_pipeline.provides("mel", "sig")
def audio_pipeline(wav, segment):
audio = sb.dataio.dataio.read_audio(wav)
audio = torch.FloatTensor(audio).unsqueeze(0)
# Random segment extraction for training
if segment:
if audio.size(1) >= segment_size:
max_audio_start = audio.size(1) - segment_size
audio_start = torch.randint(0, max_audio_start, (1,))
audio = audio[:, audio_start:audio_start + segment_size]
else:
audio = torch.nn.functional.pad(
audio, (0, segment_size - audio.size(1)), "constant"
)
mel = hparams["mel_spectogram"](audio=audio.squeeze(0))
return mel, audio
- mel: Mel-spectrogram computed on-the-fly from the audio segment
- sig: Raw waveform tensor (8192 samples for training segments)
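Assuming the common HiFi-GAN hop length of 256 samples (not stated in this section), the 8192-sample training segment maps to a fixed number of mel frames:

```python
# Hedged arithmetic check: one mel frame per hop of the spectrogram.
segment_size = 8192
hop_length = 256  # assumed; matches the product of the recipe's upsample_factors

n_frames = segment_size // hop_length
print(n_frames)  # 32
```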
Usage Example
Complete Training Script
import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file, encoding="utf-8") as fin:
hparams = load_hyperpyyaml(fin, overrides)
sb.create_experiment_directory(
experiment_directory=hparams["output_folder"],
hyperparams_to_save=hparams_file,
overrides=overrides,
)
# Data preparation
from libritts_prepare import prepare_libritts
sb.utils.distributed.run_on_main(
prepare_libritts,
kwargs={
"data_folder": hparams["data_folder"],
"save_json_train": hparams["train_json"],
"save_json_valid": hparams["valid_json"],
"save_json_test": hparams["test_json"],
"sample_rate": hparams["sample_rate"],
"split_ratio": hparams["split_ratio"],
"libritts_subsets": hparams["libritts_subsets"],
"model_name": "HiFi-GAN",
},
)
datasets = dataio_prepare(hparams)
# Initialize with four-element opt_class list
hifi_gan_brain = HifiGanBrain(
modules=hparams["modules"],
opt_class=[
hparams["opt_class_generator"],
hparams["opt_class_discriminator"],
hparams["sch_class_generator"],
hparams["sch_class_discriminator"],
],
hparams=hparams,
run_opts=run_opts,
checkpointer=hparams["checkpointer"],
)
# Train
hifi_gan_brain.fit(
hifi_gan_brain.hparams.epoch_counter,
train_set=datasets["train"],
valid_set=datasets["valid"],
train_loader_kwargs=hparams["train_dataloader_opts"],
valid_loader_kwargs=hparams["valid_dataloader_opts"],
)
# Test
if "test" in datasets:
hifi_gan_brain.evaluate(
datasets["test"],
test_loader_kwargs=hparams["test_dataloader_opts"],
)
Command-Line Invocation
python train.py hparams/train.yaml --data_folder /path/to/LibriTTS
Inference
The run_inference_sample method generates audio samples during validation:
def run_inference_sample(self, name):
with torch.no_grad():
x, y = self.last_batch
# Create inference generator with weight norm removed
inference_generator = type(self.hparams.generator)(
in_channels=self.hparams.in_channels,
out_channels=self.hparams.out_channels,
resblock_type=self.hparams.resblock_type,
...
).to(self.device)
inference_generator.load_state_dict(
self.hparams.generator.state_dict()
)
inference_generator.remove_weight_norm()
# Generate waveform
sig_out = inference_generator.inference(x)
Weight normalization is removed for inference to avoid synthesis artifacts.
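The effect of removing weight normalization can be seen on a bare convolution. This sketch uses PyTorch's torch.nn.utils.weight_norm API directly and is independent of the recipe:

```python
import torch

# Weight normalization stores the weight as a (g, v) pair during training;
# removal folds the pair back into a single plain weight tensor, so the
# weight no longer needs to be recomputed on every forward call.
conv = torch.nn.utils.weight_norm(torch.nn.Conv1d(80, 32, kernel_size=3))
print(hasattr(conv, "weight_g"))  # True: (g, v) parameterization active

torch.nn.utils.remove_weight_norm(conv)
print(hasattr(conv, "weight_g"))  # False: folded into a plain weight
```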
Key YAML Configuration
# Generator
generator: !new:speechbrain.lobes.models.HifiGAN.HifiganGenerator
in_channels: 80
out_channels: 1
resblock_type: "1"
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
resblock_kernel_sizes: [3, 7, 11]
upsample_kernel_sizes: [16, 16, 4, 4]
upsample_initial_channel: 512
upsample_factors: [8, 8, 2, 2]
# Discriminator
discriminator: !new:speechbrain.lobes.models.HifiGAN.HifiganDiscriminator
# Loss weights
generator_loss: !new:speechbrain.lobes.models.HifiGAN.GeneratorLoss
mseg_loss_weight: 1
feat_match_loss_weight: 10
l1_spec_loss_weight: 45
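A quick sanity check on this configuration: the product of upsample_factors should equal the mel hop length (256 in typical HiFi-GAN setups), so that each mel frame expands to exactly one hop of waveform samples:

```python
import math

# Upsampling factors from the YAML above; their product is the total
# mel-frame-to-sample expansion ratio of the generator.
upsample_factors = [8, 8, 2, 2]
print(math.prod(upsample_factors))  # 256
```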
See Also
- Principle:Speechbrain_Speechbrain_HiFi_GAN_Vocoder_Training - Theoretical foundations of HiFi-GAN training
- Implementation:Speechbrain_Speechbrain_Prepare_Libritts - Data preparation used before vocoder training
- Implementation:Speechbrain_Speechbrain_Tacotron2_Inference_Pipeline - Inference pipeline that uses the trained vocoder