Implementation:Lm sys FastChat Condense Rotary Embedding
| Knowledge Sources | |
|---|---|
| Domains | Positional Encoding, LLM, Context Extension |
| Last Updated | 2026-02-07 06:00 GMT |
Overview
Monkey patches LLaMA rotary positional embeddings to extend the effective context length through position index condensing.
Description
The llama_condense_monkey_patch module implements a technique for extending the context window of LLaMA models beyond their original training length. It achieves this by replacing the standard LlamaRotaryEmbedding class in the transformers library with a custom CondenseRotaryEmbedding that divides position indices by a configurable ratio, effectively "condensing" longer sequences into the original embedding space.
The CondenseRotaryEmbedding class inherits from torch.nn.Module and precomputes cosine and sine caches for rotary position embeddings. During initialization, it multiplies max_position_embeddings by the condensing ratio to size the cache and divides the position index tensor t by that same ratio. As a result, a model trained with a 2048-token context can address sequences of length 2048 * ratio, because every extended position maps to a (possibly fractional) position inside the original 0 to 2048 range. The embedding computation uses the standard RoPE formula: inverse frequencies are computed as 1 / (base^(2i/dim)), combined with the condensed position indices via an outer product, and cached as cosine/sine pairs.
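The arithmetic described above can be sketched without torch. This is a minimal illustration, not the module's code: the function names inverse_frequencies and condensed_angles are hypothetical, plain Python lists stand in for tensors, and each row is duplicated so that it has dim entries, matching the [1, 1, seq_len, dim] cache shape listed under Outputs.

```python
import math


def inverse_frequencies(dim: int, base: float = 10000.0) -> list[float]:
    """inv_freq[i] = 1 / base^(2i/dim) for i in [0, dim/2)."""
    return [1.0 / (base ** (2 * i / dim)) for i in range(dim // 2)]


def condensed_angles(dim: int, ratio: float, max_position_embeddings: int = 2048,
                     base: float = 10000.0) -> list[list[float]]:
    """Rotation angles for every cached position after condensing (hypothetical helper)."""
    inv_freq = inverse_frequencies(dim, base)
    cached_len = int(max_position_embeddings * ratio)  # the cache grows by `ratio`
    rows = []
    for p in range(cached_len):
        t = p / ratio                      # condensed position index
        half = [t * f for f in inv_freq]   # one row of the outer product
        rows.append(half + half)           # duplicate so the row has `dim` entries
    return rows


# Small numbers for readability: a 16-position model condensed 4x.
angles = condensed_angles(dim=8, ratio=4, max_position_embeddings=16)
cos_cache = [[math.cos(a) for a in row] for row in angles]
sin_cache = [[math.sin(a) for a in row] for row in angles]
```

With ratio=4, extended position 4 receives the same angles as position 1 in the unscaled model, and the largest condensed position (63 / 4 = 15.75) stays below the original limit of 16.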
The replace_llama_with_condense function uses functools.partial to bind the ratio argument of CondenseRotaryEmbedding, then assigns the resulting callable to transformers.models.llama.modeling_llama.LlamaRotaryEmbedding, replacing the original class for the whole process. It must be called before model instantiation, since layers that have already been constructed keep their original embeddings. The technique is adapted from the SuperHOT approach by kaiokendev.
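The patching mechanism can be illustrated with stand-ins that need neither torch nor transformers. Everything in this sketch is hypothetical (OriginalRotary, CondensedRotary, the modeling namespace, and build_attention_layer); it only shows how a functools.partial object can take the place of a class, and why objects built before the patch are unaffected.

```python
from functools import partial
from types import SimpleNamespace


class OriginalRotary:
    """Stand-in for the library's rotary embedding class (hypothetical)."""

    def __init__(self, dim, max_position_embeddings=2048):
        self.dim = dim
        self.cached_len = max_position_embeddings


class CondensedRotary:
    """Stand-in for the condensed variant: same arguments plus `ratio`."""

    def __init__(self, dim, ratio, max_position_embeddings=2048):
        self.dim = dim
        self.ratio = ratio
        self.cached_len = int(max_position_embeddings * ratio)


# Stand-in for the module whose attribute gets replaced.
modeling = SimpleNamespace(Rotary=OriginalRotary)


def build_attention_layer():
    """Mimics model construction: the class is looked up at call time."""
    return modeling.Rotary(128, max_position_embeddings=2048)


def replace_with_condense(ratio):
    # partial pre-binds `ratio`, so callers keep using the original call signature.
    modeling.Rotary = partial(CondensedRotary, ratio=ratio)


before = build_attention_layer()   # constructed before the patch
replace_with_condense(ratio=4)
after = build_attention_layer()    # constructed after the patch

print(type(before).__name__, before.cached_len)  # OriginalRotary 2048
print(type(after).__name__, after.cached_len)    # CondensedRotary 8192
```

Because the replacement is a partial object rather than a class, code that runs isinstance checks against the patched name would fail; the sketch assumes, as the patch does, that callers only ever call the name to construct an instance.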
Usage
Use this module when you need to run inference on LLaMA-based models with sequences longer than their original training context length. Call replace_llama_with_condense(ratio) before loading the model, where ratio is the factor by which you want to extend the context (e.g., ratio=4 extends a 2048-context model to 8192 tokens). This is a global monkey patch and affects all subsequent LLaMA model instantiations in the process.
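When the desired context length is known in advance, the ratio can be derived from it. The helper below is a hypothetical convenience, not part of FastChat; it assumes an integer ratio is wanted (the function itself accepts any number) and that trained_len is the model's original context length.

```python
import math


def condense_ratio_for(target_len: int, trained_len: int = 2048) -> int:
    """Smallest integer ratio whose condensed window covers target_len tokens."""
    if target_len <= 0 or trained_len <= 0:
        raise ValueError("lengths must be positive")
    return max(1, math.ceil(target_len / trained_len))


ratio = condense_ratio_for(8192)   # 4 for a 2048-token model
extended_len = 2048 * ratio        # 8192 positions after condensing
```

The result would then be passed to replace_llama_with_condense(ratio) before the model is loaded.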
Code Reference
Source Location
- Repository: Lm_sys_FastChat
- File: fastchat/model/llama_condense_monkey_patch.py
- Lines: 1-71
Signature
class CondenseRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim: int, ratio: float, max_position_embeddings: int = 2048,
                 base: int = 10000, device=None) -> None:
        ...

    def forward(self, x: torch.Tensor, seq_len: int = None) -> tuple[torch.Tensor, torch.Tensor]:
        """Returns (cos_cached, sin_cached) sliced to seq_len."""
        ...

def replace_llama_with_condense(ratio: float) -> None:
    """Globally replaces LlamaRotaryEmbedding with condensed version."""
    ...
Import
from fastchat.model.llama_condense_monkey_patch import replace_llama_with_condense
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| ratio | float | Yes | Condensing ratio that determines context extension factor (e.g., 4 extends 2048 to 8192) |
| dim | int | Yes (class) | Dimensionality of the rotary embedding (head dimension) |
| max_position_embeddings | int | No | Maximum original position embeddings, defaults to 2048 |
| base | int | No | Base frequency for the inverse frequency computation, defaults to 10000 |
| device | torch.device | No | Device for tensor allocation, defaults to None |
Outputs
| Name | Type | Description |
|---|---|---|
| cos_cached | torch.Tensor | Cosine values of shape [1, 1, seq_len, dim] for rotary position encoding |
| sin_cached | torch.Tensor | Sine values of shape [1, 1, seq_len, dim] for rotary position encoding |
Usage Examples
# Apply the monkey patch before loading any LLaMA model
from fastchat.model.llama_condense_monkey_patch import replace_llama_with_condense
# Condense positions by 4x; the rotary cache grows to 4 * max_position_embeddings
replace_llama_with_condense(ratio=4)
# Now load the model normally - it will use condensed rotary embeddings
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# Llama-2 checkpoints are configured with max_position_embeddings=4096, so with
# ratio=4 the extended window is 4 * 4096 = 16384 positions; a LLaMA model
# trained on 2048-token contexts would reach 8192
Related Pages
- Implements: Principle:Lm_sys_FastChat_Condensed_Rotary_Embedding
- Lm_sys_FastChat_Apply_Delta_Weights - Model weight manipulation in the same module directory
- Lm_sys_FastChat_Huggingface_API_Inference - Inference pipeline that may use extended-context models