
Implementation:Facebookresearch Audiocraft LMModel generate

From Leeroopedia

Summary

LMModel.generate is the core autoregressive generation method of the MusicGen language model. It takes conditioning attributes and an optional prompt, then iteratively predicts discrete audio tokens one step at a time using the streaming transformer with codebook interleaving patterns. The method handles classifier-free guidance setup, pattern-based sequence construction, token sampling, and sequence reconstruction.

API Signature

@torch.no_grad()
def generate(
    self,
    prompt: Optional[torch.Tensor] = None,
    conditions: List[ConditioningAttributes] = [],
    num_samples: Optional[int] = None,
    max_gen_len: int = 256,
    use_sampling: bool = True,
    temp: float = 1.0,
    top_k: int = 250,
    top_p: float = 0.0,
    cfg_coef: Optional[float] = None,
    cfg_coef_beta: Optional[float] = None,
    two_step_cfg: Optional[bool] = None,
    remove_prompts: bool = False,
    check: bool = False,
    callback: Optional[Callable[[int, int], None]] = None,
) -> torch.Tensor

Parameters

Parameter Type Default Description
prompt Optional[torch.Tensor] None Prompt tokens of shape [B, K, T] for continuation. If None, generation starts from scratch.
conditions List[ConditioningAttributes] [] List of conditioning attributes (text, melody, style). One per batch sample.
num_samples Optional[int] None Number of samples to generate. Inferred from prompt or conditions if not specified.
max_gen_len int 256 Maximum generation length in tokens (frames). Typically computed as int(duration * frame_rate).
use_sampling bool True Use probabilistic sampling if True; use argmax decoding if False.
temp float 1.0 Temperature for softmax scaling during sampling.
top_k int 250 Top-k filtering parameter for token sampling.
top_p float 0.0 Top-p (nucleus) sampling threshold. When 0.0, top-k is used instead.
cfg_coef Optional[float] None Classifier-free guidance coefficient. If None, uses self.cfg_coef.
cfg_coef_beta Optional[float] None Double CFG beta coefficient for MusicGen-Style models.
two_step_cfg Optional[bool] None Use two-step CFG instead of batched CFG. If None, uses self.two_step_cfg.
remove_prompts bool False If True, remove prompt tokens from the returned output.
check bool False Enable runtime consistency checks on the generated sequence.
callback Optional[Callable] None Progress callback function receiving (current_step, total_steps).

Return Value

Type Description
torch.Tensor Generated discrete audio tokens of shape [B, K, T'] where T' is the output length (up to max_gen_len, minus prompt length if remove_prompts=True). Token values are integers in the range [0, card].
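
The length arithmetic above can be checked with a small dependency-free sketch. The concrete values are assumptions for illustration; the 50 Hz frame rate matches MusicGen's default EnCodec configuration:

```python
# Sketch: how max_gen_len and the output length T' relate (assumed values).
frame_rate = 50    # frames per second (MusicGen's EnCodec runs at 50 Hz)
duration = 8.0     # seconds of audio requested
prompt_len = 100   # prompt tokens per codebook, when continuing audio

max_gen_len = int(duration * frame_rate)   # 400 tokens per codebook
out_len_keep = max_gen_len                 # remove_prompts=False
out_len_trim = max_gen_len - prompt_len    # remove_prompts=True -> 300
```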

Source Location

  • File: audiocraft/models/lm.py, lines 421-587
  • Class: LMModel (extends StreamingModule)
  • Import: from audiocraft.models.lm import LMModel (typically accessed through MusicGen.lm)

Internal Workflow

The generate method proceeds through five major phases:

Phase 1: Input Validation and Setup (lines 462-477)

  • Asserts the model is not in training mode.
  • Determines num_samples from the prompt, conditions, or explicit parameter.
  • Resolves the device from the model's first parameter.
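
The resolution order for num_samples can be sketched as follows; infer_num_samples is an illustrative stand-in for the inline logic in generate, not an audiocraft function:

```python
def infer_num_samples(num_samples, prompt, conditions):
    # Priority: explicit argument, then prompt batch size, then the
    # number of conditioning attributes; default to a single sample.
    if num_samples is not None:
        return num_samples
    if prompt is not None:
        return len(prompt)  # prompt has shape [B, K, T]
    if conditions:
        return len(conditions)
    return 1
```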

Phase 2: Classifier-Free Guidance Condition Preparation (lines 479-511)

Depending on the CFG configuration, conditions are prepared differently:

  • Double CFG (cfg_coef_beta is not None): Creates three condition sets: full conditional, wav-only conditional (text dropped), and null conditional. These are concatenated and tokenized together. Batch size is tripled.
  • Standard CFG with two-step (two_step_cfg=True): Creates separate condition tensors for conditional and unconditional passes. They are processed independently during generation.
  • Standard CFG batched (default): Concatenates conditional + null conditions. Batch size is doubled. The single forward pass produces both conditional and unconditional logits.
  • No conditions: Uses empty condition tensors.

The condition preparation involves:

null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
conditions = conditions + null_conditions  # concatenate for batched CFG
tokenized = self.condition_provider.tokenize(conditions)
cfg_conditions = self.condition_provider(tokenized)
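
Once conditional and unconditional logits are available, standard CFG interpolates between them as uncond + cfg_coef * (cond - uncond). A minimal sketch of that interpolation on plain Python lists rather than tensors (cfg_interpolate is a hypothetical helper, not part of audiocraft):

```python
def cfg_interpolate(cond_logits, uncond_logits, cfg_coef):
    # Classifier-free guidance: push the prediction away from the
    # unconditional logits, scaled by cfg_coef.
    return [u + cfg_coef * (c - u) for c, u in zip(cond_logits, uncond_logits)]
```

With cfg_coef=1.0 this reduces to the conditional logits; larger values amplify the conditioning signal.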

Phase 3: Pattern Sequence Construction (lines 513-534)

  • Creates an empty prompt if none is provided: [B, K, 0].
  • Obtains the codebook interleaving pattern from self.pattern_provider.get_pattern(max_gen_len).
  • Initializes gen_codes with the unknown token (-1) and fills in the prompt.
  • Builds the pattern sequence using pattern.build_pattern_sequence(), which maps [B, K, T] to [B, K, S].
  • Finds start_offset_sequence: the first pattern sequence step that includes the first timestep after the prompt.
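
The delayed interleaving that build_pattern_sequence applies can be illustrated with a dependency-free sketch of the layout it produces (delayed_layout is a hypothetical helper; -1 here marks positions that hold the special token):

```python
def delayed_layout(num_codebooks, timesteps):
    # Codebook k is shifted right by k steps (DelayedPatternProvider,
    # simplified): sequence length S = T + K - 1, and each cell holds
    # the original timestep index, or -1 where no token is scheduled.
    seq_len = timesteps + num_codebooks - 1
    grid = [[-1] * seq_len for _ in range(num_codebooks)]
    for k in range(num_codebooks):
        for t in range(timesteps):
            grid[k][k + t] = t
    return grid
```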

Phase 4: Autoregressive Generation Loop (lines 536-566)

The loop iterates over pattern sequence positions from start_offset_sequence to gen_sequence_len:

with self.streaming():
    prev_offset = 0
    for offset in range(start_offset_sequence, gen_sequence_len):
        # Feed only the tokens produced since the last step; earlier
        # positions are served by the streaming key-value cache.
        curr_sequence = gen_sequence[..., prev_offset:offset]
        next_token = self._sample_next_token(
            curr_sequence, cfg_conditions, unconditional_state,
            use_sampling, temp, top_k, top_p,
            cfg_coef=cfg_coef, cfg_coef_beta=cfg_coef_beta,
            two_step_cfg=two_step_cfg
        )
        # Apply valid mask and fill in generated tokens, leaving
        # prompt positions untouched
        gen_sequence[..., offset:offset+1] = torch.where(
            gen_sequence[..., offset:offset+1] == unknown_token,
            next_token, gen_sequence[..., offset:offset+1]
        )
        prev_offset = offset

Key aspects of the loop:

  • Uses the streaming API which caches key-value states from previous transformer forward passes.
  • Only the new portion of the sequence (prev_offset:offset) is fed at each step.
  • _sample_next_token handles CFG logit interpolation and token sampling.
  • A validity mask ensures that masked positions (according to the pattern) retain the special token.
  • Prompt tokens are never overwritten.
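
The temperature and top-k filtering that _sample_next_token applies can be sketched in plain Python. sample_top_k below is an illustrative stand-in, not the library's implementation, which operates on logit tensors per codebook:

```python
import math
import random

def sample_top_k(logits, temp=1.0, top_k=3):
    # Keep the top_k highest logits, apply temperature scaling, then
    # sample from the renormalised softmax over the kept entries.
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    kept = ranked[:top_k]
    weights = [math.exp(logits[i] / temp) for i in kept]
    r = random.random() * sum(weights)
    for idx, w in zip(kept, weights):
        r -= w
        if r <= 0:
            return idx
    return kept[-1]
```

With use_sampling=False the model instead takes the argmax, which corresponds to top_k=1 here.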

Phase 5: Sequence Reconstruction (lines 568-587)

After generation completes:

  • Asserts that no unknown tokens remain.
  • Reverts the pattern sequence back to the original codebook layout using pattern.revert_pattern_sequence().
  • Trims prompt tokens if remove_prompts=True.
  • Validates that all output codes are in the valid range [0, card].
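
The revert step is the inverse of the delay interleaving: each codebook row is shifted left by its index to recover the [K, T] layout. A sketch on plain lists (revert_delayed is a hypothetical helper, not the pattern.revert_pattern_sequence API):

```python
def revert_delayed(seq, num_codebooks, timesteps):
    # seq: [K, S] delayed layout with S = T + K - 1. Undo the
    # per-codebook delay by reading row k starting at column k.
    return [[seq[k][k + t] for t in range(timesteps)]
            for k in range(num_codebooks)]
```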

Related Classes

Class Location Role
CodebooksPatternProvider audiocraft/modules/codebooks_patterns.py, lines 272-302 Abstract base class for codebook interleaving patterns. Provides get_pattern(timesteps).
DelayedPatternProvider audiocraft/modules/codebooks_patterns.py, lines 305-356 Default pattern provider. Each codebook is delayed by its index (codebook 0 has delay 0, codebook 1 has delay 1, etc.).
ParallelPatternProvider audiocraft/modules/codebooks_patterns.py, line 359 All codebooks have zero delay (predicted simultaneously).
StreamingTransformer audiocraft/modules/transformer.py Transformer with key-value caching for efficient streaming generation.
ConditionFuser audiocraft/modules/conditioners.py Fuses encoded conditions with the sequence input (prepend, cross-attention, etc.).

Example Usage

# Typically called indirectly through MusicGen._generate_tokens():
gen_tokens = model.lm.generate(
    prompt_tokens,        # [B, K, T] or None
    attributes,           # List[ConditioningAttributes]
    callback=callback,
    max_gen_len=total_gen_len,
    use_sampling=True,
    temp=1.0,
    top_k=250,
    top_p=0.0,
    cfg_coef=3.0,
    two_step_cfg=False,
    cfg_coef_beta=None,
)
# gen_tokens: [B, K, max_gen_len]

Dependencies

  • torch - Core tensor operations, autograd disabled via @torch.no_grad()
  • xformers - Memory-efficient attention used within the StreamingTransformer
