Implementation:Predibase Lorax NeoX Modeling
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Provides a tensor-parallel implementation of the GPT-NeoX causal language model with rotary positional embeddings and optional fused attention CUDA kernels for inference within the LoRAX serving framework.
Description
This module implements the GPT-NeoX architecture (EleutherAI) with tensor parallelism support. GPT-NeoX uses rotary positional embeddings (RoPE) applied to a portion of the attention head dimensions, and parallel attention+MLP computation within each transformer block.
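To make "applied to a portion of the attention head dimensions" concrete, here is a minimal pure-Python sketch of partial rotary encoding for a single head vector. It assumes the rotated width is `int(head_size * rotary_pct)` and the "rotate half" pairing used by the reference GPT-NeoX implementation; the function name `apply_partial_rotary` and the `base` default are illustrative and not taken from this module.

```python
import math


def apply_partial_rotary(vec, position, rotary_pct=0.25, base=10000.0):
    """Rotate only the leading slice of one head vector; pass the rest through.

    Hypothetical helper for illustration; the real module works on batched tensors.
    """
    rotary_ndims = int(len(vec) * rotary_pct)  # assumed rule for the rotated width
    assert rotary_ndims % 2 == 0, "rotated width must be even to form pairs"
    rot, passthrough = vec[:rotary_ndims], vec[rotary_ndims:]
    half = rotary_ndims // 2
    # One frequency per (i, i + half) pair, reused for both halves.
    inv_freq = [base ** (-2.0 * i / rotary_ndims) for i in range(half)]
    angles = [position * f for f in inv_freq] * 2
    # "rotate half": (x1, x2) -> (-x2, x1)
    rotated = [-x for x in rot[half:]] + rot[:half]
    out = [x * math.cos(a) + r * math.sin(a) for x, r, a in zip(rot, rotated, angles)]
    return out + passthrough
```

With `rotary_pct=0.5` and an 8-dimensional head, only the first four dimensions change with position; the last four are returned untouched, and the rotated slice keeps its norm.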
Key classes:
- GPTNeoXAttention -- Multi-head self-attention with partial rotary positional embeddings. The rotary_ndims parameter controls what fraction of head dimensions receive rotary encoding (determined by config.rotary_pct). Uses a fused query_key_value projection via TensorParallelColumnLinear and a dense output via TensorParallelRowLinear. When custom kernels are available, delegates to fused_attention_cuda.
- RotaryEmbedding -- Implements RoPE (Rotary Position Embedding) with cached cosine/sine tensors. Supports dynamic sequence length extension beyond the initial max_position_embeddings.
- GPTNeoXMLP -- Two-layer feed-forward network with configurable activation (typically GeLU). Uses TensorParallelColumnLinear for the up-projection (dense_h_to_4h) and TensorParallelRowLinear for the down-projection (dense_4h_to_h).
- GPTNeoXLayer -- Single transformer block that can run attention and MLP in parallel (when use_parallel_residual=True) or sequentially. Includes two layer norms: input_layernorm and post_attention_layernorm.
- GPTNeoXModel -- Full transformer backbone with token embeddings (TensorParallelEmbedding), the layer stack, and a final layer norm. Constructs causal masks and manages position IDs.
- GPTNeoxForCausalLM -- Top-level causal LM wrapper with GPTNeoXModel and an LM head (TensorParallelHead). Returns CausalLMOutputWithPast.
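The two residual layouts selected by use_parallel_residual can be sketched as follows. This is a simplified illustration on scalars, assuming the standard GPT-NeoX formulation; `neox_layer` and its callable arguments are hypothetical stand-ins for the attention, MLP, and layer-norm submodules, not the module's actual code.

```python
def neox_layer(x, attn, mlp, input_layernorm, post_attention_layernorm,
               use_parallel_residual=True):
    """Combine attention and MLP outputs with the residual stream.

    All arguments are plain callables here; in the real block they are modules.
    """
    attn_out = attn(input_layernorm(x))
    if use_parallel_residual:
        # Parallel: both branches read the same input x.
        # x + attn(ln1(x)) + mlp(ln2(x))
        return x + attn_out + mlp(post_attention_layernorm(x))
    # Sequential: the MLP reads the attention-updated residual.
    h = x + attn_out
    return h + mlp(post_attention_layernorm(h))
```

In the parallel form the MLP does not depend on the attention output, which is why the two branches can be computed side by side.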
Usage
Used internally by the LoRAX server when serving GPT-NeoX-based models (e.g., EleutherAI/gpt-neox-20b and the Pythia series). The model class is loaded via the model registry.
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/models/custom_modeling/neox_modeling.py
- Lines: 1-717
Signature
class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
    def __init__(self, config, weights):
        ...

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        ...
Import
from lorax_server.models.custom_modeling.neox_modeling import GPTNeoxForCausalLM
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | Optional[torch.LongTensor] | No | Input token IDs of shape (batch_size, sequence_length) |
| attention_mask | Optional[torch.FloatTensor] | No | Attention mask for padding |
| position_ids | Optional[torch.LongTensor] | No | Position IDs for rotary embeddings |
| inputs_embeds | Optional[torch.FloatTensor] | No | Pre-computed input embeddings |
| head_mask | Optional[torch.FloatTensor] | No | Per-head attention mask |
| past_key_values | Optional[Tuple[Tuple[torch.FloatTensor]]] | No | Cached KV states for autoregressive decoding |
| labels | Optional[torch.LongTensor] | No | Labels for language modeling loss |
| use_cache | Optional[bool] | No | Whether to return KV cache |
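The interplay between past_key_values and use_cache can be illustrated with a toy greedy decoding loop. `ToyCausalLM` below is a hypothetical stand-in that only mimics the calling convention (it is unbatched, and its "cache" is just the list of tokens seen so far); it is not the LoRAX model and does no attention.

```python
class ToyCausalLM:
    """Hypothetical stub: the next token is (sum of the context) % vocab_size."""

    def __init__(self, vocab_size=10):
        self.vocab_size = vocab_size

    def __call__(self, input_ids, past_key_values=None, use_cache=True):
        context = list(past_key_values or [])
        logits = []
        for tok in input_ids:
            context.append(tok)
            target = sum(context) % self.vocab_size
            logits.append([1.0 if v == target else 0.0
                           for v in range(self.vocab_size)])
        return logits, (context if use_cache else None)


def greedy_decode(model, prompt, max_new_tokens):
    tokens, past = list(prompt), None
    step_input = tokens  # first call processes the whole prompt
    for _ in range(max_new_tokens):
        logits, past = model(step_input, past_key_values=past, use_cache=True)
        last = logits[-1]
        next_token = max(range(len(last)), key=last.__getitem__)
        tokens.append(next_token)
        step_input = [next_token]  # later calls feed only the new token
    return tokens
```

The point of the pattern is that, once a cache exists, each step passes only the newly generated token together with the cache returned by the previous call.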
Outputs
| Name | Type | Description |
|---|---|---|
| loss | Optional[torch.Tensor] | Language modeling loss (when labels provided) |
| logits | torch.Tensor | Prediction logits of shape (batch_size, sequence_length, vocab_size) |
| past_key_values | Tuple[Tuple[torch.Tensor]] | Cached KV states for subsequent decoding |
| hidden_states | Optional[Tuple[torch.Tensor]] | Hidden states from all layers |
| attentions | Optional[Tuple[torch.Tensor]] | Attention weights from all layers |
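As a rough illustration of how loss relates to logits and labels, the sketch below computes a shifted next-token cross-entropy for a single unbatched sequence in pure Python. It assumes the usual causal LM convention (the logits at position t are scored against the label at position t+1, averaged over positions); `causal_lm_loss` is a hypothetical helper, and the module itself may differ in details such as ignored label values.

```python
import math


def causal_lm_loss(logits, labels):
    """Mean cross-entropy of logits[t] against labels[t + 1].

    logits: list of per-position score lists, shape [seq_len][vocab_size]
    labels: list of token IDs, shape [seq_len]
    """
    pairs = list(zip(logits[:-1], labels[1:]))  # shift by one position
    if not pairs:
        raise ValueError("need at least two positions to form a target")
    total = 0.0
    for step_logits, target in pairs:
        # Numerically stable log-sum-exp for the softmax normaliser.
        m = max(step_logits)
        log_z = m + math.log(sum(math.exp(v - m) for v in step_logits))
        total += log_z - step_logits[target]
    return total / len(pairs)
```

With uniform logits over a vocabulary of size V, the loss equals log(V) regardless of the labels, which is a convenient sanity check.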
Usage Examples
# Internal usage within LoRAX server
from lorax_server.models.custom_modeling.neox_modeling import GPTNeoxForCausalLM
# Instantiated by model registry with GPT-NeoX config and pre-loaded weights