Implementation:Predibase Lorax Watermark Logits Processor
| Knowledge Sources | |
|---|---|
| Domains | Decoding, Safety |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements a watermarking logits processor based on the "A Watermark for Large Language Models" paper that biases token generation toward a pseudo-random "green list" of tokens, enabling statistical detection of machine-generated text.
Description
This module implements the watermarking scheme from Kirchenbauer et al. (2023). It is built as a Hugging Face LogitsProcessor subclass.
WatermarkLogitsProcessor (LogitsProcessor): The main class that modifies logit scores during decoding to embed a statistical watermark. Key methods:
- __init__(gamma, delta, hash_key, device): Initializes the processor with watermark parameters. gamma (default 0.5, configurable via the WATERMARK_GAMMA env var) is the fraction of the vocabulary assigned to the "green list". delta (default 2.0, configurable via the WATERMARK_DELTA env var) is the logit bias added to green list tokens. hash_key (default 15485863, a large prime) is multiplied into the RNG seed so that the seed has sufficient bit width. device (default "cpu") is the device on which the RNG generator is created.
- _seed_rng(input_ids): Seeds a PyTorch Generator with the hash key multiplied by the ID of the last token in the input sequence. Because the seed depends only on that preceding token, the green list is deterministic and can be recomputed at detection time.
- _get_greenlist_ids(input_ids, max_value, device): Seeds the RNG, generates a random permutation of the full vocabulary (max_value token IDs), and selects the first gamma * vocab_size token IDs of that permutation as the green list.
- _calc_greenlist_mask(scores, greenlist_token_ids): Creates a boolean mask tensor marking which positions in the scores correspond to green list tokens.
- _bias_greenlist_logits(scores, greenlist_mask, greenlist_bias): Adds the delta bias to all logit scores corresponding to green list tokens.
- __call__(input_ids, scores): The main entry point called during generation. Computes the green list from the current context, creates the mask, applies the bias, and returns the modified scores.
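The steps in the method list above can be sketched outside the class as two small functions. This is a minimal illustration under stated assumptions, not the module's code: the function names `sketch_greenlist_ids` and `sketch_bias_green_logits` are hypothetical, the sketch handles a single 1-D logits vector, and it returns a copy of the scores, which may differ from how the real processor treats its input.

```python
import torch


def sketch_greenlist_ids(
    prev_token: int,
    vocab_size: int,
    gamma: float = 0.5,
    hash_key: int = 15485863,
) -> torch.Tensor:
    """Return the green list selected by the previous token (hypothetical helper)."""
    rng = torch.Generator(device="cpu")
    # Same previous token -> same seed -> same permutation.
    rng.manual_seed(hash_key * prev_token)
    permutation = torch.randperm(vocab_size, generator=rng)
    # Keep the first gamma * vocab_size IDs of the permutation.
    return permutation[: int(vocab_size * gamma)]


def sketch_bias_green_logits(
    scores: torch.Tensor,
    green_ids: torch.Tensor,
    delta: float = 2.0,
) -> torch.Tensor:
    """Add delta to the green list positions of a 1-D logits vector."""
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask[green_ids] = True
    biased = scores.clone()  # assumption: leave the caller's tensor untouched
    biased[mask] += delta
    return biased


vocab_size = 10
scores = torch.zeros(vocab_size)
green = sketch_greenlist_ids(prev_token=42, vocab_size=vocab_size)
biased = sketch_bias_green_logits(scores, green)
```

Calling `sketch_greenlist_ids` twice with the same `prev_token` yields the same IDs, which is the property a detector relies on.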
Usage
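This processor is integrated into the LoRAX decoding pipeline. When watermarking is enabled for a request, the WatermarkLogitsProcessor is added to the logits processing chain. It runs after the model produces raw logits and before sampling, shifting probability mass toward green list tokens. Because the bias is additive rather than a hard restriction, tokens outside the green list remain available, which limits the effect on text quality. A verifier who knows the hash key and gamma, and who uses the same tokenizer, can recompute the green list at each position and test whether green tokens are over-represented; delta affects the strength of the watermark but is not needed for detection.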
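Detection is not part of this module, but the statistical test it enables can be sketched briefly. The example below uses the one-proportion z-test described by Kirchenbauer et al.: under the null hypothesis (no watermark), each scored token is green with probability gamma. The names `watermark_z_score`, `count_green`, and the `is_green` callback are hypothetical; a real verifier would implement `is_green` by recomputing the green list from the hash key and the previous token.

```python
import math
from typing import Callable, Sequence


def watermark_z_score(green_count: int, scored_tokens: int, gamma: float) -> float:
    """z = (green_count - gamma * T) / sqrt(T * gamma * (1 - gamma))."""
    expected = gamma * scored_tokens
    std = math.sqrt(scored_tokens * gamma * (1.0 - gamma))
    return (green_count - expected) / std


def count_green(tokens: Sequence[int], is_green: Callable[[int, int], bool]) -> int:
    """Count tokens that fall in the green list chosen by their predecessor."""
    return sum(is_green(prev, tok) for prev, tok in zip(tokens, tokens[1:]))


# Toy membership rule, for illustration only (NOT the real green list).
toy_is_green = lambda prev, tok: (prev + tok) % 2 == 0
toy_count = count_green([1, 3, 5, 7, 2], toy_is_green)

# 140 green tokens out of 200 scored tokens with gamma = 0.5.
z = watermark_z_score(green_count=140, scored_tokens=200, gamma=0.5)
```

A large positive z indicates more green tokens than chance would explain. The cutoff at which text is flagged is a policy choice for the verifier and is not defined by this module.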
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/utils/watermark.py
- Lines: 1-84
Signature
class WatermarkLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        gamma: float = GAMMA,
        delta: float = DELTA,
        hash_key: int = 15485863,
        device: str = "cpu",
    )
    def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor])
    def _get_greenlist_ids(self, input_ids, max_value: int, device: torch.device) -> List[int]
    @staticmethod
    def _calc_greenlist_mask(scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor
    @staticmethod
    def _bias_greenlist_logits(scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor
    def __call__(self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor) -> torch.FloatTensor
Import
from lorax_server.utils.watermark import WatermarkLogitsProcessor
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| gamma | float | No | Fraction of vocabulary in the green list (default 0.5, env: WATERMARK_GAMMA) |
| delta | float | No | Logit bias added to green list tokens (default 2.0, env: WATERMARK_DELTA) |
| hash_key | int | No | Large prime number for RNG seeding (default 15485863) |
| device | str | No | Device for RNG generator (default "cpu") |
| input_ids | Union[List[int], torch.LongTensor] | Yes | Current token sequence used to seed the green list RNG |
| scores | torch.FloatTensor | Yes | Raw logit scores from the model of shape (batch, vocab_size) |
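The environment-variable defaults listed for gamma and delta could be resolved along the following lines. This is a sketch under assumptions: the helper `read_float_env` is hypothetical, and the module's actual parsing may differ. The names GAMMA and DELTA match the defaults shown in the signature.

```python
import os


def read_float_env(name: str, default: float) -> float:
    """Return the env var parsed as a float, or the default if it is unset."""
    raw = os.getenv(name)
    return default if raw is None else float(raw)


# Module-level constants used as constructor defaults.
GAMMA = read_float_env("WATERMARK_GAMMA", 0.5)
DELTA = read_float_env("WATERMARK_DELTA", 2.0)
```

Because the signature binds `gamma: float = GAMMA` and `delta: float = DELTA` as default arguments, the values are fixed when the class is defined. Setting the environment variables after the module has been imported has no effect on those defaults; pass gamma and delta explicitly in that case.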
Outputs
| Name | Type | Description |
|---|---|---|
| scores | torch.FloatTensor | Modified logit scores with delta bias added to green list tokens |
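To see what the delta bias does to the sampling distribution, the worked example below computes the softmax probability mass that lands on green tokens after the bias is applied. Adding delta to a logit multiplies that token's unnormalized probability by e^delta. The function `green_probability_mass` is a hypothetical helper written in plain Python for illustration; it is not part of the module.

```python
import math
from typing import Sequence, Set


def green_probability_mass(
    logits: Sequence[float], green_ids: Set[int], delta: float
) -> float:
    """Softmax mass on green tokens after adding delta to their logits."""
    biased = [x + delta if i in green_ids else x for i, x in enumerate(logits)]
    peak = max(biased)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in biased]
    total = sum(exps)
    return sum(e for i, e in enumerate(exps) if i in green_ids) / total


# Flat distribution, half the vocabulary green: mass moves from 0.5 to sigmoid(delta).
flat = green_probability_mass([0.0, 0.0, 0.0, 0.0], {0, 1}, delta=2.0)

# Peaked distribution with a dominant non-green token: the bias barely matters.
peaked = green_probability_mass([10.0, 0.0, 0.0, 0.0], {1, 2}, delta=2.0)
```

With flat logits and delta = 2.0, the green share rises from 0.5 to about 0.88. When one non-green token dominates, the green share stays close to zero, so the watermark is weakest where the model is most certain of the next token.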
Usage Examples
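# Internal usage in the decoding pipeline
from lorax_server.utils.watermark import WatermarkLogitsProcessor

processor = WatermarkLogitsProcessor(
    gamma=0.5,
    delta=2.0,
    device="cpu",
)

# Called during generation for each token.
# current_tokens: the token sequence so far (List[int] or torch.LongTensor).
# raw_logits: model logits of shape (batch, vocab_size).
# Both are supplied by the surrounding generation loop and are not defined here.
modified_scores = processor(input_ids=current_tokens, scores=raw_logits)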