Implementation:Speechbrain Speechbrain MetricGanBrain Fit Batch
| Property | Value |
|---|---|
| Implementation Name | MetricGanBrain_Fit_Batch |
| API | MetricGanBrain.fit_batch(self, batch) |
| Source File | recipes/Voicebank/enhance/MetricGAN/train.py -- Class: L48, fit_batch: L299-346, compute_objectives: L83-163 |
| Import | Recipe-specific Brain subclass (not importable as library) |
| Type | API Doc |
| Workflow | Speech_Enhancement_Training |
| Domains | GAN_Training, Speech_Enhancement |
| Related Principle | Principle:Speechbrain_Speechbrain_GAN_Based_Enhancement_Training |
Purpose
MetricGanBrain is a custom sb.Brain subclass that implements the MetricGAN+ training procedure for speech enhancement. The core method fit_batch() manages the alternating generator/discriminator optimization, while compute_objectives() implements the sub-stage-specific loss computation that uses actual PESQ scores as discriminator training targets.
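The alternation described above can be sketched with a minimal, self-contained toy (the modules, sizes, and target scores below are illustrative stand-ins, not the recipe's actual generator and discriminator, which are configured via hyperpyyaml):

```python
import torch

# Toy generator/discriminator standing in for the recipe's modules.
G = torch.nn.Linear(4, 4)   # "enhancer": maps noisy features to enhanced
D = torch.nn.Linear(4, 1)   # "metric estimator": predicts a quality score
g_opt = torch.optim.Adam(G.parameters(), lr=5e-4)
d_opt = torch.optim.Adam(D.parameters(), lr=5e-4)
mse = torch.nn.MSELoss()

noisy = torch.randn(8, 4)

# Discriminator step: regress D's score toward a known quality target
# (in MetricGAN+, the actual PESQ score of the enhanced signal).
d_opt.zero_grad()
d_loss = mse(D(G(noisy).detach()), torch.full((8, 1), 0.5))
d_loss.backward()
d_opt.step()

# Generator step: push D's score for the enhanced output toward 1.0.
g_opt.zero_grad()
g_loss = mse(D(G(noisy)), torch.ones(8, 1))
g_loss.backward()
g_opt.step()
```

The key structural point mirrored here is that the discriminator step detaches the generator output, so each optimizer only updates its own module.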
Class Definition
```python
from enum import Enum, auto

import speechbrain as sb


class SubStage(Enum):
    """For keeping track of training stage progress"""

    GENERATOR = auto()
    CURRENT = auto()
    HISTORICAL = auto()


class MetricGanBrain(sb.Brain):
    """Brain class for MetricGAN+ speech enhancement training.

    Manages dual optimizers, sub-stage training, and historical
    sample replay for adversarial perceptual metric optimization.
    """
```
fit_batch Method
```python
def fit_batch(self, batch):
    "Compute gradients and update either D or G based on sub-stage."
    predictions = self.compute_forward(batch, sb.Stage.TRAIN)
    loss_tracker = 0
    if self.sub_stage == SubStage.CURRENT:
        # Discriminator training on current data: clean, enhanced, noisy
        for mode in ["clean", "enh", "noisy"]:
            loss = self.compute_objectives(
                predictions, batch, sb.Stage.TRAIN, f"D_{mode}"
            )
            self.d_optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.modules.parameters(), self.max_grad_norm
            )
            self.d_optimizer.step()
            loss_tracker += loss.detach() / 3
    elif self.sub_stage == SubStage.HISTORICAL:
        # Discriminator training on historical enhanced samples
        loss = self.compute_objectives(
            predictions, batch, sb.Stage.TRAIN, "D_enh"
        )
        self.d_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            self.modules.parameters(), self.max_grad_norm
        )
        self.d_optimizer.step()
        loss_tracker += loss.detach()
    elif self.sub_stage == SubStage.GENERATOR:
        # Clamp learnable sigmoid to prevent gradient explosion
        for name, param in self.modules.generator.named_parameters():
            if "Learnable_sigmoid" in name:
                param.data = torch.clamp(param, max=3.5)
                param.data[param != param] = 3.5  # set NaN to 3.5
        loss = self.compute_objectives(
            predictions, batch, sb.Stage.TRAIN, "generator"
        )
        self.g_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            self.modules.parameters(), self.max_grad_norm
        )
        self.g_optimizer.step()
        loss_tracker += loss.detach()
    return loss_tracker
```
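The GENERATOR branch's clamp-and-repair of the learnable sigmoid parameter relies on the fact that `NaN != NaN` evaluates to true; a standalone check of that idiom on a toy tensor:

```python
import torch

# Reproduce the clamp used in the GENERATOR sub-stage on a toy parameter.
param = torch.tensor([1.0, 4.2, float("nan"), -0.5])
param = torch.clamp(param, max=3.5)   # 4.2 -> 3.5; NaN passes through clamp
param[param != param] = 3.5           # NaN != NaN, so this selects only NaNs
# param is now [1.0, 3.5, 3.5, -0.5]: large values clamped, NaN repaired
```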
compute_objectives Method
The compute_objectives method dispatches on the optim_name parameter to compute the appropriate loss for each sub-stage:
```python
def compute_objectives(self, predictions, batch, stage, optim_name=""):
    "Given the network predictions and targets compute the total loss"
    predict_wav = predictions
    predict_spec = self.compute_feats(predict_wav)

    clean_wav, lens = batch.clean_sig
    clean_spec = self.compute_feats(clean_wav)
    ids = batch.id  # utterance ids, used to cache computed PESQ scores

    if optim_name == "generator":
        # Generator aims for score of 1.0 (perfect quality)
        target_score = torch.ones(self.batch_size, 1, device=self.device)
        est_score = self.est_score(predict_spec, clean_spec)
        mse_cost = self.hparams.compute_cost(predict_spec, clean_spec, lens)
        # cost = MSE(est_score, 1.0) + mse_weight * MSE(predict_spec, clean_spec)
    elif optim_name == "D_clean":
        # Discriminator learns: clean speech -> score 1.0
        target_score = torch.ones(self.batch_size, 1, device=self.device)
        est_score = self.est_score(clean_spec, clean_spec)
    elif optim_name == "D_enh" and self.sub_stage == SubStage.CURRENT:
        # Discriminator learns: enhanced speech -> actual PESQ score
        target_score = self.score(ids, predict_wav, clean_wav, lens)
        est_score = self.est_score(predict_spec, clean_spec)
    elif optim_name == "D_enh" and self.sub_stage == SubStage.HISTORICAL:
        # Discriminator relearns: historical enhanced speech -> saved score
        target_score = batch.score.unsqueeze(1).float()
        est_score = self.est_score(predict_spec, clean_spec)
    elif optim_name == "D_noisy":
        # Discriminator learns: noisy speech -> actual PESQ score
        noisy_wav, _ = batch.noisy_sig
        noisy_spec = self.compute_feats(noisy_wav)
        target_score = self.score(ids, noisy_wav, clean_wav, lens)
        est_score = self.est_score(noisy_spec, clean_spec)

    cost = self.hparams.compute_cost(est_score, target_score)
    if optim_name == "generator":
        cost += self.hparams.mse_weight * mse_cost
    return cost
```
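A note on targets: the MetricGAN paper maps PESQ from its [-0.5, 4.5] range into [0, 1] so the discriminator regresses a bounded score. The exact normalization in this recipe lives inside `self.score`; the mapping below is an assumption based on the paper, shown alongside the score-regression loss:

```python
import torch

def normalize_pesq(pesq: float) -> float:
    """Map PESQ in [-0.5, 4.5] to [0, 1] (MetricGAN-paper convention)."""
    return (pesq + 0.5) / 5.0

# Discriminator target for an utterance whose true PESQ is 3.0.
target_score = torch.tensor([[normalize_pesq(3.0)]])  # PESQ 3.0 -> 0.7
est_score = torch.tensor([[0.7]])                     # discriminator output
cost = torch.nn.functional.mse_loss(est_score, target_score)  # 0 when matched
```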
compute_forward Method
```python
def compute_forward(self, batch, stage):
    "Given an input batch computes the enhanced signal"
    batch = batch.to(self.device)
    if self.sub_stage == SubStage.HISTORICAL:
        # Historical data already has pre-computed enhanced wavs
        predict_wav, lens = batch.enh_sig
    else:
        noisy_wav, lens = batch.noisy_sig
        noisy_spec = self.compute_feats(noisy_wav)

        # Predict spectral mask via generator
        mask = self.modules.generator(noisy_spec, lengths=lens)
        mask = mask.clamp(min=self.hparams.min_mask).squeeze(2)
        predict_spec = torch.mul(mask, noisy_spec)

        # Reconstruct waveform via ISTFT
        predict_wav = self.hparams.resynth(
            torch.expm1(predict_spec), noisy_wav
        )
    return predict_wav
```
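The resynthesis step inverts the log1p magnitude compression that `compute_feats` is assumed to apply before masking; the round trip can be checked in isolation on a stand-in spectrogram:

```python
import torch

mag = torch.rand(2, 100, 257) * 10     # stand-in magnitude spectrogram
compressed = torch.log1p(mag)          # what compute_feats would produce
mask = torch.full_like(compressed, 0.8).clamp(min=0.05)
masked = mask * compressed             # generator's masked spectrogram
restored = torch.expm1(masked)         # back to linear magnitude for ISTFT

# With a unit mask, expm1 exactly inverts log1p:
assert torch.allclose(torch.expm1(compressed), mag, atol=1e-4)
```

The `min_mask` clamp (0.05 by default) guarantees the mask never zeroes a bin entirely, so `expm1` always receives a strictly positive compressed magnitude.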
Training Orchestration
The epoch-level training is managed by on_stage_start and train_discriminator:
```python
def train_discriminator(self):
    """A total of 3 data passes to update the discriminator."""
    # Pass 1: Current data (clean, enhanced, noisy)
    self.sub_stage = SubStage.CURRENT
    self.fit(
        range(1), self.train_set,
        train_loader_kwargs=self.hparams.dataloader_options,
    )

    # Pass 2: Historical enhanced data
    if self.historical_set:
        self.sub_stage = SubStage.HISTORICAL
        self.fit(
            range(1), self.historical_set,
            train_loader_kwargs=self.hparams.dataloader_options,
        )

    # Pass 3: Current data again
    self.sub_stage = SubStage.CURRENT
    self.fit(
        range(1), self.train_set,
        train_loader_kwargs=self.hparams.dataloader_options,
    )
```
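The historical pass replays a fraction of previously enhanced samples each epoch (`history_portion`, default 0.2). A hypothetical sketch of that sampling step; the dictionary layout and names here are illustrative, not the recipe's actual data structures:

```python
import random

# Fake historical set: utterance id -> saved enhanced wav path and PESQ score.
historical_set = {
    f"utt{i}": {"enh": f"enh_{i}.wav", "score": 0.6} for i in range(50)
}
history_portion = 0.2
k = max(1, int(len(historical_set) * history_portion))
sampled_ids = random.sample(sorted(historical_set), k)  # 10 of 50 utterances
```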
Dual Optimizer Initialization
```python
def init_optimizers(self):
    "Initializes the generator and discriminator optimizers"
    self.g_optimizer = self.hparams.g_opt_class(
        self.modules.generator.parameters()
    )
    self.d_optimizer = self.hparams.d_opt_class(
        self.modules.discriminator.parameters()
    )
    if self.checkpointer is not None:
        self.checkpointer.add_recoverable("g_opt", self.g_optimizer)
        self.checkpointer.add_recoverable("d_opt", self.d_optimizer)
```
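In the recipe's YAML, `g_opt_class` and `d_opt_class` are built with `!name:torch.optim.Adam`, which hyperpyyaml turns into a callable over the remaining constructor arguments. A plain-Python equivalent of that pattern (learning rates here are the table defaults, not read from any YAML):

```python
import functools
import torch

# Partially-applied optimizer classes, as the YAML's !name: tag produces.
g_opt_class = functools.partial(torch.optim.Adam, lr=5e-4)
d_opt_class = functools.partial(torch.optim.Adam, lr=5e-4)

generator = torch.nn.Linear(4, 4)
discriminator = torch.nn.Linear(4, 1)

# Mirrors init_optimizers: each optimizer only sees its own module's params.
g_optimizer = g_opt_class(generator.parameters())
d_optimizer = d_opt_class(discriminator.parameters())
```

Keeping the two parameter sets disjoint is what lets fit_batch update only D or only G per sub-stage.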
Usage Example
Full Training Pipeline
```python
import sys

import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

from train import MetricGanBrain, SubStage, dataio_prep
from voicebank_prepare import prepare_voicebank

# Load configuration
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file, encoding="utf-8") as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# Prepare data
prepare_voicebank(
    data_folder=hparams["data_folder"],
    save_folder=hparams["data_folder"],
)

# Create datasets
datasets = dataio_prep(hparams)

# Initialize Brain
se_brain = MetricGanBrain(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)
se_brain.train_set = datasets["train"]
se_brain.historical_set = {}
se_brain.noisy_scores = {}
se_brain.batch_size = hparams["dataloader_options"]["batch_size"]
se_brain.sub_stage = SubStage.GENERATOR

# Train
se_brain.fit(
    epoch_counter=hparams["epoch_counter"],
    train_set=datasets["train"],
    valid_set=datasets["valid"],
    train_loader_kwargs=hparams["dataloader_options"],
    valid_loader_kwargs=hparams["valid_dataloader_options"],
)

# Evaluate
test_stats = se_brain.evaluate(
    test_set=datasets["test"],
    max_key=hparams["target_metric"],
    test_loader_kwargs=hparams["dataloader_options"],
)
```
Key Configuration Parameters
| Parameter | Default | Description |
|---|---|---|
| target_metric | "pesq" | Metric used as discriminator target ("pesq" or "stoi") |
| G_lr | 0.0005 | Generator learning rate |
| D_lr | 0.0005 | Discriminator learning rate |
| mse_weight | 0 | Weight for spectral MSE reconstruction loss in generator objective |
| min_mask | 0.05 | Minimum mask value (prevents complete signal suppression) |
| number_of_epochs | 750 | Total training epochs |
| number_of_samples | 100 | Samples per epoch for generator training |
| history_portion | 0.2 | Fraction of historical set used per epoch |
| train_N_batch | 1 | Batch size for training |
| valid_N_batch | 20 | Batch size for validation |
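For reference, the defaults above would appear in the recipe's hparams YAML roughly as follows (key names follow the table; the actual train.yaml contains many more entries, and this fragment is a sketch rather than a copy of it):

```yaml
target_metric: pesq
G_lr: 0.0005
D_lr: 0.0005
mse_weight: 0
min_mask: 0.05
number_of_epochs: 750
number_of_samples: 100
history_portion: 0.2
train_N_batch: 1
valid_N_batch: 20
```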
Inputs and Outputs
Inputs (per batch):
- batch.noisy_sig: Noisy speech waveform tensor and lengths
- batch.clean_sig: Clean speech reference waveform tensor and lengths
- batch.noisy_wav: Path to noisy audio file (for identification)
Outputs:
- Generator loss: MSE(D(enhanced, clean), 1.0) + mse_weight * spectral_MSE
- Discriminator loss: MSE(D(input, clean), actual_score) for each input type
- Enhanced wavs: Written to enhanced_folder during validation/test
See Also
- Principle:Speechbrain_Speechbrain_GAN_Based_Enhancement_Training -- The theoretical basis for MetricGAN+ training
- Implementation:Speechbrain_Speechbrain_Load_Hyperpyyaml_Enhancement -- How the MetricGAN architecture is configured
- Implementation:Speechbrain_Speechbrain_Composite_Eval_Metrics -- Metrics used to evaluate enhancement quality