Implementation:Facebookresearch Audiocraft CompressionModel get pretrained
Overview
CompressionModel.get_pretrained is the static factory method for loading pretrained neural audio codecs (EnCodec, DAC, HuggingFace models) to serve as audio tokenizers in the MusicGen training pipeline. It dispatches to the appropriate loading mechanism based on the model name and returns a CompressionModel instance ready for encoding and decoding audio.
Source Location
| Property | Value |
|---|---|
| Source file | audiocraft/models/encodec.py lines 88-122 |
| Import | from audiocraft.models import CompressionModel |
| Module | audiocraft.models.encodec |
| Builder function | audiocraft/models/builders.py lines 70-91 (get_compression_model) |
API
get_pretrained
@staticmethod
CompressionModel.get_pretrained(
    name: str,
    device: Union[torch.device, str] = 'cpu'
) -> CompressionModel
Dispatch Logic
| Name Pattern | Action |
|---|---|
| 'dac_44khz' or 'dac_24khz' | Loads DAC model via dac.utils.load_model() |
| 'debug_compression_model' | Loads a debug/test compression model |
| Existing local path | Loads AudioCraft checkpoint via loaders.load_compression_model() |
| Any other string (e.g., 'facebook/encodec_32khz') | Loads HuggingFace EnCodec via HFEncodecModel.from_pretrained() |
The returned model is always moved to the specified device and set to eval() mode.
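The dispatch order in the table can be sketched as a small standalone function. This is an illustrative sketch, not the AudioCraft source: `classify_codec_name` and the branch labels it returns are hypothetical names, and it assumes the patterns are checked from top to bottom, so a reserved name takes precedence over a local path with the same name.

```python
from pathlib import Path

# Reserved names taken from the dispatch table.
DAC_NAMES = ("dac_44khz", "dac_24khz")
DEBUG_NAMES = ("debug_compression_model",)


def classify_codec_name(name: str) -> str:
    """Return a label for the loading branch a name would select.

    Hypothetical helper for illustration; assumes the table rows are
    checked in order from top to bottom.
    """
    if name in DAC_NAMES:
        return "dac"
    if name in DEBUG_NAMES:
        return "debug"
    if Path(name).exists():
        return "local_checkpoint"
    # Anything else is treated as a HuggingFace model identifier.
    return "huggingface"


if __name__ == "__main__":
    for candidate in ("dac_44khz", "debug_compression_model", "facebook/encodec_32khz"):
        print(candidate, "->", classify_codec_name(candidate))
```

Because the final branch accepts any string, a mistyped local path would fall through to the HuggingFace branch under this ordering rather than raising a file-not-found error.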
get_compression_model (from config)
def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel
Builds a compression model from a Hydra config (used when training a compression model from scratch, not when loading pretrained).
CompressionModel Abstract Interface
All compression models must implement:
class CompressionModel(ABC, nn.Module):
    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Encode audio [B, C, T] -> (codes [B, K, T_s], scale)"""
        ...

    def decode(self, codes: torch.Tensor, scale: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Decode codes [B, K, T_s] -> audio [B, C, T]"""
        ...

    def decode_latent(self, codes: torch.Tensor):
        """Decode codes to continuous latent space."""
        ...

    # Properties
    channels: int          # Number of audio channels
    frame_rate: float      # Frames per second
    sample_rate: int       # Audio sample rate
    cardinality: int       # Codebook size (e.g., 2048)
    num_codebooks: int     # Active number of codebooks
    total_codebooks: int   # Total available codebooks
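To make the properties concrete, the sketch below works out the token grid shape and nominal bitrate implied by frame_rate, cardinality and num_codebooks. It is illustrative arithmetic only: the `CodecSpec` dataclass and its methods are hypothetical, the example values (32 kHz, 50 frames per second, 4 codebooks of 2048 entries) are assumed rather than read from a loaded model, and the frame count uses a simple ceiling that a real codec may round differently.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class CodecSpec:
    """Hypothetical container mirroring the CompressionModel properties."""
    sample_rate: int
    frame_rate: float
    cardinality: int
    num_codebooks: int

    def num_frames(self, num_samples: int) -> int:
        # Approximate T_s for an input of T samples (assumed ceiling rounding).
        return math.ceil(num_samples * self.frame_rate / self.sample_rate)

    def codes_shape(self, batch: int, num_samples: int) -> tuple:
        # Shape [B, K, T_s] of the discrete codes returned by encode().
        return (batch, self.num_codebooks, self.num_frames(num_samples))

    def bits_per_second(self) -> float:
        # Each code carries log2(cardinality) bits; K codes per frame.
        return self.frame_rate * self.num_codebooks * math.log2(self.cardinality)


# Assumed example values, not queried from a real model.
spec = CodecSpec(sample_rate=32000, frame_rate=50, cardinality=2048, num_codebooks=4)
print(spec.codes_shape(batch=2, num_samples=32000 * 10))  # (2, 4, 500)
print(spec.bits_per_second())  # 2200.0
```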
Inputs and Outputs
Inputs:
- name -- Model identifier string. Examples:
  - 'facebook/encodec_32khz' -- HuggingFace EnCodec at 32 kHz
  - 'facebook/encodec_24khz' -- HuggingFace EnCodec at 24 kHz
  - 'dac_44khz' -- Descript Audio Codec at 44.1 kHz
  - 'dac_24khz' -- Descript Audio Codec at 24 kHz
  - Local file path to an AudioCraft-exported checkpoint
- device -- Target device (default 'cpu')
Outputs:
- CompressionModel instance in eval mode with:
  - encode() method mapping audio tensors to discrete codes
  - decode() method mapping codes back to audio tensors
  - Properties for frame_rate, sample_rate, cardinality, num_codebooks
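A minimal usage sketch of the inputs and outputs listed above. It assumes audiocraft and torch are installed and that the named checkpoint can be downloaded; `roundtrip_silence` is a hypothetical helper, and the one-second silent input is only a placeholder signal.

```python
def roundtrip_silence(name: str = "facebook/encodec_32khz", device: str = "cpu"):
    """Encode and decode one second of silence with a pretrained codec.

    Hypothetical helper; imports are deferred so that defining the
    function does not require audiocraft to be installed.
    """
    import torch
    from audiocraft.models import CompressionModel

    model = CompressionModel.get_pretrained(name, device=device)

    # Placeholder input of shape [B, C, T]: one second of silence.
    audio = torch.zeros(1, model.channels, model.sample_rate, device=device)

    with torch.no_grad():
        codes, scale = model.encode(audio)          # codes: [B, K, T_s]
        reconstructed = model.decode(codes, scale)  # audio: [B, C, T]

    return codes.shape, reconstructed.shape
```

Calling `roundtrip_silence()` returns the two tensor shapes, which gives a quick check that the codes follow the [B, K, T_s] layout before the model is wired into a training pipeline.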
Concrete Implementations
| Class | Description | Source |
|---|---|---|
| EncodecModel | Native AudioCraft EnCodec with SEANet encoder/decoder and RVQ | encodec.py:L125-260 |
| DAC | Wrapper around Descript Audio Codec | encodec.py:L262-321 |
| HFEncodecCompressionModel | Wrapper around HuggingFace transformers.EncodecModel | encodec.py:L323-394 |
| InterleaveStereoCompressionModel | Wrapper for stereo by interleaving left/right codebooks | encodec.py:L397-507 |
Usage in MusicGen Solver
In MusicGenSolver.build_model():
self.compression_model = CompressionSolver.wrapped_model_from_checkpoint(
    self.cfg, self.cfg.compression_model_checkpoint, device=self.device)
# Verify compatibility
assert self.compression_model.sample_rate == self.cfg.sample_rate
assert self.cfg.transformer_lm.card == self.compression_model.cardinality
assert self.cfg.transformer_lm.n_q == self.compression_model.num_codebooks
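The three assertions can also be expressed as an explicit check that reports which setting is mismatched. The sketch below is an illustration under stated assumptions, not solver code: `check_codec_compatibility` is a hypothetical helper, and the stand-in objects only mimic the attribute names used above.

```python
from types import SimpleNamespace


def check_codec_compatibility(cfg, compression_model) -> list:
    """Return a list of human-readable mismatches (empty if compatible).

    Hypothetical helper mirroring the solver's three assertions.
    """
    expected = {
        "sample_rate": (cfg.sample_rate, compression_model.sample_rate),
        "cardinality": (cfg.transformer_lm.card, compression_model.cardinality),
        "num_codebooks": (cfg.transformer_lm.n_q, compression_model.num_codebooks),
    }
    return [
        f"{key}: config has {want}, codec has {got}"
        for key, (want, got) in expected.items()
        if want != got
    ]


# Stand-in objects with assumed values, used only to exercise the check.
cfg = SimpleNamespace(sample_rate=32000, transformer_lm=SimpleNamespace(card=2048, n_q=4))
codec = SimpleNamespace(sample_rate=32000, cardinality=2048, num_codebooks=8)
print(check_codec_compatibility(cfg, codec))
# ['num_codebooks: config has 4, codec has 8']
```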
During the training step, encoding is performed with:
with torch.no_grad():
    audio_tokens, scale = self.compression_model.encode(audio)
    assert scale is None, "Scaled compression model not supported with LM."
Dependencies
- torch, torch.nn -- neural network base
- transformers.EncodecModel -- HuggingFace EnCodec
- dac (optional) -- Descript Audio Codec
- einops -- tensor rearrangement (for stereo interleaving)
- audiocraft.quantization -- RVQ quantizer implementations