
Implementation:Facebookresearch Audiocraft Evaluation Metrics

From Leeroopedia

Overview

This is a Wrapper Doc describing AudioCraft's metrics for evaluating the quality of generated audio. The implementation wraps external pretrained models and metric libraries (the TensorFlow-based FAD tool, the PaSST classifier, and LAION CLAP) in a consistent torchmetrics.Metric interface that integrates with the MusicGen evaluation pipeline.

Source Locations

Metric                       Source File                             Lines
FrechetAudioDistanceMetric   audiocraft/metrics/fad.py               29-329
KLDivergenceMetric           audiocraft/metrics/kld.py               53-113
PasstKLDivergenceMetric      audiocraft/metrics/kld.py               116-220
CLAPTextConsistencyMetric    audiocraft/metrics/clap_consistency.py  34-84

APIs

FrechetAudioDistanceMetric

class FrechetAudioDistanceMetric(torchmetrics.Metric):
    def __init__(
        self,
        bin: Union[Path, str],
        model_path: Union[Path, str],
        format: str = "wav",
        batch_size: Optional[int] = None,
        log_folder: Optional[Union[Path, str]] = None
    ):
        ...

    def update(
        self,
        preds: torch.Tensor,        # [B, C, T] generated audio
        targets: torch.Tensor,      # [B, C, T] reference audio
        sizes: torch.Tensor,        # [B] actual lengths
        sample_rates: torch.Tensor, # [B] sample rates
        stems: Optional[List[str]] = None
    ) -> None:
        ...

    def compute(self) -> float:
        """Returns FAD score."""
        ...

How it works:

  1. Audio samples are saved to disk as WAV files (resampled to 16 kHz mono for VGGish).
  2. An external TensorFlow process computes VGGish embeddings for both generated and reference sets.
  3. The Fréchet distance is computed between the two embedding distributions.
  4. When multiple GPUs are available, embedding computation runs in parallel.
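The distance in step 3 treats each embedding set as a multivariate Gaussian. A minimal NumPy sketch of that final computation (the real implementation delegates it to the external frechet_audio_distance tool; function name and eigenvalue-based square-root trick here are illustrative):

```python
import numpy as np

def frechet_distance(emb_a: np.ndarray, emb_b: np.ndarray) -> float:
    """Frechet distance between two embedding sets [N, D], each modeled
    as a multivariate Gaussian (mean + covariance)."""
    mu_a, mu_b = emb_a.mean(axis=0), emb_b.mean(axis=0)
    cov_a = np.cov(emb_a, rowvar=False)
    cov_b = np.cov(emb_b, rowvar=False)
    diff = mu_a - mu_b
    # Tr(sqrtm(cov_a @ cov_b)) equals the sum of the square roots of the
    # eigenvalues of cov_a @ cov_b (real and non-negative for PSD factors).
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    tr_covmean = np.sqrt(np.maximum(eigvals.real, 0.0)).sum()
    return float(diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2.0 * tr_covmean)
```

Identical embedding sets give a distance of (numerically) zero; shifting the mean of one set increases it quadratically.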

KLDivergenceMetric / PasstKLDivergenceMetric

class KLDivergenceMetric(torchmetrics.Metric):
    def update(
        self,
        preds: torch.Tensor,       # [B, C, T] generated audio
        targets: torch.Tensor,     # [B, C, T] reference audio
        sizes: torch.Tensor,       # [B] actual lengths
        sample_rates: torch.Tensor # [B] sample rates
    ) -> None:
        ...

    def compute(self) -> dict:
        """Returns dict with 'kld', 'kld_pq', 'kld_qp', 'kld_both'."""
        ...

class PasstKLDivergenceMetric(KLDivergenceMetric):
    def __init__(self, pretrained_length: Optional[float] = None):
        ...
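PaSST expects fixed-length inputs, which is why pretrained_length matters. A sketch of the segmenting idea (an illustrative helper, not AudioCraft's actual code; the real metric also handles padding and edge cases):

```python
import torch

def split_segments(wav: torch.Tensor, segment_len: int) -> torch.Tensor:
    """Split audio [..., T] into non-overlapping segments of segment_len
    samples, dropping any trailing remainder (illustrative only)."""
    n = wav.shape[-1] // segment_len
    return wav[..., : n * segment_len].reshape(*wav.shape[:-1], n, segment_len)
```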

How it works:

  1. Audio is resampled to 32 kHz mono and split into segments matching PaSST's input requirements.
  2. The pretrained PaSST model produces class probability distributions over AudioSet labels.
  3. KL divergence is computed between generated and reference probability distributions in both directions (KLD_pq and KLD_qp).
  4. Results are accumulated and averaged across all samples.
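The bidirectional divergence in step 3 can be sketched as follows (a minimal version assuming PaSST-style class logits; the function name and epsilon clamping are illustrative, not the library's API):

```python
import torch

def kl_both_ways(p_logits: torch.Tensor, q_logits: torch.Tensor,
                 eps: float = 1e-6) -> dict:
    """Per-sample KL divergence between two batches of class logits [B, K],
    computed in both directions and averaged over the batch."""
    p = p_logits.softmax(dim=-1).clamp(min=eps)
    q = q_logits.softmax(dim=-1).clamp(min=eps)
    kld_pq = (p * (p / q).log()).sum(dim=-1).mean()  # KL(p || q)
    kld_qp = (q * (q / p).log()).sum(dim=-1).mean()  # KL(q || p)
    return {"kld_pq": kld_pq.item(), "kld_qp": kld_qp.item(),
            "kld_both": (kld_pq + kld_qp).item()}
```

Matching distributions yield zero in both directions; the sum "kld_both" is a symmetric measure of how far the generated label distribution drifts from the reference.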

CLAPTextConsistencyMetric

class CLAPTextConsistencyMetric(TextConsistencyMetric):
    def __init__(
        self,
        model_path: Union[str, Path],
        model_arch: str = 'HTSAT-tiny',
        enable_fusion: bool = False
    ):
        ...

    def update(
        self,
        audio: torch.Tensor,       # [B, C, T] generated audio
        text: List[str],            # [B] text descriptions
        sizes: torch.Tensor,        # [B] actual lengths
        sample_rates: torch.Tensor  # [B] sample rates
    ) -> None:
        ...

    def compute(self) -> float:
        """Returns average cosine similarity between audio and text embeddings."""
        ...

How it works:

  1. Audio is resampled to 48 kHz mono for CLAP.
  2. Audio embeddings are extracted via model.get_audio_embedding_from_data().
  3. Text embeddings are extracted via model.get_text_embedding() using a RoBERTa tokenizer.
  4. Cosine similarity between paired audio and text embeddings is accumulated and averaged.
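The accumulate-and-average pattern in step 4 can be sketched with a small torchmetrics-style update/compute cycle (class name is illustrative; the real metric registers distributed-reduction state via torchmetrics):

```python
import torch
import torch.nn.functional as F

class CosineConsistency:
    """Minimal accumulator mimicking the update()/compute() cycle:
    sums paired cosine similarities and averages at the end."""
    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, audio_emb: torch.Tensor, text_emb: torch.Tensor) -> None:
        # Cosine similarity between paired rows of [B, D] embeddings.
        sim = F.cosine_similarity(audio_emb, text_emb, dim=-1)
        self.total += sim.sum().item()
        self.count += sim.numel()

    def compute(self) -> float:
        return self.total / max(self.count, 1)
```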

Inputs and Outputs

Metric  Inputs                                                    Output                                    Better When
FAD     Generated audio, reference audio, sizes, sample rates     float (FAD score)                         Lower
KLD     Generated audio, reference audio, sizes, sample rates     dict with kld, kld_pq, kld_qp, kld_both   Lower
CLAP    Generated audio, text descriptions, sizes, sample rates   float (cosine similarity)                 Higher

Integration with MusicGen Solver

In MusicGenSolver.evaluate_audio_generation():

# Instantiate metrics from config
fad = kldiv = text_consistency = None
if self.cfg.evaluate.metrics.fad:
    fad = builders.get_fad(self.cfg.metrics.fad).to(self.device)
if self.cfg.evaluate.metrics.kld:
    kldiv = builders.get_kldiv(self.cfg.metrics.kld).to(self.device)
if self.cfg.evaluate.metrics.text_consistency:
    text_consistency = builders.get_text_consistency(
        self.cfg.metrics.text_consistency).to(self.device)

# For each evaluation batch
for batch in loader:
    gen_outputs = self.run_generate_step(batch, gen_duration=target_duration)
    y_pred = gen_outputs['gen_audio']

    if fad is not None:
        fad.update(y_pred, y, sizes, sample_rates, audio_stems)
    if kldiv is not None:
        kldiv.update(y_pred, y, sizes, sample_rates)
    if text_consistency is not None:
        text_consistency.update(y_pred, texts, sizes, sample_rates)

# Compute final scores (guarded, matching the update calls above)
if fad is not None:
    metrics['fad'] = fad.compute()
if kldiv is not None:
    metrics.update(kldiv.compute())
if text_consistency is not None:
    metrics['text_consistency'] = text_consistency.compute()

Configuration

From config/solver/musicgen/default.yaml:

metrics:
  fad:
    use_gt: false
    model: tf
    tf:
      bin: null  # path to frechet_audio_distance code
      model_path: //reference/fad/vggish_model.ckpt
  kld:
    use_gt: false
    model: passt
    passt:
      pretrained_length: 20
  text_consistency:
    use_gt: false
    model: clap
    clap:
      model_path: //reference/clap/music_audioset_epoch_15_esc_90.14.pt
      model_arch: 'HTSAT-base'
      enable_fusion: false

evaluate:
  every: 25
  metrics:
    base: false
    fad: false
    kld: false
    text_consistency: false

Dependencies

  • torchmetrics -- base metric class with distributed reduction
  • frechet_audio_distance (TensorFlow) -- external FAD computation tool
  • hear21passt -- PaSST audio classifier for KLD
  • laion_clap -- CLAP model for text consistency
  • transformers -- RoBERTa tokenizer for CLAP text processing
