Implementation:Predibase Lorax Rotary Embedding
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements rotary positional embeddings (RoPE) with support for static, dynamic, YaRN, and SuRoPE/LongRoPE scaling strategies used across transformer attention layers.
Description
This module provides several classes for computing and applying rotary positional embeddings to query and key tensors in attention mechanisms:
PositionRotaryEmbedding: The base class that caches precomputed cos/sin tables indexed by sequence position. The forward method applies rotary embeddings in-place using platform-specific kernels: rotary_emb.apply_rotary on CUDA (from flash-attention) or ops.rotary_embedding on ROCm (from vLLM). It provides two classmethods:
- static: Creates an embedding from a model config, automatically dispatching to the appropriate subclass based on the rope_scaling config (linear, dynamic, yarn, su/longrope).
- load: Creates an embedding by loading inv_freq weights from a checkpoint.
The get_cos_sin method returns indexed cos/sin tensors for a batch of position IDs, with ROCm using float32 and CUDA matching the input dtype.
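The caching and lookup described above can be sketched in plain Python. This is a minimal illustration, not the LoRAX code: the helper names are hypothetical, the linear scaling (dividing positions by scaling_factor) is an assumption, and the rotation uses the half-split (non-interleaved) layout as an assumed convention. The real module works on torch tensors and delegates the rotation to the CUDA or ROCm kernels.

```python
import math

def create_inv_freq(dim, base):
    """Inverse frequencies 1 / base^(i/dim) for even i in [0, dim)."""
    return [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]

def build_cos_sin_cache(inv_freq, seqlen, scaling_factor=None):
    """Precompute cos/sin rows for positions 0..seqlen-1 (assumed linear scaling)."""
    cos_cache, sin_cache = [], []
    for pos in range(seqlen):
        t = pos / scaling_factor if scaling_factor is not None else float(pos)
        angles = [t * f for f in inv_freq]
        cos_cache.append([math.cos(a) for a in angles])
        sin_cache.append([math.sin(a) for a in angles])
    return cos_cache, sin_cache

def get_cos_sin(cos_cache, sin_cache, position_ids):
    """Index the cached tables by position id, one row per token."""
    cos = [cos_cache[p] for p in position_ids]
    sin = [sin_cache[p] for p in position_ids]
    return cos, sin

def apply_rotary(x, cos, sin):
    """Rotate one head vector, pairing element i with element i + dim/2."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    out1 = [a * c - b * s for a, b, c, s in zip(x1, x2, cos, sin)]
    out2 = [a * s + b * c for a, b, c, s in zip(x1, x2, cos, sin)]
    return out1 + out2

if __name__ == "__main__":
    inv_freq = create_inv_freq(dim=4, base=10000.0)
    cos_cache, sin_cache = build_cos_sin_cache(inv_freq, seqlen=8)
    cos, sin = get_cos_sin(cos_cache, sin_cache, position_ids=[0, 3])
    print(apply_rotary([1.0, 2.0, 3.0, 4.0], cos[1], sin[1]))
```

Because each pair of elements is rotated by an angle, the vector norm is unchanged; only the relative phase between positions is encoded.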
SuRotaryEmbedding: Implements the SuRoPE / LongRoPE scaling strategy, which uses separate short_inv_freq and long_inv_freq frequency tables. It selects the long frequencies when the sequence length exceeds original_max_position_embeddings.
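The short/long selection can be sketched as follows. The per-dimension factor lists and the scaled_inv_freq helper are hypothetical, modeled on LongRoPE-style configurations; how LoRAX actually derives short_inv_freq and long_inv_freq is not shown in this page, so treat that part as an assumption.

```python
def scaled_inv_freq(factors, dim, base):
    """Assumed form: 1 / (factor_i * base^(i/dim)) for even i in [0, dim)."""
    exponents = [i / dim for i in range(0, dim, 2)]
    return [1.0 / (f * base ** e) for f, e in zip(factors, exponents)]

def select_inv_freq(short_inv_freq, long_inv_freq, seqlen,
                    original_max_position_embeddings):
    """Use the long table only once the sequence outgrows the original window."""
    if seqlen > original_max_position_embeddings:
        return long_inv_freq
    return short_inv_freq

if __name__ == "__main__":
    short = scaled_inv_freq([1.0, 1.0], dim=4, base=10000.0)
    long_ = scaled_inv_freq([2.0, 4.0], dim=4, base=10000.0)
    print(select_inv_freq(short, long_, seqlen=8192,
                          original_max_position_embeddings=4096))
```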
DynamicPositionRotaryEmbedding: Implements dynamic NTK-aware scaling, which adjusts the base frequency when the sequence length exceeds max_position_embeddings, using the formula: newbase = base * ((factor * seqlen / max_pos) - (factor - 1)) ** (dim / (dim - 2)), where factor is scaling_factor and max_pos is max_position_embeddings.
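As a worked illustration of the formula above, the following sketch computes the adjusted base. The function name dynamic_ntk_base is hypothetical; only the arithmetic is taken from the description.

```python
def dynamic_ntk_base(base, dim, seqlen, max_position_embeddings, scaling_factor):
    """Return the rescaled RoPE base for a sequence of length seqlen."""
    if seqlen <= max_position_embeddings:
        # Within the trained window the base is left untouched.
        return base
    ratio = scaling_factor * seqlen / max_position_embeddings
    return base * (ratio - (scaling_factor - 1)) ** (dim / (dim - 2))

if __name__ == "__main__":
    # Doubling the context with factor 2 raises the base by a factor of
    # (2 * 2 - 1) ** (128 / 126), i.e. slightly more than 3.
    print(dynamic_ntk_base(10000.0, dim=128, seqlen=4096,
                           max_position_embeddings=2048, scaling_factor=2.0))
```

At seqlen equal to max_position_embeddings the bracketed term evaluates to 1, so the adjusted base joins the original base continuously.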
YarnPositionRotaryEmbedding: Implements the YaRN (Yet another RoPE extensioN) method, which blends interpolated and extrapolated frequencies using a linear ramp mask between correction bounds derived from beta_fast and beta_slow. It also applies magnitude scaling via get_mscale.
Helper functions include _create_inv_freq for generating inverse frequency tensors, _get_rope_config for reading the scaling config from the model or from environment variables, find_correction_dim and find_correction_range for the YaRN correction bounds, linear_ramp_mask for YaRN interpolation blending, and get_mscale for magnitude scaling.
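The YaRN helpers named above can be sketched in plain Python. The formulas follow the publicly available YaRN reference code and are assumed, not confirmed, to match what LoRAX implements; yarn_inv_freq is a hypothetical wrapper, and the defaults beta_fast=32 and beta_slow=1 are illustrative.

```python
import math

def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
    """Dimension index whose wavelength completes num_rotations over the window."""
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )

def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
    """Clamp the floor/ceil of the two correction dims to valid indices."""
    low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)

def linear_ramp_mask(lo, hi, n):
    """Ramp from 0 at index lo to 1 at index hi, clamped to [0, 1]."""
    if lo == hi:
        hi += 0.001  # avoid division by zero
    return [min(max((i - lo) / (hi - lo), 0.0), 1.0) for i in range(n)]

def get_mscale(scale=1.0):
    """Magnitude correction applied to the cos/sin tables."""
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0

def yarn_inv_freq(dim, base, scale, original_max_pos,
                  beta_fast=32, beta_slow=1, extrapolation_factor=1.0):
    """Blend extrapolated (unscaled) and interpolated (scaled) frequencies."""
    pos_freqs = [base ** (i / dim) for i in range(0, dim, 2)]
    extrapolation = [1.0 / f for f in pos_freqs]
    interpolation = [1.0 / (scale * f) for f in pos_freqs]
    low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_pos)
    ramp = linear_ramp_mask(low, high, dim // 2)
    # A mask value of 1 keeps the original frequency; 0 uses the interpolated one.
    mask = [(1.0 - r) * extrapolation_factor for r in ramp]
    return [i * (1.0 - m) + e * m for i, e, m in zip(interpolation, extrapolation, mask)]

if __name__ == "__main__":
    inv_freq = yarn_inv_freq(dim=128, base=10000.0, scale=4.0, original_max_pos=2048)
    print(len(inv_freq), get_mscale(4.0))
```

In this sketch the high-frequency dimensions (low indices) keep their original values, the low-frequency dimensions are fully interpolated, and the dimensions between the two bounds are mixed linearly.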
Usage
This module is used by all transformer model implementations in LoRAX that employ rotary positional embeddings. The PositionRotaryEmbedding.static factory is the primary entry point, automatically selecting the correct scaling strategy based on the model configuration.
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/layers/rotary.py
- Lines: 1-356
Signature
class PositionRotaryEmbedding(nn.Module):
    def __init__(self, inv_freq, scaling_factor):

class SuRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(self, short_inv_freq, long_inv_freq, scaling_factor, original_max_position_embeddings):

class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):

class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(self, dim, max_position_embeddings, base, device, scaling_factor, *,
                 extrapolation_factor, attn_factor, beta_fast, beta_slow):
Import
from lorax_server.layers.rotary import PositionRotaryEmbedding
I/O Contract
Inputs (forward)
| Name | Type | Required | Description |
|---|---|---|---|
| query | torch.Tensor | Yes | Query tensor to apply rotary embeddings to |
| key | torch.Tensor | Yes | Key tensor to apply rotary embeddings to |
| cos | torch.Tensor | Yes | Precomputed cosine values for the positions |
| sin | torch.Tensor | Yes | Precomputed sine values for the positions |
Inputs (get_cos_sin)
| Name | Type | Required | Description |
|---|---|---|---|
| position_ids | torch.Tensor | Yes | Position indices for the batch |
| max_s | int | Yes | Maximum sequence length for cache update |
| dtype | torch.dtype | Yes | Data type for cos/sin tensors |
Outputs (get_cos_sin)
| Name | Type | Description |
|---|---|---|
| cos | torch.Tensor | Cosine values indexed by position_ids, unsqueezed at dim 1 |
| sin | torch.Tensor | Sine values indexed by position_ids, unsqueezed at dim 1 |
Usage Examples
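# Used internally by transformer attention layers.
# config, device, position_ids, query, and key are supplied by the calling model code.
import torch
from lorax_server.layers.rotary import PositionRotaryEmbedding

# Build the embedding; the subclass is selected from the rope_scaling config.
rotary_emb = PositionRotaryEmbedding.static(config, dim=128, base=10000, device=device)
# Look up cos/sin for the batch positions, then rotate query and key in place.
cos, sin = rotary_emb.get_cos_sin(position_ids, max_s=2048, dtype=torch.float16)
rotary_emb(query, key, cos, sin)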