

Implementation:Facebookresearch Audiocraft JASCO generate audio

From Leeroopedia

Overview

JASCO.generate_audio decodes the continuous latents produced by the flow matching model into audio waveforms. It first denormalizes the latents using the mean and standard deviation stored in the model configuration, then passes them through the EnCodec decoder to produce the final waveform.

Implements

Principle:Facebookresearch_Audiocraft_Latent_Decoding_and_Audio_Output

API Signature

def generate_audio(self, gen_latents: torch.Tensor) -> torch.Tensor

Source: audiocraft/models/jasco.py, lines 91-97

Parameters

Parameter | Type | Default | Description
gen_latents | torch.Tensor | required | Generated latents from flow matching, shape [B, T, D], where B is the batch size, T is the number of latent frames and D is the flow dimension (typically 128)

Return Value

Returns a torch.Tensor containing the decoded waveform, with shape [B, C, T_audio], where C is the number of audio channels and T_audio is the number of audio samples.

How It Works

from audiocraft.models import JASCO

model = JASCO.get_pretrained('facebook/jasco-chords-drums-400M')

# Generate audio and also return the underlying latents
audio, latents = model.generate_music(
    descriptions=["upbeat jazz piano"],
    chords=[("C", 0.0), ("F", 5.0)],
    return_latents=True
)

# Or decode latents separately (e.g., for post-processing)
decoded_audio = model.generate_audio(latents)

The implementation:

def generate_audio(self, gen_latents: torch.Tensor) -> torch.Tensor:
    """Decode audio from generated latents"""
    assert gen_latents.dim() == 3  # [B, T, D]

    # Step 1: Denormalize latents
    gen_latents = self._unnormalized_latents(gen_latents)

    # Step 2: Decode through EnCodec decoder (note permutation to [B, D, T])
    return self.compression_model.model.decoder(gen_latents.permute(0, 2, 1))
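To make the data flow concrete without loading a checkpoint, here is a NumPy stand-in that mirrors the three steps of the method above: the rank check, the denormalization and the [B, T, D] to [B, D, T] permutation. The decoder is a hypothetical toy (an average over the latent dimension followed by naive sample repetition), the hop length of 640 samples per frame is an assumption consistent with 32 kHz audio at 50 latent frames per second, and the mean/std values are made up. It is a sketch of the control flow, not of EnCodec.

```python
import numpy as np

HOP_LENGTH = 640  # assumption: audio samples per latent frame (32 kHz / 50 Hz)


def toy_decoder(x: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the EnCodec decoder: [B, D, T] -> [B, 1, T * HOP_LENGTH]."""
    _, d, _ = x.shape
    weights = np.full((1, d), 1.0 / d)              # average over the latent dimension
    frames = np.einsum("cd,bdt->bct", weights, x)   # [B, 1, T]
    return np.repeat(frames, HOP_LENGTH, axis=-1)   # naive upsampling to audio rate


def generate_audio_sketch(gen_latents: np.ndarray,
                          latent_mean: float = 0.0,
                          latent_std: float = 5.0) -> np.ndarray:
    """Mirror of generate_audio's steps; latent_mean/latent_std are made-up values."""
    assert gen_latents.ndim == 3  # [B, T, D]
    # Step 1: denormalize (scale by std, then shift by mean)
    gen_latents = gen_latents * latent_std + latent_mean
    # Step 2: channels-first for the decoder, [B, T, D] -> [B, D, T]
    return toy_decoder(gen_latents.transpose(0, 2, 1))


latents = np.random.default_rng(0).standard_normal((1, 500, 128))
audio = generate_audio_sketch(latents)
print(audio.shape)  # (1, 1, 320000)
```

The permutation matters because convolutional decoders treat the second axis as channels, so the D latent features must sit there rather than the time axis.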

Internal Method: _unnormalized_latents

Source: audiocraft/models/jasco.py, lines 85-89

def _unnormalized_latents(self, latents: torch.Tensor) -> torch.Tensor:
    """Unnormalize latents, shifting back to EnCodec's expected mean, std"""
    assert self.cfg is not None
    scaled = latents * self.cfg.compression_model_latent_std
    return scaled + self.cfg.compression_model_latent_mean

The denormalization applies:

  • Scale: Multiply by the stored standard deviation (compression_model_latent_std)
  • Shift: Add the stored mean (compression_model_latent_mean)

These statistics are stored in the model's configuration (accessed via self.cfg) and were computed from the training dataset's EnCodec latent distribution.
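Denormalization undoes a standardization step. The sketch below assumes, as an illustration only, that the flow model was trained on latents standardized as z = (x - mean) / std; the page documents only the inverse, so `normalize_latents` is a hypothetical helper, and the mean/std values are illustrative rather than the real config values. The round trip recovers the original latents exactly (up to floating point).

```python
import numpy as np


def normalize_latents(x, mean, std):
    """Hypothetical training-side transform: standardize EnCodec latents."""
    return (x - mean) / std


def unnormalize_latents(z, mean, std):
    """Mirrors _unnormalized_latents: scale by std, then shift by mean."""
    return z * std + mean


rng = np.random.default_rng(0)
mean, std = 0.5, 4.0  # illustrative values, not the real configuration
encodec_latents = rng.normal(mean, std, size=(2, 500, 128))  # [B, T, D]

z = normalize_latents(encodec_latents, mean, std)   # roughly zero mean, unit std
restored = unnormalize_latents(z, mean, std)

print(np.allclose(restored, encodec_latents))  # True
```

The order of operations is what makes this an exact inverse: the forward transform shifts then scales, so the inverse scales then shifts.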

Tensor Shape Flow

Stage | Shape | Description
Flow matching output | [B, T, D] | Raw generated latents (e.g., [1, 500, 128])
After denormalization | [B, T, D] | Latents in EnCodec's native scale
After permutation | [B, D, T] | Transposed for the decoder (channels-first)
Decoder output | [B, C, T_audio] | Decoded audio waveform (e.g., [1, 1, 320000] for 10 s at 32 kHz)
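The example shapes in the table follow from simple arithmetic. The helper below is a hypothetical sketch that assumes a 50 Hz latent frame rate, a 32 kHz sample rate and mono output; these assumptions reproduce the table's numbers (500 frames and 320000 samples for 10 s) but are not read from the model.

```python
def trace_shapes(batch: int, duration_s: float, dim: int = 128,
                 sample_rate: int = 32000, frame_rate: int = 50, channels: int = 1):
    """Return the tensor shape at each stage of generate_audio under assumed rates."""
    frames = int(duration_s * frame_rate)   # T: number of latent frames
    hop = sample_rate // frame_rate         # audio samples produced per latent frame
    return {
        "flow_matching_output": (batch, frames, dim),
        "after_denormalization": (batch, frames, dim),
        "after_permutation": (batch, dim, frames),
        "decoder_output": (batch, channels, frames * hop),
    }


for stage, shape in trace_shapes(batch=1, duration_s=10).items():
    print(f"{stage:>22}: {list(shape)}")
```

Under these assumptions, T_audio is simply T multiplied by the hop length (500 × 640 = 320000).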

Saving Audio to Disk

After decoding, the audio_write() utility from audiocraft/data/audio.py (lines 159-231) can save the audio:

from audiocraft.data.audio import audio_write

# Generate audio
audio = model.generate_music(descriptions=["calm piano"])

# Save to disk
audio_write(
    stem_name='output/generated_music',
    wav=audio[0].cpu(),  # first sample, move to CPU
    sample_rate=model.sample_rate,
    format='wav',
    normalize=True,
    strategy='peak'
)
# Writes: output/generated_music.wav
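For readers without audiocraft installed, the following stdlib-only sketch approximates the peak-normalized WAV path: it scales a mono float signal so its largest sample sits a little below full scale, converts it to 16-bit PCM and writes it with Python's `wave` module. The helper name `write_wav_peak` and the 1 dB headroom are assumptions for illustration; this is not audiocraft's audio_write implementation.

```python
import math
import struct
import wave


def write_wav_peak(path: str, samples: list, sample_rate: int,
                   headroom_db: float = 1.0) -> None:
    """Hypothetical helper: peak-normalize mono float samples, write 16-bit PCM WAV."""
    peak = max((abs(s) for s in samples), default=0.0)
    target = 10 ** (-headroom_db / 20)      # linear peak level below 0 dBFS
    gain = target / peak if peak > 0 else 1.0
    pcm = [int(round(max(-1.0, min(1.0, s * gain)) * 32767)) for s in samples]
    with wave.open(path, "wb") as f:
        f.setnchannels(1)
        f.setsampwidth(2)                    # 2 bytes = 16-bit samples
        f.setframerate(sample_rate)
        f.writeframes(struct.pack(f"<{len(pcm)}h", *pcm))


# One second of a quiet 440 Hz tone at 32 kHz
sr = 32000
tone = [0.1 * math.sin(2 * math.pi * 440 * n / sr) for n in range(sr)]
write_wav_peak("tone.wav", tone, sr)
```

Peak normalization here raises the quiet tone so its peak lands about 1 dB below full scale, which is the general idea behind strategy='peak'.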

The audio_write() function supports:

Parameter | Options | Description
format | 'wav', 'mp3', 'ogg', 'flac' | Output audio format
normalize | True / False | Whether to normalize audio levels
strategy | 'peak', 'rms', 'clip', 'loudness' | Normalization strategy
mp3_rate | integer (kbps) | Bitrate for MP3 encoding (default 320)
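The sketch below illustrates, in simplified NumPy form, what the 'clip', 'peak' and 'rms' strategies generally mean. It is not audiocraft's normalization code: the function name, the headroom values and the clamping choices are assumptions, and 'loudness' is omitted because it requires a perceptual loudness measurement.

```python
import numpy as np


def normalize(wav: np.ndarray, strategy: str = "peak",
              peak_headroom_db: float = 1.0, rms_headroom_db: float = 18.0) -> np.ndarray:
    """Simplified illustration of common waveform normalization strategies."""
    eps = 1e-8
    if strategy == "clip":
        # Leave levels untouched, only clamp to the valid [-1, 1] range
        return np.clip(wav, -1.0, 1.0)
    if strategy == "peak":
        # Scale so the largest absolute sample sits peak_headroom_db below full scale
        target = 10 ** (-peak_headroom_db / 20)
        return wav * target / (np.abs(wav).max() + eps)
    if strategy == "rms":
        # Scale so the RMS level sits rms_headroom_db below full scale, then clamp
        target = 10 ** (-rms_headroom_db / 20)
        rms = np.sqrt(np.mean(wav ** 2))
        return np.clip(wav * target / (rms + eps), -1.0, 1.0)
    raise ValueError(f"unsupported strategy: {strategy}")


wav = 3.0 * np.sin(np.linspace(0, 2 * np.pi, 1000))  # deliberately too loud
for s in ("clip", "peak", "rms"):
    out = normalize(wav, s)
    print(s, round(float(np.abs(out).max()), 3))
```

Peak normalization controls the loudest sample, while RMS normalization controls average energy, so the two can produce very different levels for the same input.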

Comparison with MusicGen Decoding

Aspect | MusicGen (BaseGenModel) | JASCO
Input type | Discrete tokens [B, K, T] | Continuous latents [B, T, D]
Decoding path | compression_model.decode(tokens, None) | compression_model.model.decoder(latents.permute(...))
Quantization | Passes through the RVQ dequantizer | Bypasses quantization entirely
Preprocessing | None (tokens are dequantized inside compression_model.decode) | Denormalization using the stored mean/std
Source | genmodel.py:L262-267 | jasco.py:L91-97
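To contrast the two input paths, the toy below builds decoder inputs both ways. The discrete path follows the general residual vector quantization idea, where each of K codebooks contributes one embedding per frame and the contributions are summed; the random codebooks and the sizes (K = 4, 2048 entries) are illustrative assumptions, not audiocraft's quantizer. The continuous path only denormalizes and permutes, with made-up statistics. Both end at the channels-first [B, D, T] layout a decoder consumes.

```python
import numpy as np

rng = np.random.default_rng(0)
B, K, T, D, CODEBOOK_SIZE = 1, 4, 500, 128, 2048

# Discrete path (MusicGen-style): tokens [B, K, T] -> summed codebook vectors
codebooks = rng.standard_normal((K, CODEBOOK_SIZE, D))          # hypothetical codebooks
tokens = rng.integers(0, CODEBOOK_SIZE, size=(B, K, T))
# codebooks[k][tokens[:, k]] has shape [B, T, D]; sum the K residual contributions
dequantized = sum(codebooks[k][tokens[:, k]] for k in range(K))  # [B, T, D]
discrete_input = dequantized.transpose(0, 2, 1)                   # [B, D, T]

# Continuous path (JASCO-style): latents [B, T, D] -> denormalize -> permute
latent_mean, latent_std = 0.0, 5.0                                # illustrative values
latents = rng.standard_normal((B, T, D))
continuous_input = (latents * latent_std + latent_mean).transpose(0, 2, 1)  # [B, D, T]

print(discrete_input.shape, continuous_input.shape)  # (1, 128, 500) (1, 128, 500)
```

Because the continuous latents never pass through a codebook lookup, JASCO can skip the quantizer and call the decoder directly.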
