Implementation:Facebookresearch Audiocraft MultiScaleDiscriminator
| Knowledge Sources | |
|---|---|
| Domains | Audio_Synthesis, GAN |
| Last Updated | 2026-02-14 01:00 GMT |
Overview
A discriminator for adversarial training that scores audio waveforms at multiple temporal resolutions. It uses average pooling to build the coarser resolutions and 1D convolutions to classify each one.
Description
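MultiScaleDiscriminator implements the multi-scale discriminator (MSD) architecture, which evaluates audio at several temporal scales. The first scale is the raw waveform. Each later scale is produced by one more average-pooling step by downsample_factor, so scale k (counting from 0) sees the audio downsampled by a total factor of downsample_factor^k. A separate 1D convolutional sub-discriminator runs on each scale. The finest scale captures fine-grained, high-frequency detail, and the coarser scales capture longer-range temporal structure.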
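The pooling pyramid described above can be sketched in plain Python. This minimal sketch assumes non-overlapping windows (kernel size equal to `downsample_factor`) and drops trailing samples that do not fill a window. The real module may use a different pooling kernel or padding, which would change the exact lengths. The helper names `avg_pool_1d` and `build_scales` are hypothetical.

```python
from typing import List


def avg_pool_1d(signal: List[float], factor: int) -> List[float]:
    """Non-overlapping average pooling (hypothetical helper).

    Each output value is the mean of `factor` consecutive samples; trailing
    samples that do not fill a complete window are dropped.
    """
    n_windows = len(signal) // factor
    return [sum(signal[i * factor:(i + 1) * factor]) / factor for i in range(n_windows)]


def build_scales(signal: List[float], n_scales: int, downsample_factor: int) -> List[List[float]]:
    """Return n_scales views: the raw signal, then successively pooled copies."""
    scales = [signal]
    for _ in range(n_scales - 1):
        # Each new scale pools the previous one, so scale k is downsampled by factor**k.
        scales.append(avg_pool_1d(scales[-1], downsample_factor))
    return scales
```

For example, `build_scales([1, 3, 5, 7, 2, 4, 6, 8], n_scales=3, downsample_factor=2)` returns the original list, then `[2.0, 6.0, 3.0, 7.0]`, then `[4.0, 5.0]`. Under the same assumptions, a 24000-sample clip gives scales of 24000, 12000 and 6000 samples.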
Usage
Import this class when building adversarial training pipelines for audio compression or generation models that need a multi-scale temporal discriminator.
Code Reference
Source Location
- Repository: Facebookresearch_Audiocraft
- File: audiocraft/adversarial/discriminators/msd.py
- Lines: 1-126
Signature
```python
class MultiScaleDiscriminator(MultiDiscriminator):
    def __init__(self, in_channels: int = 1, out_channels: int = 1,
                 n_layers: int = 4, downsample_factor: int = 2,
                 n_scales: int = 3):
        """
        Args:
            in_channels: Number of input audio channels.
            out_channels: Output channels per sub-discriminator.
            n_layers: Number of convolution layers per scale.
            downsample_factor: Pooling factor between scales.
            n_scales: Number of temporal scales.
        """
        ...
```
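To make the signature concrete, here is a minimal, hypothetical PyTorch sketch of a module that takes the same constructor arguments and returns the same `(logits, features)` pair. It is not audiocraft's implementation. The class names `ScaleDiscriminatorSketch` and `MultiScaleDiscriminatorSketch`, the `hidden` width, the kernel sizes and strides, the LeakyReLU slope, and the non-overlapping pooling window are all assumptions chosen for illustration.

```python
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaleDiscriminatorSketch(nn.Module):
    """Hypothetical sub-discriminator: n_layers 1D convs followed by a logit head."""

    def __init__(self, in_channels: int = 1, out_channels: int = 1,
                 n_layers: int = 4, hidden: int = 16):
        super().__init__()
        convs = []
        ch = in_channels
        for i in range(n_layers):
            out_ch = hidden * 2 ** i  # widen the network layer by layer (assumed)
            stride = 1 if i == 0 else 2  # later layers halve the time axis (assumed)
            convs.append(nn.Conv1d(ch, out_ch, kernel_size=5, stride=stride, padding=2))
            ch = out_ch
        self.convs = nn.ModuleList(convs)
        self.head = nn.Conv1d(ch, out_channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        features = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), negative_slope=0.2)
            features.append(x)  # kept for the feature-matching loss
        return self.head(x), features


class MultiScaleDiscriminatorSketch(nn.Module):
    """Hypothetical MSD: one sub-discriminator per scale, average pooling between scales."""

    def __init__(self, in_channels: int = 1, out_channels: int = 1,
                 n_layers: int = 4, downsample_factor: int = 2, n_scales: int = 3):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [ScaleDiscriminatorSketch(in_channels, out_channels, n_layers) for _ in range(n_scales)]
        )
        # Non-overlapping average pooling (kernel == stride); an assumption for simplicity.
        self.downsample = nn.AvgPool1d(kernel_size=downsample_factor, stride=downsample_factor)

    def forward(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
        logits, features = [], []
        for i, disc in enumerate(self.discriminators):
            if i > 0:
                x = self.downsample(x)  # scale i sees the input pooled i times
            logit, fmap = disc(x)
            logits.append(logit)
            features.append(fmap)
        return logits, features
```

With these assumed settings, a `[2, 1, 1024]` input yields logits of length 128, 64 and 32 for the three scales: each scale halves the input, and the three strided convolutions in each sub-discriminator halve it three more times.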
Import
```python
from audiocraft.adversarial.discriminators.msd import MultiScaleDiscriminator
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x | torch.Tensor | Yes | Audio waveform of shape [B, C, T] (batch, channels, samples), with C equal to in_channels |
Outputs
| Name | Type | Description |
|---|---|---|
| logits | list[torch.Tensor] | Discriminator logits, one tensor per scale (finest scale first) |
| features | list[list[torch.Tensor]] | One list of intermediate feature maps per scale, used for the feature-matching loss |
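The nested output lists are usually reduced to scalar losses by averaging over scales and, for features, over layers. The sketch below assumes a least-squares GAN objective and an L1 feature-matching term. These choices, and the function names `discriminator_adv_loss`, `generator_adv_loss` and `feature_matching_loss`, are illustrative assumptions rather than audiocraft's own loss API.

```python
from typing import List

import torch
import torch.nn.functional as F


def discriminator_adv_loss(real_logits: List[torch.Tensor],
                           fake_logits: List[torch.Tensor]) -> torch.Tensor:
    """Least-squares loss: real logits pushed toward 1, fake toward 0, averaged over scales."""
    losses = [torch.mean((1 - r) ** 2) + torch.mean(f ** 2)
              for r, f in zip(real_logits, fake_logits)]
    return torch.stack(losses).mean()


def generator_adv_loss(fake_logits: List[torch.Tensor]) -> torch.Tensor:
    """Least-squares loss for the generator: fake logits pushed toward 1 at every scale."""
    return torch.stack([torch.mean((1 - f) ** 2) for f in fake_logits]).mean()


def feature_matching_loss(real_features: List[List[torch.Tensor]],
                          fake_features: List[List[torch.Tensor]]) -> torch.Tensor:
    """L1 distance between real and generated feature maps, averaged over all (scale, layer) pairs."""
    terms = [F.l1_loss(f, r.detach())  # no gradient flows into the real branch
             for real_scale, fake_scale in zip(real_features, fake_features)
             for r, f in zip(real_scale, fake_scale)]
    return torch.stack(terms).mean()
```

In the discriminator step, the generated audio is usually detached before it is passed to the discriminator, so that `discriminator_adv_loss` only updates the discriminator's parameters.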
Usage Examples
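```python
import torch
from audiocraft.adversarial.discriminators.msd import MultiScaleDiscriminator

msd = MultiScaleDiscriminator(in_channels=1, n_scales=3)
# Batch of 4 mono waveforms, 24000 samples each (e.g. 1 s at 24 kHz).
wav = torch.randn(4, 1, 24000)
# logits: one tensor per scale; features: one list of feature maps per scale.
logits, features = msd(wav)
```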