Implementation:Facebookresearch Audiocraft JASCO get pretrained
Overview
JASCO_get_pretrained loads a pretrained JASCO model from HuggingFace Hub or a local path, assembling the compression model, flow matching language model, and chord mapping into a ready-to-use JASCO instance.
Implements
Principle:Facebookresearch_Audiocraft_JASCO_Model_Loading
API Signature
@staticmethod
def get_pretrained(
    name: str = 'facebook/jasco-chords-drums-400M',
    device=None,
    chords_mapping_path: str = 'assets/chord_to_index_mapping.pkl'
) -> JASCO
Source: audiocraft/models/jasco.py, lines 43-64
Import:
from audiocraft.models import JASCO
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| name | str | 'facebook/jasco-chords-drums-400M' | HuggingFace model ID or local path to model directory |
| device | Optional[str] | None | Target device; auto-detects CUDA if available, falls back to CPU |
| chords_mapping_path | str | 'assets/chord_to_index_mapping.pkl' | Path to the chord-to-index mapping pickle file |
Return Value
Returns a fully initialized JASCO instance with:
- compression_model: Loaded EnCodec model for audio encoding/decoding
- lm: Loaded FlowMatchingModel for vector field prediction
- chords_mapping: Dictionary mapping chord label strings to integer indices
Dependencies
- torch -- for tensor operations and device management
- huggingface_hub -- for downloading model weights from HuggingFace Hub
- torchdiffeq -- required by the FlowMatchingModel for ODE integration
- pickle -- for loading the chord mapping file
- demucs -- required by the DrumsConditioner for drum stem separation
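Before attempting to load a model, it can help to confirm that the packages listed above are importable. The sketch below is one way to do that with the standard library's importlib.util.find_spec. The helper name missing_dependencies and the constant JASCO_DEPENDENCIES are hypothetical and are not part of audiocraft.

```python
from importlib.util import find_spec

# Import names taken from the dependency list above; pickle ships with Python.
JASCO_DEPENDENCIES = ("torch", "huggingface_hub", "torchdiffeq", "pickle", "demucs")


def missing_dependencies(names=JASCO_DEPENDENCIES):
    """Return the subset of `names` that cannot be found on the import path.

    Hypothetical helper for illustration; audiocraft does not provide it.
    """
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_dependencies()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All listed dependencies are importable.")
```

This only checks that each top-level package can be located; it does not verify versions or that the package imports without error.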
How It Works
from audiocraft.models import JASCO
# Load default model
model = JASCO.get_pretrained('facebook/jasco-chords-drums-400M')
# Load with explicit device
model = JASCO.get_pretrained('facebook/jasco-chords-drums-1B', device='cuda')
# Load model with melody support
model = JASCO.get_pretrained('facebook/jasco-chords-drums-melody-400M')
The implementation:
@staticmethod
def get_pretrained(name='facebook/jasco-chords-drums-400M', device=None,
                   chords_mapping_path='assets/chord_to_index_mapping.pkl'):
    if device is None:
        if torch.cuda.device_count():
            device = 'cuda'
        else:
            device = 'cpu'
    compression_model = load_compression_model(name, device=device)
    lm = load_jasco_model(name, compression_model, device=device)
    kwargs = {'name': name,
              'compression_model': compression_model,
              'lm': lm,
              'chords_mapping_path': chords_mapping_path}
    return JASCO(**kwargs)
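The device fallback in the listing above can be isolated as a small pure function. This is a minimal sketch, not audiocraft code: resolve_device is a hypothetical name, and the CUDA device count is passed in as an argument so the logic can be exercised without torch installed.

```python
from typing import Optional


def resolve_device(device: Optional[str], cuda_device_count: int) -> str:
    """Mirror the fallback above: keep an explicit device, else prefer CUDA.

    Hypothetical helper; `cuda_device_count` stands in for
    torch.cuda.device_count() so the function has no torch dependency.
    """
    if device is not None:
        return device
    return "cuda" if cuda_device_count else "cpu"
```

With torch available, the equivalent call would be resolve_device(None, torch.cuda.device_count()). An explicitly passed device is always returned unchanged, which matches how get_pretrained() only auto-detects when device is None.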
Internal Loading Chain
The load_jasco_model() function in audiocraft/models/loaders.py (lines 158-172) handles JASCO-specific model construction:
def load_jasco_model(file_or_url_or_id, compression_model, device='cpu', cache_dir=None):
    pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
    cfg = OmegaConf.create(pkg['xp.cfg'])
    cfg.device = str(device)
    if cfg.device == 'cpu':
        cfg.dtype = 'float32'
    else:
        cfg.dtype = 'float16'
    model = builders.get_jasco_model(cfg, compression_model)
    model.load_state_dict(pkg['best_state'])
    model.eval()
    model.cfg = cfg
    return model
Key difference from MusicGen loading: the compression_model is passed to get_jasco_model() because the JASCO architecture needs to know the compression model's latent dimensions during construction.
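The dtype choice in load_jasco_model() depends only on the string form of the device. The sketch below restates that rule as a standalone function; select_dtype is a hypothetical name used for illustration and does not exist in audiocraft.

```python
def select_dtype(device) -> str:
    """Return the dtype name that the loader above writes into cfg.dtype.

    Hypothetical helper. The loader compares str(device) against 'cpu',
    so any other value (for example 'cuda' or 'cuda:0') selects float16.
    """
    return "float32" if str(device) == "cpu" else "float16"
```

One consequence of the exact string comparison is that only the literal 'cpu' selects float32; every other device string results in half precision.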
JASCO Constructor
After get_pretrained() constructs the components, the JASCO constructor (lines 30-40) performs additional initialization:
def __init__(self, chords_mapping_path='assets/chord_to_index_mapping.pkl', **kwargs):
    super().__init__(**kwargs)
    # Fixed sequence length from config
    self.duration = self.lm.cfg.dataset.segment_duration
    # Load chord vocabulary
    self.chords_mapping = pickle.load(open(chords_mapping_path, "rb"))
    # Initialize generation parameters
    self.set_generation_params()
Note that self.duration is set from the config's segment_duration rather than being user-configurable, reflecting JASCO's fixed-length generation design.
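The chord vocabulary step can be tried in isolation. The sketch below writes a toy mapping to a temporary pickle file and reads it back using a context manager, which closes the file handle that the one-line pickle.load(open(...)) form leaves open. The function load_chords_mapping and the toy chord labels are assumptions for illustration; the contents of the real chord_to_index_mapping.pkl may differ.

```python
import pickle
import tempfile
from pathlib import Path


def load_chords_mapping(path):
    """Load a chord-to-index dict, closing the file handle when done.

    Hypothetical helper; audiocraft loads the mapping inline in __init__.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


# Toy mapping for illustration only; the real file's labels and indices may differ.
toy_mapping = {"C": 0, "G": 1, "Am": 2, "F": 3}

with tempfile.TemporaryDirectory() as tmp:
    toy_path = Path(tmp) / "chord_to_index_mapping.pkl"
    with open(toy_path, "wb") as f:
        pickle.dump(toy_mapping, f)
    loaded = load_chords_mapping(toy_path)

print(loaded["Am"])  # prints 2
```

Two practical points follow from the constructor shown earlier. The default chords_mapping_path is relative, so it resolves against the current working directory of the process. And because unpickling can execute arbitrary code, the mapping file should only be loaded from a trusted source.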