
Implementation:Facebookresearch Audiocraft JASCO prepare temporal conditions

From Leeroopedia

Overview

JASCO_prepare_temporal_conditions orchestrates the preparation of all temporal conditioning inputs (chords, drums, melody) by delegating to specialized internal methods. It transforms raw user inputs into ConditioningAttributes objects that the flow matching model consumes during generation.

Implements

Principle:Facebookresearch_Audiocraft_Temporal_Conditioning_Preparation

API Signature

@torch.no_grad()
def _prepare_temporal_conditions(
    self,
    attributes: List[ConditioningAttributes],
    expected_length: int,
    chords: Optional[List[Tuple[str, float]]],
    drums_wav: Optional[torch.Tensor],
    salience_matrix: Optional[torch.Tensor],
    melody_bins: int = 53,
) -> List[ConditioningAttributes]

Source: audiocraft/models/jasco.py, lines 239-266

Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| attributes | List[ConditioningAttributes] | required | Pre-constructed attributes containing text descriptions |
| expected_length | int | required | Expected number of generated frames (segment_duration * frame_rate) |
| chords | Optional[List[Tuple[str, float]]] | required | Chord progression as (label, start_time) tuples, e.g. [("C", 0.0), ("F", 4.0)] |
| drums_wav | Optional[torch.Tensor] | required | Drum audio waveform of shape [B, C, T] |
| salience_matrix | Optional[torch.Tensor] | required | Melody salience matrix of shape [B, melody_bins, T] |
| melody_bins | int | 53 | Number of pitch bins in the salience matrix |

Return Value

Returns List[ConditioningAttributes] with symbolic and wav fields populated for all temporal conditions.

How It Works

The method delegates to three internal methods in sequence:

@torch.no_grad()
def _prepare_temporal_conditions(self, attributes, expected_length,
                                  chords, drums_wav, salience_matrix,
                                  melody_bins=53):
    attributes = self._prepare_chord_conditions(attributes=attributes, chords=chords)
    attributes = self._prepare_drums_conditions(attributes=attributes, drums_wav=drums_wav)
    attributes = self._prepare_melody_conditions(attributes=attributes, melody=salience_matrix,
                                                  expected_length=expected_length,
                                                  melody_bins=melody_bins)
    return attributes
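Note that expected_length counts latent frames (segment_duration * frame_rate), while the drum branch below works in audio samples (segment_duration * sample_rate). A quick sketch of the distinction, using assumed rates rather than values read from a real model config:

```python
# Illustrative values; in JASCO these come from the compression model
# and LM configs (frame_rate, sample_rate here are assumptions).
segment_duration = 10.0   # seconds of audio to generate
frame_rate = 50           # latent frames per second (assumed)
sample_rate = 32000       # audio samples per second (assumed)

expected_length = int(segment_duration * frame_rate)    # frames for temporal conditions
expected_samples = int(segment_duration * sample_rate)  # samples for drum waveforms
```

With these assumed rates, a 10-second segment yields 500 condition frames but 320,000 drum samples, which matches the shapes used in the usage example below.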

Internal Method: _prepare_chord_conditions

Source: audiocraft/models/jasco.py, lines 137-173

Converts symbolic chord progressions into per-frame integer sequences:

def _prepare_chord_conditions(self, attributes, chords):
    if chords is None or chords == []:
        for att in attributes:
            att.symbolic[JascoCondConst.CRD.value] = SymbolicCondition(
                frame_chords=-1 * torch.ones(1, dtype=torch.int32))
        return attributes

    # Flip from (chord, start_time) to (start_time, chord)
    chords_time_first = [(item[1], item[0]) for item in chords]

    # Convert to per-frame integer sequence using chord vocabulary
    frame_chords = construct_frame_chords(
        min_timestamp=0,
        chord_changes=chords_time_first,
        mapping_dict=self.chords_mapping,
        prev_chord='',
        frame_rate=self.compression_model.frame_rate,
        segment_duration=self.duration)

    for att in attributes:
        att.symbolic[JascoCondConst.CRD.value] = SymbolicCondition(
            frame_chords=torch.tensor(frame_chords))
    return attributes
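construct_frame_chords itself lives in audiocraft's utilities; the idea it implements (hold each chord label until the next change, emitting one vocabulary id per frame) can be sketched with a simplified, hypothetical stand-in. The toy mapping dictionary below is an assumption, not the real chord vocabulary:

```python
def frame_chords_sketch(chord_changes, mapping, frame_rate, segment_duration):
    """Expand (start_time, chord) changes into one chord id per frame.
    Simplified stand-in for construct_frame_chords, not the library function."""
    n_frames = int(segment_duration * frame_rate)
    changes = sorted(chord_changes)           # (start_time, label), time-ascending
    frames, current = [], ''                  # '' = no chord seen yet
    for i in range(n_frames):
        t = i / frame_rate
        while changes and changes[0][0] <= t:
            current = changes.pop(0)[1]       # advance to the active chord
        frames.append(mapping.get(current, -1))  # -1 = unknown / null chord
    return frames

mapping = {'C': 0, 'F': 5, 'G': 7}            # toy chord vocabulary (assumed)
frames = frame_chords_sketch([(0.0, 'C'), (1.0, 'F')], mapping,
                             frame_rate=2, segment_duration=2.0)
# frames -> [0, 0, 5, 5]: 'C' for the first second, 'F' for the second
```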

Internal Method: _prepare_drums_conditions

Source: audiocraft/models/jasco.py, lines 175-211

Wraps drum audio into a WavCondition with proper padding/trimming:

@torch.no_grad()
def _prepare_drums_conditions(self, attributes, drums_wav):
    for attr in attributes:
        if drums_wav is None:
            # Null condition: zero tensor
            attr.wav[JascoCondConst.DRM.value] = WavCondition(
                torch.zeros((1, 1, 1), device=self.device),
                torch.tensor([0], device=self.device),
                sample_rate=[self.sample_rate], path=[None])
        else:
            expected_length = self.lm.cfg.dataset.segment_duration * self.sample_rate
            drums_wav = drums_wav[..., :expected_length]  # trim
            if drums_wav.shape[-1] < expected_length:     # pad
                diff = expected_length - drums_wav.shape[-1]
            drums_wav = torch.cat((drums_wav,
                torch.zeros((*drums_wav.shape[:-1], diff),
                            device=drums_wav.device)), dim=-1)
            attr.wav[JascoCondConst.DRM.value] = WavCondition(
                drums_wav.to(device=self.device),
                torch.tensor([drums_wav.shape[-1]], device=self.device),
                sample_rate=[self.sample_rate], path=[None])
    return attributes
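The trim/pad step reduces to fitting the last (time) dimension to a fixed length. A standalone sketch of that logic, separated from the conditioning bookkeeping:

```python
import torch

def fit_to_length(wav: torch.Tensor, expected_length: int) -> torch.Tensor:
    """Trim or zero-pad the last (time) dimension to exactly expected_length."""
    wav = wav[..., :expected_length]              # trim if too long
    diff = expected_length - wav.shape[-1]
    if diff > 0:                                  # pad with silence if too short
        pad = torch.zeros((*wav.shape[:-1], diff),
                          device=wav.device, dtype=wav.dtype)
        wav = torch.cat((wav, pad), dim=-1)
    return wav

short = fit_to_length(torch.ones(1, 1, 3), 5)  # padded to shape [1, 1, 5]
long = fit_to_length(torch.ones(1, 1, 8), 5)   # trimmed to shape [1, 1, 5]
```

The padding uses `*wav.shape[:-1]` so that batch and channel dimensions are preserved whatever their sizes.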

Internal Method: _prepare_melody_conditions

Source: audiocraft/models/jasco.py, lines 213-237

Wraps the pre-computed salience matrix into a SymbolicCondition:

@torch.no_grad()
def _prepare_melody_conditions(self, attributes, melody, expected_length, melody_bins=53):
    for attr in attributes:
        if melody is None:
            melody = torch.zeros((melody_bins, expected_length))
        attr.symbolic[JascoCondConst.MLD.value] = SymbolicCondition(melody=melody)
    return attributes

JASCO Conditioner Classes

The prepared conditions are later processed by these conditioner classes in audiocraft/modules/jasco_conditioners.py:

| Class | Source Lines | Input | Output |
|---|---|---|---|
| ChordsEmbConditioner(card, out_dim, device) | L36-56 | SymbolicCondition(frame_chords) | Embedded chord vectors via nn.Embedding |
| DrumsConditioner(out_dim, sample_rate, blurring_factor=3) | L59-214 | WavCondition with drum audio | Blurred drum latents from EnCodec encoding |
| MelodyConditioner(card, out_dim, device) | L15-33 | SymbolicCondition(melody) | Projected salience matrix |
| JascoConditioningProvider | L216-300 | All condition types | Tokenized and collated condition tensors |
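The role of ChordsEmbConditioner, mapping per-frame chord ids to dense vectors, can be reduced to a bare nn.Embedding. The real class in jasco_conditioners.py also handles masking and device placement; the card and out_dim values below are illustrative assumptions:

```python
import torch
import torch.nn as nn

card, out_dim = 194, 512                 # chord vocabulary size / model width (assumed)
emb = nn.Embedding(card, out_dim)        # one learned vector per chord id

frame_chords = torch.tensor([[0, 0, 5, 5, 7]])  # [B, T] per-frame chord ids
chord_vectors = emb(frame_chords)               # [B, T, out_dim]
```

Because the lookup is per-frame, the time resolution of the chord condition matches the frame rate used when building `frame_chords` in `_prepare_chord_conditions`.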

Usage Example

from audiocraft.models import JASCO
import torch

model = JASCO.get_pretrained('facebook/jasco-chords-drums-400M')

# Chord progression input
chords = [("C", 0.0), ("Am", 2.5), ("F", 5.0), ("G", 7.5)]

# Drum audio input (pre-loaded)
drums_wav = torch.randn(1, 1, 320000)  # 10 seconds at 32kHz

# These are used internally by generate_music(), but can be called directly:
# attributes = model._prepare_temporal_conditions(
#     attributes=attrs, expected_length=500,
#     chords=chords, drums_wav=drums_wav, salience_matrix=None
# )
