
Implementation:Facebookresearch Audiocraft JASCO generate music

From Leeroopedia

Overview

JASCO_generate_music is the main generation entry point for JASCO. It orchestrates the full pipeline, taking text descriptions and temporal conditions through flow-matching ODE integration to produce audio latents, which are then decoded to waveforms.

Implements

Principle:Facebookresearch_Audiocraft_Flow_Matching_Generation

API Signature

@torch.no_grad()
def generate_music(
    self,
    descriptions: List[str],
    drums_wav: Optional[torch.Tensor] = None,
    drums_sample_rate: int = 32000,
    chords: Optional[List[Tuple[str, float]]] = None,
    melody_salience_matrix: Optional[torch.Tensor] = None,
    iopaint_wav: Optional[torch.Tensor] = None,
    segment_duration: float = 10.0,
    frame_rate: float = 50.0,
    melody_bins: int = 53,
    progress: bool = False,
    return_latents: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Source: audiocraft/models/jasco.py, lines 269-315

Parameters

Parameter | Type | Default | Description
descriptions | List[str] | required | Text descriptions for conditioning, one per sample in the batch
drums_wav | Optional[Tensor] | None | Drum audio waveform [B, C, T] or [C, T] used for conditioning
drums_sample_rate | int | 32000 | Sample rate of the provided drum audio
chords | Optional[List[Tuple[str, float]]] | None | Chord progression as (label, start_time) tuples
melody_salience_matrix | Optional[Tensor] | None | Melody salience matrix [B, 53, T]
iopaint_wav | Optional[Tensor] | None | In/out-painting waveform (reserved for future use)
segment_duration | float | 10.0 | Segment duration the model was trained on, in seconds
frame_rate | float | 50.0 | Frame rate the model was trained on, in frames per second
melody_bins | int | 53 | Number of melody pitch bins
progress | bool | False | Display ODE integration progress
return_latents | bool | False | If True, return both decoded audio and raw latents
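The melody salience condition must line up with the model's frame grid: its time axis should contain int(segment_duration * frame_rate) frames (500 with the defaults above). The following shape check is a hypothetical helper, not part of the audiocraft API, written only to make the constraint concrete:

```python
def salience_shape_ok(shape, melody_bins=53, segment_duration=10.0, frame_rate=50.0):
    """Check that a melody salience matrix shape is [B, melody_bins, T]
    with T matching the model's frame grid. Hypothetical helper."""
    if len(shape) != 3:
        return False
    _, bins, frames = shape
    return bins == melody_bins and frames == int(segment_duration * frame_rate)
```

With the defaults, a valid matrix for a batch of one is shaped (1, 53, 500).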

Return Value

  • When return_latents=False: torch.Tensor of decoded audio waveform [B, C, T]
  • When return_latents=True: Tuple of (audio_tensor, latent_tensor)

Dependencies

  • torch -- for tensor operations
  • torchdiffeq -- for ODE integration via odeint()

How It Works

from audiocraft.models import JASCO

model = JASCO.get_pretrained('facebook/jasco-chords-drums-400M')
model.set_generation_params(cfg_coef_all=5.0)

# Generate with text + chords
audio = model.generate_music(
    descriptions=["upbeat jazz piano"],
    chords=[("Cmaj7", 0.0), ("Dm7", 2.5), ("G7", 5.0), ("Cmaj7", 7.5)],
    progress=True
)

# Generate with text + drums
import torchaudio
drums, sr = torchaudio.load("drums_loop.wav")
audio = model.generate_music(
    descriptions=["electronic dance music"],
    drums_wav=drums.unsqueeze(0),
    drums_sample_rate=sr
)

# Generate with text only
audio = model.generate_music(descriptions=["calm acoustic guitar melody"])

# Get both audio and latents
audio, latents = model.generate_music(
    descriptions=["orchestral film score"],
    chords=[("Am", 0.0), ("F", 5.0)],
    return_latents=True
)
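The chords argument pairs each label with an absolute start time in seconds. To space a progression evenly across the segment, a small helper (hypothetical, not part of audiocraft) can build the expected list:

```python
def evenly_spaced_chords(labels, segment_duration=10.0):
    """Spread chord labels evenly over the segment, returning the
    (label, start_time) tuples that generate_music expects."""
    step = segment_duration / len(labels)
    return [(label, round(i * step, 3)) for i, label in enumerate(labels)]
```

For example, evenly_spaced_chords(["Cmaj7", "Dm7", "G7", "Cmaj7"]) reproduces the chord list used in the jazz-piano example above.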

Internal Pipeline

The implementation orchestrates the full generation flow:

@torch.no_grad()
def generate_music(self, descriptions, drums_wav=None, drums_sample_rate=32000,
                   chords=None, melody_salience_matrix=None, iopaint_wav=None,
                   segment_duration=10.0, frame_rate=50.0, melody_bins=53,
                   progress=False, return_latents=False):
    # Step 1: Convert drum audio sample rate if provided
    if drums_wav is not None:
        if drums_wav.dim() == 2:
            drums_wav = drums_wav[None]
        drums_wav = convert_audio(drums_wav, drums_sample_rate,
                                   self.sample_rate, self.audio_channels)

    # Step 2: Prepare text conditioning attributes
    cond_attributes, prompt_tokens = self._prepare_tokens_and_attributes(
        descriptions=descriptions, prompt=None)

    # Step 3: Prepare temporal conditions (chords, drums, melody)
    jasco_attributes = self._prepare_temporal_conditions(
        attributes=cond_attributes,
        expected_length=int(segment_duration * frame_rate),
        chords=chords, drums_wav=drums_wav,
        salience_matrix=melody_salience_matrix, melody_bins=melody_bins)

    # Step 4: Generate latents via flow matching
    tokens = self._generate_tokens(jasco_attributes, prompt_tokens, progress)

    # Step 5: Decode latents to audio
    if return_latents:
        return self.generate_audio(tokens), tokens
    return self.generate_audio(tokens)
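Step 3 turns sparse temporal conditions into frame-aligned sequences of length int(segment_duration * frame_rate). The sketch below illustrates the idea for chords only, producing per-frame labels; the real _prepare_temporal_conditions maps labels to embedding indices and also handles drums and melody, and the "N" no-chord symbol here is an assumption for illustration:

```python
def rasterize_chords(chords, segment_duration=10.0, frame_rate=50.0, no_chord="N"):
    """Expand sparse (label, start_time) chords into one label per frame.
    Illustrative sketch only; the real implementation builds index tensors."""
    num_frames = int(segment_duration * frame_rate)
    frames = [no_chord] * num_frames
    for label, start in sorted(chords, key=lambda c: c[1]):
        start_frame = int(start * frame_rate)
        for i in range(start_frame, num_frames):
            frames[i] = label  # each chord holds until the next one starts
    return frames
```

With the defaults, [("C", 0.0), ("G", 5.0)] expands to 500 frames: "C" for frames 0-249, then "G" for frames 250-499.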

FlowMatchingModel.generate()

The core ODE integration happens in FlowMatchingModel.generate() at audiocraft/models/flow_matching.py, lines 419-516:

@torch.no_grad()
def generate(self, prompt=None, conditions=[], num_samples=None,
             max_gen_len=256, callback=None,
             cfg_coef_all=3.0, cfg_coef_txt=1.0,
             euler=False, euler_steps=100,
             ode_rtol=1e-5, ode_atol=1e-5):
    # Setup
    B, T, D = num_samples, max_gen_len, self.flow_dim
    z_0 = torch.randn((B, T, D), device=device)  # noise prior

    # Preprocess multi-source CFG
    condition_tensors, cfg_terms = self._multi_source_cfg_preprocess(
        conditions, cfg_coef_all, cfg_coef_txt)

    if euler:
        # Fixed-step Euler integration
        dt = 1 / euler_steps
        z = z_0
        t = torch.zeros((1,), device=device)
        for _ in range(euler_steps):
            v_theta = self.estimated_vector_field(
                z, t, condition_tensors=condition_tensors, cfg_terms=cfg_terms)
            z = z + dt * v_theta
            t = t + dt
        z_1 = z
    else:
        # Adaptive ODE solver (torchdiffeq's default, dopri5);
        # inner_ode_func (defined in the full source) wraps
        # estimated_vector_field with the same conditioning and CFG terms
        t = torch.tensor([0, 1.0 - 1e-5], device=device)
        z = odeint(inner_ode_func, z_0, t,
                   atol=ode_atol, rtol=ode_rtol)
        z_1 = z[-1]

    return z_1  # [B, T, D]
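The fixed-step Euler branch above is plain forward integration of dz/dt = v_theta(z, t) from t=0 to t=1. A dependency-free toy with a known vector field shows the same loop structure and its accuracy at 100 steps (for dz/dt = -z, the exact solution at t=1 is exp(-1) ≈ 0.3679):

```python
import math

def euler_integrate(vector_field, z0, steps=100):
    """Fixed-step Euler integration from t=0 to t=1, mirroring the
    euler=True branch of FlowMatchingModel.generate()."""
    dt = 1.0 / steps
    z, t = z0, 0.0
    for _ in range(steps):
        z = z + dt * vector_field(z, t)  # z_{k+1} = z_k + dt * v_theta(z_k, t_k)
        t += dt
    return z

# Toy vector field dz/dt = -z, exact solution z(1) = exp(-1)
z1 = euler_integrate(lambda z, t: -z, 1.0)
```

Here z1 lands within about 2e-3 of exp(-1); the adaptive dopri5 path trades this fixed per-sample cost for error-controlled step sizes governed by ode_rtol and ode_atol.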

Generation Parameters Reference

Parameter | Default | Effect
cfg_coef_all | 5.0 | Higher values increase overall condition adherence
cfg_coef_txt | 0.0 | Non-zero values add text-specific guidance
euler | False | Switch from the adaptive solver to fixed-step Euler integration
euler_steps | 100 | Number of Euler steps (used only when euler=True)
ode_rtol | 1e-5 | Relative tolerance for the adaptive solver
ode_atol | 1e-5 | Absolute tolerance for the adaptive solver

The defaults listed here are the model-level values set via set_generation_params(); FlowMatchingModel.generate() itself defaults to cfg_coef_all=3.0 and cfg_coef_txt=1.0, as its signature above shows.
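The two coefficients combine several conditional vector-field evaluations before each integration step. The exact mixing lives in _multi_source_cfg_preprocess and estimated_vector_field and is not reproduced on this page; a common multi-source classifier-free-guidance form consistent with these parameter names (an assumption, not the confirmed audiocraft formula) is:

```python
def combine_cfg(v_uncond, v_all, v_txt, cfg_coef_all=5.0, cfg_coef_txt=0.0):
    """Hypothetical multi-source CFG mix: push the field toward the fully
    conditioned estimate (cfg_coef_all) and optionally add extra
    text-only guidance (cfg_coef_txt)."""
    return (v_uncond
            + cfg_coef_all * (v_all - v_uncond)
            + cfg_coef_txt * (v_txt - v_uncond))
```

With cfg_coef_txt=0.0 this reduces to standard single-source CFG; raising either coefficient pushes generations harder toward the corresponding conditions, at some cost in diversity.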
