Implementation:Facebookresearch Audiocraft JASCO generate audio
Overview
JASCO.generate_audio decodes the continuous latents produced by the flow matching model into audio waveforms. It first denormalizes the latents using statistics stored from training, then passes them through the EnCodec decoder to produce the final audio output.
Implements
Principle:Facebookresearch_Audiocraft_Latent_Decoding_and_Audio_Output
API Signature
```python
def generate_audio(self, gen_latents: torch.Tensor) -> torch.Tensor
```
Source: audiocraft/models/jasco.py, lines 91-97
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| gen_latents | torch.Tensor | required | Generated latents from flow matching, shape [B, T, D], where B is the batch size, T is the sequence length, and D is the flow dimension (typically 128) |
Return Value
Returns a torch.Tensor containing the decoded audio waveform, with shape [B, C, T_audio], where C is the number of audio channels and T_audio is the number of audio samples.
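The link between T and T_audio can be worked out from the example sizes given in the Tensor Shape Flow table (500 latent frames decode to 320,000 samples), which implies 640 samples per latent frame, i.e. 32,000 Hz / 50 frames per second. The helper below is hypothetical (not part of audiocraft) and assumes a mono 32 kHz decoder with that hop length; check the actual compression model before relying on these constants.

```python
def expected_audio_shape(latent_shape, hop_length=640, channels=1):
    """Hypothetical helper: predict [B, C, T_audio] from a [B, T, D] latent shape.

    Assumes the decoder upsamples each latent frame to `hop_length` samples
    (640 = 32000 Hz / 50 Hz) and emits `channels` audio channels.
    """
    if len(latent_shape) != 3:
        raise ValueError(f"expected [B, T, D], got {latent_shape}")
    batch, frames, _dim = latent_shape
    return (batch, channels, frames * hop_length)


def duration_seconds(num_samples, sample_rate=32000):
    """Convert a sample count to seconds at the assumed sample rate."""
    return num_samples / sample_rate


shape = expected_audio_shape((1, 500, 128))
print(shape)                        # (1, 1, 320000)
print(duration_seconds(shape[-1]))  # 10.0
```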
How It Works
```python
from audiocraft.models import JASCO

model = JASCO.get_pretrained('facebook/jasco-chords-drums-400M')

# Generate audio and also return the latents it was decoded from
audio, latents = model.generate_music(
    descriptions=["upbeat jazz piano"],
    chords=[("C", 0.0), ("F", 5.0)],  # (chord symbol, start time in seconds)
    return_latents=True
)

# Or decode latents separately (e.g., after post-processing them)
decoded_audio = model.generate_audio(latents)
```
The implementation:
```python
def generate_audio(self, gen_latents: torch.Tensor) -> torch.Tensor:
    """Decode audio from generated latents"""
    assert gen_latents.dim() == 3  # [B, T, D]
    # Step 1: Denormalize latents
    gen_latents = self._unnormalized_latents(gen_latents)
    # Step 2: Decode through EnCodec decoder (note permutation to [B, D, T])
    return self.compression_model.model.decoder(gen_latents.permute(0, 2, 1))
```
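To make the data flow concrete without loading a checkpoint, here is a NumPy sketch that mirrors the two steps above. The decoder is a toy stand-in (an average over D followed by repeating each frame 640 times); it only reproduces the shape contract assumed for EnCodec's decoder, not its behavior. The function names, the hop length, and the mean/std values are illustrative assumptions.

```python
import numpy as np

HOP_LENGTH = 640  # assumed samples per latent frame (32 kHz / 50 Hz)


def toy_decoder(latents_bdt):
    """Toy stand-in for compression_model.model.decoder: [B, D, T] -> [B, 1, T * HOP]."""
    _b, d, _t = latents_bdt.shape
    weights = np.full((d,), 1.0 / d)                       # average over D
    frames = np.einsum("bdt,d->bt", latents_bdt, weights)  # [B, T]
    audio = np.repeat(frames, HOP_LENGTH, axis=-1)         # [B, T * HOP]
    return audio[:, None, :]                               # [B, 1, T * HOP]


def generate_audio_sketch(gen_latents, mean, std, decoder=toy_decoder):
    """Mirror of generate_audio: denormalize, permute to channels-first, decode."""
    assert gen_latents.ndim == 3  # [B, T, D]
    unnormalized = gen_latents * std + mean           # same arithmetic as _unnormalized_latents
    return decoder(unnormalized.transpose(0, 2, 1))   # [B, T, D] -> [B, D, T]


rng = np.random.default_rng(0)
latents = rng.standard_normal((2, 500, 128)).astype(np.float32)
audio = generate_audio_sketch(latents, mean=0.0, std=5.0)
print(audio.shape)  # (2, 1, 320000)
```

Passing the decoder in as an argument keeps the denormalize-then-permute logic testable on its own, independent of any real codec.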
Internal Method: _unnormalized_latents
Source: audiocraft/models/jasco.py, lines 85-89
```python
def _unnormalized_latents(self, latents: torch.Tensor) -> torch.Tensor:
    """Unnormalize latents, shifting back to EnCodec's expected mean, std"""
    assert self.cfg is not None
    scaled = latents * self.cfg.compression_model_latent_std
    return scaled + self.cfg.compression_model_latent_mean
```
The denormalization applies:
- Scale: multiply by the stored standard deviation (compression_model_latent_std)
- Shift: add the stored mean (compression_model_latent_mean)
These statistics are stored in the model's configuration (accessed via self.cfg) and were computed from the training dataset's EnCodec latent distribution.
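Denormalization is the inverse of a z-score normalization: latent = normalized * std + mean. The sketch below assumes the latents were normalized as (latent - mean) / std during training (only the inverse direction is shown here) and checks that the two operations round-trip. It also shows that the same expression works whether the statistics are scalars or per-dimension arrays of length D, through broadcasting; which form the JASCO config stores is not stated here. The function names and statistic values are made up for illustration.

```python
import numpy as np


def normalize_latents(latents, mean, std):
    """Assumed forward transform: shift and scale latents toward zero mean, unit std."""
    return (latents - mean) / std


def unnormalize_latents(latents, mean, std):
    """Inverse transform, same arithmetic as _unnormalized_latents."""
    return latents * std + mean


rng = np.random.default_rng(0)
encodec_latents = rng.normal(loc=0.3, scale=4.0, size=(1, 500, 128))

# Scalar statistics (illustrative values).
z = normalize_latents(encodec_latents, mean=0.3, std=4.0)
restored = unnormalize_latents(z, mean=0.3, std=4.0)
print(np.allclose(restored, encodec_latents))  # True

# Per-dimension statistics of shape [D] broadcast over the trailing axis.
mean_d = encodec_latents.mean(axis=(0, 1))
std_d = encodec_latents.std(axis=(0, 1))
z_d = normalize_latents(encodec_latents, mean_d, std_d)
print(np.allclose(unnormalize_latents(z_d, mean_d, std_d), encodec_latents))  # True
```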
Tensor Shape Flow
| Stage | Shape | Description |
|---|---|---|
| Flow matching output | [B, T, D] | Raw generated latents (e.g., [1, 500, 128]) |
| After denormalization | [B, T, D] | Latents in EnCodec's native scale |
| After permutation | [B, D, T] | Transposed for the decoder (channels-first) |
| Decoder output | [B, C, T_audio] | Decoded audio waveform (e.g., [1, 1, 320000] for 10 s at 32 kHz) |
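The permutation step only reorders axes; no values change. This NumPy check, where transpose(0, 2, 1) plays the role of PyTorch's permute(0, 2, 1), confirms that the value at [b, t, d] in the flow matching layout ends up at [b, d, t] in the channels-first layout the decoder expects, using the example sizes from the table.

```python
import numpy as np

B, T, D = 1, 500, 128
latents_btd = np.arange(B * T * D, dtype=np.float32).reshape(B, T, D)

# Equivalent of gen_latents.permute(0, 2, 1) in PyTorch.
latents_bdt = latents_btd.transpose(0, 2, 1)
print(latents_bdt.shape)  # (1, 128, 500)

# Every element moves from [b, t, d] to [b, d, t].
b, t, d = 0, 42, 7
assert latents_btd[b, t, d] == latents_bdt[b, d, t]

# Like torch's permute, transpose returns a non-contiguous view rather than a copy.
print(latents_bdt.flags["C_CONTIGUOUS"])  # False
```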
Saving Audio to Disk
After decoding, the audio_write() utility from audiocraft/data/audio.py (lines 159-231) can save the audio:
```python
from audiocraft.data.audio import audio_write

# Generate audio (assumes `model` is the JASCO instance loaded earlier)
audio = model.generate_music(descriptions=["calm piano"])

# Save to disk
audio_write(
    stem_name='output/generated_music',
    wav=audio[0].cpu(),  # first sample in the batch, shape [C, T_audio], moved to CPU
    sample_rate=model.sample_rate,
    format='wav',
    normalize=True,
    strategy='peak'
)
# Writes: output/generated_music.wav
```
The audio_write() function supports:
| Parameter | Options | Description |
|---|---|---|
| format | 'wav', 'mp3', 'ogg', 'flac' | Output audio format |
| normalize | True/False | Whether to normalize audio levels |
| strategy | 'peak', 'rms', 'clip', 'loudness' | Normalization strategy |
| mp3_rate | integer (kbps) | Bitrate for MP3 encoding (default 320) |
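For intuition about the two simplest strategy options, here is a NumPy approximation of peak and clip normalization on a [C, T] waveform. It is not audiocraft's implementation: the function name is hypothetical, and the 1 dB headroom default is an assumption modeled on audio_write's peak headroom setting, so check the real defaults before comparing outputs. 'rms' and 'loudness' depend on level targets and are left out.

```python
import numpy as np


def normalize_sketch(wav, strategy="peak", headroom_db=1.0):
    """Hypothetical approximation of 'peak' and 'clip' normalization for [C, T] audio."""
    if strategy == "peak":
        # Scale so the loudest sample sits headroom_db below full scale (1.0).
        target = 10 ** (-headroom_db / 20)
        peak = np.abs(wav).max()
        return wav if peak == 0 else wav * (target / peak)
    if strategy == "clip":
        # Leave levels alone and hard-limit anything outside [-1, 1].
        return np.clip(wav, -1.0, 1.0)
    raise ValueError(f"strategy not covered by this sketch: {strategy}")


wav = np.array([[0.0, 0.5, -2.0, 1.5]])  # [C=1, T=4], exceeds full scale
print(normalize_sketch(wav, "clip").tolist())  # [[0.0, 0.5, -1.0, 1.0]]
peaked = normalize_sketch(wav, "peak")
print(round(float(np.abs(peaked).max()), 4))  # 0.8913
```

The trade-off is visible even on four samples: 'peak' preserves the waveform's shape and lowers the overall level, while 'clip' preserves the level and distorts the out-of-range samples.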
Comparison with MusicGen Decoding
| Aspect | MusicGen (BaseGenModel) | JASCO |
|---|---|---|
| Input type | Discrete tokens [B, K, T] | Continuous latents [B, T, D] |
| Decoding path | compression_model.decode(tokens, None) | compression_model.model.decoder(latents.permute(...)) |
| Quantization | Passes through RVQ dequantizer | Bypasses quantization entirely |
| Preprocessing | None (tokens are already in discrete format) | Denormalization using stored mean/std |
| Source | genmodel.py:L262-267 | jasco.py:L91-97 |
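The quantization row is the key difference between the two paths. The NumPy sketch below contrasts them with toy sizes: on the MusicGen side, residual vector quantization (RVQ) dequantization looks up one embedding per codebook and sums them into a [B, D, T] latent; on the JASCO side, the continuous [B, T, D] latents only need denormalization and a permute. The codebook size, K, the statistics, and all values are made up, and the real EnCodec quantizer's details are omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
B, K, T, D, BINS = 1, 4, 50, 128, 2048  # toy sizes; K = number of codebooks

# MusicGen-style input: discrete tokens [B, K, T], one random toy codebook per level.
codebooks = rng.standard_normal((K, BINS, D))
tokens = rng.integers(0, BINS, size=(B, K, T))


def rvq_dequantize(tokens, codebooks):
    """Sum one embedding per codebook level: [B, K, T] -> [B, D, T]."""
    latents = np.zeros((tokens.shape[0], codebooks.shape[-1], tokens.shape[-1]))
    for k in range(codebooks.shape[0]):
        emb = codebooks[k][tokens[:, k]]   # look up level-k embeddings: [B, T, D]
        latents += emb.transpose(0, 2, 1)  # accumulate as [B, D, T]
    return latents


discrete_path = rvq_dequantize(tokens, codebooks)

# JASCO-style input: continuous latents [B, T, D]; no codebook lookup at all.
mean, std = 0.0, 5.0  # illustrative statistics
gen_latents = rng.standard_normal((B, T, D))
continuous_path = (gen_latents * std + mean).transpose(0, 2, 1)

print(discrete_path.shape, continuous_path.shape)  # (1, 128, 50) (1, 128, 50)
```

Both paths hand the decoder a channels-first [B, D, T] tensor; they differ only in how that tensor is produced.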