Implementation: Eric Mitchell Direct Preference Optimization Disable Dropout
| Knowledge Sources | |
|---|---|
| Domains | Regularization, Training_Stability |
| Last Updated | 2026-02-08 02:00 GMT |
Overview
A utility from the direct-preference-optimization repository that disables all dropout layers in a PyTorch model.
Description
The disable_dropout function iterates over all modules in a PyTorch model and sets the dropout probability to 0 for every nn.Dropout instance. This makes forward passes deterministic during DPO and SFT training, so the policy and reference log-probabilities are computed under identical conditions.
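The effect can be illustrated with a minimal sketch (a toy nn.Sequential model standing in for a transformer, not the repository's code): once every dropout probability is zeroed, repeated forward passes in train mode produce identical outputs.

```python
import torch
import torch.nn as nn

def disable_dropout(model: nn.Module):
    """Set p = 0 on every nn.Dropout submodule (mirrors the repo's utils.py)."""
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = 0

# Toy model standing in for a transformer block
model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.1), nn.Linear(8, 8))
model.train()  # dropout is only active in train mode

disable_dropout(model)

x = torch.randn(4, 8)
out1 = model(x)
out2 = model(x)
assert torch.equal(out1, out2)  # forward passes are now deterministic
assert all(m.p == 0 for m in model.modules() if isinstance(m, nn.Dropout))
```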
Usage
Call this function immediately after loading a model with AutoModelForCausalLM.from_pretrained and before passing the model to any trainer. Apply it to both the policy and reference models.
Code Reference
Source Location
- Repository: direct-preference-optimization
- File: utils.py
- Lines: 99-103
Signature
```python
def disable_dropout(model: torch.nn.Module):
    """Disable dropout in a model."""
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0
```
Import
```python
from utils import disable_dropout
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | torch.nn.Module | Yes | Any PyTorch model; all nn.Dropout submodules will have p set to 0 |
Outputs
| Name | Type | Description |
|---|---|---|
| (in-place) | None | Modifies model in-place; all nn.Dropout.p values set to 0 |
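The in-place contract above can be checked directly (a minimal sketch with a toy model and a local copy of the function, assuming the repository's utils.py behaves as shown): the call returns None and mutates the model it was given.

```python
import torch
import torch.nn as nn

def disable_dropout(model: nn.Module):
    """Local copy of the repo's utils.disable_dropout for illustration."""
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = 0

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
result = disable_dropout(model)

assert result is None   # nothing is returned...
assert model[1].p == 0  # ...the passed-in model itself is mutated
```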
Usage Examples
Disabling Dropout After Model Loading
```python
import torch
import transformers

from utils import disable_dropout

model = transformers.AutoModelForCausalLM.from_pretrained("gpt2-large")
disable_dropout(model)

# Verify dropout is disabled
for module in model.modules():
    if isinstance(module, torch.nn.Dropout):
        assert module.p == 0
```
Related Pages
Implements Principle
Requires Environment