Jump to content

Connect Leeroopedia MCP: Equip your AI agents to search best practices, build plans, verify code, diagnose failures, and look up hyperparameter defaults.

Implementation:Microsoft DeepSpeedExamples Create DSVL Model

From Leeroopedia


  1. Implementation: Create_DSVL_Model

Metadata

Field Value
Page Type Implementation (Pattern Doc)
Title Create_DSVL_Model
Repository Microsoft/DeepSpeedExamples
Application DeepSpeed-VisualChat
File applications/DeepSpeed-VisualChat/utils/model/modeling_dsvl.py
Lines 32-377
Language Python
Status Active

Overview

Factory function and model class that build the composed DeepSpeed-VisualChat model from a vision encoder, a projection layer, and a language decoder.

Code Reference

Factory Function: create_dsvl_model_and_transforms (Lines 32-97)

def create_dsvl_model_and_transforms(
        text_tokenizer=None,
        ds_config=None,
        args=None):
    """Build the DeepSpeed-VisualChat model plus its image processor and tokenizer.

    Args:
        text_tokenizer: Base tokenizer of the language model; it is extended
            via ``add_special_token`` and its pad token is set to EOS.
        ds_config: Optional DeepSpeed config dict. When ZeRO stage 3 is
            configured, an ``HfDeepSpeedConfig`` is created before any
            ``from_pretrained`` call.
        args: Namespace providing ``vision_model_name_or_path``,
            ``lm_model_name_or_path``, ``enable_mmca_attention`` and
            ``max_seq_len``.

    Returns:
        ``(model, image_processor, tokenizer)``.

    Raises:
        AssertionError: if either model path is ``None``.
        ValueError: if the vision path names neither a Qwen-VL nor a CLIP
            encoder, or the LM path does not name a LLaMA-family model.
    """
    assert args.vision_model_name_or_path is not None
    assert args.lm_model_name_or_path is not None

    vision_name = args.vision_model_name_or_path.lower()
    lm_name = args.lm_model_name_or_path.lower()
    # Reject unsupported backbones before loading anything. Previously these
    # cases fell through and crashed later with a NameError on
    # `vis_encoder` / `lang_decoder`, after large checkpoints were loaded.
    if 'qwen' not in vision_name and 'clip' not in vision_name:
        raise ValueError(
            "Unsupported vision encoder "
            f"{args.vision_model_name_or_path!r}: only Qwen-VL's CLIP and "
            "CLIP models are supported")
    if 'llama' not in lm_name:
        raise ValueError(
            "Unsupported language decoder "
            f"{args.lm_model_name_or_path!r}: only LLaMA-family models are "
            "supported")

    dschf = None
    if ds_config is not None and \
            ds_config.get("zero_optimization", {}).get("stage") == 3:
        # The HfDeepSpeedConfig object must stay alive while the
        # from_pretrained calls below run; holding it in a local until this
        # function returns satisfies that.
        dschf = HfDeepSpeedConfig(ds_config)  # noqa: F841

    # Vision encoder loading (Qwen-VL or standard CLIP). 'qwen' is tested
    # first, so a path containing both substrings takes the Qwen branch.
    if 'qwen' in vision_name:
        # The CLIP-bigG config only supplies the ViT hyper-parameters; the
        # weights come from the Qwen checkpoint loaded below.
        vis_config = AutoConfig.from_pretrained(
            "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
        vis_config = vis_config.vision_config
        vis_encoder = VisionTransformer(
            image_size=448,
            patch_size=vis_config.patch_size,
            width=vis_config.hidden_size,
            layers=vis_config.num_hidden_layers,
            heads=vis_config.num_attention_heads,
            mlp_size=vis_config.intermediate_size,
            output_dim=4096,
        )
        # NOTE: torch.load unpickles the file; only point this at trusted
        # checkpoints.
        vis_encoder.load_state_dict(
            torch.load(os.path.join(args.vision_model_name_or_path,
                                    'pytorch_model.bin'),
                       map_location='cpu'),
            strict=True)
        # Match the encoder's output_dim above, so downstream consumers of
        # vis_config see a 4096-wide feature size.
        vis_config.hidden_size = 4096
    else:
        vis_encoder = CLIPVisionModel.from_pretrained(
            args.vision_model_name_or_path)
        vis_config = vis_encoder.config

    image_processor = CLIPImageProcessor.from_pretrained(
        args.vision_model_name_or_path)
    tokenizer = add_special_token(text_tokenizer)
    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    # Language decoder loading (LLaMA family, validated above).
    lang_config = LlamaConfig.from_pretrained(args.lm_model_name_or_path)
    lang_config.enable_mmca_attention = args.enable_mmca_attention
    lang_config.max_position_embeddings = args.max_seq_len
    lang_decoder = LlamaForCausalLM.from_pretrained(
        args.lm_model_name_or_path, config=lang_config)
    # Previously never assigned, so the constructor call below raised
    # NameError.
    # NOTE(review): assumes DeepSpeedViLModel._get_model_stat recognises
    # 'llama' -- confirm against its lookup.
    decoder_name = 'llama'

    # Size the vocabulary to the extended tokenizer.
    lang_config.vocab_size = len(tokenizer)
    lang_decoder.resize_token_embeddings(len(tokenizer))

    model = DeepSpeedViLModel(vis_encoder, lang_decoder, tokenizer,
                              vis_config=vis_config,
                              decoder_name=decoder_name,
                              lang_config=lang_config,
                              max_seq_length=args.max_seq_len,
                              args=args)
    return model, image_processor, tokenizer

Main Model Class: DeepSpeedViLModel (Lines 100-377)

class DeepSpeedViLModel(nn.Module):
    """Vision encoder + projection + language decoder composed as one module.

    Only the constructor is reproduced here. It stores the sub-modules,
    resolves the text-embedding table and builds the visual projection.
    """

    def __init__(self, vis_encoder, lang_decoder, tokenizer,
                 vis_config=None, decoder_name='gpt2',
                 lang_config=None, max_seq_length=512, args=None):
        super().__init__()
        # Components handed over by the factory.
        self.vis_encoder = vis_encoder
        self.lang_decoder = lang_decoder
        self.tokenizer = tokenizer
        self.args = args
        self._enable_special_token()

        # Model statistics are derived from the decoder config / name.
        self.lang_config = lang_config
        self._get_model_stat(decoder_name)
        text_table, position_table = self._languag_embedding()
        self.pos_embedding = position_table
        self.max_seq_length = max_seq_length

        # Use the embedding returned by _languag_embedding(); if it returned
        # None, create a fresh table sized from the config.
        self.lang_embed = (
            text_table if text_table is not None
            else nn.Embedding(num_embeddings=self.lang_config.vocab_size,
                              embedding_dim=self.hidden_size,
                              padding_idx=self.pad_token_id))

        target_width = self.lang_config.hidden_size
        self.projection = self.build_projection(vis_config, target_width)
        self._init_weight()
        # Runtime flags, populated later by other code paths.
        self.padding_embedding = None
        self.vis_encoder_update = None

Forward Method (Lines 277-347)

def forward(self, img, lang,
            attention_mask=None, input_labels=None,
            image_num=1, past_key_values=None,
            use_cache=False, output_attentions=False,
            output_hidden_states=False, return_dict=True):
    """Compute the language-modelling loss on supervised (answer) tokens.

    Args:
        img: images fed to ``self.vis_encoder``.
        lang: token ids containing ``<image>`` placeholders.
        attention_mask: required; merged with image positions by ``concat``.
        input_labels: required; positions labelled ``-100`` are ignored.
        image_num: per-sample image counts, forwarded to ``concat``.
        past_key_values, use_cache, output_attentions, output_hidden_states,
        return_dict: forwarded to ``self.lang_decoder``.

    Returns:
        ``[loss]``, a single-element list holding the scalar loss. If no
        position is supervised, the loss is an exact zero that remains
        attached to the graph.

    Raises:
        AssertionError: if ``attention_mask`` or ``input_labels`` is None.
    """
    assert attention_mask is not None, "attention mask is required"
    assert input_labels is not None, "input labels is required"

    # Step 1: Encode images; gradients flow only when the encoder is trainable.
    if self.vis_encoder_update:
        img_feature = self.vis_encoder(img)
    else:
        with torch.no_grad():
            img_feature = self.vis_encoder(img)
    if not isinstance(img_feature, torch.Tensor):
        # Model-output objects expose the features as last_hidden_state.
        img_feature = img_feature.last_hidden_state

    # Step 2: Project to language space
    img_proj = self.projection(img_feature)

    # Step 3: Splice image features into the text embeddings
    hidden_states, attention_mask, input_labels = self.concat(
        img_proj, lang, attention_mask, input_labels, image_num)

    # Step 4: Run language decoder
    outputs = self.lang_decoder(
        input_ids=None,
        inputs_embeds=hidden_states,
        attention_mask=attention_mask,
        labels=None,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict)
    # With return_dict=False the decoder returns a tuple whose first element
    # is the logits (no loss, since labels=None), so `.logits` would fail.
    logits = outputs.logits if return_dict else outputs[0]

    # Step 5: Next-token cross-entropy on answer tokens only. Labels come
    # from concat's output (the original referenced an undefined `labels`).
    logits_shift = logits[..., :-1, :].contiguous().view(-1, self.vocab_size)
    labels_shift = input_labels[..., 1:].contiguous().to(
        logits_shift.device).view(-1)

    labels_index = labels_shift != -100
    if torch.sum(labels_index) == 0:
        # Nothing supervised. CrossEntropyLoss over only ignored targets is
        # 0/0 = NaN, so return a zero that keeps the graph connected instead
        # (the gradients are zero either way).
        loss = (logits_shift[-2:, :] * 0.0).sum()
    else:
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits_shift[labels_index, :].contiguous(),
                        labels_shift[labels_index].contiguous())
    return [loss,]

