Implementation:Facebookresearch Audiocraft JASCO prepare temporal conditions
Overview
JASCO's _prepare_temporal_conditions orchestrates the preparation of all temporal conditioning inputs (chords, drums, melody) by delegating to specialized internal methods. It transforms raw user inputs into ConditioningAttributes objects that the flow-matching model consumes during generation.
Implements
Principle:Facebookresearch_Audiocraft_Temporal_Conditioning_Preparation
API Signature
@torch.no_grad()
def _prepare_temporal_conditions(
    self,
    attributes: List[ConditioningAttributes],
    expected_length: int,
    chords: Optional[List[Tuple[str, float]]],
    drums_wav: Optional[torch.Tensor],
    salience_matrix: Optional[torch.Tensor],
    melody_bins: int = 53,
) -> List[ConditioningAttributes]
Source: audiocraft/models/jasco.py, lines 239-266
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| attributes | List[ConditioningAttributes] | required | Pre-constructed attributes containing text descriptions |
| expected_length | int | required | Expected number of generated frames (segment_duration * frame_rate) |
| chords | Optional[List[Tuple[str, float]]] | required | Chord progression as (label, start_time) tuples, e.g., [("C", 0.0), ("F", 4.0)] |
| drums_wav | Optional[torch.Tensor] | required | Drum audio waveform of shape [B, C, T] |
| salience_matrix | Optional[torch.Tensor] | required | Melody salience matrix of shape [B, melody_bins, T] |
| melody_bins | int | 53 | Number of pitch bins in the salience matrix |
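As a quick sanity check on the expected_length parameter, it is the product of the segment duration and the compression model's frame rate. The values below are illustrative, not constants from the JASCO source:

```python
# Illustrative values only: segment_duration and frame_rate come from the
# loaded model's configuration, not from hard-coded constants.
segment_duration = 10.0  # seconds of audio to generate
frame_rate = 50          # compression-model frames per second

expected_length = int(segment_duration * frame_rate)
print(expected_length)  # 500
```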
Return Value
Returns List[ConditioningAttributes] with symbolic and wav fields populated for all temporal conditions.
How It Works
The method delegates to three internal methods in sequence:
@torch.no_grad()
def _prepare_temporal_conditions(self, attributes, expected_length,
                                 chords, drums_wav, salience_matrix,
                                 melody_bins=53):
    attributes = self._prepare_chord_conditions(attributes=attributes, chords=chords)
    attributes = self._prepare_drums_conditions(attributes=attributes, drums_wav=drums_wav)
    attributes = self._prepare_melody_conditions(attributes=attributes, melody=salience_matrix,
                                                 expected_length=expected_length,
                                                 melody_bins=melody_bins)
    return attributes
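The pattern is simple: each sub-method fills its own slot on every attribute, substituting a null condition when its input is absent. A minimal sketch using plain dicts in place of ConditioningAttributes (the real class lives in audiocraft.modules.conditioners):

```python
# Sketch only: dicts stand in for ConditioningAttributes, and NULL stands in
# for the zero/sentinel tensors the real methods construct.
NULL = "null-condition"

def prepare_temporal_conditions(attrs, chords, drums_wav, salience_matrix):
    for attr in attrs:
        # Each step is independent: a missing input only nulls its own slot.
        attr["symbolic"]["chords"] = chords if chords else NULL
        attr["wav"]["drums"] = drums_wav if drums_wav is not None else NULL
        attr["symbolic"]["melody"] = salience_matrix if salience_matrix is not None else NULL
    return attrs

attrs = [{"symbolic": {}, "wav": {}}]
out = prepare_temporal_conditions(attrs, chords=[("C", 0.0)],
                                  drums_wav=None, salience_matrix=None)
print(out[0]["symbolic"]["chords"])  # [('C', 0.0)]
print(out[0]["wav"]["drums"])        # null-condition
```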
Internal Method: _prepare_chord_conditions
Source: audiocraft/models/jasco.py, lines 137-173
Converts symbolic chord progressions into per-frame integer sequences:
def _prepare_chord_conditions(self, attributes, chords):
    if chords is None or chords == []:
        for att in attributes:
            att.symbolic[JascoCondConst.CRD.value] = SymbolicCondition(
                frame_chords=-1 * torch.ones(1, dtype=torch.int32))
        return attributes
    # Flip from (chord, start_time) to (start_time, chord)
    chords_time_first = [(item[1], item[0]) for item in chords]
    # Convert to per-frame integer sequence using chord vocabulary
    frame_chords = construct_frame_chords(
        min_timestamp=0,
        chord_changes=chords_time_first,
        mapping_dict=self.chords_mapping,
        prev_chord='',
        frame_rate=self.compression_model.frame_rate,
        segment_duration=self.duration)
    for att in attributes:
        att.symbolic[JascoCondConst.CRD.value] = SymbolicCondition(
            frame_chords=torch.tensor(frame_chords))
    return attributes
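Conceptually, construct_frame_chords turns sparse (start_time, label) change points into a dense per-frame index sequence: each frame holds the index of the most recent chord change at or before its timestamp. A hedged sketch of that behavior (the mapping values and function name here are illustrative, not AudioCraft's actual vocabulary):

```python
# Hedged sketch of the per-frame chord expansion. The real implementation is
# audiocraft's construct_frame_chords; the mapping below is made up.
def frame_chords_sketch(chord_changes, mapping, frame_rate, duration):
    n_frames = int(frame_rate * duration)
    changes = sorted(chord_changes)       # (start_time, label), time-ascending
    current = mapping["N"]                # "no chord" until the first change
    frames = []
    for i in range(n_frames):
        t = i / frame_rate
        # Advance past every change point that has started by frame time t
        while changes and changes[0][0] <= t:
            current = mapping[changes.pop(0)[1]]
        frames.append(current)
    return frames

mapping = {"N": 0, "C": 1, "F": 2}
frames = frame_chords_sketch([(0.0, "C"), (1.0, "F")], mapping,
                             frame_rate=2, duration=2.0)
print(frames)  # [1, 1, 2, 2]
```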
Internal Method: _prepare_drums_conditions
Source: audiocraft/models/jasco.py, lines 175-211
Wraps drum audio into a WavCondition with proper padding/trimming:
@torch.no_grad()
def _prepare_drums_conditions(self, attributes, drums_wav):
    for attr in attributes:
        if drums_wav is None:
            # Null condition: zero tensor
            attr.wav[JascoCondConst.DRM.value] = WavCondition(
                torch.zeros((1, 1, 1), device=self.device),
                torch.tensor([0], device=self.device),
                sample_rate=[self.sample_rate], path=[None])
        else:
            expected_length = self.lm.cfg.dataset.segment_duration * self.sample_rate
            drums_wav = drums_wav[..., :expected_length]  # trim
            if drums_wav.shape[-1] < expected_length:  # pad
                diff = expected_length - drums_wav.shape[-1]
                drums_wav = torch.cat(
                    (drums_wav,
                     torch.zeros((*drums_wav.shape[:-1], diff), device=drums_wav.device)),
                    dim=-1)
            attr.wav[JascoCondConst.DRM.value] = WavCondition(
                drums_wav.to(device=self.device),
                torch.tensor([drums_wav.shape[-1]], device=self.device),
                sample_rate=[self.sample_rate], path=[None])
    return attributes
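Note that expected_length here is measured in raw audio samples (segment_duration * sample_rate), not model frames. The trim/pad step can be isolated as a small helper; this standalone sketch (fit_to_length is a name introduced for illustration) shows both branches:

```python
import torch

# Standalone sketch of the trim/pad logic above; expected_length is in
# raw audio samples, not compression-model frames.
def fit_to_length(wav: torch.Tensor, expected_length: int) -> torch.Tensor:
    wav = wav[..., :expected_length]            # trim if too long (no-op otherwise)
    if wav.shape[-1] < expected_length:         # zero-pad if too short
        diff = expected_length - wav.shape[-1]
        pad = torch.zeros((*wav.shape[:-1], diff), device=wav.device)
        wav = torch.cat((wav, pad), dim=-1)
    return wav

short = torch.ones(1, 1, 100)   # shorter than target: gets zero-padded
long = torch.ones(1, 1, 900)    # longer than target: gets trimmed
print(fit_to_length(short, 500).shape)  # torch.Size([1, 1, 500])
print(fit_to_length(long, 500).shape)   # torch.Size([1, 1, 500])
```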
Internal Method: _prepare_melody_conditions
Source: audiocraft/models/jasco.py, lines 213-237
Wraps the pre-computed salience matrix into a SymbolicCondition:
@torch.no_grad()
def _prepare_melody_conditions(self, attributes, melody, expected_length, melody_bins=53):
    for attr in attributes:
        if melody is None:
            melody = torch.zeros((melody_bins, expected_length))
        attr.symbolic[JascoCondConst.MLD.value] = SymbolicCondition(melody=melody)
    return attributes
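One subtlety in the loop above: because the local variable melody is overwritten on the first iteration, the null zero tensor of shape [melody_bins, expected_length] is created once and then shared by every attribute. A minimal reproduction of that behavior (the function here is a simplified stand-in, not the JASCO source):

```python
import torch

# Simplified stand-in for _prepare_melody_conditions, using plain dicts.
def prepare_melody(attrs, melody, melody_bins=53, expected_length=500):
    for attr in attrs:
        if melody is None:
            # Created once; subsequent iterations see the non-None tensor
            melody = torch.zeros((melody_bins, expected_length))
        attr["melody"] = melody
    return attrs

attrs = prepare_melody([{}, {}], melody=None)
print(attrs[0]["melody"].shape)                    # torch.Size([53, 500])
print(attrs[0]["melody"] is attrs[1]["melody"])    # True: one shared tensor
```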
JASCO Conditioner Classes
The prepared conditions are later processed by these conditioner classes in audiocraft/modules/jasco_conditioners.py:
| Class | Source Lines | Input | Output |
|---|---|---|---|
| ChordsEmbConditioner(card, out_dim, device) | L36-56 | SymbolicCondition(frame_chords) | Embedded chord vectors via nn.Embedding |
| DrumsConditioner(out_dim, sample_rate, blurring_factor=3) | L59-214 | WavCondition with drum audio | Blurred drum latents from EnCodec encoding |
| MelodyConditioner(card, out_dim, device) | L15-33 | SymbolicCondition(melody) | Projected salience matrix |
| JascoConditioningProvider | L216-300 | All condition types | Tokenized and collated condition tensors |
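For instance, the chord path ends in a plain embedding lookup: each per-frame chord index becomes a learned vector. A conceptual sketch, assuming illustrative sizes (the real ChordsEmbConditioner also handles devices and masking, and card is the actual chord vocabulary size):

```python
import torch
from torch import nn

# Conceptual sketch of ChordsEmbConditioner's core operation: an embedding
# lookup over per-frame chord indices. Sizes below are illustrative.
card, out_dim = 194, 64               # vocabulary size / embedding width (assumed)
emb = nn.Embedding(card, out_dim)

frame_chords = torch.tensor([[1, 1, 2, 2]])   # [B, T] integer chord indices
chord_vectors = emb(frame_chords)             # [B, T, out_dim] embedded vectors
print(chord_vectors.shape)  # torch.Size([1, 4, 64])
```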
Usage Example
from audiocraft.models import JASCO
import torch
model = JASCO.get_pretrained('facebook/jasco-chords-drums-400M')
# Chord progression input
chords = [("C", 0.0), ("Am", 2.5), ("F", 5.0), ("G", 7.5)]
# Drum audio input (pre-loaded)
drums_wav = torch.randn(1, 1, 320000) # 10 seconds at 32kHz
# These are used internally by generate_music(), but can be called directly:
# attributes = model._prepare_temporal_conditions(
# attributes=attrs, expected_length=500,
# chords=chords, drums_wav=drums_wav, salience_matrix=None
# )