

Implementation:Allenai Open instruct ModelConfig

From Leeroopedia


Knowledge Sources
Domains Machine Learning, Software Engineering, MLOps
Last Updated 2026-02-07 00:00 GMT

Overview

A typed dataclass provided by the Open Instruct library that structures model configuration settings into a single, reproducible object.

Description

The ModelConfig dataclass consolidates all model-related hyperparameters for reproducible training. It includes fields for the model checkpoint identity, attention implementation, precision, gradient checkpointing, and all LoRA/PEFT settings. It also includes quantization settings for QLoRA workflows. The __post_init__ method enforces invariants, such as disabling KV-cache when gradient checkpointing is enabled (since they are incompatible).

This dataclass is primarily used in the GRPO and DPO training pipelines within Open Instruct. The SFT pipeline (finetune.py) uses FlatArguments directly, but ModelConfig is the canonical model configuration for newer training scripts.

Usage

Create a ModelConfig instance to configure the model for training. It is typically instantiated from command-line arguments via HuggingFace's argument parsing and passed to the model loading functions.

Code Reference

Source Location

  • Repository: Open Instruct
  • File: open_instruct/model_utils.py
  • Lines: L130-177

Signature

@dataclass
class ModelConfig:
    model_name_or_path: str | None = None
    """The model checkpoint for weights initialization."""
    model_revision: str | None = None
    """The specific model version to use (can be a branch name, tag name or commit id)."""
    dtype: str | None = None
    """The data type to load the model under."""
    attn_implementation: Literal["flash_attention_2", "sdpa"] = "flash_attention_2"
    """Which attention implementation to use."""
    use_cache: bool | None = None
    """Whether to use cache in the model."""
    gradient_checkpointing: bool = False
    """Whether to use gradient checkpointing in the model."""

    # PEFT-related args
    use_peft: bool = False
    """Whether to use PEFT or not for training."""
    lora_r: int | None = 16
    """LoRA R value."""
    lora_alpha: int | None = 32
    """LoRA alpha."""
    lora_dropout: float | None = 0.05
    """LoRA dropout."""
    lora_target_modules: list[str] | None = None
    """LoRA target modules."""
    lora_modules_to_save: list[str] | None = None
    """Model layers to unfreeze & train."""
    lora_task_type: str = "CAUSAL_LM"
    """The task_type to pass for LoRA."""

    # Quantization args
    load_in_8bit: bool = False
    """Use 8 bit precision for the base model (LoRA only)."""
    load_in_4bit: bool = False
    """Use 4 bit precision for the base model (LoRA only)."""
    bnb_4bit_quant_type: str | None = "nf4"
    """Quantization type (fp4 or nf4)."""
    use_bnb_nested_quant: bool = False
    """Use nested quantization."""

Import

from open_instruct.model_utils import ModelConfig

I/O Contract

Inputs

  • model_name_or_path (str | None; required for model loading): HuggingFace model ID or local path to the pre-trained model weights.
  • model_revision (str | None; optional): Specific model version (branch, tag, or commit hash) for reproducibility.
  • dtype (str | None; optional): Data type override for model loading (e.g., "bfloat16"). If None, uses the model default.
  • attn_implementation (Literal["flash_attention_2", "sdpa"]; optional): Attention backend. Defaults to "flash_attention_2".
  • use_cache (bool | None; optional): Enable KV-cache. Auto-disabled when gradient checkpointing is enabled.
  • gradient_checkpointing (bool; optional): Enable gradient checkpointing to reduce memory usage. Defaults to False.
  • use_peft (bool; optional): Whether to use PEFT (LoRA). Defaults to False.
  • lora_r (int | None; optional): LoRA rank. Defaults to 16.
  • lora_alpha (int | None; optional): LoRA scaling factor. Defaults to 32.
  • lora_dropout (float | None; optional): LoRA dropout rate. Defaults to 0.05.
  • lora_target_modules (list[str] | None; optional): Which model modules to apply LoRA to. If None, uses library defaults.
  • lora_modules_to_save (list[str] | None; optional): Additional modules to unfreeze and train alongside LoRA adapters.
  • lora_task_type (str; optional): PEFT task type. Defaults to "CAUSAL_LM".
  • load_in_8bit (bool; optional): Load the base model in 8-bit precision (LoRA only). Defaults to False.
  • load_in_4bit (bool; optional): Load the base model in 4-bit precision (QLoRA). Defaults to False.
  • bnb_4bit_quant_type (str | None; optional): Quantization type for 4-bit loading ("fp4" or "nf4"). Defaults to "nf4".
  • use_bnb_nested_quant (bool; optional): Use nested quantization for further memory savings. Defaults to False.

Outputs

  • Returns a ModelConfig dataclass instance: a fully configured model settings object that can be passed to model loading utilities.
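Downstream loaders presumably translate these fields into keyword arguments for HuggingFace's from_pretrained. A hedged sketch of that mapping (the helper name and exact keyword names are illustrative assumptions, not the actual Open Instruct loading code):

```python
# Hypothetical helper showing how a few ModelConfig fields might be
# translated into from_pretrained keyword arguments. Illustrative only.
def to_load_kwargs(dtype, attn_implementation, use_cache):
    kwargs = {"attn_implementation": attn_implementation}
    if dtype is not None:
        # The dtype keyword name varies across transformers versions.
        kwargs["torch_dtype"] = dtype
    if use_cache is not None:
        kwargs["use_cache"] = use_cache
    return kwargs


print(to_load_kwargs("bfloat16", "sdpa", None))
```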

Usage Examples

Basic Usage

from open_instruct.model_utils import ModelConfig

config = ModelConfig(
    model_name_or_path="allenai/Llama-3.1-Tulu-3-8B",
    attn_implementation="flash_attention_2",
    gradient_checkpointing=True,
)

# Access configuration fields
print(config.model_name_or_path)
print(config.use_cache)  # None -> auto-set to False by __post_init__ due to gradient_checkpointing

LoRA Configuration

config = ModelConfig(
    model_name_or_path="allenai/Llama-3.1-Tulu-3-8B",
    use_peft=True,
    lora_r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    lora_target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    gradient_checkpointing=True,
)
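QLoRA Configuration

The 4-bit flags (load_in_4bit, bnb_4bit_quant_type, use_bnb_nested_quant) are intended for QLoRA workflows. As a hedged sketch of how a loader might translate them into bitsandbytes-style options (the helper below is hypothetical, not the actual loader code):

```python
# Hypothetical translation of the ModelConfig 4-bit flags into
# bitsandbytes-style keyword arguments. Illustrative only; the
# actual Open Instruct loading code may differ.
def quantization_kwargs(load_in_4bit: bool, bnb_4bit_quant_type: str,
                        use_bnb_nested_quant: bool) -> dict:
    if not load_in_4bit:
        return {}
    return {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": bnb_4bit_quant_type,  # "fp4" or "nf4"
        "bnb_4bit_use_double_quant": use_bnb_nested_quant,
    }


print(quantization_kwargs(True, "nf4", True))
```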