Concatenation Method (Lines 197-275)

def concat(self, img_proj, lang, attention_mask, input_labels,
           image_num, do_generation=False):
    """Replace <image> placeholder tokens with actual visual features.

    Excerpt: parts of the body are elided (``# ...``), so this snippet is
    not runnable as shown.

    Args:
        img_proj: projected features for every image in the batch; split
            per sample with ``split_tensor_by_a_list(img_proj, image_num)``.
        lang: token ids, read per sample as ``lang[index]``.
        attention_mask: attention mask aligned with ``lang``.
        input_labels: labels aligned with ``lang``.
        image_num: per-sample image counts used for the split.
        do_generation: not referenced in the visible code.

    Returns:
        ``(embeddings, attention_mask, input_labels)``, each the
        ``torch.cat`` of the per-sample lists accumulated in the loop.
    """
    output_lang = []
    output_attention_mask = []
    output_input_labels = []

    # Split projected images by per-sample image counts
    img_proj = split_tensor_by_a_list(img_proj, image_num)

    for index in range(len(img_proj)):
        cur_img = img_proj[index]
        cur_lang = lang[index]
        # Find <image> token positions
        img_pos_list = cur_lang.eq(
            self.DEFAULT_IMAGE_TOKEN_ID).nonzero(as_tuple=True)[0]

        cur_lang = self.lang_embed(cur_lang)  # embed text tokens
        # Insert visual features at <image> positions (in reverse order)
        # NOTE(review): positions are walked last-to-first (presumably so
        # earlier indices stay valid after each splice), but cur_img is
        # consumed first-to-last. Image 0 is therefore paired with the LAST
        # <image> token unless the images were reversed upstream -- confirm.
        for img_i, img_pos in zip(cur_img,
                                   torch.flip(img_pos_list, dims=(0,))):
            # NOTE(review): lang_full is always rebuilt from cur_lang. Unless
            # the elided code assigns it back to cur_lang, only the last
            # splice survives for multi-image samples -- confirm.
            lang_full = torch.cat((
                cur_lang[:img_pos],
                img_i,
                cur_lang[img_pos+1:]), dim=0)
            # Mark image positions with attention value 2
            # (one entry per row of img_i, via img_i[:, 0]).
            # NOTE(review): this slices attention_mask along its first axis
            # rather than attention_mask[index]. With a [batch, seq_len] mask
            # that cuts across samples -- confirm.
            attention_mask_full = torch.cat((
                attention_mask[:img_pos],
                2 * torch.ones_like(img_i[:, 0]),
                attention_mask[img_pos+1:]), dim=0)
            # ...

    # Pad to uniform length (divisible by 8)
    # ...
    return torch.cat(output_lang), torch.cat(output_attention_mask), \
           torch.cat(output_input_labels)

