Implementation: CarperAI Trlx ILQL Generate
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, Offline_RL, Text_Generation |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
A concrete tool for Q-value guided text generation, provided by the ILQL model class in trlx.
Description
The generate() method on AutoModelForCausalLMWithILQLHeads implements Q-value guided autoregressive sampling. At each token step, it computes the base model logits, Q-values from the target Q-heads, and value estimates from the V-head. It then modifies the token distribution by adding advantage-weighted bonuses before sampling. The method supports top-k filtering, temperature scaling, KV-cache reuse, and PEFT adapter bypassing.
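The per-step logic can be sketched in plain PyTorch. This is a simplified illustration of the sampling rule described above, not the trlx implementation; the tensor values and shapes are toy assumptions.

```python
import torch

torch.manual_seed(0)

# Toy per-step quantities for a vocabulary of 6 tokens (illustrative values,
# not produced by a real model).
log_pi = torch.log_softmax(torch.randn(6), dim=-1)  # base model log-probs
q_values = torch.randn(6)                           # Q(s, a) from the Q-heads
v_value = torch.randn(())                           # V(s) from the V-head

beta, temperature, top_k = 1.0, 0.7, 3

# Advantage-weighted logits: log_pi + beta * (Q - V)
advantage = q_values - v_value
logits = log_pi + beta * advantage

# Top-k filtering: mask everything below the k-th largest modified logit
topk_vals, _ = torch.topk(logits, top_k)
logits = logits.masked_fill(logits < topk_vals[-1], float("-inf"))

# Temperature scaling, then sample the next token
probs = torch.softmax(logits / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
```

In the real method this step repeats autoregressively, reusing the KV cache between steps.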
Usage
This method is called automatically during ILQL evaluation (via trainer.generate_eval()) and can be called directly for inference. It is used in place of HuggingFace's standard .generate() when Q-value guidance is desired.
Code Reference
Source Location
- Repository: trlx
- File: trlx/models/modeling_ilql.py
- Lines: L325-412
Signature
```python
def generate(
    self,
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple] = None,
    beta: float = 1,
    max_new_tokens: int = 32,
    max_length: int = 1024,
    temperature: float = 1,
    top_k: int = 20,
    logit_mask: Optional[torch.Tensor] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
) -> torch.LongTensor:
    """
    Generate samples with Q-value guided sampling.

    At each step: pi = softmax((log_pi + beta * advantage) / temperature)
    where advantage = Q(s,a) - V(s)

    Args:
        input_ids: Tokenized prompt input IDs.
        attention_mask: Attention mask for padding.
        beta: Advantage weighting strength (higher = more reward-seeking).
        max_new_tokens: Maximum tokens to generate.
        temperature: Sampling temperature (0 = greedy).
        top_k: Top-k filtering on advantage-modified logits.
        logit_mask: Optional mask for invalid tokens.

    Returns:
        Generated token sequences (prompt + completion).
    """
```
Import
```python
from trlx.models.modeling_ilql import AutoModelForCausalLMWithILQLHeads
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.LongTensor | Yes | Tokenized prompt [batch_size, seq_len] |
| attention_mask | torch.LongTensor | No | Padding mask (auto-computed if None) |
| beta | float | No | Advantage weighting strength (default 1) |
| max_new_tokens | int | No | Maximum new tokens to generate (default 32) |
| temperature | float | No | Sampling temperature (default 1, 0=greedy) |
| top_k | int | No | Top-k filtering (default 20) |
Outputs
| Name | Type | Description |
|---|---|---|
| return | torch.LongTensor | Generated sequences [batch_size, seq_len + new_tokens] |
Usage Examples
Q-Value Guided Generation
```python
import torch
from trlx.models.modeling_ilql import AutoModelForCausalLMWithILQLHeads
from transformers import AutoTokenizer

# Load trained ILQL model
model = AutoModelForCausalLMWithILQLHeads.from_pretrained("./ilql_checkpoint")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

# Tokenize prompts
prompts = ["Once upon a time", "The meaning of life is"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# Generate with Q-value guidance
outputs = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    beta=1.0,  # Moderate reward-seeking
    max_new_tokens=50,
    temperature=0.7,
    top_k=20,
)

# Decode
generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
```
Varying Beta for Reward-Fluency Tradeoff
```python
# beta=0: Pure language model (no Q-value guidance)
outputs_base = model.generate(input_ids=inputs.input_ids, beta=0)

# beta=1: Moderate guidance
outputs_moderate = model.generate(input_ids=inputs.input_ids, beta=1)

# beta=4: Strong reward optimization
outputs_strong = model.generate(input_ids=inputs.input_ids, beta=4)
```
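The tradeoff can be made concrete with toy numbers: as beta grows, the sampling distribution concentrates on the highest-advantage token at the expense of base-model fluency. A small plain-PyTorch illustration (the advantage values are invented, not from a trained model):

```python
import torch

log_pi = torch.log_softmax(torch.tensor([1.5, 0.0, -0.5, -2.0]), dim=-1)
advantage = torch.tensor([-0.5, 2.0, 0.5, -1.0])  # toy Q - V values

# Probability of the highest-advantage token (index 1) at each beta
probs_at_beta = {
    beta: float(torch.softmax(log_pi + beta * advantage, dim=-1)[1])
    for beta in (0.0, 1.0, 4.0)
}

# The probability of the high-advantage token grows monotonically with beta
assert probs_at_beta[0.0] < probs_at_beta[1.0] < probs_at_beta[4.0]
```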