Implementation:CarperAI Trlx Default SFT Config
| Knowledge Sources | |
|---|---|
| Domains | Supervised_Learning, NLP, Configuration |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
default_sft_config() is a factory function in the trlx library that builds a default training configuration for supervised fine-tuning (SFT).
Description
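The default_sft_config() factory function returns a TRLConfig populated with defaults for supervised fine-tuning. It selects SFTConfig as the method, whose only setting besides its name is gen_kwargs (the generation arguments used during evaluation), and sets the trainer to AccelerateSFTTrainer. The optimizer is AdamW with lr=1.0e-4, paired with a cosine_annealing scheduler. Because the scheduler's eta_min equals the optimizer's lr, the schedule has no room to decay and the learning rate stays effectively constant under these defaults. The model and tokenizer both default to gpt2, and all layers are trainable (num_layers_unfrozen=-1).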
Usage
Import this function when setting up supervised fine-tuning on text or instruction datasets. The returned config can be customized either by assigning to its fields directly or by calling TRLConfig.evolve(), which returns a new config with the given nested parameters updated.
Code Reference
Source Location
- Repository: trlx
- File: trlx/data/default_configs.py
- Lines: L97-121
Signature
def default_sft_config() -> TRLConfig:
    return TRLConfig(
        train=TrainConfig(
            seq_length=1024,
            epochs=100,
            total_steps=1000,
            batch_size=8,
            checkpoint_interval=10000,
            eval_interval=100,
            pipeline="PromptPipeline",
            trainer="AccelerateSFTTrainer",
        ),
        model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1),
        tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"),
        optimizer=OptimizerConfig(
            name="adamw",
            kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6),
        ),
        scheduler=SchedulerConfig(
            name="cosine_annealing",
            kwargs=dict(T_max=1e12, eta_min=1.0e-4),
        ),
        method=SFTConfig(
            name="sftconfig",
            gen_kwargs=dict(max_new_tokens=40, top_k=0, top_p=1.0, do_sample=True),
        ),
    )
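The optimizer and scheduler values above interact in a way that is easy to miss. The following is a minimal sketch, assuming the cosine_annealing scheduler follows the standard closed-form cosine annealing formula; cosine_annealing_lr is a hypothetical helper written for illustration, not a trlx function. It evaluates the schedule with the default lr, eta_min, and T_max:

```python
import math


def cosine_annealing_lr(step, base_lr, eta_min, t_max):
    """Standard closed-form cosine annealing (hypothetical helper, not trlx code)."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * step / t_max))


# Values copied from default_sft_config().
BASE_LR = 1.0e-4   # optimizer kwargs: lr
ETA_MIN = 1.0e-4   # scheduler kwargs: eta_min
T_MAX = 1e12       # scheduler kwargs: T_max

# Learning rate at the start, middle, and end of the default 1000 steps.
schedule = [cosine_annealing_lr(s, BASE_LR, ETA_MIN, T_MAX) for s in (0, 500, 1000)]

# For contrast: an illustrative (non-default) setting that actually decays.
decayed = cosine_annealing_lr(1000, base_lr=1.0e-4, eta_min=1.0e-5, t_max=1000)
```

Under this formula the amplitude term (base_lr - eta_min) is zero for the defaults, so every entry of schedule equals the base learning rate. To get a decaying schedule, one would lower eta_min and set T_max close to the planned number of steps.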
Import
from trlx.data.default_configs import default_sft_config
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| (none) | — | — | Factory function takes no arguments |
Outputs
| Name | Type | Description |
|---|---|---|
| return | TRLConfig | Fully configured TRLConfig with SFTConfig method, trainer set to AccelerateSFTTrainer |
Usage Examples
Basic SFT Config
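from trlx.data.default_configs import default_sft_config

# Start from the SFT defaults, then override individual fields in place.
config = default_sft_config()
config.model.model_path = "EleutherAI/gpt-j-6B"  # swap the default gpt2 for a larger model
config.train.batch_size = 4                      # smaller batch than the default of 8
config.train.total_steps = 5000                  # raise the default step limit of 1000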
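How long such a run lasts depends on both total_steps and epochs. The sketch below assumes the trainer stops at whichever limit is reached first and takes one optimizer step per batch on a single process; that is an assumption about the training loop, not something this page states. The helper effective_steps and the dataset sizes are hypothetical:

```python
import math


def effective_steps(n_samples, batch_size, epochs, total_steps):
    """Estimate steps actually run (hypothetical helper).

    Assumes one step per batch, a partial final batch is kept,
    and training ends at min(total_steps, epochs * batches_per_epoch).
    """
    batches_per_epoch = math.ceil(n_samples / batch_size)
    return min(total_steps, epochs * batches_per_epoch)


# Settings from the example above: batch_size=4, total_steps=5000, default epochs=100.
large = effective_steps(n_samples=52_000, batch_size=4, epochs=100, total_steps=5000)
small = effective_steps(n_samples=100, batch_size=4, epochs=100, total_steps=5000)
```

With the larger made-up dataset the step limit binds first. With only 100 samples the epoch limit binds instead (25 batches per epoch times 100 epochs), so raising total_steps alone would not lengthen training.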
Using evolve() for Alpaca-Style Config
from trlx.data.default_configs import default_sft_config

config = default_sft_config()
config = config.evolve(
    train=dict(seq_length=512, batch_size=4, total_steps=2000),
    model=dict(model_path="EleutherAI/gpt-j-6B"),
    method=dict(gen_kwargs=dict(max_new_tokens=100)),
)
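The nested update performed by evolve() can be pictured with plain dictionaries. This is a stand-in sketch, not the library's code: merge_nested is a hypothetical helper, the dictionaries mirror only part of the config, and how trlx treats keys that do not already exist is not modeled here.

```python
import copy


def merge_nested(base, update):
    """Return a copy of `base` with `update` applied recursively (hypothetical helper)."""
    result = copy.deepcopy(base)
    for key, value in update.items():
        if isinstance(value, dict) and isinstance(result.get(key), dict):
            # Descend into nested sections instead of replacing them wholesale.
            result[key] = merge_nested(result[key], value)
        else:
            result[key] = value
    return result


# Plain-dict stand-in for part of the defaults (values from default_sft_config()).
defaults = {
    "train": {"seq_length": 1024, "batch_size": 8, "total_steps": 1000},
    "model": {"model_path": "gpt2", "num_layers_unfrozen": -1},
    "method": {"gen_kwargs": {"max_new_tokens": 40, "top_k": 0, "top_p": 1.0, "do_sample": True}},
}

# The same overrides passed to evolve() in the example above.
overrides = {
    "train": {"seq_length": 512, "batch_size": 4, "total_steps": 2000},
    "model": {"model_path": "EleutherAI/gpt-j-6B"},
    "method": {"gen_kwargs": {"max_new_tokens": 100}},
}

merged = merge_nested(defaults, overrides)
```

The point of the nested merge is that overriding max_new_tokens leaves the sibling keys top_k, top_p, and do_sample at their defaults, and overriding model_path leaves num_layers_unfrozen untouched. The original dictionary is not modified, which matches evolve() returning a new config rather than mutating the old one.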