Implementation:Facebookresearch Audiocraft MusicGen get pretrained validation
Overview
MusicGen_get_pretrained_validation uses the MusicGen.get_pretrained() API to validate exported model integrity by loading from a local export directory and optionally running a test generation. This is the same API used for production model loading, repurposed here as a validation step in the export pipeline.
Implements
Principle:Facebookresearch_Audiocraft_Exported_Model_Validation
API Signature
@staticmethod
def get_pretrained(
    name: str = 'facebook/musicgen-melody',
    device=None
) -> MusicGen
Source: audiocraft/models/musicgen.py, lines 57-94
Import:
from audiocraft.models import MusicGen
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| name | str | 'facebook/musicgen-melody' | For validation, pass the local path to the export directory (e.g., '/path/to/my_exported_model') |
| device | Optional[str] | None | Target device; auto-detects CUDA if available, falls back to CPU |
Return Value
Returns a fully constructed MusicGen instance with loaded compression model and language model, confirming export integrity.
How It Works
from audiocraft.models import MusicGen

# Step 1: Load the exported model from local directory
model = MusicGen.get_pretrained('/path/to/my_exported_model', device='cpu')

# Step 2: Optionally run a short generation to validate inference
model.set_generation_params(duration=2)  # short duration for quick validation
audio = model.generate(['test prompt'])
assert audio.shape[0] == 1  # one sample
assert audio.shape[-1] > 0  # non-empty audio
print("Export validation passed!")
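The two steps above can be wrapped into a reusable smoke test. The following is a minimal sketch, not part of Audiocraft: smoke_test_generation is a hypothetical helper name, and it assumes only that the model exposes set_generation_params(duration=...) and generate(list_of_prompts) as used above, and that the returned audio has a shape of the form [batch, channels, samples].

```python
from typing import Any, Sequence, Tuple


def smoke_test_generation(
    model: Any, prompts: Sequence[str], duration: float = 2.0
) -> Tuple[int, ...]:
    """Run a short generation and sanity-check the output shape.

    Hypothetical helper (not an Audiocraft API). `model` is expected to be
    the object returned by MusicGen.get_pretrained().
    """
    model.set_generation_params(duration=duration)  # keep validation fast
    audio = model.generate(list(prompts))

    shape = tuple(audio.shape)
    # Assumption: generated audio is laid out as [batch, channels, samples].
    if len(shape) != 3:
        raise ValueError(f"expected [batch, channels, samples], got {shape}")
    if shape[0] != len(prompts):
        raise ValueError(f"expected {len(prompts)} samples, got {shape[0]}")
    if shape[-1] <= 0:
        raise ValueError("generated audio is empty")
    return shape
```

Raising ValueError instead of using bare assert statements keeps the checks active even when Python runs with the -O flag, which strips assertions.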
Internal Loading Flow
When get_pretrained() is called with a local directory path, the following chain executes (simplified; the arrows in the comments trace the nested calls):
@staticmethod
def get_pretrained(name='facebook/musicgen-melody', device=None):
    if device is None:
        device = 'cuda' if torch.cuda.device_count() else 'cpu'
    if name in _HF_MODEL_CHECKPOINTS_MAP:
        name = _HF_MODEL_CHECKPOINTS_MAP[name]
    lm = load_lm_model(name, device=device)
    # -> calls load_lm_model_ckpt(name)
    #    -> calls _get_state_dict(name, filename="state_dict.bin")
    #       -> detects local directory, loads name/state_dict.bin
    compression_model = load_compression_model(name, device=device)
    # -> calls load_compression_model_ckpt(name)
    #    -> calls _get_state_dict(name, filename="compression_state_dict.bin")
    #       -> detects local directory, loads name/compression_state_dict.bin
    return MusicGen(name, compression_model, lm)
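Because the loader looks for two fixed file names inside the directory, a cheap pre-flight check can catch an incomplete export before any weights are read. The sketch below uses only the standard library; missing_export_files and EXPECTED_FILES are hypothetical names introduced for illustration, and the file list is taken from the flow above, not from any Audiocraft constant.

```python
from pathlib import Path
from typing import List

# File names as they appear in the loading flow above.
EXPECTED_FILES = ("state_dict.bin", "compression_state_dict.bin")


def missing_export_files(export_dir: str) -> List[str]:
    """Return the expected checkpoint files that are absent from export_dir.

    Hypothetical helper (not an Audiocraft API). An empty list means both
    files exist; it does not prove that their contents are loadable.
    """
    root = Path(export_dir)
    if not root.is_dir():
        raise NotADirectoryError(f"{export_dir!r} is not a directory")
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]
```

An empty result only confirms that the files are present, so the full get_pretrained() call is still needed to verify that they load.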
Validation vs. Production Loading
| Aspect | Validation Use | Production Use |
|---|---|---|
| Input | Local export directory path | HuggingFace Hub model ID or pre-defined name |
| Purpose | Verify export integrity before distribution | Load model for generation |
| Error handling | Errors indicate export problems to fix | Errors indicate deployment issues |
| Generation test | Short test prompt to validate forward pass | Full-length generation for end users |
| API | Identical MusicGen.get_pretrained() | Identical MusicGen.get_pretrained() |
Error Conditions
Common validation failures and their meanings:
| Error | Likely Cause |
|---|---|
| FileNotFoundError for state_dict.bin | Language model was not exported or was saved with the wrong filename |
| FileNotFoundError for compression_state_dict.bin | Compression model export step was skipped |
| RuntimeError: size mismatch during load_state_dict | Configuration does not match the exported weights (architecture mismatch) |
| KeyError: 'best_state' | Checkpoint was not properly exported (a raw training checkpoint was passed instead of an exported one) |
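In an export pipeline, the table above can be turned into a small diagnostic wrapper around the load call. This is a hedged sketch: diagnose_load_failure and try_load are hypothetical helpers, the matching relies on the exception types and message fragments listed in the table, and the exact message text produced by a given PyTorch or Audiocraft version is an assumption.

```python
from typing import Any, Callable, Optional, Tuple


def diagnose_load_failure(exc: BaseException) -> str:
    """Map an exception raised during loading to a likely export problem."""
    message = str(exc)
    if isinstance(exc, FileNotFoundError):
        # Check the longer name first: "state_dict.bin" is a substring of it.
        if "compression_state_dict.bin" in message:
            return "Compression model export step was skipped"
        if "state_dict.bin" in message:
            return "Language model was not exported or was saved with the wrong filename"
    if isinstance(exc, RuntimeError) and "size mismatch" in message:
        return "Configuration does not match the exported weights"
    if isinstance(exc, KeyError) and exc.args and exc.args[0] == "best_state":
        return "Raw training checkpoint was passed instead of an exported one"
    return f"Unrecognized failure: {type(exc).__name__}: {message}"


def try_load(
    load_fn: Callable[[str], Any], export_dir: str
) -> Tuple[Optional[Any], Optional[str]]:
    """Call load_fn(export_dir); return (model, None) or (None, diagnosis)."""
    try:
        return load_fn(export_dir), None
    except (FileNotFoundError, RuntimeError, KeyError) as exc:
        return None, diagnose_load_failure(exc)
```

Passing the loader as an argument keeps the sketch independent of Audiocraft; in a real pipeline it could be called as try_load(MusicGen.get_pretrained, '/path/to/my_exported_model').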