
Implementation:Alibaba ROLL Qwen3OmniMoeModel

From Leeroopedia


Knowledge Sources
Domains Model_Architecture, Multimodal, Vision_Language
Last Updated 2026-02-07 20:00 GMT

Overview

Implementation of the Qwen3-Omni multimodal Mixture-of-Experts (MoE) model, supporting audio, image, and video inputs for distributed training with Megatron-Core.

Description

modeling_qwen3_omni.py implements Qwen3OmniMoeModel, a multimodal Mixture-of-Experts model that extends Qwen3VLModel to support audio modality in addition to images and videos. The model is registered as qwen3_omni_moe via the @register_model decorator.
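The registration mechanism can be sketched as a decorator that maps a model-type string to its class so an AutoModel-style factory can resolve it later. The registry dict and lookup helper below are illustrative stand-ins, not the actual mcore_adapter internals:

```python
# Minimal sketch of the @register_model("qwen3_omni_moe") pattern.
# MODEL_REGISTRY and resolve_model_class are hypothetical names.

MODEL_REGISTRY = {}

def register_model(model_type):
    """Return a class decorator that registers cls under model_type."""
    def decorator(cls):
        MODEL_REGISTRY[model_type] = cls  # keyed by model-type string
        return cls
    return decorator

@register_model("qwen3_omni_moe")
class Qwen3OmniMoeModelStub:
    """Stand-in for the real Qwen3OmniMoeModel."""

def resolve_model_class(model_type):
    # AutoModel-style resolution from a config's model_type field.
    return MODEL_REGISTRY[model_type]
```

With this pattern, `AutoModel.from_pretrained` only needs the config's model-type string to construct the right class.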

The architecture consists of:

  • Audio encoder: A Qwen3OmniMoeAudioEncoder initialized on the first pipeline stage (pre_process) that processes raw audio features into embeddings. Uses SDPA attention and supports gradient checkpointing for memory optimization.
  • Vision encoder: A Qwen3OmniMoeVisionEncoder also on the first pipeline stage, inherited from Qwen3VLModel, that processes images and videos with deepstack visual embeddings.
  • Language model: The Megatron-Core GPT decoder inherited from Qwen3VLGPTModel.
  • Optional talker and code2wav modules: On the last pipeline stage (post_process) when enable_audio_output is enabled, for text-to-speech generation.

The construct_inputs_embeds method (lines 88-226) is the core multimodal fusion function. It:

  1. Delegates image/video processing to the parent Qwen3VLModel.construct_inputs_embeds
  2. Processes audio features by extracting relevant audio segments based on input_ranges (sub-sequences assigned to this pipeline/sequence parallel rank)
  3. Runs the audio encoder on collected features
  4. Scatters audio embeddings back into the combined input embedding tensor using masked_scatter
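Step 4 can be sketched with a toy tensor: audio embeddings are copied, in order, into the positions of the audio placeholder tokens via `masked_scatter`. The token id and shapes below are illustrative, not the model's real values:

```python
import torch

AUDIO_TOKEN_ID = 9  # hypothetical audio placeholder id

input_ids = torch.tensor([[1, AUDIO_TOKEN_ID, AUDIO_TOKEN_ID, 2]])
hidden_size = 4
inputs_embeds = torch.zeros(1, 4, hidden_size)  # text embeddings (zeros here)
audio_embeds = torch.ones(2, hidden_size)       # one row per audio placeholder

# Boolean mask over (batch, seq, hidden) marking audio token positions;
# masked_scatter fills the masked slots with audio_embeds in row order.
audio_mask = (input_ids == AUDIO_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_embeds)
```

After the scatter, positions 1 and 2 hold the audio embeddings while the surrounding text positions are untouched.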

The audio processing logic handles complex scenarios including:

  • Multiple audio segments across batch samples
  • Audio features split across sub-ranges in pipeline parallel settings
  • Deduplication of audio features already processed in previous sub-ranges
  • Proper index tracking for feature-to-embedding mapping
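The sub-range filtering above can be sketched as selecting only the audio placeholder positions that fall inside the `[start, end)` window owned by the current pipeline/sequence-parallel rank. The helper name and token id are illustrative:

```python
import torch

def audio_positions_in_range(input_ids, audio_token_id, input_range):
    """Return audio token positions inside this rank's [start, end) window."""
    start, end = input_range
    positions = (input_ids == audio_token_id).nonzero(as_tuple=True)[0]
    # Positions outside the window belong to another rank's sub-range,
    # so skipping them avoids processing the same audio feature twice.
    return positions[(positions >= start) & (positions < end)]

ids = torch.tensor([5, 9, 9, 7, 9, 6])            # 9 = hypothetical audio token id
local = audio_positions_in_range(ids, 9, [0, 3])  # this rank owns tokens 0..2
```

Here the audio token at position 4 is excluded because it lies outside the rank's window.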

The forward method (lines 228-305) handles:

  • Automatic rope index computation for all modalities (image, video, audio)
  • Context parallelism batch slicing
  • Vision and audio feature injection on the first pipeline stage
  • Fallback to standard decoder forward on non-first pipeline stages
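The stage-dependent dispatch above can be sketched with a toy class: multimodal feature injection happens only when `pre_process` is set (the first pipeline stage), and later stages fall through to the plain decoder forward. The class and return values are illustrative stand-ins, not the real Megatron control flow:

```python
class StageAwareForward:
    """Toy model illustrating first-stage vs. later-stage forward paths."""

    def __init__(self, pre_process: bool):
        self.pre_process = pre_process  # True only on the first pipeline stage

    def forward(self, input_ids, pixel_values=None, input_features=None):
        if self.pre_process:
            # First stage: embed tokens, then inject vision/audio features.
            return ("embed_and_fuse",
                    pixel_values is not None,
                    input_features is not None)
        # Non-first stages: standard decoder forward on incoming hidden states.
        return ("decoder_only", False, False)

first = StageAwareForward(pre_process=True)
later = StageAwareForward(pre_process=False)
```

This mirrors why `pixel_values` and `input_features` are only consumed on the first stage.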

Usage

Use this model class for training multimodal MoE models that process audio, images, and video simultaneously. It is instantiated via AutoModel using the qwen3_omni_moe model type registration.

Code Reference

Source Location

Signature

@register_model("qwen3_omni_moe")
class Qwen3OmniMoeModel(Qwen3VLModel):
    config_class = Qwen3OmniMoeConfig

Key Methods

__init__

def __init__(self, config: "Qwen3OmniMoeConfig", **kwargs)  # lines 16-86

Initializes the model with:

  • Parent Qwen3VLGPTModel initialization
  • Audio encoder (Qwen3OmniMoeAudioEncoder) on the first pipeline stage with SDPA attention and gradient checkpointing for full recomputation
  • Vision encoder (Qwen3OmniMoeVisionEncoder) on the first pipeline stage with optional gradient checkpointing
  • Optional talker and code2wav modules on the last pipeline stage for audio output
  • Rope index computation methods bound from the HF Qwen3OmniMoePreTrainedModelForConditionalGeneration
  • All encoder parameters marked with the sequence_parallel attribute

construct_inputs_embeds

def construct_inputs_embeds(
    self,
    input_ids: "torch.LongTensor",
    inputs_embeds: "torch.FloatTensor",
    pixel_values: "torch.Tensor",
    grid_thw: "torch.LongTensor",
    pixel_values_videos: "torch.Tensor",
    video_grid_thw: "torch.LongTensor",
    input_features: "torch.Tensor",
    feature_lens: "torch.Tensor",
    feature_attention_mask: "torch.Tensor",
    input_ranges: List[List[int]],
    image_token_id: int,
    video_token_id: int,
    audio_token_id: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]  # lines 88-226

Processes all modalities and merges them into a single embedding tensor. Currently images and videos cannot be processed simultaneously (assertion enforced). Returns (inputs_embeds, visual_pos_masks, deepstack_visual_embeds).

forward

def forward(
    self,
    input_ids: "torch.Tensor",
    position_ids: Optional["torch.Tensor"] = None,
    attention_mask: Optional["torch.Tensor"] = None,
    decoder_input: Optional["torch.Tensor"] = None,
    labels: Optional["torch.Tensor"] = None,
    pixel_values: Optional["torch.Tensor"] = None,
    pixel_values_videos: Optional["torch.Tensor"] = None,
    image_grid_thw: Optional["torch.LongTensor"] = None,
    video_grid_thw: Optional["torch.LongTensor"] = None,
    input_features: Optional["torch.Tensor"] = None,
    feature_attention_mask: Optional["torch.Tensor"] = None,
    **kwargs,
) -> "torch.Tensor"  # lines 228-305

Full forward pass. Computes rope indices for all modalities, applies context parallelism, embeds tokens, injects visual and audio features, then runs the transformer decoder.

Import

import torch
from megatron.core import mpu
from mcore_adapter.models.qwen3_omni.modeling_qwen3_omni import Qwen3OmniMoeModel
from mcore_adapter.models.qwen3_omni.config_qwen3_omni import Qwen3OmniMoeConfig

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.LongTensor | Yes | Token IDs with special placeholder tokens for image, video, and audio |
| pixel_values | torch.Tensor | No | Image pixel values for the vision encoder |
| pixel_values_videos | torch.Tensor | No | Video pixel values for the vision encoder |
| image_grid_thw | torch.LongTensor | No | Grid dimensions (temporal, height, width) for images |
| video_grid_thw | torch.LongTensor | No | Grid dimensions (temporal, height, width) for videos |
| input_features | torch.Tensor | No | Audio features of shape (batch, frequency, frames) |
| feature_attention_mask | torch.Tensor | No | Attention mask for audio features |
| labels | torch.Tensor | No | Target labels for the language-modeling loss |
| attention_mask | torch.Tensor | No | Attention mask for the sequence |
| position_ids | torch.Tensor | No | Position IDs (auto-computed from the rope index if not provided) |
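The relationship between `feature_attention_mask` and the `feature_lens` argument of `construct_inputs_embeds` can be sketched as a row sum over a 0/1 frame mask; the shapes here are illustrative:

```python
import torch

# 0/1 mask over audio frames: 1 marks a valid frame, 0 marks padding.
feature_attention_mask = torch.tensor([
    [1, 1, 1, 0, 0],  # sample 0: 3 valid frames
    [1, 1, 1, 1, 1],  # sample 1: 5 valid frames
])

# Per-sample valid length is simply the row sum of the mask.
feature_lens = feature_attention_mask.sum(dim=-1)
```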

Outputs

| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Model output logits or loss, depending on pipeline stage |

Usage Examples

from mcore_adapter.models import AutoModel
from mcore_adapter.training_args import TrainingArguments

# Load Qwen3-Omni MoE model with distributed training
args = TrainingArguments(
    tensor_model_parallel_size=4,
    pipeline_model_parallel_size=2,
    expert_model_parallel_size=2,
    bf16=True,
    output_dir="/tmp/output",
)
model = AutoModel.from_pretrained("Qwen/Qwen3-Omni-MoE", args)

# Forward pass with multimodal inputs
output = model(
    input_ids=input_ids,
    pixel_values=pixel_values,
    image_grid_thw=image_grid_thw,
    input_features=audio_features,
    feature_attention_mask=audio_mask,
    labels=labels,
)
