Implementation: AudioCraft Evaluation Metrics (facebookresearch/audiocraft)
Overview
This is a wrapper doc describing AudioCraft's evaluation metrics for generated audio quality. The implementation wraps external pretrained models and metric libraries (the TensorFlow-based FAD tool, the PaSST classifier, and LAION CLAP) in a consistent torchmetrics.Metric interface that integrates with the MusicGen evaluation pipeline.
Source Locations
| Metric | Source File | Lines |
|---|---|---|
| `FrechetAudioDistanceMetric` | `audiocraft/metrics/fad.py` | 29-329 |
| `KLDivergenceMetric` | `audiocraft/metrics/kld.py` | 53-113 |
| `PasstKLDivergenceMetric` | `audiocraft/metrics/kld.py` | 116-220 |
| `CLAPTextConsistencyMetric` | `audiocraft/metrics/clap_consistency.py` | 34-84 |
APIs
FrechetAudioDistanceMetric
```python
class FrechetAudioDistanceMetric(torchmetrics.Metric):
    def __init__(
        self,
        bin: Union[Path, str],
        model_path: Union[Path, str],
        format: str = "wav",
        batch_size: Optional[int] = None,
        log_folder: Optional[Union[Path, str]] = None,
    ):
        ...

    def update(
        self,
        preds: torch.Tensor,         # [B, C, T] generated audio
        targets: torch.Tensor,       # [B, C, T] reference audio
        sizes: torch.Tensor,         # [B] actual lengths
        sample_rates: torch.Tensor,  # [B] sample rates
        stems: Optional[List[str]] = None,
    ) -> None:
        ...

    def compute(self) -> float:
        """Returns FAD score."""
        ...
```
How it works:
- Audio samples are saved to disk as WAV files (resampled to 16 kHz mono for VGGish).
- An external TensorFlow process computes VGGish embeddings for both generated and reference sets.
- The Frechet distance is computed between the embedding distributions.
- When multiple GPUs are available, embedding computation runs in parallel.
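The final step above is the classical Fréchet distance between two Gaussian fits of the embedding sets. As a standalone sketch (AudioCraft itself delegates this to the external TensorFlow `frechet_audio_distance` tool, so the helper name and NumPy formulation here are illustrative, not the project's code):

```python
import numpy as np

def frechet_distance(emb_gen: np.ndarray, emb_ref: np.ndarray) -> float:
    """Frechet distance between two embedding sets of shape [N, D].

    FAD = ||mu_g - mu_r||^2 + Tr(S_g) + Tr(S_r) - 2 Tr(sqrtm(S_g @ S_r))
    """
    mu_g, mu_r = emb_gen.mean(axis=0), emb_ref.mean(axis=0)
    cov_g = np.cov(emb_gen, rowvar=False)
    cov_r = np.cov(emb_ref, rowvar=False)
    diff = mu_g - mu_r
    # Tr(sqrtm(S_g @ S_r)) via eigenvalues: the product of two PSD
    # matrices has non-negative real eigenvalues, so the trace of its
    # matrix square root is the sum of the eigenvalue square roots.
    eigvals = np.linalg.eigvals(cov_g @ cov_r)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(cov_g) + np.trace(cov_r) - 2.0 * tr_sqrt)
```

Identical embedding sets give a distance near zero; a pure mean shift contributes only the squared-distance term, since the covariance terms cancel.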
KLDivergenceMetric / PasstKLDivergenceMetric
```python
class KLDivergenceMetric(torchmetrics.Metric):
    def update(
        self,
        preds: torch.Tensor,         # [B, C, T] generated audio
        targets: torch.Tensor,       # [B, C, T] reference audio
        sizes: torch.Tensor,         # [B] actual lengths
        sample_rates: torch.Tensor,  # [B] sample rates
    ) -> None:
        ...

    def compute(self) -> dict:
        """Returns dict with 'kld', 'kld_pq', 'kld_qp', 'kld_both'."""
        ...


class PasstKLDivergenceMetric(KLDivergenceMetric):
    def __init__(self, pretrained_length: Optional[float] = None):
        ...
```
How it works:
- Audio is resampled to 32 kHz mono and split into segments matching PaSST's input requirements.
- The pretrained PaSST model produces class probability distributions over AudioSet labels.
- KL divergence is computed between generated and reference probability distributions in both directions (KLD_pq and KLD_qp).
- Results are accumulated and averaged across all samples.
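The two-directional KL step can be sketched as follows. This is an illustrative standalone version over already-computed PaSST class probabilities; the helper names and the exact p/q direction convention are assumptions, not the code in `kld.py`:

```python
import numpy as np

def kl_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Per-sample KL(p || q) over class probability rows of shape [B, K]."""
    p = np.clip(p, eps, 1.0)
    q = np.clip(q, eps, 1.0)
    return (p * (np.log(p) - np.log(q))).sum(axis=-1)

def kld_scores(probs_gen: np.ndarray, probs_ref: np.ndarray) -> dict:
    """Batch-averaged KL in both directions, mirroring the compute() keys."""
    kld_pq = float(kl_divergence(probs_ref, probs_gen).mean())
    kld_qp = float(kl_divergence(probs_gen, probs_ref).mean())
    return {"kld": kld_pq, "kld_pq": kld_pq,
            "kld_qp": kld_qp, "kld_both": kld_pq + kld_qp}
```

KL divergence is asymmetric, which is why both `kld_pq` and `kld_qp` are reported; `kld_both` is simply their sum.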
CLAPTextConsistencyMetric
```python
class CLAPTextConsistencyMetric(TextConsistencyMetric):
    def __init__(
        self,
        model_path: Union[str, Path],
        model_arch: str = 'HTSAT-tiny',
        enable_fusion: bool = False,
    ):
        ...

    def update(
        self,
        audio: torch.Tensor,         # [B, C, T] generated audio
        text: List[str],             # [B] text descriptions
        sizes: torch.Tensor,         # [B] actual lengths
        sample_rates: torch.Tensor,  # [B] sample rates
    ) -> None:
        ...

    def compute(self) -> float:
        """Returns average cosine similarity between audio and text embeddings."""
        ...
```
How it works:
- Audio is resampled to 48 kHz mono for CLAP.
- Audio embeddings are extracted via
model.get_audio_embedding_from_data(). - Text embeddings are extracted via
model.get_text_embedding()using a RoBERTa tokenizer. - Cosine similarity between paired audio and text embeddings is accumulated and averaged.
Inputs and Outputs
| Metric | Inputs | Output | Better When |
|---|---|---|---|
| FAD | Generated audio, reference audio, sizes, sample rates | float (FAD score) | Lower |
| KLD | Generated audio, reference audio, sizes, sample rates | dict with `kld`, `kld_pq`, `kld_qp`, `kld_both` | Lower |
| CLAP | Generated audio, text descriptions, sizes, sample rates | float (cosine similarity) | Higher |
Integration with MusicGen Solver
In MusicGenSolver.evaluate_audio_generation():
```python
# Instantiate metrics from config
if self.cfg.evaluate.metrics.fad:
    fad = builders.get_fad(self.cfg.metrics.fad).to(self.device)
if self.cfg.evaluate.metrics.kld:
    kldiv = builders.get_kldiv(self.cfg.metrics.kld).to(self.device)
if self.cfg.evaluate.metrics.text_consistency:
    text_consistency = builders.get_text_consistency(
        self.cfg.metrics.text_consistency).to(self.device)

# For each evaluation batch
for batch in loader:
    gen_outputs = self.run_generate_step(batch, gen_duration=target_duration)
    y_pred = gen_outputs['gen_audio']
    if fad is not None:
        fad.update(y_pred, y, sizes, sample_rates, audio_stems)
    if kldiv is not None:
        kldiv.update(y_pred, y, sizes, sample_rates)
    if text_consistency is not None:
        text_consistency.update(y_pred, texts, sizes, sample_rates)

# Compute final scores
metrics['fad'] = fad.compute()
metrics.update(kldiv.compute())
metrics['text_consistency'] = text_consistency.compute()
```
Configuration
From `config/solver/musicgen/default.yaml`:

```yaml
metrics:
  fad:
    use_gt: false
    model: tf
    tf:
      bin: null  # path to frechet_audio_distance code
      model_path: //reference/fad/vggish_model.ckpt
  kld:
    use_gt: false
    model: passt
    passt:
      pretrained_length: 20
  text_consistency:
    use_gt: false
    model: clap
    clap:
      model_path: //reference/clap/music_audioset_epoch_15_esc_90.14.pt
      model_arch: 'HTSAT-base'
      enable_fusion: false
evaluate:
  every: 25
  metrics:
    base: false
    fad: false
    kld: false
    text_consistency: false
```
Dependencies
- `torchmetrics` -- base metric class with distributed reduction
- `frechet_audio_distance` (TensorFlow) -- external FAD computation tool
- `hear21passt` -- PaSST audio classifier for KLD
- `laion_clap` -- CLAP model for text consistency
- `transformers` -- RoBERTa tokenizer for CLAP text processing