Generation Method (Lines 349-372)

@torch.no_grad()
def generate(self, img, lang,
             attention_mask=None, input_labels=None,
             generation_length=128, generation_kwargs=None):
    """Generate a reply for a single multimodal prompt (batch size 1).

    Excerpt: the steps that encode the image, project it and build
    ``hidden_states`` / ``attention_mask`` are elided (``# ...``), so this
    snippet is not runnable as shown.

    Args:
        img: image input for the prompt.
        lang: prompt token ids; the first dimension must be 1.
        attention_mask: attention mask for ``lang``.
        input_labels: accepted for signature parity; not used in the
            visible code.
        generation_length: forwarded as ``max_new_tokens``.
        generation_kwargs: extra keyword arguments for
            ``lang_decoder.generate``. ``None`` (the default) means none; it
            replaces the former mutable ``{}`` default shared across calls.

    Returns:
        ``(output, text)``: the raw generated ids and the first sequence
        decoded with special tokens skipped.

    Raises:
        AssertionError: if the batch size of ``lang`` is not 1.
    """
    assert lang.size()[0] == 1, "only support batch size == 1 for now"
    if generation_kwargs is None:
        generation_kwargs = {}
    # ... encode, project, concatenate ...
    output = self.lang_decoder.generate(
        input_ids=None,
        inputs_embeds=hidden_states,
        attention_mask=attention_mask,
        pad_token_id=self.tokenizer.pad_token_id,
        max_new_tokens=generation_length,
        **generation_kwargs)
    return (output,
            self.tokenizer.batch_decode(output, skip_special_tokens=True)[0])

