Implementation:Microsoft DeepSpeedExamples Domino GPT Model
| Knowledge Sources | |
|---|---|
| Domains | Deep Learning, Language Modeling, Distributed Training |
| Last Updated | 2026-02-07 12:00 GMT |
Overview
A GPT-2 language model wrapper adapted from Megatron-LM for the DeepSpeed Domino distributed training framework, supporting pipeline parallelism with pre/post-processing stages.
Description
The GPTModel class extends MegatronModule to implement a GPT-2 language model for use with the DeepSpeed Domino distributed training system. It wraps the Megatron-LM get_language_model factory with causal attention masking (AttnMaskType.causal) and supports pipeline parallelism through configurable pre_process and post_process flags that control whether embedding and output-head layers are included in the current pipeline stage.
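In upstream Megatron-LM, which this file is adapted from, the two flags are usually derived from the rank's position in the pipeline. Only the first stage builds the embedding, only the last stage builds the output head, and intermediate stages receive activations through set_input_tensor instead of input_ids. The sketch below illustrates that mapping. `stage_flags` is a hypothetical helper, the mapping is assumed rather than taken from the Domino code, and interleaved (virtual) pipeline stages are ignored.

```python
from typing import Tuple

def stage_flags(pipeline_rank: int, pipeline_size: int) -> Tuple[bool, bool]:
    """Return (pre_process, post_process) for one pipeline stage (hypothetical helper)."""
    if not 0 <= pipeline_rank < pipeline_size:
        raise ValueError("pipeline_rank must be in [0, pipeline_size)")
    pre_process = pipeline_rank == 0                   # first stage owns the embedding
    post_process = pipeline_rank == pipeline_size - 1  # last stage owns the output head
    return pre_process, post_process

for rank in range(4):
    print(rank, stage_flags(rank, 4))
```

With a single stage both flags are True, which matches the constructor defaults.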
The forward method passes input tokens through the language model. When post_process is enabled, it then calls post_language_model_processing to compute the final logits or loss. That function has three jobs:
- It computes logits with parallel_lm_logits against the tensor-parallel output weights.
- It computes the loss with vocab_parallel_cross_entropy, which avoids gathering the full vocabulary onto a single rank. When fp16_lm_cross_entropy is set, the loss is computed directly on the half-precision logits; otherwise the logits are cast to fp32 first.
- It converts between Megatron's sequence-first layout and the batch-first layout callers expect. Logits go from [s, b, v] to [b, s, v], labels from [b, s] to [s, b], and the per-token loss from [s, b] back to [b, s].
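The following pure-Python sketch shows what post_language_model_processing computes on a single device (tensor-parallel size 1), including the layout transposes. It is not the Megatron implementation. It does not split the vocabulary across ranks as parallel_lm_logits and vocab_parallel_cross_entropy do, and it ignores the fp16 path. The names `post_process_reference`, `transpose01`, and `token_cross_entropy` are hypothetical.

```python
import math
from typing import List, Optional

def transpose01(x: list) -> list:
    """Swap the first two axes of a nested list, e.g. [s][b][v] -> [b][s][v]."""
    return [list(row) for row in zip(*x)]

def token_cross_entropy(logits: List[float], target: int) -> float:
    """Numerically stable -log softmax(logits)[target]."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[target]

def post_process_reference(lm_output: list, labels: Optional[list], logit_weights: list) -> list:
    """Single-device sketch of the post-processing step (hypothetical).

    lm_output:     [s][b][h] hidden states (sequence-first, as in Megatron)
    labels:        [b][s] token ids, or None
    logit_weights: [v][h] output weights (the word embeddings when tied)
    """
    # logits[s][b][v]: dot product of each hidden vector with each vocab row.
    logits = [[[sum(a * w for a, w in zip(hidden, row)) for row in logit_weights]
               for hidden in step]
              for step in lm_output]
    if labels is None:
        return transpose01(logits)   # [s, b, v] -> [b, s, v]
    labels_sb = transpose01(labels)  # [b, s] -> [s, b]
    loss_sb = [[token_cross_entropy(logits[i][j], labels_sb[i][j])
                for j in range(len(labels_sb[i]))]
               for i in range(len(labels_sb))]
    return transpose01(loss_sb)      # [s, b] -> [b, s]

lm_output = [[[1.0, 0.0]], [[0.0, 1.0]]]  # s=2, b=1, h=2
weights = [[1.0, 0.0], [0.0, 1.0]]         # v=2
print(post_process_reference(lm_output, None, weights))  # [[[1.0, 0.0], [0.0, 1.0]]]
print(post_process_reference(lm_output, [[0, 1]], weights))
```

The first call returns batch-first logits. The second returns a [b][s] list of per-token losses, matching the unreduced loss the real function produces.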
The model supports both tied and untied embedding-output weight configurations via the untie_embeddings_and_output_weights argument. Custom state_dict_for_save_checkpoint and load_state_dict methods ensure correct serialization of the language model and optional word embedding weights across pipeline stages.
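The checkpoint layout can be sketched as follows, assuming the file keeps upstream Megatron-LM's logic. The language model's state is always stored under one key. A last pipeline stage that has no input embedding but ties its output head to the embedding also stores its own copy of the word embeddings. The key strings and the `checkpoint_keys` helper are assumptions for illustration; check gpt_model.py for the exact names.

```python
from typing import List

# Key names used by upstream Megatron-LM's GPTModel; assumed, not verified, for Domino.
LANGUAGE_MODEL_KEY = "language_model"
WORD_EMBEDDINGS_FOR_HEAD_KEY = "word_embeddings_for_head"

def checkpoint_keys(pre_process: bool, post_process: bool,
                    untie_embeddings_and_output_weights: bool) -> List[str]:
    """Top-level keys a stage would write (hypothetical helper).

    load_state_dict is assumed to apply the same condition when deciding
    whether to restore the separate head embedding.
    """
    keys = [LANGUAGE_MODEL_KEY]
    # Last stage without the input embedding, but with tied output weights:
    # it holds its own copy of the word embeddings, saved under a separate key.
    if post_process and not pre_process and not untie_embeddings_and_output_weights:
        keys.append(WORD_EMBEDDINGS_FOR_HEAD_KEY)
    return keys

print(checkpoint_keys(pre_process=False, post_process=True,
                      untie_embeddings_and_output_weights=False))
```

With untied weights, the output layer's weight lives inside the language model, so no extra key is needed.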
Usage
Use this model class when training GPT models with DeepSpeed Domino, which aims to hide tensor-parallel communication by slicing computation and overlapping communication with it. In the example training script, the class is instantiated by the model_builder function in pretrain_gpt.py and passed to Domino's pretrain function.
Code Reference
Source Location
- Repository: Microsoft_DeepSpeedExamples
- File: training/DeepSpeed-Domino/domino/gpt_model.py
- Lines: 1-122
Signature
```python
def post_language_model_processing(lm_output, labels, logit_weights,
                                   parallel_output,
                                   fp16_lm_cross_entropy) -> Tensor: ...

class GPTModel(MegatronModule):
    def __init__(self,
                 config,
                 num_tokentypes=0,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True): ...

    def set_input_tensor(self, input_tensor) -> None: ...

    def forward(self, input_ids, position_ids, attention_mask,
                retriever_input_ids=None, retriever_position_ids=None,
                retriever_attn_mask=None, labels=None,
                tokentype_ids=None, inference_params=None) -> Tensor: ...

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False) -> dict: ...

    def load_state_dict(self, state_dict, strict=True) -> None: ...
```
Import
```python
from domino.gpt_model import GPTModel
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| config | TransformerConfig | Yes | Megatron transformer configuration object |
| num_tokentypes | int | No | Number of token types for embeddings (default: 0) |
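| parallel_output | bool | No | If True, logits stay split along the vocabulary dimension across tensor-parallel ranks; if False, they are gathered to the full vocabulary (default: True) |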
| pre_process | bool | No | Include embedding layer in this pipeline stage (default: True) |
| post_process | bool | No | Include output head in this pipeline stage (default: True) |
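| input_ids | Tensor | Yes | Input token IDs of shape [batch, seq_len]; consumed by the embedding layer, so only used when pre_process is True |
| position_ids | Tensor | Yes | Position IDs of shape [batch, seq_len]; like input_ids, only used when pre_process is True |
| attention_mask | Tensor | Yes | Causal attention mask; in Megatron's convention, a boolean tensor of shape [batch or 1, 1, seq_len, seq_len] where True marks masked positions |
| labels | Tensor | No | Target token IDs of shape [batch, seq_len]; if provided, the last stage (post_process=True) returns loss instead of logits |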
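Megatron-style GPT pretraining scripts usually build attention_mask as a causal (lower-triangular) mask in which True marks positions to be masked out, and set position_ids to 0..seq_len-1 for every sequence. Upstream Megatron-LM does this in get_ltor_masks_and_position_ids. Assuming Domino's data pipeline follows the same convention, the pure-Python sketch below shows the values for one sequence. The real tensors carry extra dimensions ([batch or 1, 1, seq_len, seq_len] for the mask). `make_causal_mask` and `make_position_ids` are hypothetical helpers, and document-boundary resets are ignored.

```python
from typing import List

def make_causal_mask(seq_len: int) -> List[List[bool]]:
    """mask[query][key] is True when the query must NOT attend to the key (a future token)."""
    return [[key > query for key in range(seq_len)] for query in range(seq_len)]

def make_position_ids(batch: int, seq_len: int) -> List[List[int]]:
    """Positions 0..seq_len-1 repeated for every sequence in the batch."""
    return [list(range(seq_len)) for _ in range(batch)]

for row in make_causal_mask(3):
    print(row)
print(make_position_ids(2, 3))  # [[0, 1, 2], [0, 1, 2]]
```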
Outputs
| Name | Type | Description |
|---|---|---|
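| output | Tensor | When post_process is True: logits of shape [batch, seq_len, vocab] if labels is None (with parallel_output=True, the last dimension is this rank's partition, vocab / tensor-parallel size), or per-token cross-entropy loss of shape [batch, seq_len] if labels are provided. When post_process is False: the language model's hidden states in sequence-first layout [seq_len, batch, hidden], to be passed to the next pipeline stage |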
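The output shape rules above can be summarized in a small helper. `gpt_output_shape` is hypothetical. `vocab` stands for the padded vocabulary size, since Megatron pads the vocabulary so that it divides evenly across tensor-parallel ranks. Sequence parallelism, which would further split the sequence dimension of the hidden states, is ignored.

```python
from typing import Tuple

def gpt_output_shape(batch: int, seq_len: int, hidden: int, vocab: int,
                     tp_size: int, parallel_output: bool,
                     post_process: bool, has_labels: bool) -> Tuple[int, ...]:
    """Expected shape of GPTModel.forward's return value (hypothetical helper)."""
    if not post_process:
        # Intermediate pipeline stage: hidden states, sequence-first.
        return (seq_len, batch, hidden)
    if has_labels:
        # Unreduced per-token loss, batch-first.
        return (batch, seq_len)
    if vocab % tp_size != 0:
        raise ValueError("padded vocab must be divisible by tp_size")
    local_vocab = vocab // tp_size if parallel_output else vocab
    return (batch, seq_len, local_vocab)

print(gpt_output_shape(2, 1024, 768, 50304, tp_size=2, parallel_output=True,
                       post_process=True, has_labels=False))  # (2, 1024, 25152)
```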
Usage Examples
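```python
from domino.gpt_model import GPTModel
from megatron.arguments import core_transformer_config_from_args
from megatron import get_args

# Assumes the training arguments and parallel process groups have already been
# initialized by the framework; get_args() reads that global state.
args = get_args()
config = core_transformer_config_from_args(args)

# Single-stage model: this rank owns both the embedding and the output head.
model = GPTModel(
    config=config,
    num_tokentypes=0,
    parallel_output=True,
    pre_process=True,
    post_process=True
)

# input_ids, position_ids, attention_mask and labels come from the training batch.
# Passing labels makes the model return an unreduced per-token loss of shape [batch, seq_len].
loss = model(input_ids, position_ids, attention_mask, labels=labels)
```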
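The model returns an unreduced per-token loss, so the training step still has to reduce it to a scalar. In upstream Megatron-LM's pretrain_gpt.py, this is a loss-mask-weighted mean. Assuming the Domino script does the same, the reduction looks like the pure-Python sketch below. `masked_mean_loss` is a hypothetical name.

```python
from typing import Sequence

def masked_mean_loss(losses: Sequence[Sequence[float]],
                     loss_mask: Sequence[Sequence[float]]) -> float:
    """Average per-token losses [batch][seq_len] over the tokens the mask keeps."""
    total = 0.0
    count = 0.0
    for loss_row, mask_row in zip(losses, loss_mask):
        for loss, mask in zip(loss_row, mask_row):
            total += loss * mask
            count += mask
    if count == 0:
        raise ValueError("loss_mask selects no tokens")
    return total / count

print(masked_mean_loss([[1.0, 2.0], [3.0, 4.0]], [[1, 1], [1, 0]]))  # 2.0
```

Weighting by the mask lets padding or end-of-document tokens be excluded without changing the tensor shapes the model produces.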