Implementation:Predibase Lorax Rotary Embedding

From Leeroopedia


Knowledge Sources
Domains: Model_Architecture, Inference
Last Updated: 2026-02-08 00:00 GMT

Overview

Implements rotary positional embeddings (RoPE) with support for static, dynamic, YaRN, and SuRoPE/LongRoPE scaling strategies used across transformer attention layers.

Description

This module provides several classes for computing and applying rotary positional embeddings to query and key tensors in attention mechanisms:

PositionRotaryEmbedding: The base class that caches precomputed cos/sin tables indexed by sequence position. The forward method applies rotary embeddings in-place using platform-specific kernels: rotary_emb.apply_rotary on CUDA (from flash-attention) or ops.rotary_embedding on ROCm (from vLLM). It provides two classmethods:

  • static: Creates an embedding from a model config, automatically dispatching to the appropriate subclass based on the rope_scaling config (linear, dynamic, yarn, su/longrope).
  • load: Creates an embedding by loading inv_freq weights from a checkpoint.

The get_cos_sin method returns indexed cos/sin tensors for a batch of position IDs, with ROCm using float32 and CUDA matching the input dtype.
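
The caching and lookup described above can be illustrated with a minimal pure-Python sketch. It is not the LoRAX code, which operates on torch tensors and fused kernels; the function names, the division of positions by the scaling factor (linear scaling), and the pairing of the first half of a head vector with the second half are assumptions made for illustration.

```python
import math

def build_cos_sin_cache(inv_freq, max_positions, scaling_factor=None):
    """Return (cos, sin) tables: one row per position, one column per frequency."""
    cos_table, sin_table = [], []
    for pos in range(max_positions):
        # Assumption: linear scaling divides the position by the factor.
        t = pos / scaling_factor if scaling_factor is not None else float(pos)
        cos_table.append([math.cos(t * f) for f in inv_freq])
        sin_table.append([math.sin(t * f) for f in inv_freq])
    return cos_table, sin_table

def get_cos_sin(cos_table, sin_table, position_ids):
    """Gather the cached rows for a batch of position IDs."""
    cos = [cos_table[p] for p in position_ids]
    sin = [sin_table[p] for p in position_ids]
    return cos, sin

def apply_rotary(x, cos, sin):
    """Rotate one head vector x (length 2 * len(cos)), pairing x[i] with x[i + half]."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    out1 = [a * c - b * s for a, b, c, s in zip(x1, x2, cos, sin)]
    out2 = [a * s + b * c for a, b, c, s in zip(x1, x2, cos, sin)]
    return out1 + out2
```

Because each pair is rotated by a plain 2D rotation, the vector norm is unchanged; only the relative phase between positions is encoded.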

SuRotaryEmbedding: Implements the SuRoPE / LongRoPE scaling strategy, which uses separate short_inv_freq and long_inv_freq frequency tables. It selects the long frequencies when the sequence length exceeds original_max_position_embeddings.
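
The frequency selection rule can be sketched as follows. The helper name select_inv_freq is hypothetical, and the real class works on tensors and also rebuilds its cos/sin cache; this sketch only shows the threshold comparison stated above.

```python
def select_inv_freq(seqlen, short_inv_freq, long_inv_freq,
                    original_max_position_embeddings):
    """Pick the long-context frequencies only past the original context length."""
    if seqlen > original_max_position_embeddings:
        return long_inv_freq
    return short_inv_freq
```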

DynamicPositionRotaryEmbedding: Implements dynamic NTK-aware scaling that adjusts the base frequency when the sequence exceeds max_position_embeddings, using the formula: newbase = base * ((factor * seqlen / max_pos) - (factor - 1)) ** (dim / (dim - 2)).
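
As a worked illustration of the formula, the sketch below computes the adjusted base in plain Python. The function name dynamic_ntk_base is hypothetical; returning the unmodified base at or below max_pos follows the "when the sequence exceeds max_position_embeddings" condition above.

```python
def dynamic_ntk_base(base, factor, seqlen, max_pos, dim):
    """Return the rescaled RoPE base for a sequence of length seqlen."""
    if seqlen <= max_pos:
        # Within the trained context: no adjustment.
        return base
    return base * ((factor * seqlen / max_pos) - (factor - 1)) ** (dim / (dim - 2))
```

At seqlen equal to max_pos the scaling term evaluates to 1, so the base changes continuously as the sequence grows past the trained length.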

YarnPositionRotaryEmbedding: Implements the YaRN (Yet another RoPE extensioN) method, which blends interpolated and extrapolated frequencies using a linear ramp mask between correction bounds (beta_fast and beta_slow). Applies magnitude scaling via get_mscale.
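
The blending step can be sketched in pure Python using the formulas from the publicly available YaRN reference implementation. Whether rotary.py matches these expressions term for term is an assumption, as are the yarn_inv_freq name and the default values beta_fast=32 and beta_slow=1.

```python
import math

def find_correction_dim(num_rotations, dim, base, max_pos):
    """Dimension index at which a frequency completes num_rotations over max_pos."""
    return (dim * math.log(max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

def find_correction_range(beta_fast, beta_slow, dim, base, max_pos):
    """Clamp the [low, high] dimension range over which blending happens."""
    low = math.floor(find_correction_dim(beta_fast, dim, base, max_pos))
    high = math.ceil(find_correction_dim(beta_slow, dim, base, max_pos))
    return max(low, 0), min(high, dim - 1)

def linear_ramp_mask(low, high, n):
    """Ramp from 0 to 1 between low and high, clamped outside that range."""
    if low == high:
        high += 0.001  # avoid division by zero
    return [min(max((i - low) / (high - low), 0.0), 1.0) for i in range(n)]

def get_mscale(scale=1.0):
    """Magnitude scaling; no change when the context is not extended."""
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0

def yarn_inv_freq(dim, base, scale, max_pos, beta_fast=32, beta_slow=1,
                  extrapolation_factor=1.0):
    """Blend extrapolated (unscaled) and interpolated (scaled) inverse frequencies."""
    pos_freqs = [base ** (2 * i / dim) for i in range(dim // 2)]
    low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_pos)
    ramp = linear_ramp_mask(low, high, dim // 2)
    blended = []
    for f, r in zip(pos_freqs, ramp):
        mask = (1.0 - r) * extrapolation_factor
        blended.append((1.0 / (scale * f)) * (1.0 - mask) + (1.0 / f) * mask)
    return blended
```

In this sketch the high-frequency dimensions (low index) keep their original values, while the low-frequency dimensions (high index) are fully interpolated by the scale factor.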

Helper functions include _create_inv_freq (generates inverse frequency tensors), _get_rope_config (reads the scaling config from the model config or from environment variables), find_correction_dim and find_correction_range (compute the YaRN correction bounds), linear_ramp_mask (builds the YaRN interpolation blending mask), and get_mscale (computes magnitude scaling).
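
For reference, the standard RoPE inverse-frequency computation is sketched below in plain Python. The name create_inv_freq is illustrative; the module's _create_inv_freq returns a torch tensor on a given device, and its exact signature is not shown on this page.

```python
def create_inv_freq(dim, base):
    """Standard RoPE inverse frequencies: 1 / base^(i / dim) for even i < dim."""
    return [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]
```

The result has dim // 2 entries, one per rotated pair of channels, decreasing from 1.0 toward 1 / base.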

Usage

This module is used by all transformer model implementations in LoRAX that employ rotary positional embeddings. The PositionRotaryEmbedding.static factory is the primary entry point, automatically selecting the correct scaling strategy based on the model configuration.
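
The dispatch performed by the factory can be pictured roughly as below. This is a hypothetical sketch, not the factory's code: the function name, the "type" key, and the choice to map linear scaling onto the base class (which accepts a scaling_factor) are assumptions. Only the strategy names and class names come from this page.

```python
def select_rope_strategy(rope_scaling):
    """Map a rope_scaling config dict to the name of the class that would handle it."""
    if rope_scaling is None:
        return "PositionRotaryEmbedding"
    kind = rope_scaling.get("type")  # assumed key name
    if kind == "linear":
        # Assumption: linear scaling uses the base class with a scaling_factor.
        return "PositionRotaryEmbedding"
    if kind == "dynamic":
        return "DynamicPositionRotaryEmbedding"
    if kind == "yarn":
        return "YarnPositionRotaryEmbedding"
    if kind in ("su", "longrope"):
        return "SuRotaryEmbedding"
    raise NotImplementedError(f"unsupported rope scaling type: {kind!r}")
```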

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/layers/rotary.py
  • Lines: 1-356

Signature

class PositionRotaryEmbedding(nn.Module):
    def __init__(self, inv_freq, scaling_factor): ...

class SuRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(self, short_inv_freq, long_inv_freq, scaling_factor, original_max_position_embeddings): ...

class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): ...

class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(self, dim, max_position_embeddings, base, device, scaling_factor, *,
                 extrapolation_factor, attn_factor, beta_fast, beta_slow): ...

Import

from lorax_server.layers.rotary import PositionRotaryEmbedding

I/O Contract

Inputs (forward)

Name | Type | Required | Description
query | torch.Tensor | Yes | Query tensor to apply rotary embeddings to
key | torch.Tensor | Yes | Key tensor to apply rotary embeddings to
cos | torch.Tensor | Yes | Precomputed cosine values for the positions
sin | torch.Tensor | Yes | Precomputed sine values for the positions

Inputs (get_cos_sin)

Name | Type | Required | Description
position_ids | torch.Tensor | Yes | Position indices for the batch
max_s | int | Yes | Maximum sequence length for cache update
dtype | torch.dtype | Yes | Data type for cos/sin tensors

Outputs (get_cos_sin)

Name | Type | Description
cos | torch.Tensor | Cosine values indexed by position_ids, unsqueezed at dim 1
sin | torch.Tensor | Sine values indexed by position_ids, unsqueezed at dim 1

Usage Examples

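# Used internally by transformer attention layers.
# Assumes config, device, position_ids, query, and key are supplied by the calling model code.
import torch
from lorax_server.layers.rotary import PositionRotaryEmbedding

# Build the embedding; the rope_scaling entry in config selects the scaling strategy.
rotary_emb = PositionRotaryEmbedding.static(config, dim=128, base=10000, device=device)
# Look up cos/sin for the positions in this batch.
cos, sin = rotary_emb.get_cos_sin(position_ids, max_s=2048, dtype=torch.float16)
# Rotate query and key in place.
rotary_emb(query, key, cos, sin)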
