
Implementation:Microsoft Onnxruntime Torch NN Module Construction

From Leeroopedia


Overview

Standard PyTorch torch.nn.Module construction, producing a model instance ready for wrapping with ORTModule for accelerated training.

Metadata

Implementation Name: Torch_NN_Module_Construction
Type: External Tool Doc
Language: Python
API: Standard torch.nn.Module construction: model = MyModel() or model = build_model()
Import: import torch
Domain: Accelerated_Training, PyTorch_Integration
Repository: microsoft/onnxruntime
Source Reference: docs/ORTModule_Training_Guidelines.md:L32-33
Last Updated: 2026-02-10

Description

This implementation covers the standard PyTorch model construction step that precedes ORTModule wrapping. The model is constructed using standard PyTorch conventions; no ONNX Runtime-specific modifications are required at this stage.

The model must satisfy the following requirements for compatibility with ORTModule:

  • Must be a subclass of torch.nn.Module
  • The forward() method must accept and return tensors (or nested structures of tensors)
  • All learnable parameters must be registered via nn.Parameter or contained in submodules
  • Dynamic control flow based on tensor values should be avoided where possible for optimal ONNX tracing
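The requirements above can be illustrated with a minimal sketch. GatedBlock below is a hypothetical module (not from the onnxruntime repository) showing parameter registration and a trace-friendly alternative to tensor-dependent Python branching:

```python
import torch
import torch.nn as nn

class GatedBlock(nn.Module):
    """Hypothetical module illustrating the requirements above."""
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)            # registered via a submodule
        self.gate = nn.Parameter(torch.ones(dim))  # registered via nn.Parameter
        self.bias = torch.zeros(dim)               # plain attribute: NOT a parameter

    def forward(self, x):
        # Trace-friendly: torch.where instead of a Python `if` on tensor
        # values, so the traced graph does not depend on the example input.
        y = self.proj(x)
        return torch.where(self.gate > 0, y, x) + self.bias

m = GatedBlock(4)
names = [name for name, _ in m.named_parameters()]
print(names)  # 'gate', 'proj.weight', 'proj.bias' appear; 'bias' does not
```

The unregistered `bias` tensor is invisible to `named_parameters()`, so it would not be trained or exported; wrapping it in `nn.Parameter` (or `register_buffer` for non-learnable state) fixes that.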

API Signature

import torch
import torch.nn as nn

# Option 1: Direct instantiation
model = MyModel(config)

# Option 2: Factory function
model = build_model()

Key Parameters

Parameter          Type                       Description
Model class        torch.nn.Module subclass   Any standard PyTorch model class
Constructor args   Varies                     Model-specific configuration (hidden size, number of layers, etc.)

I/O Contract

Direction   Type                        Description
Input       Model configuration         Architecture hyperparameters passed to the model constructor
Output      torch.nn.Module instance    Constructed model with initialized parameters, ready for ORTModule wrapping
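A minimal sketch of this contract, using hypothetical hyperparameter names: architecture configuration goes in, an nn.Module with initialized, trainable parameters comes out.

```python
import torch.nn as nn

# Input side of the contract: architecture hyperparameters (hypothetical names)
hidden_dim, n_classes = 64, 10

# Output side: an nn.Module instance with initialized parameters
model = nn.Sequential(
    nn.Linear(hidden_dim, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, n_classes),
)

assert isinstance(model, nn.Module)
assert all(p.requires_grad for p in model.parameters())
```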

Code Reference

From docs/ORTModule_Training_Guidelines.md:

	model = build_model()

+	from onnxruntime.training.ortmodule import ORTModule
+	model = ORTModule(model)

The first line (model = build_model()) represents the standard model construction step covered by this implementation.

Usage Example

import torch
import torch.nn as nn

# Example 1: Simple feedforward model
class FeedForward(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

model = FeedForward(784, 256, 10)

# Example 2: Using a pretrained model from a library
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Example 3: Using a factory function (ModelConfig and TransformerModel
# stand in for your own model classes)
def build_model():
    config = ModelConfig(hidden_size=768, num_layers=12, num_heads=12)
    return TransformerModel(config)

model = build_model()

# Model is now ready for ORTModule wrapping
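Before wrapping, a quick forward pass on dummy input can confirm the model meets the tensor-in/tensor-out contract. A minimal sketch, using an nn.Sequential stand-in for the FeedForward(784, 256, 10) model from Example 1:

```python
import torch
import torch.nn as nn

# Stand-in for the FeedForward(784, 256, 10) model above
model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))

x = torch.randn(32, 784)        # dummy batch: 32 samples of 784 features
out = model(x)
assert out.shape == (32, 10)    # forward() accepts and returns tensors
```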

Implements

Principle:Microsoft_Onnxruntime_PyTorch_Model_Preparation
