
Implementation: LLMBook-zh (llmbook-zh.github.io) / LlamaRewardModel

From Leeroopedia


Knowledge Sources
Domains NLP, Alignment, Reinforcement_Learning
Last Updated 2026-02-08 00:00 GMT

Overview

A concrete tool, provided by the LLMBook repository, for training a LLaMA-based reward model with a pairwise contrastive loss and language-modeling (LM) regularization.

Description

The LlamaRewardModel class extends LlamaForCausalLM by adding a reward_head (a linear projection from hidden_size to a scalar). It implements three key methods:

  • _forward_rmloss: Encodes input and projects to scalar reward via the reward head.
  • _forward_lmloss: Computes cross-entropy LM loss on response tokens as regularization.
  • forward: Combines both losses — computes rewards for positive and negative examples, then applies binary cross-entropy on the reward difference plus the LM loss.
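The combined objective can be sketched as follows. This is a minimal illustration, not the repository's exact code; the function name and the convention of targets equal to 1 (preferred response wins) are assumptions, equivalent in effect to the repository's all-zero labels marking sent1 as the positive example.

```python
import torch
import torch.nn.functional as F

def reward_pair_loss(r_pos, r_neg, lm_loss):
    """Pairwise reward objective sketch: binary cross-entropy on the
    reward difference, plus the LM regularization term."""
    # r_pos, r_neg: (batch,) scalar rewards for preferred / rejected responses.
    logits = r_pos - r_neg
    # Target 1 everywhere: the preferred response should outrank the rejected one.
    rm_loss = F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits))
    return rm_loss + lm_loss
```

Note that minimizing this loss pushes the reward of the preferred response above that of the rejected one, while the LM term keeps the backbone's language-modeling ability from degrading.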

Usage

Use this class when training a reward model for RLHF alignment. It requires paired preference data (positive and negative responses to the same prompt).
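Paired preference data means each training example carries the same prompt with both a preferred and a rejected response. A toy illustration of how such a pair maps onto the model's inputs (the whitespace `encode` below is a stand-in for the real LLaMA tokenizer, and the example strings are invented):

```python
import torch

def encode(text, vocab):
    """Toy whitespace tokenizer; a real pipeline uses the LLaMA tokenizer."""
    return [vocab.setdefault(tok, len(vocab)) for tok in text.split()]

vocab = {}
prompt = "Explain RLHF briefly ."
chosen = "RLHF aligns models with human preferences ."
rejected = "RLHF is unrelated to preferences ."

# sent1_idx / sent2_idx: prompt + preferred vs. prompt + rejected response.
sent1_idx = torch.tensor([encode(prompt + " " + chosen, vocab)])
sent2_idx = torch.tensor([encode(prompt + " " + rejected, vocab)])
attention_mask_1 = torch.ones_like(sent1_idx)
attention_mask_2 = torch.ones_like(sent2_idx)
```

Both sequences share the same prompt prefix; only the response portion differs, which is what the pairwise loss compares.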

Code Reference

Source Location

  • Repository: LLMBook-zh
  • File: code/8.1 奖励模型训练.py ("Reward Model Training")
  • Lines: 7-79

Signature

class LlamaRewardModel(LlamaForCausalLM):
    def __init__(self, config: LlamaConfig):
        """
        Extends LlamaForCausalLM with:
            reward_head: nn.Linear(hidden_size, 1, bias=False)
        """

    def _forward_rmloss(self, input_ids, attention_mask, **kargs) -> Tensor:
        """Computes scalar reward for a sequence."""

    def _forward_lmloss(self, prompt_ids, lm_attn_mask, response_ids) -> Tensor:
        """Computes LM cross-entropy loss as regularization."""

    def forward(self, sent1_idx, attention_mask_1, sent2_idx, attention_mask_2,
                labels, prompt_ids, lm_attn_mask, response_ids, **kargs) -> Tensor:
        """
        Args:
            sent1_idx: Positive example token IDs.
            attention_mask_1: Mask for positive example.
            sent2_idx: Negative example token IDs.
            attention_mask_2: Mask for negative example.
            labels: Labels indicating positive position (all 0).
            prompt_ids: Concatenated prompt+response IDs for LM loss.
            lm_attn_mask: Attention mask for LM loss.
            response_ids: Target IDs for LM loss computation.

        Returns:
            Combined loss: rm_loss + lm_loss.
        """

Import

from reward_model import LlamaRewardModel

I/O Contract

Inputs

Name Type Required Description
sent1_idx LongTensor Yes Positive example: prompt + preferred response token IDs
attention_mask_1 Tensor Yes Attention mask for sent1_idx
sent2_idx LongTensor Yes Negative example: prompt + rejected response token IDs
attention_mask_2 Tensor Yes Attention mask for sent2_idx
labels Tensor Yes Labels (all 0, indicating positive is sent1)
prompt_ids LongTensor Yes Prompt+response IDs for LM regularization
lm_attn_mask Tensor Yes Attention mask for LM loss
response_ids LongTensor Yes Target IDs for LM cross-entropy

Outputs

Name Type Description
loss Tensor Combined rm_loss (BCE on reward difference) + lm_loss (cross-entropy)

Usage Examples

import torch
from reward_model import LlamaRewardModel
from transformers import LlamaConfig

config = LlamaConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
reward_model = LlamaRewardModel(config)

# Training with paired preference data. The *_ids / *_mask tensors are
# assumed to be pre-tokenized batches of shape (batch_size, seq_len).
loss = reward_model(
    sent1_idx=positive_ids,          # prompt + preferred response
    attention_mask_1=positive_mask,
    sent2_idx=negative_ids,          # prompt + rejected response
    attention_mask_2=negative_mask,
    labels=torch.zeros(batch_size, seq_len),  # all zeros: sent1 is the positive
    prompt_ids=prompt_response_ids,
    lm_attn_mask=lm_mask,
    response_ids=target_ids,
)
loss.backward()
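The backward pass propagates gradients into both the backbone and the reward head. A self-contained toy check of that pattern (the embedding "backbone", mean pooling, and omission of the LM term are simplifications for illustration, not the real model):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

backbone = nn.Embedding(10, 8)               # stand-in for the LLaMA backbone
reward_head = nn.Linear(8, 1, bias=False)    # scalar reward projection

ids_pos = torch.tensor([[1, 2, 3]])          # toy preferred sequence
ids_neg = torch.tensor([[1, 2, 4]])          # toy rejected sequence
r_pos = reward_head(backbone(ids_pos).mean(dim=1)).squeeze(-1)
r_neg = reward_head(backbone(ids_neg).mean(dim=1)).squeeze(-1)

# BCE on the reward difference, as in the combined loss (LM term omitted here).
loss = F.binary_cross_entropy_with_logits(r_pos - r_neg, torch.ones_like(r_pos))
loss.backward()
```

After `loss.backward()`, both parameter tensors carry gradients, confirming that the pairwise loss trains the encoder and the reward head jointly.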
