Implementation: Allenai Open Instruct ModelConfig
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, Software Engineering, MLOps |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
A typed dataclass provided by the Open Instruct library that consolidates model configuration settings for training.
Description
The ModelConfig dataclass consolidates all model-related hyperparameters for reproducible training. It includes fields for the model checkpoint identity, attention implementation, precision, gradient checkpointing, and all LoRA/PEFT settings. It also includes quantization settings for QLoRA workflows. The __post_init__ method enforces invariants, such as disabling KV-cache when gradient checkpointing is enabled (since they are incompatible).
This dataclass is primarily used in the GRPO and DPO training pipelines within Open Instruct. The SFT pipeline (finetune.py) uses FlatArguments directly, but ModelConfig is the canonical model configuration for newer training scripts.
Usage
Create a ModelConfig instance to configure the model for training. It is typically instantiated from command-line arguments via HuggingFace's argument parsing and passed to the model loading functions.
Code Reference
Source Location
- Repository: Open Instruct
- File: open_instruct/model_utils.py
- Lines: L130-177
Signature
@dataclass
class ModelConfig:
    model_name_or_path: str | None = None
    """The model checkpoint for weights initialization."""
    model_revision: str | None = None
    """The specific model version to use (can be a branch name, tag name or commit id)."""
    dtype: str | None = None
    """The data type to load the model under."""
    attn_implementation: Literal["flash_attention_2", "sdpa"] = "flash_attention_2"
    """Which attention implementation to use."""
    use_cache: bool | None = None
    """Whether to use cache in the model."""
    gradient_checkpointing: bool = False
    """Whether to use gradient checkpointing in the model."""

    # PEFT-related args
    use_peft: bool = False
    """Whether to use PEFT or not for training."""
    lora_r: int | None = 16
    """LoRA R value."""
    lora_alpha: int | None = 32
    """LoRA alpha."""
    lora_dropout: float | None = 0.05
    """LoRA dropout."""
    lora_target_modules: list[str] | None = None
    """LoRA target modules."""
    lora_modules_to_save: list[str] | None = None
    """Model layers to unfreeze & train."""
    lora_task_type: str = "CAUSAL_LM"
    """The task_type to pass for LoRA."""

    # Quantization args
    load_in_8bit: bool = False
    """Use 8 bit precision for the base model (LoRA only)."""
    load_in_4bit: bool = False
    """Use 4 bit precision for the base model (LoRA only)."""
    bnb_4bit_quant_type: str | None = "nf4"
    """Quantization type (fp4 or nf4)."""
    use_bnb_nested_quant: bool = False
    """Use nested quantization."""
Import
from open_instruct.model_utils import ModelConfig
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_name_or_path | str or None | Yes (for model loading) | HuggingFace model ID or local path to the pre-trained model weights. |
| model_revision | str or None | No | Specific model version (branch, tag, or commit hash) for reproducibility. |
| dtype | str or None | No | Data type override for model loading (e.g., "bfloat16"). If None, uses the model default. |
| attn_implementation | Literal["flash_attention_2", "sdpa"] | No | Attention backend. Defaults to "flash_attention_2". |
| use_cache | bool or None | No | Enable KV-cache. Auto-disabled when gradient checkpointing is enabled. |
| gradient_checkpointing | bool | No | Enable gradient checkpointing to reduce memory usage. Defaults to False. |
| use_peft | bool | No | Whether to use PEFT (LoRA). Defaults to False. |
| lora_r | int or None | No | LoRA rank. Defaults to 16. |
| lora_alpha | int or None | No | LoRA scaling factor. Defaults to 32. |
| lora_dropout | float or None | No | LoRA dropout rate. Defaults to 0.05. |
| lora_target_modules | list[str] or None | No | Which model modules to apply LoRA to. If None, uses library defaults. |
| lora_modules_to_save | list[str] or None | No | Additional modules to unfreeze and train alongside LoRA adapters. |
| lora_task_type | str | No | PEFT task type. Defaults to "CAUSAL_LM". |
| load_in_8bit | bool | No | Load base model in 8-bit. Defaults to False. |
| load_in_4bit | bool | No | Load base model in 4-bit (QLoRA). Defaults to False. |
| bnb_4bit_quant_type | str or None | No | Quantization type for 4-bit. Defaults to "nf4". |
| use_bnb_nested_quant | bool | No | Use nested quantization for further memory savings. Defaults to False. |
Outputs
| Name | Type | Description |
|---|---|---|
| (dataclass instance) | ModelConfig | A fully configured model settings object that can be passed to model loading utilities. |
Usage Examples
Basic Usage
from open_instruct.model_utils import ModelConfig
config = ModelConfig(
model_name_or_path="allenai/Llama-3.1-Tulu-3-8B",
attn_implementation="flash_attention_2",
gradient_checkpointing=True,
)
# Access configuration fields
print(config.model_name_or_path)
print(config.use_cache) # None -> auto-set to False by __post_init__ due to gradient_checkpointing
LoRA Configuration
config = ModelConfig(
model_name_or_path="allenai/Llama-3.1-Tulu-3-8B",
use_peft=True,
lora_r=64,
lora_alpha=16,
lora_dropout=0.05,
lora_target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
gradient_checkpointing=True,
)
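QLoRA Configuration

The quantization fields can be combined with LoRA settings for a QLoRA-style setup. The specific values below are illustrative choices, not prescribed defaults:

```python
from open_instruct.model_utils import ModelConfig

config = ModelConfig(
    model_name_or_path="allenai/Llama-3.1-Tulu-3-8B",
    use_peft=True,
    lora_r=64,
    lora_alpha=16,
    lora_target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    load_in_4bit=True,            # QLoRA: load the base model in 4-bit
    bnb_4bit_quant_type="nf4",    # NormalFloat4 quantization
    use_bnb_nested_quant=True,    # nested quantization for extra memory savings
    gradient_checkpointing=True,
)
```

Only the LoRA adapter weights are trained in this setup; the frozen base model stays in 4-bit, which is what makes fine-tuning large checkpoints feasible on a single GPU.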