Implementation:Facebookresearch Audiocraft MultiPeriodDiscriminator
| Knowledge Sources | |
|---|---|
| Domains | Audio_Synthesis, GAN |
| Last Updated | 2026-02-14 01:00 GMT |
Overview
A discriminator that analyzes audio waveforms at several periodic intervals, capturing fine-grained harmonic patterns for adversarial training.
Description
MultiPeriodDiscriminator implements the multi-period discriminator (MPD) from HiFi-GAN. For each period p in a list (default [2,3,5,7,11]), a sub-discriminator reshapes the 1D waveform into a 2D representation of width p, so that samples p steps apart line up in the same column, and then applies 2D convolutions to it. Each sub-discriminator therefore sees a different periodic slicing of the signal, which lets the ensemble assess both harmonic quality and temporal structure.
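The reshaping step described above can be sketched in a few lines of PyTorch. This is an illustrative sketch, not the code from mpd.py: the helper name fold_by_period is hypothetical, and the right-side reflect padding for lengths that are not a multiple of the period is an assumption borrowed from the usual HiFi-GAN formulation.

```python
import torch
import torch.nn.functional as F


def fold_by_period(x: torch.Tensor, period: int) -> torch.Tensor:
    """Reshape [B, C, T] audio into [B, C, T // period, period].

    Hypothetical helper for illustration only.
    """
    b, c, t = x.shape
    if t % period != 0:
        # Assumption: pad on the right so the length becomes a multiple of the period.
        n_pad = period - (t % period)
        x = F.pad(x, (0, n_pad), mode="reflect")
        t = t + n_pad
    # After folding, each column holds samples that are `period` steps apart.
    return x.view(b, c, t // period, period)


if __name__ == "__main__":
    wav = torch.randn(4, 1, 24000)
    for p in [2, 3, 5, 7, 11]:
        print(p, tuple(fold_by_period(wav, p).shape))
```

A 2D convolution whose kernel spans only the height axis of the folded tensor then processes each column independently, which is how a period-based discriminator attends to samples spaced exactly one period apart.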
Usage
Import this class when building adversarial training pipelines for audio compression or generation models that need a period-based discriminator for quality assessment.
Code Reference
Source Location
- Repository: Facebookresearch_Audiocraft
- File: audiocraft/adversarial/discriminators/mpd.py
- Lines: 1-106
Signature
class MultiPeriodDiscriminator(MultiDiscriminator):
    def __init__(self, in_channels: int = 1, out_channels: int = 1,
                 n_layers: int = 5, periods: tp.Optional[tp.List[int]] = None):
        """
        Args:
            in_channels: Number of input audio channels.
            out_channels: Output channels per sub-discriminator.
            n_layers: Number of convolution layers per period.
            periods: List of period values (default [2,3,5,7,11]).
        """
Import
from audiocraft.adversarial.discriminators.mpd import MultiPeriodDiscriminator
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x | torch.Tensor | Yes | Audio waveform [B, C, T] |
Outputs
| Name | Type | Description |
|---|---|---|
| logits | list[torch.Tensor] | Discriminator logits from each period sub-discriminator |
| features | list[list[torch.Tensor]] | Intermediate features for feature matching loss |
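To show how the two outputs are typically consumed, the sketch below computes a hinge adversarial loss from the logits and an L1 feature matching loss from the features. It is a minimal sketch under stated assumptions: the function names are hypothetical, these are not the library's own loss classes, and random tensors stand in for real discriminator outputs so the block runs without Audiocraft installed.

```python
import torch
import torch.nn.functional as F


def hinge_d_loss(real_logits, fake_logits):
    """Discriminator hinge loss, averaged over sub-discriminators (hypothetical helper)."""
    loss = 0.0
    for real, fake in zip(real_logits, fake_logits):
        loss = loss + F.relu(1.0 - real).mean() + F.relu(1.0 + fake).mean()
    return loss / len(real_logits)


def feature_matching_loss(real_feats, fake_feats):
    """Mean L1 distance between real and generated feature maps (hypothetical helper)."""
    total, count = 0.0, 0
    for real_layers, fake_layers in zip(real_feats, fake_feats):
        for real, fake in zip(real_layers, fake_layers):
            # Real features are treated as fixed targets for the generator.
            total = total + F.l1_loss(fake, real.detach())
            count += 1
    return total / count


def stand_in_outputs(seed: int, periods=(2, 3, 5)):
    """Random tensors shaped like (logits, features); placeholders, not real outputs."""
    g = torch.Generator().manual_seed(seed)
    logits = [torch.randn(4, 1, 10, p, generator=g) for p in periods]
    feats = [[torch.randn(4, 8, 10, p, generator=g) for _ in range(3)] for p in periods]
    return logits, feats


if __name__ == "__main__":
    real_logits, real_feats = stand_in_outputs(0)
    fake_logits, fake_feats = stand_in_outputs(1)
    print(float(hinge_d_loss(real_logits, fake_logits)))
    print(float(feature_matching_loss(real_feats, fake_feats)))
```

In a training loop, the real and generated waveforms would each be passed through the discriminator, and the resulting pairs of logits and features would replace the stand-in tensors.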
Usage Examples
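from audiocraft.adversarial.discriminators.mpd import MultiPeriodDiscriminator
import torch
# One sub-discriminator is built per period; periods fall back to the default list.
mpd = MultiPeriodDiscriminator(in_channels=1)
# Batch of 4 mono waveforms with 24000 samples each, shaped [B, C, T].
wav = torch.randn(4, 1, 24000)
# logits: one tensor per period; features: one list of feature maps per period.
logits, features = mpd(wav)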