

Implementation:Speechbrain Speechbrain HifiGanBrain Fit Batch

From Leeroopedia


Property            Value
Type                API Doc
Repository          speechbrain/speechbrain
Source File         recipes/LibriTTS/vocoder/hifigan/train.py:L25 (class), L26-58 (compute_forward), L59-81 (compute_objectives), L82-111 (fit_batch)
Import              Recipe-specific Brain subclass (not directly importable as a library)
Related Principle   Principle:Speechbrain_Speechbrain_HiFi_GAN_Vocoder_Training

Class Definition

class HifiGanBrain(sb.Brain):
    """Brain class for HiFi-GAN vocoder training with adversarial loss"""

HifiGanBrain extends SpeechBrain's Brain class to implement the GAN training loop, with separate generator and discriminator optimization steps within each batch.

API Signatures

compute_forward

def compute_forward(self, batch, stage):
    """Generates synthesized waveforms and computes discriminator scores.

    Arguments
    ---------
    batch : PaddedBatch
        a single batch
    stage : speechbrain.Stage
        the current stage (TRAIN, VALID, or TEST)

    Returns
    -------
    y_g_hat : torch.Tensor
        Generated waveform
    scores_fake : torch.Tensor
        Discriminator scores for generated audio
    feats_fake : torch.Tensor
        Discriminator intermediate features for generated audio
    scores_real : torch.Tensor
        Discriminator scores for real audio
    feats_real : torch.Tensor
        Discriminator intermediate features for real audio
    """

compute_objectives

def compute_objectives(self, predictions, batch, stage):
    """Computes and combines generator and discriminator losses.

    Returns
    -------
    loss : dict
        Dictionary containing 'G_loss' and 'D_loss' keys
    """

fit_batch

def fit_batch(self, batch):
    """Train discriminator and generator adversarially.

    Returns
    -------
    loss : torch.Tensor
        Generator loss (detached)
    """

Description

The HifiGanBrain class implements the complete adversarial training loop for the HiFi-GAN vocoder. The key distinction from standard SpeechBrain Brain subclasses is the alternating generator/discriminator training within fit_batch, with separate optimizers, loss functions, and learning rate schedulers for each network.

Forward Pass

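The compute_forward method loads the mel-spectrogram and ground-truth waveform from the batch, synthesizes a waveform from the mel-spectrogram, and scores both the synthesized and the real waveform with the discriminator: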

def compute_forward(self, batch, stage):
    batch = batch.to(self.device)
    x, _ = batch.mel          # Mel-spectrogram input
    y, _ = batch.sig          # Ground truth waveform

    # Generate synthesized waveform from mel
    y_g_hat = self.modules.generator(x)[:, :, :y.size(2)]

    # Score both real and fake with discriminator
    scores_fake, feats_fake = self.modules.discriminator(y_g_hat.detach())
    scores_real, feats_real = self.modules.discriminator(y)

    return (y_g_hat, scores_fake, feats_fake, scores_real, feats_real)

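The generator output, of shape [batch, 1, time], is truncated along the time axis to the target length via [:, :, :y.size(2)], because the upsampled length need not equal the ground-truth length exactly. The fake waveform is detached before scoring, so the discriminator loss computed from these scores does not propagate gradients into the generator.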
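To see why the truncation is needed, the sketch below works through the lengths for one training segment. It assumes a hop length of 256 samples and a centered STFT (1 + T // hop frames); neither value is given on this page, so both are assumptions, and the helper names are hypothetical. It also assumes each upsampling stage multiplies the length exactly by its factor. The segment length (8192) and upsample_factors ([8, 8, 2, 2]) are the recipe values quoted on this page.

```python
import math
from typing import Sequence


def num_mel_frames(num_samples: int, hop_length: int) -> int:
    """Frame count of a centered STFT (ASSUMPTION: center=True padding)."""
    return 1 + num_samples // hop_length


def generator_output_length(frames: int, upsample_factors: Sequence[int]) -> int:
    """ASSUMPTION: each upsampling stage multiplies the length exactly by its factor."""
    return frames * math.prod(upsample_factors)


segment_size = 8192              # training segment length in samples
hop_length = 256                 # ASSUMPTION: not stated on this page
upsample_factors = [8, 8, 2, 2]  # generator upsample_factors

frames = num_mel_frames(segment_size, hop_length)
out_len = generator_output_length(frames, upsample_factors)
print(frames, out_len, out_len - segment_size)  # 33 8448 256
```

Under these assumptions the generator returns 256 more samples than the target segment, and the slice keeps only the first 8192.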

Adversarial Training Loop

The fit_batch method implements the alternating optimization strategy:

def fit_batch(self, batch):
    batch = batch.to(self.device)
    y, _ = batch.sig

    # Forward pass
    outputs = self.compute_forward(batch, sb.core.Stage.TRAIN)
    (y_g_hat, scores_fake, feats_fake, scores_real, feats_real) = outputs

    # --- Step 1: Train Discriminator ---
    loss_d = self.compute_objectives(outputs, batch, sb.core.Stage.TRAIN)["D_loss"]
    self.optimizer_d.zero_grad()
    loss_d.backward()
    self.optimizer_d.step()

    # --- Step 2: Train Generator ---
    # Re-score with updated discriminator
    scores_fake, feats_fake = self.modules.discriminator(y_g_hat)
    scores_real, feats_real = self.modules.discriminator(y)
    outputs = (y_g_hat, scores_fake, feats_fake, scores_real, feats_real)
    loss_g = self.compute_objectives(outputs, batch, sb.core.Stage.TRAIN)["G_loss"]
    self.optimizer_g.zero_grad()
    loss_g.backward()
    self.optimizer_g.step()

    return loss_g.detach().cpu()

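Key detail: after the discriminator update, the generator's output is re-scored with the updated discriminator before the generator loss is computed, so the generator trains against the most up-to-date discriminator. In this second pass y_g_hat is not detached, so the generator loss backpropagates through the discriminator into the generator. Only optimizer_g steps here, and the gradients this leaves on the discriminator's parameters are cleared by optimizer_d.zero_grad() before the next discriminator update.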
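The ordering can be made concrete with a deliberately tiny toy, which is not the recipe's code: a hypothetical "generator" that outputs a constant c, and a linear "discriminator" D(x) = a*x + b, both trained with least-squares GAN losses and hand-derived gradients. The sketch only mirrors fit_batch's order: a discriminator step in which the fake is treated as a constant (the detach), then re-scoring and a generator step.

```python
def discriminator(x: float, a: float, b: float) -> float:
    """Toy linear discriminator D(x) = a*x + b."""
    return a * x + b


def train_step(a, b, c, y_real, lr=0.1, rescore=True):
    """One alternating update mirroring fit_batch (toy, hand-derived gradients)."""
    y_fake = c  # generator output; treated as a constant in the D step (detach)

    # --- Step 1: discriminator, L_D = (D(y_real) - 1)^2 + D(y_fake)^2 ---
    d_real = discriminator(y_real, a, b)
    d_fake = discriminator(y_fake, a, b)
    grad_a = 2 * (d_real - 1) * y_real + 2 * d_fake * y_fake
    grad_b = 2 * (d_real - 1) + 2 * d_fake
    a_new, b_new = a - lr * grad_a, b - lr * grad_b

    # --- Step 2: generator, L_G = (D(y_fake) - 1)^2 ---
    # fit_batch re-scores with the updated discriminator (a_new, b_new);
    # rescore=False uses the stale discriminator for comparison
    a_g, b_g = (a_new, b_new) if rescore else (a, b)
    d_fake = discriminator(y_fake, a_g, b_g)
    grad_c = 2 * (d_fake - 1) * a_g  # chain rule: dD/dc = a
    c_new = c - lr * grad_c
    return a_new, b_new, c_new


print(train_step(0.5, 0.0, 0.0, y_real=1.0))                 # re-scored
print(train_step(0.5, 0.0, 0.0, y_real=1.0, rescore=False))  # stale D
```

With the updated discriminator the generator moves c to about 0.108 instead of 0.1, so the two orderings genuinely produce different generator updates.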

Loss Functions

Generator Loss

The generator loss combines multiple components configured in the YAML:

loss_g = self.hparams.generator_loss(
    stage, y_hat, y, scores_fake, feats_fake, feats_real
)

Components and their weights:

  • L1 mel-spectrogram loss (l1_spec_loss), weight 45: L1 distance between the mel-spectrograms of real and generated audio
  • Feature matching loss (feat_match_loss), weight 10: L1 distance between the discriminator's intermediate features for real and generated audio
  • MSE adversarial loss (mseg_loss), weight 1: MSE between the discriminator scores for generated audio and the target 1.0
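As a rough illustration of how the three weighted terms combine, the sketch below uses plain Python lists in place of tensors. The function names, the extra dictionary keys ("mel", "feat_match", "adv"), the flat single-discriminator score list, and the averaging over feature layers are simplifying assumptions; the recipe's GeneratorLoss operates on tensors and on the outputs of several sub-discriminators.

```python
def l1(a, b):
    """Mean absolute error between two equal-length sequences."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def mse_to_target(scores, target):
    """Mean squared error between scores and a constant target."""
    return sum((s - target) ** 2 for s in scores) / len(scores)


def generator_loss(mel_fake, mel_real, feats_fake, feats_real, scores_fake,
                   l1_spec_loss_weight=45, feat_match_loss_weight=10,
                   mseg_loss_weight=1):
    mel_term = l1(mel_fake, mel_real)
    # Feature matching: L1 per discriminator layer, averaged over layers
    fm_term = sum(l1(f, r) for f, r in zip(feats_fake, feats_real)) / len(feats_fake)
    # Least-squares adversarial term: push fake scores toward 1.0
    adv_term = mse_to_target(scores_fake, 1.0)
    total = (l1_spec_loss_weight * mel_term
             + feat_match_loss_weight * fm_term
             + mseg_loss_weight * adv_term)
    return {"G_loss": total, "mel": mel_term, "feat_match": fm_term, "adv": adv_term}


losses = generator_loss(
    mel_fake=[1.0, 2.0], mel_real=[1.0, 1.0],
    feats_fake=[[0.0, 0.0]], feats_real=[[1.0, 1.0]],
    scores_fake=[0.5],
)
print(losses["G_loss"])  # 45*0.5 + 10*1.0 + 1*0.25 = 32.75
```

With these weights, a unit change in the mel L1 term moves G_loss 45 times as much as the same change in the adversarial term.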

Discriminator Loss

loss_d = self.hparams.discriminator_loss(scores_fake, scores_real)
  • MSE discriminator loss (msed_loss): the MSE between real scores and the target 1 plus the MSE between fake scores and the target 0
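A list-based sketch of this least-squares discriminator loss follows. It assumes, as in the HiFi-GAN paper, that the losses of the individual sub-discriminators are summed; the nested lists stand in for per-sub-discriminator score tensors, and the function name is hypothetical.

```python
def discriminator_loss(scores_fake, scores_real):
    """Least-squares GAN discriminator loss summed over sub-discriminators.

    scores_fake / scores_real: one list of scores per sub-discriminator.
    """
    total = 0.0
    for fake, real in zip(scores_fake, scores_real):
        loss_real = sum((r - 1.0) ** 2 for r in real) / len(real)  # target 1
        loss_fake = sum(f ** 2 for f in fake) / len(fake)          # target 0
        total += loss_real + loss_fake
    return {"D_loss": total}


out = discriminator_loss(
    scores_fake=[[0.0, 0.5], [0.0]],
    scores_real=[[1.0, 0.5], [1.0]],
)
print(out["D_loss"])  # (0.125 + 0.125) + (0.0 + 0.0) = 0.25
```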

Dual Optimizer Initialization

The init_optimizers method sets up separate optimizers and schedulers:

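def init_optimizers(self):
    (opt_g_class, opt_d_class, sch_g_class, sch_d_class) = self.opt_class

    # One optimizer and one LR scheduler per network
    self.optimizer_g = opt_g_class(self.modules.generator.parameters())
    self.optimizer_d = opt_d_class(self.modules.discriminator.parameters())
    self.scheduler_g = sch_g_class(self.optimizer_g)
    self.scheduler_d = sch_d_class(self.optimizer_d)

    # Register the optimizers so their state is saved in checkpoints
    if self.checkpointer is not None:
        self.checkpointer.add_recoverable("optimizer_g", self.optimizer_g)
        self.checkpointer.add_recoverable("optimizer_d", self.optimizer_d)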

In the recipe's YAML, both optimizers are AdamW with a learning rate of 0.0002, and both schedulers are ExponentialLR with gamma=0.9999.
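ExponentialLR multiplies the learning rate by gamma on each scheduler step, so after n steps the rate is 0.0002 * 0.9999 ** n. The helper below (a hypothetical name) just evaluates that closed form. How often the recipe calls scheduler.step() is not stated on this page, so n counts scheduler steps rather than batches or epochs.

```python
def exponential_lr(base_lr: float, gamma: float, steps: int) -> float:
    """Learning rate after `steps` calls to an ExponentialLR scheduler."""
    return base_lr * gamma ** steps


for n in (0, 100, 1000, 10000):
    print(n, f"{exponential_lr(2e-4, 0.9999, n):.3e}")
```

Because gamma is so close to 1, the decay is gentle: roughly 10,000 steps are needed to shrink the learning rate to about 37% (about 1/e) of its initial value.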

Batch Structure

The batch contains two dynamic items provided by the data pipeline:

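import torch
import speechbrain as sb

# Defined inside dataio_prepare(hparams); segment_size comes from the
# enclosing scope (8192 samples for training segments)
@sb.utils.data_pipeline.takes("wav", "segment")
@sb.utils.data_pipeline.provides("mel", "sig")
def audio_pipeline(wav, segment):
    audio = sb.dataio.dataio.read_audio(wav)
    audio = torch.FloatTensor(audio).unsqueeze(0)  # shape [1, time]

    # Random segment extraction for training
    if segment:
        if audio.size(1) >= segment_size:
            max_audio_start = audio.size(1) - segment_size
            # torch.randint excludes its upper bound: +1 keeps the last valid
            # start and avoids an error when the clip is exactly segment_size
            audio_start = torch.randint(0, max_audio_start + 1, (1,)).item()
            audio = audio[:, audio_start:audio_start + segment_size]
        else:
            # Zero-pad short clips on the right up to segment_size
            audio = torch.nn.functional.pad(
                audio, (0, segment_size - audio.size(1)), "constant"
            )

    mel = hparams["mel_spectogram"](audio=audio.squeeze(0))
    return mel, audio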
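  • mel: Mel-spectrogram computed on the fly from the audio segment, shape [80, frames] per example (80 mel channels, matching the generator's in_channels)
  • sig: Raw waveform tensor of shape [1, time] per example (time = 8192 samples for training segments); after padding and batching, batch.sig holds a [batch, 1, time] tensor, which is why compute_forward reads the target length as y.size(2)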
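The crop-or-pad logic of audio_pipeline can be sketched without PyTorch on a plain list of samples. This is a stand-in, not the recipe's code: the function name and the seeded random.Random argument are assumptions. random.Random.randint includes both endpoints, so every valid start offset, including the last one, can be drawn.

```python
import random


def crop_or_pad(samples, segment_size, rng):
    """Return exactly `segment_size` samples: random crop or right zero-pad."""
    if len(samples) >= segment_size:
        max_start = len(samples) - segment_size
        start = rng.randint(0, max_start)  # inclusive of max_start
        return samples[start:start + segment_size]
    return samples + [0.0] * (segment_size - len(samples))


rng = random.Random(0)
long_clip = [float(i) for i in range(10)]
print(crop_or_pad(long_clip, 4, rng))             # a random 4-sample window
print(crop_or_pad([1.0, 2.0], 4, rng))            # [1.0, 2.0, 0.0, 0.0]
print(crop_or_pad([1.0, 2.0, 3.0, 4.0], 4, rng))  # exact length: unchanged
```

In PyTorch terms, torch.randint(low, high, size) excludes high, so the equivalent inclusive draw there is torch.randint(0, max_start + 1, (1,)).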

Usage Example

Complete Training Script

import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# HifiGanBrain and dataio_prepare are defined in this same train.py

hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file, encoding="utf-8") as fin:
    hparams = load_hyperpyyaml(fin, overrides)

sb.create_experiment_directory(
    experiment_directory=hparams["output_folder"],
    hyperparams_to_save=hparams_file,
    overrides=overrides,
)

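# Data preparation (libritts_prepare is a recipe-level script, not part of
# the speechbrain package, so it must be importable from sys.path)
from libritts_prepare import prepare_libritts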
sb.utils.distributed.run_on_main(
    prepare_libritts,
    kwargs={
        "data_folder": hparams["data_folder"],
        "save_json_train": hparams["train_json"],
        "save_json_valid": hparams["valid_json"],
        "save_json_test": hparams["test_json"],
        "sample_rate": hparams["sample_rate"],
        "split_ratio": hparams["split_ratio"],
        "libritts_subsets": hparams["libritts_subsets"],
        "model_name": "HiFi-GAN",
    },
)

datasets = dataio_prepare(hparams)

# Initialize with a four-element opt_class list (order must match init_optimizers)
hifi_gan_brain = HifiGanBrain(
    modules=hparams["modules"],
    opt_class=[
        hparams["opt_class_generator"],
        hparams["opt_class_discriminator"],
        hparams["sch_class_generator"],
        hparams["sch_class_discriminator"],
    ],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)

# Train
hifi_gan_brain.fit(
    hifi_gan_brain.hparams.epoch_counter,
    train_set=datasets["train"],
    valid_set=datasets["valid"],
    train_loader_kwargs=hparams["train_dataloader_opts"],
    valid_loader_kwargs=hparams["valid_dataloader_opts"],
)

# Test
if "test" in datasets:
    hifi_gan_brain.evaluate(
        datasets["test"],
        test_loader_kwargs=hparams["test_dataloader_opts"],
    )

Command-Line Invocation

python train.py hparams/train.yaml --data_folder /path/to/LibriTTS

Inference

The run_inference_sample method generates audio samples during validation:

def run_inference_sample(self, name):
    with torch.no_grad():
        # Mel-spectrogram and waveform kept from the most recent batch
        x, y = self.last_batch

        # Build a separate copy of the generator so that removing weight
        # norm does not change the generator that is still being trained
        inference_generator = type(self.hparams.generator)(
            in_channels=self.hparams.in_channels,
            out_channels=self.hparams.out_channels,
            resblock_type=self.hparams.resblock_type,
            # ... (remaining constructor arguments elided)
        ).to(self.device)
        inference_generator.load_state_dict(
            self.hparams.generator.state_dict()
        )
        inference_generator.remove_weight_norm()

        # Generate waveform
        sig_out = inference_generator.inference(x)

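Weight normalization is a training-time reparameterization that stores each weight as a magnitude g and a direction v. Removing it folds g and v into a single weight tensor, which leaves the generator's output unchanged but avoids recomputing the normalization on every forward pass during inference.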
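Weight normalization writes each weight vector as w = g * v / ||v||. The plain-Python sketch below, for a single output unit with hypothetical values, shows that folding g and v into one vector (what removing weight norm does) gives the same output as the reparameterized form.

```python
import math


def weight_norm_output(g, v, x):
    """Output of one unit with weight norm: (g * v / ||v||) . x"""
    norm_v = math.sqrt(sum(vi * vi for vi in v))
    return sum((g * vi / norm_v) * xi for vi, xi in zip(v, x))


def fold_weight_norm(g, v):
    """Compute the plain weight vector w = g * v / ||v|| once."""
    norm_v = math.sqrt(sum(vi * vi for vi in v))
    return [g * vi / norm_v for vi in v]


def plain_output(w, x):
    """Output of the same unit after weight norm has been removed."""
    return sum(wi * xi for wi, xi in zip(w, x))


g, v, x = 2.0, [3.0, 4.0], [1.0, -1.0]
w = fold_weight_norm(g, v)
print(w)                                                # [1.2, 1.6]
print(weight_norm_output(g, v, x), plain_output(w, x))  # same value twice
```

In a real network the folding is done per output channel; the folded model simply skips the norm computation on each forward pass.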

Key YAML Configuration

# Generator
generator: !new:speechbrain.lobes.models.HifiGAN.HifiganGenerator
  in_channels: 80
  out_channels: 1
  resblock_type: "1"
  resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
  resblock_kernel_sizes: [3, 7, 11]
  upsample_kernel_sizes: [16, 16, 4, 4]
  upsample_initial_channel: 512
  upsample_factors: [8, 8, 2, 2]

# Discriminator
discriminator: !new:speechbrain.lobes.models.HifiGAN.HifiganDiscriminator

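# Generator loss weights (excerpt: the recipe also passes the loss modules themselves, such as mseg_loss, feat_match_loss, and l1_spec_loss)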
generator_loss: !new:speechbrain.lobes.models.HifiGAN.GeneratorLoss
  mseg_loss_weight: 1
  feat_match_loss_weight: 10
  l1_spec_loss_weight: 45
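The upsampling settings in the generator YAML have to fit together. The check below is a sketch that copies those values and applies PyTorch's ConvTranspose1d output-length formula (dilation 1, no output padding) with padding (k - u) // 2; that padding rule is an assumption taken from the original HiFi-GAN implementation, not from this page, and the function names are hypothetical.

```python
# Values copied from the generator YAML
upsample_factors = [8, 8, 2, 2]
upsample_kernel_sizes = [16, 16, 4, 4]
resblock_kernel_sizes = [3, 7, 11]
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]


def conv_transpose1d_length(l_in, kernel, stride, padding):
    """Output length of ConvTranspose1d with dilation=1, output_padding=0."""
    return (l_in - 1) * stride - 2 * padding + kernel


def check_generator_config(num_frames=33):
    """Check list pairing and that each stage upsamples exactly by its factor."""
    assert len(upsample_factors) == len(upsample_kernel_sizes)
    assert len(resblock_kernel_sizes) == len(resblock_dilation_sizes)
    length = num_frames  # arbitrary number of mel frames
    for u, k in zip(upsample_factors, upsample_kernel_sizes):
        padding = (k - u) // 2  # ASSUMPTION: original HiFi-GAN padding rule
        new_length = conv_transpose1d_length(length, k, u, padding)
        assert new_length == length * u, (u, k, new_length)
        length = new_length
    return length


print(check_generator_config())  # 8448
```

Starting from 33 frames, the four stages give 264, 2112, 4224 and finally 8448 samples, which is 33 * 256, the product of the upsampling factors.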
