Implementation: CarperAI Trlx NeMo ILQL Model
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, NLP, Megatron |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
Concrete tool for running Implicit Language Q-Learning (ILQL) on NVIDIA NeMo's Megatron-GPT framework with model-parallel and pipeline-parallel support.
Description
The ILQLGPT class extends NeMo's MegatronGPTModel to add ILQL-specific Q-value and value heads (ParallelILQLHeads), pipeline-parallel training with custom loss computation, checkpoint loading with resharding support, and Q-value guided text generation. It uses tensor-parallel linear layers (ParallelLinear) and a combined language-model + ILQL-heads module (LMHeads). It also supports sequence parallelism, activation checkpointing, and Megatron's distributed training infrastructure.
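The two_qs option mentioned below refers to double Q-learning: the model keeps two independent Q-heads and takes their elementwise minimum as a conservative estimate to curb value overestimation. The sketch below is a minimal pure-Python illustration of that reduction, not trlx's implementation; the names q_head_1 and q_head_2 are hypothetical.

```python
# Minimal sketch of the double-Q trick behind two_qs=True: keep two
# independent per-token Q estimates and use their elementwise minimum
# as the conservative value. Illustrative only -- not trlx code.
def min_double_q(q1, q2):
    """Elementwise minimum of two per-token Q-value lists."""
    return [min(a, b) for a, b in zip(q1, q2)]

# Two hypothetical per-token Q estimates for the same action sequence.
q_head_1 = [1.2, -0.5, 0.3]
q_head_2 = [0.9, -0.2, 0.7]
print(min_double_q(q_head_1, q_head_2))  # [0.9, -0.5, 0.3]
```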
Usage
Use this model class when training ILQL on large-scale models (1B+ parameters) that require NeMo's Megatron distributed training backend. For smaller models trained with HuggingFace Accelerate, use the standard ILQL model in modeling_ilql.py instead.
Code Reference
Source Location
- Repository: CarperAI_Trlx
- File: trlx/models/modeling_nemo_ilql.py
- Lines: 1-785
Signature
class ILQLGPT(MegatronGPTModel):
    def __init__(
        self,
        ilql_config: ILQLConfig,
        metric_fn: Optional[Callable] = None,
        **kwargs,
    ):
        """
        Args:
            ilql_config: ILQL hyperparameters (alpha, two_qs, gen_kwargs, etc.).
            metric_fn: Optional evaluation metric function.
            **kwargs: Passed to MegatronGPTModel.
        """

    def load_from_pretrained(self, checkpoint_dir: str):
        """Load weights from a pretrained checkpoint with pipeline resharding."""

    def training_step(self, batch: ILQLBatch, batch_idx: int):
        """Execute one ILQL training step."""

    def generate(
        self,
        inputs: dict,
        length_params: LengthParam,
        sampling_params: Optional[SamplingParam] = None,
    ) -> list:
        """Generate text using the model with Q-value guidance."""
Import
from trlx.models.modeling_nemo_ilql import ILQLGPT
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| ilql_config | ILQLConfig | Yes | ILQL hyperparameters (alpha, two_qs, gen_kwargs) |
| metric_fn | Callable | No | Evaluation metric function |
| batch | ILQLBatch | Yes | Training batch with input_ids, actions, rewards, etc. |
| checkpoint_dir | str | No | Path to pretrained checkpoint for loading |
| inputs | dict | Yes | Generation inputs (context tokens) |
| length_params | LengthParam | Yes | Generation length parameters |
| sampling_params | SamplingParam | No | Generation sampling parameters |
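The length_params and sampling_params inputs above are NeMo-style typed dicts rather than custom classes. The sketch below shows what they might look like as plain dicts; the key names follow NeMo's text-generation utilities and should be verified against the NeMo version in use.

```python
# Hedged sketch: NeMo's LengthParam / SamplingParam are typed dicts.
# Key names below follow NeMo's text-generation utilities; verify them
# against the NeMo version you are running.
length_params = {
    "max_length": 64,  # maximum number of new tokens to generate
    "min_length": 0,   # minimum number of new tokens
}
sampling_params = {
    "use_greedy": False,  # sample instead of greedy decoding
    "temperature": 0.7,   # matches gen_kwargs in the ILQL config
    "top_k": 0,
    "top_p": 0.9,
    "repetition_penalty": 1.0,
    "add_BOS": False,
}
```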
Outputs
| Name | Type | Description |
|---|---|---|
| training_step returns | torch.Tensor | Training loss (reduced across data parallel group) |
| generate returns | list | Generated text sequences |
| validation_step returns | dict | Validation metrics including loss and generated samples |
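The ILQLBatch passed to training_step bundles tokens, action/state indices, and rewards. The sketch below shows one plausible layout using plain lists; the field names mirror trlx's ILQL data types but are assumptions here, so check trlx/data/ilql_types.py for the authoritative definition.

```python
# Rough sketch of an ILQL training batch as plain lists. Field names
# mirror trlx's ILQL batch layout but are assumptions here -- see
# trlx/data/ilql_types.py for the authoritative definition.
batch = {
    "input_ids": [[101, 7592, 2088, 102]],  # tokenized prompt + response
    "attention_mask": [[1, 1, 1, 1]],
    "actions_ixs": [[1, 2, 3]],    # positions of generated (action) tokens
    "states_ixs": [[0, 1, 2, 3]],  # positions where V(s) is evaluated
    "rewards": [[0.0, 0.0, 1.0]],  # per-action rewards (sparse terminal reward)
    "dones": [[0, 0, 0, 1]],       # episode-termination flags
}
# One state index per action index, plus the final state:
assert len(batch["states_ixs"][0]) == len(batch["actions_ixs"][0]) + 1
```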
Usage Examples
Instantiate ILQLGPT with NeMo Config
from omegaconf import OmegaConf
from trlx.models.modeling_ilql import ILQLConfig
from trlx.models.modeling_nemo_ilql import ILQLGPT
# 1. Load NeMo Megatron config
megatron_cfg = OmegaConf.load("configs/nemo_configs/megatron_1.3b.yaml")
# 2. Define ILQL config
ilql_config = ILQLConfig(
    tau=0.7,
    gamma=0.99,
    cql_scale=0.1,
    awac_scale=1.0,
    alpha=0.001,
    beta=0.0,
    steps_for_target_q_sync=5,
    two_qs=True,
    gen_kwargs={"temperature": 0.7, "max_new_tokens": 64},
)
# 3. Create model (typically done by the NeMoILQLTrainer)
model = ILQLGPT(ilql_config=ilql_config, cfg=megatron_cfg.model)
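The tau=0.7 in the config above controls ILQL's expectile regression for the value head: residuals where Q exceeds V are weighted by tau, the rest by 1 - tau, so with tau > 0.5 underestimating V is penalized more heavily. The snippet below is a minimal numeric sketch of that asymmetric loss in pure Python, not trlx code.

```python
# Minimal sketch of the asymmetric expectile weight behind tau=0.7:
# with tau > 0.5, the value head is pushed toward an upper expectile
# of the Q distribution. Illustrative only -- not trlx's loss code.
def expectile_loss(diff, tau=0.7):
    """diff = q - v; weight is tau when diff > 0, else 1 - tau."""
    weight = tau if diff > 0 else 1.0 - tau
    return weight * diff ** 2

print(expectile_loss(1.0))   # 0.7 -- V below Q is heavily penalized
print(expectile_loss(-1.0))  # ~0.3 -- V above Q is penalized less
```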