Implementation:Facebookresearch Audiocraft LMModel generate

From Leeroopedia
Revision as of 12:33, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Facebookresearch_Audiocraft_LMModel_generate.md)

Summary

LMModel.generate is the core autoregressive generation method of the MusicGen language model. It takes conditioning attributes and an optional prompt, then iteratively predicts discrete audio tokens one step at a time using the streaming transformer with codebook interleaving patterns. The method handles classifier-free guidance setup, pattern-based sequence construction, token sampling, and sequence reconstruction.

API Signature

@torch.no_grad()
def generate(
    self,
    prompt: Optional[torch.Tensor] = None,
    conditions: List[ConditioningAttributes] = [],
    num_samples: Optional[int] = None,
    max_gen_len: int = 256,
    use_sampling: bool = True,
    temp: float = 1.0,
    top_k: int = 250,
    top_p: float = 0.0,
    cfg_coef: Optional[float] = None,
    cfg_coef_beta: Optional[float] = None,
    two_step_cfg: Optional[bool] = None,
    remove_prompts: bool = False,
    check: bool = False,
    callback: Optional[Callable[[int, int], None]] = None,
) -> torch.Tensor

Parameters

Parameter | Type | Default | Description
prompt | Optional[torch.Tensor] | None | Prompt tokens of shape [B, K, T] for continuation. If None, generation starts from scratch.
conditions | List[ConditioningAttributes] | [] | List of conditioning attributes (text, melody, style). One per batch sample.
num_samples | Optional[int] | None | Number of samples to generate. Inferred from prompt or conditions if not specified.
max_gen_len | int | 256 | Maximum generation length in tokens (frames). Typically computed as int(duration * frame_rate).
use_sampling | bool | True | Use probabilistic sampling if True; use argmax decoding if False.
temp | float | 1.0 | Temperature for softmax scaling during sampling.
top_k | int | 250 | Top-k filtering parameter for token sampling.
top_p | float | 0.0 | Top-p (nucleus) sampling threshold. When 0.0, top-k is used instead.
cfg_coef | Optional[float] | None | Classifier-free guidance coefficient. If None, uses self.cfg_coef.
cfg_coef_beta | Optional[float] | None | Double CFG beta coefficient for MusicGen-Style models.
two_step_cfg | Optional[bool] | None | Use two-step CFG instead of batched CFG. If None, uses self.two_step_cfg.
remove_prompts | bool | False | If True, remove prompt tokens from the returned output.
check | bool | False | Enable runtime consistency checks on the generated sequence.
callback | Optional[Callable] | None | Progress callback function receiving (current_step, total_steps).
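
The interaction between use_sampling, temp, top_k and top_p can be illustrated with a small standard-library sketch. It is not the audiocraft implementation: pick_token and softmax are hypothetical helpers, and the precedence shown (top_p when positive, otherwise top_k, otherwise plain sampling) is an assumption read off the parameter descriptions above.

```python
import math
import random
from typing import List, Optional


def softmax(logits: List[float], temp: float) -> List[float]:
    """Temperature-scaled softmax over a single logit vector."""
    scaled = [x / temp for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def pick_token(logits: List[float], use_sampling: bool = True, temp: float = 1.0,
               top_k: int = 250, top_p: float = 0.0,
               rng: Optional[random.Random] = None) -> int:
    """Hypothetical helper: choose one token index from a logit vector."""
    rng = rng or random.Random()
    if not use_sampling or temp <= 0.0:
        # Greedy (argmax) decoding
        return max(range(len(logits)), key=logits.__getitem__)
    probs = softmax(logits, temp)
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    if top_p > 0.0:
        # Nucleus: keep the smallest prefix whose probability mass reaches top_p
        kept, mass = [], 0.0
        for idx in order:
            kept.append(idx)
            mass += probs[idx]
            if mass >= top_p:
                break
    elif top_k > 0:
        kept = order[:top_k]  # keep the k most likely tokens
    else:
        kept = order  # unrestricted sampling
    # random.choices renormalizes the kept weights
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]
```

With top_k=1 or a small top_p, the sketch collapses to greedy decoding, which is a convenient sanity check when tuning these values.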

Return Value

Type | Description
torch.Tensor | Generated discrete audio tokens of shape [B, K, T'], where T' is max_gen_len, or max_gen_len minus the prompt length if remove_prompts=True. Token values are integers in the range [0, card].
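
The length arithmetic can be written out as a short sketch. Both functions are hypothetical helpers, the frame rate of 50 used in the usage note is only an example value, and the requirement that the prompt be shorter than max_gen_len is an assumption.

```python
def frames_for_duration(duration_s: float, frame_rate: float) -> int:
    """max_gen_len as described above: int(duration * frame_rate)."""
    return int(duration_s * frame_rate)


def output_length(max_gen_len: int, prompt_len: int = 0,
                  remove_prompts: bool = False) -> int:
    """Length T' of the returned token tensor along the time axis."""
    # Assumption: the prompt must leave room for at least one generated frame.
    if prompt_len >= max_gen_len:
        raise ValueError("prompt must be shorter than max_gen_len")
    return max_gen_len - prompt_len if remove_prompts else max_gen_len
```

For example, 8 seconds at an assumed 50 frames per second gives max_gen_len = 400; with a 100-frame prompt and remove_prompts=True the output holds 300 frames.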

Source Location

  • File: audiocraft/models/lm.py, lines 421-587
  • Class: LMModel (extends StreamingModule)
  • Import: from audiocraft.models.lm import LMModel (typically accessed through MusicGen.lm)

Internal Workflow

The generate method proceeds through five major phases:

Phase 1: Input Validation and Setup (lines 462-477)

  • Asserts the model is not in training mode.
  • Determines num_samples from the prompt, conditions, or explicit parameter.
  • Resolves the device from the model's first parameter.
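
A minimal sketch of how num_samples might be resolved, assuming the priority order explicit value, then prompt batch size, then number of conditions. The function name and the fallback of 1 when nothing is provided are assumptions for illustration.

```python
from typing import Optional, Sequence


def resolve_num_samples(num_samples: Optional[int],
                        prompt_batch_size: Optional[int],
                        conditions: Sequence[object]) -> int:
    """Hypothetical helper mirroring the inference order described above."""
    if num_samples is not None:
        return num_samples            # explicit argument wins
    if prompt_batch_size is not None:
        return prompt_batch_size      # B from a prompt of shape [B, K, T]
    if conditions:
        return len(conditions)        # one ConditioningAttributes per sample
    return 1                          # assumed fallback: a single sample
```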

Phase 2: Classifier-Free Guidance Condition Preparation (lines 479-511)

Depending on the CFG configuration, conditions are prepared differently:

  • Double CFG (cfg_coef_beta is not None): Creates three condition sets -- full conditional, wav-only conditional (text dropped), and null conditional. These are concatenated and tokenized together. Batch size is tripled.
  • Standard CFG with two-step (two_step_cfg=True): Creates separate condition tensors for conditional and unconditional passes. They are processed independently during generation.
  • Standard CFG batched (default): Concatenates conditional + null conditions. Batch size is doubled. The single forward pass produces both conditional and unconditional logits.
  • No conditions: Uses empty condition tensors.

The condition preparation involves:

null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
conditions = conditions + null_conditions  # concatenate for batched CFG
tokenized = self.condition_provider.tokenize(conditions)
cfg_conditions = self.condition_provider(tokenized)
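
The batch layout that results from the batched CFG path can be sketched with plain Python structures. Here drop_all is a stand-in for ClassifierFreeGuidanceDropout(p=1.0), conditions are modeled as dictionaries rather than ConditioningAttributes, and split_logits is a hypothetical helper showing how a doubled batch would be separated again after the forward pass.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Condition = Dict[str, Optional[str]]


def drop_all(conditions: List[Condition]) -> List[Condition]:
    """Stand-in for dropout with p=1.0: every attribute is nulled."""
    return [{key: None for key in cond} for cond in conditions]


def batched_cfg_inputs(conditions: List[Condition]) -> List[Condition]:
    """Conditional entries first, null entries second: batch size doubles."""
    return conditions + drop_all(conditions)


def split_logits(rows: Sequence[object], batch_size: int) -> Tuple[Sequence[object], Sequence[object]]:
    """Split the doubled batch back into (conditional, unconditional) halves."""
    return rows[:batch_size], rows[batch_size:]


conds = [{"description": "lofi beat"}, {"description": "string quartet"}]
doubled = batched_cfg_inputs(conds)
cond_half, null_half = split_logits(doubled, len(conds))
```

The same idea extends to the double CFG case, where three condition sets are concatenated and the batch is split into thirds.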

Phase 3: Pattern Sequence Construction (lines 513-534)

  • Creates an empty prompt if none is provided: [B, K, 0].
  • Obtains the codebook interleaving pattern from self.pattern_provider.get_pattern(max_gen_len).
  • Initializes gen_codes with the unknown token (-1) and fills in the prompt.
  • Builds the pattern sequence using pattern.build_pattern_sequence(), which maps [B, K, T] to [B, K, S].
  • Finds start_offset_sequence: the first pattern sequence step that includes the first timestep after the prompt.
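
The mapping from [K, T] codes to a delayed [K, S] sequence, and the search for the first sequence step that contains a given timestep, can be sketched for a single batch item. This is a simplified illustration of a delay-by-index pattern: the helper names are hypothetical, the special value is a placeholder, and any extra bookkeeping steps the real pattern classes add are omitted.

```python
from typing import List, Tuple

Layout = List[List[Tuple[int, int]]]  # per sequence step: list of (t, k) coordinates


def delayed_layout(timesteps: int, num_codebooks: int) -> Layout:
    """Codebook k is delayed by k steps, so (t, k) lands at step t + k."""
    steps = timesteps + num_codebooks - 1
    return [[(s - k, k) for k in range(num_codebooks) if 0 <= s - k < timesteps]
            for s in range(steps)]


def build_sequence(codes: List[List[int]], layout: Layout, special: int) -> List[List[int]]:
    """Map codes[k][t] into a [K, S] grid; unfilled cells hold the special token."""
    seq = [[special] * len(layout) for _ in codes]
    for step, coords in enumerate(layout):
        for t, k in coords:
            seq[k][step] = codes[k][t]
    return seq


def first_step_with_timestep(layout: Layout, timestep: int) -> int:
    """First sequence step whose coordinates include the given timestep."""
    for step, coords in enumerate(layout):
        if any(t == timestep for t, _ in coords):
            return step
    raise ValueError("timestep not covered by the layout")


SPECIAL = -2  # placeholder value for this sketch only
layout = delayed_layout(timesteps=3, num_codebooks=2)
sequence = build_sequence([[1, 2, 3], [4, 5, 6]], layout, SPECIAL)
```

With a prompt of two frames, first_step_with_timestep(layout, 2) gives the step at which generation would begin in this simplified layout.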

Phase 4: Autoregressive Generation Loop (lines 536-566)

The loop iterates over pattern sequence positions from start_offset_sequence to gen_sequence_len:

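with self.streaming():
    prev_offset = 0  # nothing has been fed to the transformer yet
    for offset in range(start_offset_sequence, gen_sequence_len):
        # Feed only the steps the cached streaming state has not seen
        curr_sequence = gen_sequence[..., prev_offset:offset]
        # Sample the next token for every codebook
        next_token = self._sample_next_token(
            curr_sequence, cfg_conditions, unconditional_state,
            use_sampling, temp, top_k, top_p,
            cfg_coef=cfg_coef, cfg_coef_beta=cfg_coef_beta,
            two_step_cfg=two_step_cfg
        )
        # (the pattern's validity mask is applied to next_token here; omitted in this excerpt)
        # Write only over unknown positions so prompt tokens are preserved
        gen_sequence[..., offset:offset+1] = torch.where(
            gen_sequence[..., offset:offset+1] == unknown_token,
            next_token, gen_sequence[..., offset:offset+1]
        )
        prev_offset = offset  # the next iteration feeds only the new step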
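
To see which slices of the sequence reach the transformer, the offsets can be traced without any model. The sketch assumes prev_offset starts at 0 and is advanced to offset at the end of each iteration; streamed_slices is a hypothetical helper.

```python
from typing import List, Tuple


def streamed_slices(start_offset: int, sequence_len: int) -> List[Tuple[int, int]]:
    """Return the (prev_offset, offset) slice bounds fed at each loop iteration."""
    prev_offset = 0
    slices = []
    for offset in range(start_offset, sequence_len):
        slices.append((prev_offset, offset))
        prev_offset = offset  # everything before `offset` is now cached
    return slices


trace = streamed_slices(start_offset=3, sequence_len=6)
```

The first iteration consumes the whole prefix (the prompt steps), and every later iteration feeds exactly one new step, which is what makes key-value caching pay off.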

Key aspects of the loop:

  • Uses the streaming API, which caches key-value states from previous transformer forward passes.
  • Only the new portion of the sequence (prev_offset:offset) is fed at each step.
  • _sample_next_token handles CFG logit interpolation and token sampling.
  • A validity mask ensures that positions masked by the pattern hold the special token rather than a sampled value.
  • Prompt tokens are never overwritten.
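
CFG logit interpolation is commonly written as uncond + cfg_coef * (cond - uncond). The sketch below applies that standard formula to plain lists; whether _sample_next_token uses exactly this form is an assumption, and cfg_interpolate is a hypothetical name.

```python
from typing import List


def cfg_interpolate(cond_logits: List[float], uncond_logits: List[float],
                    cfg_coef: float) -> List[float]:
    """Push logits away from the unconditional prediction by a factor cfg_coef."""
    return [u + cfg_coef * (c - u) for c, u in zip(cond_logits, uncond_logits)]


guided = cfg_interpolate([2.0, 0.5], [1.0, -1.0], cfg_coef=3.0)
```

A coefficient of 1.0 reproduces the conditional logits and 0.0 reproduces the unconditional ones; larger values exaggerate the difference between the two.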

Phase 5: Sequence Reconstruction (lines 568-587)

After generation completes:

  • Asserts that no unknown tokens remain.
  • Reverts the pattern sequence back to the original codebook layout using pattern.revert_pattern_sequence().
  • Trims prompt tokens if remove_prompts=True.
  • Validates that all output codes are in the valid range [0, card].
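
The reconstruction steps can be sketched for a single batch item under a simplified delay-by-index layout, where code (k, t) sits at sequence step t + k. The helper names are hypothetical and the real revert_pattern_sequence operates on tensors with a mask; this only illustrates the order of the checks.

```python
from typing import List

UNKNOWN = -1  # placeholder for positions not generated yet


def revert_delayed_sequence(seq: List[List[int]], timesteps: int) -> List[List[int]]:
    """Undo the delay: read code (k, t) back from sequence step t + k."""
    return [[row[t + k] for t in range(timesteps)] for k, row in enumerate(seq)]


def finalize(seq: List[List[int]], timesteps: int, prompt_len: int, card: int,
             remove_prompts: bool = False) -> List[List[int]]:
    """Check, revert, optionally trim the prompt, then validate the code range."""
    if any(tok == UNKNOWN for row in seq for tok in row):
        raise ValueError("generation left unknown tokens in the sequence")
    codes = revert_delayed_sequence(seq, timesteps)
    if remove_prompts:
        codes = [row[prompt_len:] for row in codes]
    if not all(0 <= tok <= card for row in codes for tok in row):
        raise ValueError("code outside [0, card]")
    return codes


CARD = 10  # example cardinality; the value 10 also marks padded cells here
delayed = [[1, 2, 3, CARD], [CARD, 4, 5, 6]]
full = finalize(delayed, timesteps=3, prompt_len=1, card=CARD)
trimmed = finalize(delayed, timesteps=3, prompt_len=1, card=CARD, remove_prompts=True)
```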

Related Classes

Class | Location | Role
CodebooksPatternProvider | audiocraft/modules/codebooks_patterns.py, lines 272-302 | Abstract base class for codebook interleaving patterns. Provides get_pattern(timesteps).
DelayedPatternProvider | audiocraft/modules/codebooks_patterns.py, lines 305-356 | Default pattern provider. Each codebook is delayed by its index (codebook 0 has delay 0, codebook 1 has delay 1, etc.).
ParallelPatternProvider | audiocraft/modules/codebooks_patterns.py, line 359 | All codebooks have zero delay (predicted simultaneously).
StreamingTransformer | audiocraft/modules/transformer.py | Transformer with key-value caching for efficient streaming generation.
ConditionFuser | audiocraft/modules/conditioners.py | Fuses encoded conditions with the sequence input (prepend, cross-attention, etc.).

Example Usage

# Typically called indirectly through MusicGen._generate_tokens():
gen_tokens = model.lm.generate(
    prompt_tokens,        # [B, K, T] or None
    attributes,           # List[ConditioningAttributes]
    callback=callback,
    max_gen_len=total_gen_len,
    use_sampling=True,
    temp=1.0,
    top_k=250,
    top_p=0.0,
    cfg_coef=3.0,
    two_step_cfg=False,
    cfg_coef_beta=None,
)
# gen_tokens: [B, K, max_gen_len]
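
The callback argument in the call above receives (current_step, total_steps). A minimal sketch of such a callback follows; make_progress_callback is a hypothetical helper that records progress lines in a list instead of printing them.

```python
from typing import Callable, List


def make_progress_callback(log: List[str]) -> Callable[[int, int], None]:
    """Build a callback with the (current_step, total_steps) signature."""
    def callback(current_step: int, total_steps: int) -> None:
        percent = 100 * current_step / total_steps
        log.append(f"{current_step}/{total_steps} ({percent:.0f}%)")
    return callback


progress_log: List[str] = []
callback = make_progress_callback(progress_log)
# Would be passed as generate(..., callback=callback); simulated here:
callback(1, 4)
callback(4, 4)
```

Because the callback runs once per generated sequence step, it should stay cheap; throttling output to every N steps is a reasonable refinement.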

Dependencies

  • torch - Core tensor operations, autograd disabled via @torch.no_grad()
  • xformers - Memory-efficient attention used within the StreamingTransformer
