Implementation:Microsoft DeepSpeedExamples Create DSVL Model
Appearance
- Implementation: Create_DSVL_Model
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation (Pattern Doc) |
| Title | Create_DSVL_Model |
| Repository | Microsoft/DeepSpeedExamples |
| Application | DeepSpeed-VisualChat |
| File | applications/DeepSpeed-VisualChat/utils/model/modeling_dsvl.py |
| Lines | 32-377 |
| Language | Python |
| Status | Active |
Overview
Concrete tool for creating the composed DeepSpeed-VisualChat model with vision encoder, projection, and language decoder.
Code Reference
Factory Function: create_dsvl_model_and_transforms (Lines 32-97)
def create_dsvl_model_and_transforms(
        text_tokenizer=None,
        ds_config=None,
        args=None):
    """Create the composed DeepSpeed-VisualChat model and its preprocessing helpers.

    Composes a vision encoder (extracted Qwen-VL ViT or standard CLIP), a
    visual projection, and a LLaMA-family language decoder into a single
    ``DeepSpeedViLModel``.

    Args:
        text_tokenizer: base tokenizer loaded from the language model; it is
            extended here with the special tokens (e.g. the ``<image>``
            placeholder) via ``add_special_token``.
        ds_config: optional DeepSpeed config dict; a ZeRO stage-3 config
            requires ``HfDeepSpeedConfig`` to be alive before any
            ``from_pretrained`` call.
        args: ``argparse.Namespace`` providing at least
            ``vision_model_name_or_path``, ``lm_model_name_or_path``,
            ``enable_mmca_attention`` and ``max_seq_len``.

    Returns:
        Tuple ``(model, image_processor, tokenizer)``.
    """
    assert args.vision_model_name_or_path is not None
    assert args.lm_model_name_or_path is not None
    if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
        # Keep the object referenced: merely instantiating HfDeepSpeedConfig
        # is what enables ZeRO-3 weight partitioning in later
        # ``from_pretrained`` calls.
        dschf = HfDeepSpeedConfig(ds_config)  # noqa: F841
    lang_config = AutoConfig.from_pretrained(args.lm_model_name_or_path)

    # ---- Vision encoder: extracted Qwen-VL ViT or a standard CLIP model ----
    if 'qwen' in args.vision_model_name_or_path.lower():
        # Qwen-VL's ViT follows the CLIP bigG layout, so borrow that config
        # as a template and override the Qwen-specific dimensions below.
        vis_config = AutoConfig.from_pretrained(
            "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
        vis_config = vis_config.vision_config
        vis_encoder = VisionTransformer(
            image_size=448,
            patch_size=vis_config.patch_size,
            width=vis_config.hidden_size,
            layers=vis_config.num_hidden_layers,
            heads=vis_config.num_attention_heads,
            mlp_size=vis_config.intermediate_size,
            output_dim=4096,
        )
        vis_encoder.load_state_dict(
            torch.load(os.path.join(args.vision_model_name_or_path,
                                    'pytorch_model.bin'),
                       map_location='cpu'),
            strict=True)
        # The Qwen ViT outputs 4096-dim features, not the CLIP hidden size.
        vis_config.hidden_size = 4096
    elif 'clip' in args.vision_model_name_or_path.lower():
        vis_encoder = CLIPVisionModel.from_pretrained(
            args.vision_model_name_or_path)
        vis_config = vis_encoder.config
    else:
        raise ValueError(
            f"Unsupported vision model: {args.vision_model_name_or_path}")
    # Image preprocessing is CLIP-style for both encoder families.
    image_processor = CLIPImageProcessor.from_pretrained(
        args.vision_model_name_or_path)

    tokenizer = add_special_token(text_tokenizer)
    tokenizer.pad_token = tokenizer.eos_token

    # ---- Language decoder (LLaMA family only) ----
    if 'llama' in args.lm_model_name_or_path.lower():
        decoder_name = 'llama'  # was undefined in the original excerpt
        lang_config = LlamaConfig.from_pretrained(args.lm_model_name_or_path)
        lang_config.enable_mmca_attention = args.enable_mmca_attention
        lang_config.max_position_embeddings = args.max_seq_len
        lang_decoder = LlamaForCausalLM.from_pretrained(
            args.lm_model_name_or_path, config=lang_config)
    else:
        raise ValueError(
            f"Unsupported language model: {args.lm_model_name_or_path}")
    # Account for the special tokens added to the tokenizer above.
    lang_config.vocab_size = len(tokenizer)
    lang_decoder.resize_token_embeddings(len(tokenizer))

    model = DeepSpeedViLModel(vis_encoder, lang_decoder, tokenizer,
                              vis_config=vis_config,
                              decoder_name=decoder_name,
                              lang_config=lang_config,
                              max_seq_length=args.max_seq_len,
                              args=args)
    return model, image_processor, tokenizer
Main Model Class: DeepSpeedViLModel (Lines 100-377)
class DeepSpeedViLModel(nn.Module):
    """Composed vision-language model.

    Wires together a vision encoder, a trainable visual projection, and a
    causal language decoder; ``_init_weight`` freezes the pretrained
    backbones and leaves the glue layers trainable.
    """

    def __init__(self, vis_encoder, lang_decoder, tokenizer,
                 vis_config=None, decoder_name='gpt2',
                 lang_config=None, max_seq_length=512, args=None):
        super().__init__()
        self.vis_encoder = vis_encoder
        self.lang_decoder = lang_decoder
        self.tokenizer = tokenizer
        self.args = args
        # Presumably caches the special-token ids (e.g. <image>) on self --
        # helper body is outside this excerpt; verify in the full source.
        self._enable_special_token()
        self.lang_config = lang_config
        # Presumably records decoder-specific stats such as hidden_size,
        # vocab_size and pad_token_id on self (used below) -- TODO confirm.
        self._get_model_stat(decoder_name)
        # NOTE: "_languag_embedding" (sic) is the upstream helper's spelling.
        lang_embed, pos_embedding = self._languag_embedding()
        self.pos_embedding = pos_embedding
        self.max_seq_length = max_seq_length
        if lang_embed is None:
            # Decoder did not expose an embedding table; build a fresh one.
            self.lang_embed = nn.Embedding(
                self.lang_config.vocab_size,
                self.hidden_size,
                self.pad_token_id)
        else:
            self.lang_embed = lang_embed
        # Maps vision features into the decoder's hidden space.
        self.projection = self.build_projection(
            vis_config, self.lang_config.hidden_size)
        self._init_weight()  # freeze backbones, unfreeze glue layers
        # Lazily initialised by later methods (see forward/concat).
        self.padding_embedding = None
        self.vis_encoder_update = None
Forward Method (Lines 277-347)
def forward(self, img, lang,
            attention_mask=None, input_labels=None,
            image_num=1, past_key_values=None,
            use_cache=False, output_attentions=False,
            output_hidden_states=False, return_dict=True):
    """Training forward pass.

    Encodes images, splices the projected features into the text embedding
    sequence at the <image> placeholders, runs the language decoder, and
    returns the next-token cross-entropy loss over supervised positions.

    Args:
        img: all images in the batch stacked as ``[total_images, 3, H, W]``.
        lang: token ids ``[batch, seq_len]`` containing <image> placeholders.
        attention_mask: ``[batch, seq_len]``; 1 = real token, 0 = padding.
        input_labels: ``[batch, seq_len]``; -100 marks masked positions.
        image_num: number of images per sample (list of length ``batch``).

    Returns:
        Single-element list ``[loss]``.
    """
    assert attention_mask is not None, "attention mask is required"
    assert input_labels is not None, "input labels is required"

    # Step 1: encode images; gradients flow into the vision encoder only
    # when vis_encoder_update is enabled.
    if self.vis_encoder_update:
        img_feature = self.vis_encoder(img)
    else:
        with torch.no_grad():
            img_feature = self.vis_encoder(img)
    # HF vision models return a ModelOutput; unwrap to the raw tensor.
    if not isinstance(img_feature, torch.Tensor):
        img_feature = img_feature.last_hidden_state

    # Step 2: project visual features into the decoder's hidden space.
    img_proj = self.projection(img_feature)

    # Step 3: replace <image> placeholders with the projected features.
    hidden_states, attention_mask, input_labels = self.concat(
        img_proj, lang, attention_mask, input_labels, image_num)

    # Step 4: run the language decoder on the fused embeddings.
    logits = self.lang_decoder(
        input_ids=None,
        inputs_embeds=hidden_states,
        attention_mask=attention_mask,
        labels=None,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict).logits

    # Step 5: cross-entropy over answer tokens only. Fix: use the labels
    # returned by concat (the excerpt referenced an undefined name
    # ``labels``; ``input_labels`` is the variable rebound above).
    logits_shift = logits[..., :-1, :].contiguous().view(-1, self.vocab_size)
    labels_shift = input_labels[..., 1:].contiguous().to(
        logits_shift.device).view(-1)
    labels_index = labels_shift != -100
    if torch.sum(labels_index) == 0:
        # Fully-masked sample: compute a dummy loss on the last two tokens
        # so every rank still produces a gradient of the right structure.
        logits_shift = logits_shift[-2:, :].contiguous()
        labels_shift = labels_shift[-2:].contiguous()
    else:
        logits_shift = logits_shift[labels_index, :].contiguous()
        labels_shift = labels_shift[labels_index].contiguous()
    loss_fct = CrossEntropyLoss()
    loss = loss_fct(logits_shift, labels_shift)
    return [loss, ]
Concatenation Method (Lines 197-275)
def concat(self, img_proj, lang, attention_mask, input_labels,
image_num, do_generation=False):
"""Replace <image> placeholder tokens with actual visual features.

Splits ``img_proj`` into per-sample chunks (``image_num[i]`` images for
sample ``i``) and, for each sample, splices every image's feature
sequence into the text embeddings at its <image> position. Spliced
image positions get attention value 2 so the MMCA attention can tell
image tokens apart from text tokens (value 1) and padding (value 0).

NOTE: this page shows an abridged excerpt -- the label splicing and the
pad-to-multiple-of-8 logic are elided below.
"""
output_lang = []
output_attention_mask = []
output_input_labels = []
# Split projected images by per-sample image counts
img_proj = split_tensor_by_a_list(img_proj, image_num)
for index in range(len(img_proj)):
cur_img = img_proj[index]
cur_lang = lang[index]
# Find <image> token positions in this sample's token ids
img_pos_list = cur_lang.eq(
self.DEFAULT_IMAGE_TOKEN_ID).nonzero(as_tuple=True)[0]
cur_lang = self.lang_embed(cur_lang) # embed text tokens
# Insert visual features at <image> positions, iterating positions in
# reverse so earlier offsets stay valid as the sequence grows
for img_i, img_pos in zip(cur_img,
torch.flip(img_pos_list, dims=(0,))):
lang_full = torch.cat((
cur_lang[:img_pos],
img_i,
cur_lang[img_pos+1:]), dim=0)
# Mark image positions with attention value 2.
# NOTE(review): this slices the batch-level ``attention_mask``; the
# full source presumably slices the current sample's mask -- verify
# against the repository before relying on this excerpt.
attention_mask_full = torch.cat((
attention_mask[:img_pos],
2 * torch.ones_like(img_i[:, 0]),
attention_mask[img_pos+1:]), dim=0)
# ... (label splicing elided in this excerpt)
# Pad to uniform length (divisible by 8)
# ... (padding logic elided in this excerpt)
return torch.cat(output_lang), torch.cat(output_attention_mask), \
torch.cat(output_input_labels)
Generation Method (Lines 349-372)
@torch.no_grad()
def generate(self, img, lang,
attention_mask=None, input_labels=None,
generation_length=128, generation_kwargs={}):
"""Autoregressive generation for a single sample (batch size 1).

Encodes and projects the image(s), splices them into the prompt
embeddings (elided in this excerpt; mirrors steps 1-3 of ``forward``)
and delegates token generation to the HF ``generate`` API.

Returns:
Tuple of (raw output token-id tensor, decoded string).

NOTE(review): ``generation_kwargs={}`` is a mutable default argument;
it is only forwarded (never mutated) in the visible code, but a None
default would be the safer idiom -- verify against the full source.
"""
assert lang.size()[0] == 1, "only support batch size == 1 for now"
# ... encode, project, concatenate ... (abridged here; builds the
# ``hidden_states`` and ``attention_mask`` used below)
output = self.lang_decoder.generate(
input_ids=None,
inputs_embeds=hidden_states,
attention_mask=attention_mask,
pad_token_id=self.tokenizer.pad_token_id,
max_new_tokens=generation_length,
**generation_kwargs)
return (output,
self.tokenizer.batch_decode(output, skip_special_tokens=True)[0])
I/O Contract
create_dsvl_model_and_transforms
| Direction | Parameter | Type | Description |
|---|---|---|---|
| Input | text_tokenizer |
AutoTokenizer |
Base tokenizer from language model |
| Input | ds_config |
dict or None |
DeepSpeed configuration (for ZeRO-3 support) |
| Input | args |
argparse.Namespace |
Command-line arguments |
| Output | model |
DeepSpeedViLModel |
The composed multimodal model |
| Output | image_processor |
CLIPImageProcessor |
Image preprocessing pipeline |
| Output | tokenizer |
AutoTokenizer |
Extended tokenizer with special tokens |
DeepSpeedViLModel.forward
| Direction | Parameter | Type | Shape | Description |
|---|---|---|---|---|
| Input | img |
torch.Tensor |
[total_images, 3, H, W] |
Batch of all images (across all samples) |
| Input | lang |
torch.Tensor |
[batch, seq_len] |
Token IDs with <image> placeholders |
| Input | attention_mask |
torch.Tensor |
[batch, seq_len] |
1 for real tokens, 0 for padding |
| Input | input_labels |
torch.Tensor |
[batch, seq_len] |
Labels (-100 for masked positions) |
| Input | image_num |
list[int] |
[batch] |
Number of images per sample |
| Output | (return) | list[torch.Tensor] |
[1] |
List containing the scalar loss |
Import Pattern
from utils.model.modeling_dsvl import create_dsvl_model_and_transforms
from utils.model.modeling_dsvl import DeepSpeedViLModel
Or via the package-level import used in the training script:
from utils.model import create_dsvl_model_and_transforms
Usage Example
# In training/main.py: build the tokenizer, then the composed model.
tokenizer = AutoTokenizer.from_pretrained(args.lm_model_name_or_path,
fast_tokenizer=True)
tokenizer.padding_side = 'right'  # right-padding for causal-LM training
model, image_processor, tokenizer = create_dsvl_model_and_transforms(
text_tokenizer=tokenizer,
args=args,
ds_config=ds_config)
# Optional: apply LoRA adapters to the (frozen) language decoder
if args.lang_lora_dim > 0:
model.lang_decoder = convert_linear_layer_to_lora(
model.lang_decoder,
args.lang_lora_module_name,
args.lang_lora_dim)
# Training forward pass; .half() feeds images as fp16, and the model
# returns a single-element list containing the scalar loss.
loss = model(
batch["image"].half(),
batch["input_ids"],
attention_mask=batch["attention_mask"],
input_labels=batch["labels"],
image_num=batch["image_num"],
)[0]
Internal Architecture
The _init_weight method controls the trainable/frozen split:
def _init_weight(self):
self.vis_encoder.requires_grad_(False) # frozen
self.lang_decoder.requires_grad_(False) # frozen (LoRA applied separately)
self.lang_embed.requires_grad_(True) # trainable
self.projection.requires_grad_(True) # trainable
if self.pos_embedding is not None:
self.pos_embedding.requires_grad_(True) # trainable
Dependencies
- transformers -- AutoConfig, AutoTokenizer, AutoModelForCausalLM, CLIPVisionModel, CLIPImageProcessor
- torch -- Core PyTorch modules
- utils.model.vis_proj -- VisProjection_vit, VisProjection_perceiver
- utils.model.third_party_model.hf_model.modeling_llama -- Custom LlamaForCausalLM with MMCA support
- utils.model.third_party_model.qwen_clip.qwen_clip -- VisionTransformer for Qwen-VL
- utils.data -- build_dataset, DataCollatorPadToMaxLen, add_special_token
Related Pages
- Principle:Microsoft_DeepSpeedExamples_Multimodal_Model_Composition -- The theoretical basis for multimodal model composition
- Implementation:Microsoft_DeepSpeedExamples_Extract_Qwen_VL -- Provides the extracted vision encoder weights
- Implementation:Microsoft_DeepSpeedExamples_VisProjection -- The projection modules used within this model
- Implementation:Microsoft_DeepSpeedExamples_Build_Dataset -- Prepares training data for this model
- Implementation:Microsoft_DeepSpeedExamples_DeepSpeed_Initialize_VisualChat -- Initializes distributed training for this model
- Environment:Microsoft_DeepSpeedExamples_VisualChat_Training_Environment
Page Connections
Double-click a node to navigate. Hold to expand connections.
Principle
Implementation
Heuristic
Environment