Implementation:CarperAI Trlx ILQL Generate

From Leeroopedia


Knowledge Sources
Domains Reinforcement_Learning, Offline_RL, Text_Generation
Last Updated 2026-02-07 16:00 GMT

Overview

The generate() method of trlx's ILQL model class is a concrete tool for Q-value guided text generation.

Description

The generate() method on AutoModelForCausalLMWithILQLHeads implements Q-value guided autoregressive sampling. At each token step, it computes the base model logits, Q-values from the target Q-heads, and value estimates from the V-head. It then modifies the token distribution by adding advantage-weighted bonuses before sampling. The method supports top-k filtering, temperature scaling, KV-cache reuse, and PEFT adapter bypassing.
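The per-step sampling rule described above can be sketched as follows. This is an illustrative reimplementation, not the trlx source; the function name and tensor layout are assumptions for the sketch:

```python
import torch
import torch.nn.functional as F

def ilql_sample_step(logits, qs, vs, beta=1.0, temperature=1.0, top_k=20):
    """One Q-value guided sampling step (sketch of the rule described above).

    logits: [batch, vocab] base model logits for the current position
    qs:     [batch, vocab] Q(s, a) from the target Q-heads
    vs:     [batch, 1]     V(s) from the V-head
    """
    advantage = qs - vs                    # A(s, a) = Q(s, a) - V(s)
    adjusted = logits + beta * advantage   # advantage-weighted bonus
    if top_k > 0:
        # Keep only the top-k adjusted logits; mask out the rest.
        kth = torch.topk(adjusted, top_k, dim=-1).values[..., -1, None]
        adjusted = adjusted.masked_fill(adjusted < kth, float("-inf"))
    probs = F.softmax(adjusted / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # [batch, 1] next tokens
```

In an autoregressive loop, the sampled token is appended to the sequence and the step repeats until max_new_tokens or the EOS token is reached.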

Usage

This method is called automatically during ILQL evaluation (via trainer.generate_eval()) and can be called directly for inference. It is used in place of HuggingFace's standard .generate() when Q-value guidance is desired.

Code Reference

Source Location

  • Repository: trlx
  • File: trlx/models/modeling_ilql.py
  • Lines: L325-412

Signature

def generate(
    self,
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple] = None,
    beta: float = 1,
    max_new_tokens: int = 32,
    max_length: int = 1024,
    temperature: float = 1,
    top_k: int = 20,
    logit_mask: Optional[torch.Tensor] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
) -> torch.LongTensor:
    """
    Generate samples with Q-value guided sampling.

    At each step: pi = softmax((log_pi + beta * advantage) / temperature)
    where advantage = Q(s,a) - V(s)

    Args:
        input_ids: Tokenized prompt input IDs.
        attention_mask: Attention mask for padding.
        beta: Advantage weighting strength (higher = more reward-seeking).
        max_new_tokens: Maximum tokens to generate.
        temperature: Sampling temperature (0 = greedy).
        top_k: Top-k filtering on advantage-modified logits.
        logit_mask: Optional mask for invalid tokens.

    Returns:
        Generated token sequences (prompt + completion).
    """

Import

from trlx.models.modeling_ilql import AutoModelForCausalLMWithILQLHeads

I/O Contract

Inputs

  • input_ids (torch.LongTensor, required): Tokenized prompt, shape [batch_size, seq_len]
  • attention_mask (torch.LongTensor, optional): Padding mask (auto-computed if None)
  • beta (float, optional, default 1): Advantage weighting strength
  • max_new_tokens (int, optional, default 32): Maximum new tokens to generate
  • temperature (float, optional, default 1): Sampling temperature (0 = greedy)
  • top_k (int, optional, default 20): Top-k filtering

Outputs

  • return (torch.LongTensor): Generated sequences, shape [batch_size, seq_len + new_tokens]

Usage Examples

Q-Value Guided Generation

import torch
from trlx.models.modeling_ilql import AutoModelForCausalLMWithILQLHeads
from transformers import AutoTokenizer

# Load trained ILQL model
model = AutoModelForCausalLMWithILQLHeads.from_pretrained("./ilql_checkpoint")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; reuse EOS for batch padding

# Tokenize prompts
prompts = ["Once upon a time", "The meaning of life is"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# Generate with Q-value guidance
outputs = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    beta=1.0,          # Moderate reward-seeking
    max_new_tokens=50,
    temperature=0.7,
    top_k=20,
)

# Decode
generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)

Varying Beta for Reward-Fluency Tradeoff

# beta=0: Pure language model (no Q-value guidance)
outputs_base = model.generate(input_ids=inputs.input_ids, beta=0)

# beta=1: Moderate guidance
outputs_moderate = model.generate(input_ids=inputs.input_ids, beta=1)

# beta=4: Strong reward optimization
outputs_strong = model.generate(input_ids=inputs.input_ids, beta=4)
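The effect of the beta sweep above can be illustrated on a toy three-token vocabulary. The numbers below are made up for illustration; this simply evaluates the docstring formula pi = softmax(log_pi + beta * advantage) at temperature 1:

```python
import torch
import torch.nn.functional as F

# Toy base language-model distribution over a 3-token vocabulary.
log_pi = torch.log(torch.tensor([0.7, 0.2, 0.1]))
q = torch.tensor([0.0, 1.0, 0.0])  # token 1 has the highest Q-value
v = torch.tensor(0.2)              # state value V(s)
advantage = q - v                  # A(s, a) = Q(s, a) - V(s)

for beta in (0.0, 1.0, 4.0):
    pi = F.softmax(log_pi + beta * advantage, dim=-1)  # temperature = 1
    print(beta, pi.tolist())
# beta=0 recovers the base distribution exactly; as beta grows,
# probability mass shifts toward the high-advantage token.
```

This is the reward-fluency tradeoff: beta=0 preserves the language model's fluency, while large beta increasingly optimizes for the learned reward at the cost of drifting from the base distribution.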

Related Pages

Implements Principle

Requires Environment
