Implementation:Predibase Lorax NeoX Modeling

From Leeroopedia
Revision as of 16:21, 16 February 2026 by Admin (auto-imported from implementations/Predibase_Lorax_NeoX_Modeling.md)


Knowledge Sources
Domains: Model_Architecture, Inference
Last Updated: 2026-02-08 00:00 GMT

Overview

Provides a tensor-parallel implementation of the GPT-NeoX causal language model with rotary positional embeddings and optional fused attention CUDA kernels for inference within the LoRAX serving framework.
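
To illustrate what "tensor-parallel" means for a feed-forward block, here is a minimal NumPy sketch, not the LoRAX code. It assumes the common scheme in which the up-projection is split by output columns and the down-projection by input rows, with a final sum standing in for the all-reduce across ranks. The function names, the tanh GELU approximation, and the omission of biases are all assumptions made for illustration.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU; chosen here only for illustration
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def tensor_parallel_mlp(x, w_up, w_down, world_size):
    """Simulate a column-parallel up-projection followed by a row-parallel
    down-projection on `world_size` shards (hypothetical helper, no biases)."""
    # Column-parallel: each shard owns a slice of the output features.
    up_shards = np.split(w_up, world_size, axis=1)
    # Row-parallel: each shard owns the matching slice of the input features.
    down_shards = np.split(w_down, world_size, axis=0)
    # Each "rank" computes a partial result without talking to the others.
    partials = [gelu(x @ up) @ down for up, down in zip(up_shards, down_shards)]
    # Summing the partials plays the role of the all-reduce.
    return sum(partials)
```

Because the activation is elementwise, the sharded computation matches the unsharded `gelu(x @ w_up) @ w_down` up to floating-point error, which is why only one communication step is needed per MLP under this scheme.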

Description

This module implements the GPT-NeoX architecture (EleutherAI) with tensor parallelism support. GPT-NeoX applies rotary positional embeddings (RoPE) to only a portion of each attention head's dimensions, and it can compute attention and the MLP in parallel within each transformer block.
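
The partial rotary scheme can be sketched in NumPy as follows. This is an illustrative sketch rather than the module's code: it assumes the half-split rotation convention commonly used for GPT-NeoX, a base of 10000, and a single head with input shape (seq_len, head_dim). The function names and defaults are hypothetical.

```python
import numpy as np

def rotate_half(x):
    # Split the last dimension into two halves and map (x1, x2) -> (-x2, x1).
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([-x2, x1], axis=-1)

def apply_partial_rotary(x, positions, rotary_pct=0.25, base=10000.0):
    """Rotate only the first `rotary_ndims` dimensions of x (seq_len, head_dim).
    Assumes rotary_ndims is even; the remaining dimensions pass through."""
    head_dim = x.shape[-1]
    rotary_ndims = int(head_dim * rotary_pct)
    x_rot, x_pass = x[..., :rotary_ndims], x[..., rotary_ndims:]
    # One inverse frequency per rotated pair of dimensions.
    inv_freq = 1.0 / (base ** (np.arange(0, rotary_ndims, 2) / rotary_ndims))
    freqs = np.outer(positions, inv_freq)            # (seq_len, rotary_ndims / 2)
    emb = np.concatenate([freqs, freqs], axis=-1)    # (seq_len, rotary_ndims)
    x_rot = x_rot * np.cos(emb) + rotate_half(x_rot) * np.sin(emb)
    return np.concatenate([x_rot, x_pass], axis=-1)
```

Since each pair of dimensions is rotated by an angle that depends only on position, vector norms are preserved and position 0 leaves the input unchanged.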

Key classes:

  • GPTNeoXAttention -- Multi-head self-attention with partial rotary positional embeddings. rotary_ndims is the number of head dimensions that receive rotary encoding, derived from config.rotary_pct. Uses a fused query_key_value projection via TensorParallelColumnLinear and a dense output projection via TensorParallelRowLinear. When custom kernels are available, delegates to fused_attention_cuda.
  • RotaryEmbedding -- Implements RoPE (Rotary Position Embedding) with cached cosine/sine tensors. Supports dynamic sequence length extension beyond the initial max_position_embeddings.
  • GPTNeoXMLP -- Two-layer feed-forward network with configurable activation (typically GELU). Uses TensorParallelColumnLinear for the up-projection (dense_h_to_4h) and TensorParallelRowLinear for the down-projection (dense_4h_to_h).
  • GPTNeoXLayer -- Single transformer block that can run attention and MLP in parallel (when use_parallel_residual=True) or sequentially. Includes two layer norms: input_layernorm and post_attention_layernorm.
  • GPTNeoXModel -- Full transformer backbone with token embeddings (TensorParallelEmbedding), the layer stack, and a final layer norm. Constructs causal masks and manages position IDs.
  • GPTNeoxForCausalLM -- Top-level causal LM wrapper with GPTNeoXModel and an LM head (TensorParallelHead). Returns CausalLMOutputWithPast.
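
The two residual layouts described for GPTNeoXLayer can be sketched as below. This is a schematic under stated assumptions, not the module's implementation: `neox_block` is a hypothetical function, and the attention, MLP, and the two layer norms are passed in as plain callables so that only the data flow is shown.

```python
def neox_block(x, attn, mlp, input_norm, post_attn_norm, use_parallel_residual=True):
    """Data flow of one transformer block (hypothetical helper).

    attn, mlp, input_norm and post_attn_norm are callables standing in for
    the attention module, the MLP, and the two layer norms.
    """
    attn_out = attn(input_norm(x))
    if use_parallel_residual:
        # Parallel form: both branches read the same block input x.
        return x + attn_out + mlp(post_attn_norm(x))
    # Sequential form: the MLP reads the output of the attention residual.
    h = x + attn_out
    return h + mlp(post_attn_norm(h))
```

In the parallel form the MLP does not depend on the attention output, so the two branches can be evaluated independently before being summed.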

Usage

Used internally by the LoRAX server when serving GPT-NeoX-based models (e.g., EleutherAI/gpt-neox-20b and the Pythia series). The class is loaded via the model registry rather than instantiated directly by users.

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/models/custom_modeling/neox_modeling.py
  • Lines: 1-717

Signature

class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
    def __init__(self, config, weights):
        ...

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        ...

Import

from lorax_server.models.custom_modeling.neox_modeling import GPTNeoxForCausalLM

I/O Contract

Inputs

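| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | Optional[torch.LongTensor] | No | Input token IDs of shape (batch_size, sequence_length) |
| attention_mask | Optional[torch.FloatTensor] | No | Attention mask for padding |
| position_ids | Optional[torch.LongTensor] | No | Position IDs for rotary embeddings |
| inputs_embeds | Optional[torch.FloatTensor] | No | Pre-computed input embeddings |
| head_mask | Optional[torch.FloatTensor] | No | Per-head attention mask |
| past_key_values | Optional[Tuple[Tuple[torch.FloatTensor]]] | No | Cached KV states for autoregressive decoding |
| labels | Optional[torch.LongTensor] | No | Labels for language modeling loss |
| use_cache | Optional[bool] | No | Whether to return the KV cache |
| output_attentions | Optional[bool] | No | Whether to return attention weights |
| output_hidden_states | Optional[bool] | No | Whether to return hidden states from all layers |
| return_dict | Optional[bool] | No | Whether to return a CausalLMOutputWithPast instead of a plain tuple |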
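
To show why past_key_values and use_cache exist, the following NumPy sketch compares full causal attention with step-by-step decoding that reuses cached keys and values. It is a single-head toy under stated assumptions, not the module's attention code, and `causal_attention` and `decode_step` are hypothetical helpers.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def causal_attention(q, k, v):
    """Full single-head attention over (seq_len, d) inputs with a causal mask."""
    seq_len = q.shape[0]
    scores = q @ k.T / np.sqrt(q.shape[-1])
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)   # hide later positions
    return softmax(scores) @ v

def decode_step(q_new, k_new, v_new, past=None):
    """Attend from one new token (1, d) to all cached tokens plus itself.
    Returns the output and the updated (keys, values) cache."""
    if past is not None:
        k_new = np.concatenate([past[0], k_new], axis=0)
        v_new = np.concatenate([past[1], v_new], axis=0)
    scores = q_new @ k_new.T / np.sqrt(q_new.shape[-1])
    return softmax(scores) @ v_new, (k_new, v_new)
```

Feeding tokens one at a time through `decode_step`, and passing the returned cache back in, reproduces the rows of `causal_attention` without recomputing keys and values for earlier positions.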

Outputs

| Name | Type | Description |
|---|---|---|
| loss | Optional[torch.Tensor] | Language modeling loss (when labels are provided) |
| logits | torch.Tensor | Prediction logits of shape (batch_size, sequence_length, vocab_size) |
| past_key_values | Tuple[Tuple[torch.Tensor]] | Cached KV states for subsequent decoding |
| hidden_states | Optional[Tuple[torch.Tensor]] | Hidden states from all layers |
| attentions | Optional[Tuple[torch.Tensor]] | Attention weights from all layers |
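
As a sketch of how a language modeling loss relates logits to labels, the NumPy function below assumes the usual next-token convention, in which the logits at position t are scored against the label at position t+1. Whether this module follows exactly that convention is an assumption, and `causal_lm_loss` is a hypothetical helper.

```python
import numpy as np

def causal_lm_loss(logits, labels):
    """Mean next-token cross-entropy (hypothetical helper).

    logits: (batch_size, sequence_length, vocab_size) floats
    labels: (batch_size, sequence_length) integer token IDs
    """
    shift_logits = logits[:, :-1, :]   # predictions for positions 1..T-1
    shift_labels = labels[:, 1:]       # the tokens that actually came next
    # Numerically stable log-softmax over the vocabulary axis.
    z = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    picked = np.take_along_axis(log_probs, shift_labels[..., None], axis=-1)
    return float(-picked.mean())
```

With uniform logits the loss equals the natural log of the vocabulary size, which is a convenient sanity check for an untrained model.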

Usage Examples

# Internal usage within LoRAX server
from lorax_server.models.custom_modeling.neox_modeling import GPTNeoxForCausalLM
# Instantiated by model registry with GPT-NeoX config and pre-loaded weights
