Implementation:Predibase Lorax Watermark Logits Processor

From Leeroopedia


Knowledge Sources
Domains: Decoding, Safety
Last Updated: 2026-02-08 00:00 GMT

Overview

Implements a watermarking logits processor based on "A Watermark for Large Language Models" (Kirchenbauer et al., 2023). The processor biases token generation toward a pseudo-random "green list" of tokens, which makes machine-generated text statistically detectable.

Description

This module implements the watermarking scheme from Kirchenbauer et al. (2023). It is built as a Hugging Face LogitsProcessor subclass.
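Each decoding step of this scheme fits in three lines of math. The notation is chosen here for illustration and does not come from the source: h is hash_key, x_{t-1} is the previous token ID, V is the vocabulary, π_s(V) is a pseudo-random permutation of V seeded with s, and ℓ_k is the raw logit of token k.

```latex
\begin{aligned}
G_t &= \pi_{h \cdot x_{t-1}}(V)\big[\,1 : \lfloor \gamma\,|V| \rfloor\,\big] \\
\tilde{\ell}_k &= \ell_k + \delta\,\mathbb{1}[k \in G_t] \\
\hat{p}_k &= \frac{\exp(\tilde{\ell}_k)}{\sum_{i \in V} \exp(\tilde{\ell}_i)}
\end{aligned}
```

Only the logits change; the sampler downstream sees the renormalized distribution \hat{p}, in which green-list tokens are favored by a factor of e^δ relative to their original odds.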

WatermarkLogitsProcessor (LogitsProcessor): The main class that modifies logit scores during decoding to embed a statistical watermark. Key methods:

  • __init__(gamma, delta, hash_key, device): Initializes the processor with the watermark parameters. gamma (default 0.5, configurable via the WATERMARK_GAMMA env var) sets the fraction of the vocabulary assigned to the "green list". delta (default 2.0, configurable via the WATERMARK_DELTA env var) is the bias added to the logits of green-list tokens. hash_key (default 15485863, a large prime) is multiplied by the previous token ID to seed the RNG, giving seeds with sufficient bit width. device selects where the PyTorch Generator is created.
  • _seed_rng(input_ids): Seeds a PyTorch Generator with hash_key multiplied by the last token ID in the input sequence. Because the seed depends only on that token, the green list at each step is a deterministic function of the preceding token, which is what lets a verifier recompute it during detection.
  • _get_greenlist_ids(input_ids, max_value, device): Seeds the RNG, draws a random permutation of the max_value vocabulary IDs on the given device, and keeps the first int(gamma * max_value) IDs as the green list.
  • _calc_greenlist_mask(scores, greenlist_token_ids): Creates a boolean mask tensor marking which positions in the scores correspond to green list tokens.
  • _bias_greenlist_logits(scores, greenlist_mask, greenlist_bias): Adds the delta bias to all logit scores corresponding to green list tokens.
  • __call__(input_ids, scores): The main entry point, called at each generation step. Derives the green list from the last token of input_ids (using the vocabulary size given by the last dimension of scores), builds the mask, adds the delta bias, and returns the modified scores.
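The steps in the list above can be condensed into a standalone sketch. SketchWatermarkProcessor is a hypothetical name; it depends only on PyTorch (not on Transformers or lorax_server), assumes input_ids is either a list of token IDs or a tensor of shape (1, seq_len), and returns a new tensor rather than editing scores in place. It illustrates the described behavior and is not a line-for-line copy of watermark.py.

```python
from typing import List, Union

import torch


class SketchWatermarkProcessor:
    """Hypothetical re-implementation of the green-list watermark steps."""

    def __init__(self, gamma: float = 0.5, delta: float = 2.0,
                 hash_key: int = 15485863, device: str = "cpu"):
        self.gamma = gamma
        self.delta = delta
        self.hash_key = hash_key
        self.rng = torch.Generator(device=device)

    def _seed_rng(self, input_ids: Union[List[int], torch.Tensor]) -> None:
        # Only the most recent token determines the seed.
        if isinstance(input_ids, list):
            prev_token = input_ids[-1]
        else:
            prev_token = int(input_ids[0, -1])  # assumes shape (1, seq_len)
        self.rng.manual_seed(self.hash_key * prev_token)

    def _get_greenlist_ids(self, input_ids, max_value: int,
                           device: torch.device) -> torch.Tensor:
        self._seed_rng(input_ids)
        greenlist_size = int(max_value * self.gamma)
        perm = torch.randperm(max_value, device=device, generator=self.rng)
        return perm[:greenlist_size]

    @staticmethod
    def _calc_greenlist_mask(scores: torch.Tensor,
                             greenlist_ids: torch.Tensor) -> torch.Tensor:
        # Boolean mask with True at every green-list vocabulary position.
        mask = torch.zeros_like(scores, dtype=torch.bool)
        mask[..., greenlist_ids] = True
        return mask

    @staticmethod
    def _bias_greenlist_logits(scores: torch.Tensor, mask: torch.Tensor,
                               bias: float) -> torch.Tensor:
        # Adds `bias` to green positions; returns a new tensor.
        return scores + mask.to(scores.dtype) * bias

    def __call__(self, input_ids, scores: torch.Tensor) -> torch.Tensor:
        greenlist_ids = self._get_greenlist_ids(
            input_ids, scores.shape[-1], scores.device)
        mask = self._calc_greenlist_mask(scores, greenlist_ids)
        return self._bias_greenlist_logits(scores, mask, self.delta)
```

For example, with gamma=0.5 and a 10-token vocabulary of zero logits, exactly 5 positions come back as 2.0, and any context ending in the same token yields the same positions, which is the determinism detection relies on.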

Usage

This processor is part of the LoRAX decoding pipeline. When watermarking is enabled for a request, a WatermarkLogitsProcessor is added to the logits processing chain. It runs after the model produces raw logits and before sampling, shifting probability mass toward green-list tokens; the larger delta is, the stronger (and more detectable) the watermark and the greater its potential cost to text quality. A verifier who knows the hash key, gamma, and the vocabulary can recompute each step's green list and test whether green tokens occur more often than the expected fraction gamma; delta itself is not needed for detection.
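Detection can be sketched as the one-proportion z-test from Kirchenbauer et al. (2023): if |s|_G of T scored tokens fall in their step's green list, then z = (|s|_G − γT) / sqrt(Tγ(1 − γ)). The functions greenlist_ids and watermark_z_score below are hypothetical helpers, not part of lorax_server. They assume the seeding described above (hash_key times the previous token ID) and a CPU generator; because CPU and CUDA generators are not guaranteed to produce the same permutation, a verifier would need to match the device type the server used.

```python
import math
from typing import List

import torch


def greenlist_ids(prev_token: int, vocab_size: int, gamma: float,
                  hash_key: int = 15485863) -> torch.Tensor:
    # Assumed to mirror the processor's seeding: hash_key * previous token ID.
    rng = torch.Generator(device="cpu")
    rng.manual_seed(hash_key * prev_token)
    perm = torch.randperm(vocab_size, generator=rng)
    return perm[: int(vocab_size * gamma)]


def watermark_z_score(tokens: List[int], vocab_size: int, gamma: float = 0.5,
                      hash_key: int = 15485863) -> float:
    """One-proportion z-test over every token that has a predecessor."""
    if len(tokens) < 2:
        raise ValueError("need a prefix token plus at least one scored token")
    green_hits = 0
    for prev, cur in zip(tokens[:-1], tokens[1:]):
        green = greenlist_ids(prev, vocab_size, gamma, hash_key)
        green_hits += int((green == cur).any())
    t = len(tokens) - 1
    return (green_hits - gamma * t) / math.sqrt(t * gamma * (1 - gamma))
```

A large positive z suggests the text carries the watermark, while unwatermarked text is expected to score near zero. With gamma = 0.5, a 16-token continuation in which every token is green scores exactly z = 4.0.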

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/utils/watermark.py
  • Lines: 1-84

Signature

class WatermarkLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        gamma: float = GAMMA,  # default 0.5, overridable via WATERMARK_GAMMA
        delta: float = DELTA,  # default 2.0, overridable via WATERMARK_DELTA
        hash_key: int = 15485863,
        device: str = "cpu",
    ): ...
    def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]): ...
    def _get_greenlist_ids(self, input_ids, max_value: int, device: torch.device) -> List[int]: ...
    @staticmethod
    def _calc_greenlist_mask(scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor: ...
    @staticmethod
    def _bias_greenlist_logits(scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor: ...
    def __call__(self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor) -> torch.FloatTensor: ...

Import

from lorax_server.utils.watermark import WatermarkLogitsProcessor

I/O Contract

Inputs

Name | Type | Required | Description
gamma | float | No | __init__: fraction of the vocabulary in the green list (default 0.5, env: WATERMARK_GAMMA)
delta | float | No | __init__: logit bias added to green-list tokens (default 2.0, env: WATERMARK_DELTA)
hash_key | int | No | __init__: large prime multiplied by the previous token ID to seed the RNG (default 15485863)
device | str | No | __init__: device on which the RNG generator is created (default "cpu")
input_ids | Union[List[int], torch.LongTensor] | Yes | __call__: current token sequence; only its last token seeds the green-list RNG
scores | torch.FloatTensor | Yes | __call__: raw logits from the model, shape (batch, vocab_size); the vocabulary size is read from the last dimension
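The gamma and delta defaults come from the WATERMARK_GAMMA and WATERMARK_DELTA environment variables. The sketch below shows one way such defaults can be read; the function name watermark_defaults and the range check are hypothetical additions, not part of LoRAX. If the module reads these values into constants such as GAMMA and DELTA (as the signature's defaults suggest), the lookup happens once at import time, so the variables would need to be set before the server process starts.

```python
import os
from typing import Tuple


def watermark_defaults() -> Tuple[float, float]:
    """Hypothetical helper: read gamma/delta defaults from the environment."""
    gamma = float(os.getenv("WATERMARK_GAMMA", 0.5))
    delta = float(os.getenv("WATERMARK_DELTA", 2.0))
    # Added sanity check (not from the source): gamma is a fraction of the
    # vocabulary, so it should lie strictly between 0 and 1.
    if not 0.0 < gamma < 1.0:
        raise ValueError(f"WATERMARK_GAMMA must be in (0, 1), got {gamma}")
    return gamma, delta
```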

Outputs

Name | Type | Description
scores | torch.FloatTensor | Logit scores, same shape as the input, with delta added to every green-list token
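To get a feel for what a given delta does, consider the idealized case of uniform logits: after biasing, the green list holds γe^δ / (γe^δ + 1 − γ) of the probability mass. With the defaults gamma = 0.5 and delta = 2.0 that is about 88%, and delta = 4.0 raises it to about 98%. The helper names below are hypothetical. Real logits are rarely uniform; for peaked (low-entropy) distributions the shift is smaller, which is why such text is harder to watermark.

```python
import math

import torch


def green_mass(gamma: float, delta: float) -> float:
    """Closed form: green-list probability mass after biasing uniform logits."""
    boosted = gamma * math.exp(delta)
    return boosted / (boosted + (1.0 - gamma))


def green_mass_empirical(gamma: float, delta: float,
                         vocab_size: int = 1000) -> float:
    """Same quantity computed directly with softmax over a toy vocabulary."""
    n_green = int(vocab_size * gamma)
    logits = torch.zeros(vocab_size)
    logits[:n_green] += delta  # which IDs are green does not matter here
    probs = torch.softmax(logits, dim=-1)
    return probs[:n_green].sum().item()
```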

Usage Examples

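# Internal usage in the decoding pipeline
import torch
from lorax_server.utils.watermark import WatermarkLogitsProcessor

# The RNG device should match the device the logits live on
processor = WatermarkLogitsProcessor(
    gamma=0.5,
    delta=2.0,
    device="cpu",
)

# Placeholder inputs for illustration: one sequence (batch size 1)
# and a 32,000-token vocabulary
current_tokens = torch.tensor([[1, 450, 4996]], dtype=torch.long)
raw_logits = torch.randn(1, 32000)

# Called once per generated token, before sampling
modified_scores = processor(input_ids=current_tokens, scores=raw_logits)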
