Implementation:Facebookresearch Audiocraft LMModel generate
Summary
LMModel.generate is the core autoregressive generation method of the MusicGen language model. It takes conditioning attributes and an optional prompt, then iteratively predicts discrete audio tokens one step at a time using the streaming transformer with codebook interleaving patterns. The method handles classifier-free guidance setup, pattern-based sequence construction, token sampling, and sequence reconstruction.
API Signature
@torch.no_grad()
def generate(
    self,
    prompt: Optional[torch.Tensor] = None,
    conditions: List[ConditioningAttributes] = [],
    num_samples: Optional[int] = None,
    max_gen_len: int = 256,
    use_sampling: bool = True,
    temp: float = 1.0,
    top_k: int = 250,
    top_p: float = 0.0,
    cfg_coef: Optional[float] = None,
    cfg_coef_beta: Optional[float] = None,
    two_step_cfg: Optional[bool] = None,
    remove_prompts: bool = False,
    check: bool = False,
    callback: Optional[Callable[[int, int], None]] = None,
) -> torch.Tensor
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| prompt | Optional[torch.Tensor] | None | Prompt tokens of shape [B, K, T] for continuation. If None, generation starts from scratch. |
| conditions | List[ConditioningAttributes] | [] | List of conditioning attributes (text, melody, style); one per batch sample. |
| num_samples | Optional[int] | None | Number of samples to generate. Inferred from the prompt or conditions if not specified. |
| max_gen_len | int | 256 | Maximum generation length in tokens (frames). Typically computed as int(duration * frame_rate). |
| use_sampling | bool | True | Use probabilistic sampling if True; argmax decoding if False. |
| temp | float | 1.0 | Temperature for softmax scaling during sampling. |
| top_k | int | 250 | Top-k filtering parameter for token sampling. |
| top_p | float | 0.0 | Top-p (nucleus) sampling threshold. When 0.0, top-k is used instead. |
| cfg_coef | Optional[float] | None | Classifier-free guidance coefficient. If None, uses self.cfg_coef. |
| cfg_coef_beta | Optional[float] | None | Double CFG beta coefficient for MusicGen-Style models. |
| two_step_cfg | Optional[bool] | None | Use two-step CFG instead of batched CFG. If None, uses self.two_step_cfg. |
| remove_prompts | bool | False | If True, remove prompt tokens from the returned output. |
| check | bool | False | Enable runtime consistency checks on the generated sequence. |
| callback | Optional[Callable] | None | Progress callback receiving (current_step, total_steps). |
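The max_gen_len computation mentioned above is a simple product of the requested duration and the model's frame rate. A minimal sketch, assuming MusicGen's 50 Hz frame rate (EnCodec at 32 kHz); the variable names here are illustrative:

```python
# Sketch of the max_gen_len computation. The 50 Hz frame rate is an
# assumption matching MusicGen's 32 kHz EnCodec configuration.
frame_rate = 50                           # token frames per second of audio
duration = 8.0                            # requested seconds of audio
max_gen_len = int(duration * frame_rate)  # number of frames to generate
print(max_gen_len)  # 400
```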
Return Value
| Type | Description |
|---|---|
| torch.Tensor | Generated discrete audio tokens of shape [B, K, T'], where T' is the output length (up to max_gen_len, minus the prompt length if remove_prompts=True). Token values are integers in the range [0, card), i.e. valid codebook entries. |
Source Location
- File: audiocraft/models/lm.py, lines 421-587
- Class: LMModel (extends StreamingModule)
- Import: from audiocraft.models.lm import LMModel (typically accessed through MusicGen.lm)
Internal Workflow
The generate method proceeds through five major phases:
Phase 1: Input Validation and Setup (lines 462-477)
- Asserts the model is not in training mode.
- Determines num_samples from the prompt, conditions, or explicit parameter.
- Resolves the device from the model's first parameter.
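The precedence among these sources of num_samples can be sketched as follows. This is a simplified stand-in for the logic in lm.py, using plain Python values instead of tensors; the helper name infer_num_samples is illustrative:

```python
# Simplified sketch of num_samples inference, assuming the precedence
# described above: explicit argument > prompt batch size > number of
# conditions > fallback of 1.
def infer_num_samples(num_samples, prompt_batch, num_conditions):
    if num_samples is not None:      # caller specified it explicitly
        return num_samples
    if prompt_batch is not None:     # infer from prompt tokens [B, K, T]
        return prompt_batch
    if num_conditions > 0:           # one sample per conditioning attribute
        return num_conditions
    return 1                         # unconditional single sample

print(infer_num_samples(None, 2, 0))    # 2  (from the prompt batch)
print(infer_num_samples(None, None, 3)) # 3  (from the conditions)
```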
Phase 2: Classifier-Free Guidance Condition Preparation (lines 479-511)
Depending on the CFG configuration, conditions are prepared differently:
- Double CFG (cfg_coef_beta is not None): creates three condition sets: full conditional, wav-only conditional (text dropped), and null conditional. These are concatenated and tokenized together; the batch size is tripled.
- Standard CFG, two-step (two_step_cfg=True): creates separate condition tensors for the conditional and unconditional passes, processed independently during generation.
- Standard CFG, batched (default): concatenates conditional + null conditions; the batch size is doubled. A single forward pass produces both conditional and unconditional logits.
- No conditions: uses empty condition tensors.
The condition preparation involves:
null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
conditions = conditions + null_conditions # concatenate for batched CFG
tokenized = self.condition_provider.tokenize(conditions)
cfg_conditions = self.condition_provider(tokenized)
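In the batched case, each forward pass yields conditional logits (first half of the doubled batch) and unconditional logits (second half), which are combined by the standard CFG interpolation uncond + (cond - uncond) * cfg_coef. A minimal sketch with plain floats standing in for per-token logits; the helper name cfg_combine is illustrative, not part of the library:

```python
# Sketch of classifier-free guidance logit combination, assuming the
# doubled-batch layout described above (conditional logits in the first
# half, unconditional in the second).
def cfg_combine(cond_logits, uncond_logits, cfg_coef):
    """Interpolate: uncond + (cond - uncond) * cfg_coef."""
    return [u + (c - u) * cfg_coef
            for c, u in zip(cond_logits, uncond_logits)]

# With cfg_coef=3.0, tokens favored by the condition are amplified
# threefold relative to the unconditional baseline.
print(cfg_combine([2.0, 0.5], [1.0, 0.5], cfg_coef=3.0))  # [4.0, 0.5]
```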
Phase 3: Pattern Sequence Construction (lines 513-534)
- Creates an empty prompt of shape [B, K, 0] if none is provided.
- Obtains the codebook interleaving pattern from self.pattern_provider.get_pattern(max_gen_len).
- Initializes gen_codes with the unknown token (-1) and fills in the prompt.
- Builds the pattern sequence via pattern.build_pattern_sequence(), which maps [B, K, T] to [B, K, S].
- Finds start_offset_sequence: the first pattern-sequence step that covers the first timestep after the prompt.
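The effect of the default delayed interleaving can be shown with a toy example. This is a simplified stand-in for Pattern.build_pattern_sequence, assuming K=3 codebooks, T=4 timesteps, and codebook k delayed by k steps; -1 plays the role of the masked/special positions:

```python
# Toy illustration of the delayed interleaving pattern: codebook k is
# shifted right by k steps, giving a sequence of length S = T + K - 1.
K, T = 3, 4
S = T + K - 1
codes = [[10 * k + t for t in range(T)] for k in range(K)]  # toy [K, T] codes

pattern = [[-1] * S for _ in range(K)]  # -1 marks masked positions
for k in range(K):
    for t in range(T):
        pattern[k][t + k] = codes[k][t]  # apply per-codebook delay k

for row in pattern:
    print(row)
# [0, 1, 2, 3, -1, -1]
# [-1, 10, 11, 12, 13, -1]
# [-1, -1, 20, 21, 22, 23]
```

Each generation step then predicts one column of this layout, so codebook k at timestep t is conditioned on codebooks with smaller delays at the same timestep.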
Phase 4: Autoregressive Generation Loop (lines 536-566)
The loop iterates over pattern sequence positions from start_offset_sequence to gen_sequence_len:
with self.streaming():
for offset in range(start_offset_sequence, gen_sequence_len):
curr_sequence = gen_sequence[..., prev_offset:offset]
next_token = self._sample_next_token(
curr_sequence, cfg_conditions, unconditional_state,
use_sampling, temp, top_k, top_p,
cfg_coef=cfg_coef, cfg_coef_beta=cfg_coef_beta,
two_step_cfg=two_step_cfg
)
# Apply valid mask and fill in generated tokens
gen_sequence[..., offset:offset+1] = torch.where(
gen_sequence[..., offset:offset+1] == unknown_token,
next_token, gen_sequence[..., offset:offset+1]
)
Key aspects of the loop:
- Uses the streaming API which caches key-value states from previous transformer forward passes.
- Only the new portion of the sequence (prev_offset:offset) is fed at each step.
- _sample_next_token handles CFG logit interpolation and token sampling.
- A validity mask ensures that positions masked by the pattern retain the special token.
- Prompt tokens are never overwritten.
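The top-k sampling step inside _sample_next_token can be sketched as follows. This is a simplified stand-in using plain Python lists and the stdlib RNG; the real implementation operates on torch tensors (torch.topk and torch.multinomial), and the function name sample_top_k here is illustrative:

```python
# Sketch of top-k sampling: keep only the k most probable tokens,
# renormalize, then draw one of them at random.
import random

def sample_top_k(probs, k, rng):
    ranked = sorted(range(len(probs)), key=lambda i: -probs[i])[:k]
    total = sum(probs[i] for i in ranked)   # mass of the surviving tokens
    r = rng.random() * total                # draw within the kept mass
    acc = 0.0
    for i in ranked:
        acc += probs[i]
        if r <= acc:
            return i
    return ranked[-1]                       # numerical-edge fallback

rng = random.Random(0)
token = sample_top_k([0.1, 0.6, 0.05, 0.25], k=2, rng=rng)
# token is always 1 or 3: only the two highest-probability entries survive
```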
Phase 5: Sequence Reconstruction (lines 568-587)
After generation completes:
- Asserts that no unknown tokens remain.
- Reverts the pattern sequence back to the original codebook layout using pattern.revert_pattern_sequence().
- Trims prompt tokens if remove_prompts=True.
- Validates that all output codes are in the valid range [0, card).
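The revert step is the inverse of the interleaving applied in Phase 3. A minimal sketch, assuming a toy delay-by-index pattern (codebook k shifted right by k steps) rather than the real Pattern object; revert_delayed is an illustrative helper name:

```python
# Sketch of reverting a delayed pattern sequence [K, S] back to the
# original codebook layout [K, T] by undoing each codebook's delay.
def revert_delayed(seq, K, T):
    return [[seq[k][t + k] for t in range(T)] for k in range(K)]

seq = [
    [0, 1, 2, 3, -1, -1],
    [-1, 10, 11, 12, 13, -1],
    [-1, -1, 20, 21, 22, 23],
]
print(revert_delayed(seq, K=3, T=4))
# [[0, 1, 2, 3], [10, 11, 12, 13], [20, 21, 22, 23]]
```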
Related Classes
| Class | Location | Role |
|---|---|---|
| CodebooksPatternProvider | audiocraft/modules/codebooks_patterns.py, lines 272-302 | Abstract base class for codebook interleaving patterns. Provides get_pattern(timesteps). |
| DelayedPatternProvider | audiocraft/modules/codebooks_patterns.py, lines 305-356 | Default pattern provider. Each codebook is delayed by its index (codebook 0 has delay 0, codebook 1 has delay 1, etc.). |
| ParallelPatternProvider | audiocraft/modules/codebooks_patterns.py, line 359 | All codebooks have zero delay (predicted simultaneously). |
| StreamingTransformer | audiocraft/modules/transformer.py | Transformer with key-value caching for efficient streaming generation. |
| ConditionFuser | audiocraft/modules/conditioners.py | Fuses encoded conditions with the sequence input (prepend, cross-attention, etc.). |
Example Usage
# Typically called indirectly through MusicGen._generate_tokens():
gen_tokens = model.lm.generate(
prompt_tokens, # [B, K, T] or None
attributes, # List[ConditioningAttributes]
callback=callback,
max_gen_len=total_gen_len,
use_sampling=True,
temp=1.0,
top_k=250,
top_p=0.0,
cfg_coef=3.0,
two_step_cfg=False,
cfg_coef_beta=None,
)
# gen_tokens: [B, K, max_gen_len]
Dependencies
- torch - core tensor operations; autograd disabled via @torch.no_grad()
- xformers - memory-efficient attention used within the StreamingTransformer
Related Pages
- Principle:Facebookresearch_Audiocraft_Autoregressive_Token_Generation
- Implementation:Facebookresearch_Audiocraft_MusicGen_prepare_tokens_and_attributes - Prepares the conditions and prompt tokens consumed by this method.
- Implementation:Facebookresearch_Audiocraft_MusicGen_set_generation_params - Provides the sampling parameters unpacked into this method's arguments.
- Implementation:Facebookresearch_Audiocraft_EncodecModel_decode - Decodes the output tokens into audio waveforms.
- Environment:Facebookresearch_Audiocraft_XFormers_Memory_Efficient_Attention