I/O Contract

create_dsvl_model_and_transforms

Direction Parameter Type Description
Input text_tokenizer AutoTokenizer Base tokenizer from language model
Input ds_config dict or None DeepSpeed configuration (for ZeRO-3 support)
Input args argparse.Namespace Command-line arguments
Output model DeepSpeedViLModel The composed multimodal model
Output image_processor CLIPImageProcessor Image preprocessing pipeline
Output tokenizer AutoTokenizer Extended tokenizer with special tokens

DeepSpeedViLModel.forward

Direction Parameter Type Shape Description
Input img torch.Tensor [total_images, 3, H, W] Batch of all images (across all samples)
Input lang torch.Tensor [batch, seq_len] Token IDs with <image> placeholders
Input attention_mask torch.Tensor [batch, seq_len] 1 for real tokens, 0 for padding (concat later marks inserted image-feature positions with 2)
Input input_labels torch.Tensor [batch, seq_len] Labels (-100 for masked positions)
Input image_num list[int] [batch] Number of images per sample
Output (return) list[torch.Tensor] [1] List containing the scalar loss

Import Pattern

from utils.model.modeling_dsvl import create_dsvl_model_and_transforms
from utils.model.modeling_dsvl import DeepSpeedViLModel

Or via the package-level import used in the training script:

from utils.model import create_dsvl_model_and_transforms

Usage Example

# In training/main.py
# `use_fast` is the keyword AutoTokenizer.from_pretrained documents for
# selecting the fast (Rust) tokenizer; `fast_tokenizer=True` is not a
# recognised option.
tokenizer = AutoTokenizer.from_pretrained(args.lm_model_name_or_path,
                                          use_fast=True)
tokenizer.padding_side = 'right'

# Returns the composed model, the CLIP image processor and the tokenizer
# extended with special tokens.
model, image_processor, tokenizer = create_dsvl_model_and_transforms(
    text_tokenizer=tokenizer,
    args=args,
    ds_config=ds_config)

# Optional: Apply LoRA to language decoder
if args.lang_lora_dim > 0:
    model.lang_decoder = convert_linear_layer_to_lora(
        model.lang_decoder,
        args.lang_lora_module_name,
        args.lang_lora_dim)

# Training forward pass: forward() returns [loss], so take element 0.
# NOTE: .half() casts images to fp16; adjust if training in another dtype.
loss = model(
    batch["image"].half(),
    batch["input_ids"],
    attention_mask=batch["attention_mask"],
    input_labels=batch["labels"],
    image_num=batch["image_num"],
)[0]

Internal Architecture

The _init_weight method controls the trainable/frozen split:

def _init_weight(self):
    """Set which sub-modules are trained.

    The vision encoder and language decoder are frozen. The text embedding,
    the projection and (if present) the position embedding are trainable.
    Freezing happens first, so any parameters shared with a frozen module
    end up trainable.
    """
    # Frozen backbones (LoRA, if used, is applied to the decoder separately).
    for frozen in (self.vis_encoder, self.lang_decoder):
        frozen.requires_grad_(False)

    # Trainable pieces, enabled after the freeze above.
    trainable = [self.lang_embed, self.projection]
    if self.pos_embedding is not None:
        trainable.append(self.pos_embedding)
    for module in trainable:
        module.requires_grad_(True)

Dependencies

  • transformers -- AutoConfig, AutoTokenizer, AutoModelForCausalLM, LlamaConfig, CLIPVisionModel, CLIPImageProcessor, HfDeepSpeedConfig
  • torch -- Core PyTorch modules
  • utils.model.vis_proj -- VisProjection_vit, VisProjection_perceiver
  • utils.model.third_party_model.hf_model.modeling_llama -- Custom LlamaForCausalLM with MMCA support
  • utils.model.third_party_model.qwen_clip.qwen_clip -- VisionTransformer for Qwen-VL
  • utils.data -- build_dataset, DataCollatorPadToMaxLen, add_special_token

Related Pages

Page Connections

Double-click a node to navigate. Hold to expand connections.
Principle
Implementation
Heuristic
Environment