

Implementation:Predibase Lorax OPT Modeling

From Leeroopedia


Knowledge Sources
Domains: Model_Architecture, Inference
Last Updated: 2026-02-08 00:00 GMT

Overview

Provides a tensor-parallel implementation of the OPT (Open Pre-trained Transformer) causal language model with learned positional embeddings for inference within the LoRAX serving framework.

Description

This module implements Meta AI's OPT architecture, a decoder-only transformer with learned positional embeddings. Its attention layers use separate query, key, and value projections rather than a single fused QKV projection.
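To make the separate-projection design concrete, here is a minimal single-head, single-device sketch in plain Python. The helper names (`matmul`, `softmax`, `causal_self_attention`) and the list-of-lists weight layout are assumptions for illustration, not the LoRAX code. It leaves out biases, multiple heads, tensor parallelism, and KV caching, and shows only three independent Q/K/V projections, scaled dot-product scores, a causal mask, and an output projection.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Multiply an (n x k) matrix by a (k x m) matrix."""
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of scores."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def causal_self_attention(
    x: Matrix, w_q: Matrix, w_k: Matrix, w_v: Matrix, w_o: Matrix
) -> Matrix:
    """Hypothetical single-head causal attention with separate (unfused) Q, K, V projections.

    x is (seq_len, hidden); w_q, w_k, w_v are (hidden, head_dim); w_o is (head_dim, hidden).
    """
    q = matmul(x, w_q)  # three independent projections ...
    k = matmul(x, w_k)  # ... instead of one fused QKV matmul
    v = matmul(x, w_v)
    head_dim = len(w_q[0])
    scale = 1.0 / math.sqrt(head_dim)
    context = []
    for i in range(len(x)):
        # Causal mask: position i attends only to positions 0..i.
        scores = [scale * sum(q[i][d] * k[j][d] for d in range(head_dim)) for j in range(i + 1)]
        probs = softmax(scores)
        context.append([sum(probs[j] * v[j][d] for j in range(i + 1)) for d in range(len(v[0]))])
    return matmul(context, w_o)  # output projection
```

Because the first position can attend only to itself, its output is its own value vector passed through `w_o`. A fused-QKV layer would compute the same `q`, `k`, `v` from one wider matrix and then split the result.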

Key classes:

  • OPTLearnedPositionalEmbedding -- Learned positional embeddings with a fixed offset of 2. Computes positions from attention masks to handle padding correctly, supporting incremental decoding with past_key_values_length.
  • OPTAttention -- Multi-headed self-attention with separate query, key, and value projections via TensorParallelColumnLinear, and an output projection via TensorParallelRowLinear. Supports decoder self-attention with KV caching.
  • OPTDecoderLayer -- Single decoder layer combining self-attention, a two-layer feed-forward network (fc1 column-parallel, fc2 row-parallel), and layer normalization. Supports both pre-norm and post-norm configurations via do_layer_norm_before.
  • OPTDecoder -- Full decoder stack with token embeddings (TensorParallelEmbedding), learned positional embeddings, optional project-in/project-out linear layers (for models where word_embed_proj_dim != hidden_size), the decoder layer stack, and optional final layer norm.
  • OPTModel -- Wraps the OPTDecoder and returns BaseModelOutputWithPast.
  • OPTForCausalLM -- Top-level causal LM wrapper with OPTModel and an LM head (TensorParallelHead). Returns CausalLMOutputWithPast.
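The position computation in OPTLearnedPositionalEmbedding can be sketched as follows. This sketch assumes a cumulative-sum formulation in the style of the Hugging Face Transformers OPT code; whether LoRAX uses exactly the same expression is an assumption, and `opt_position_indices` is a hypothetical helper working on plain lists rather than tensors. The offset of 2 comes from the description above.

```python
from typing import List

OFFSET = 2  # fixed offset of OPT's learned positional embedding, per the description above


def opt_position_indices(
    attention_mask: List[List[int]], past_key_values_length: int = 0
) -> List[List[int]]:
    """Map a 0/1 attention mask to positional-embedding row indices (hypothetical helper).

    The mask is assumed to cover the whole sequence (cached + new tokens). Real tokens
    are numbered 0, 1, 2, ... from the first non-padding token, padding gets -1, and
    every value is then shifted by OFFSET.
    """
    indices = []
    for row in attention_mask:
        running = 0
        positions = []
        for m in row:
            running += m                       # cumulative sum of the mask
            positions.append(running * m - 1)  # zero out padding, make positions 0-based
        # With a KV cache, only the positions of the newly fed tokens are needed.
        indices.append([p + OFFSET for p in positions[past_key_values_length:]])
    return indices
```

With a left-padded row such as `[0, 0, 1, 1, 1]`, the three real tokens get indices 2, 3, 4, the same as an unpadded three-token prompt. This is why deriving positions from the mask handles padding correctly.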
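The column-parallel/row-parallel pairing used for fc1/fc2 (and likewise for the Q/K/V projections followed by the output projection) can be checked with a small single-process simulation. The sketch assumes a ReLU activation and omits biases. Any elementwise activation works, because it can be applied to each column shard independently. The function names are hypothetical and do not correspond to TensorParallelColumnLinear or TensorParallelRowLinear. The "all-reduce" is simulated as a plain sum over ranks.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Multiply an (n x k) matrix by a (k x m) matrix."""
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def relu(m: Matrix) -> Matrix:
    return [[max(0.0, v) for v in row] for row in m]


def mlp_single_device(x: Matrix, fc1: Matrix, fc2: Matrix) -> Matrix:
    """Reference feed-forward block fc2(relu(fc1(x))), biases omitted."""
    return matmul(relu(matmul(x, fc1)), fc2)


def mlp_tensor_parallel(x: Matrix, fc1: Matrix, fc2: Matrix, world_size: int) -> Matrix:
    """Simulate fc1 as column-parallel and fc2 as row-parallel across `world_size` ranks."""
    ffn_dim = len(fc1[0])
    assert ffn_dim % world_size == 0, "ffn_dim must divide evenly across ranks"
    shard = ffn_dim // world_size
    partials = []
    for rank in range(world_size):
        lo, hi = rank * shard, (rank + 1) * shard
        # Column-parallel fc1: this rank holds only output columns lo:hi.
        h_local = relu(matmul(x, [row[lo:hi] for row in fc1]))
        # Row-parallel fc2: this rank holds the matching input rows lo:hi -> partial output.
        partials.append(matmul(h_local, fc2[lo:hi]))
    # Simulated all-reduce: elementwise sum of the per-rank partial outputs.
    return [
        [sum(p[i][j] for p in partials) for j in range(len(partials[0][0]))]
        for i in range(len(x))
    ]
```

Because the column split already matches the row split, no communication is needed between fc1 and fc2. A single reduction after fc2 recovers the full output.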
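The difference between the two layouts selected by do_layer_norm_before comes down to where the normalization sits relative to the residual connection. This is a generic sketch assuming the usual pre-norm and post-norm patterns, applied in turn to the attention and feed-forward sublayers. `residual_sublayer` is a hypothetical helper, and the learnable gain and bias of layer norm are omitted.

```python
import math
from typing import Callable, List

Vector = List[float]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """Plain layer norm over one hidden vector (learnable gain and bias omitted)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def add(a: Vector, b: Vector) -> Vector:
    return [u + w for u, w in zip(a, b)]


def residual_sublayer(
    x: Vector,
    sublayer: Callable[[Vector], Vector],
    norm: Callable[[Vector], Vector] = layer_norm,
    do_layer_norm_before: bool = True,
) -> Vector:
    """Wrap a sublayer (attention or feed-forward) with a residual connection."""
    if do_layer_norm_before:
        # Pre-norm: normalize the input, run the sublayer, then add the residual.
        return add(x, sublayer(norm(x)))
    # Post-norm: run the sublayer on the raw input, add the residual, then normalize.
    return norm(add(x, sublayer(x)))
```

In the post-norm branch, every sublayer output is renormalized. In the pre-norm branch, the residual stream passes through unnormalized, which is why pre-norm stacks typically end with a final layer norm (the optional final layer norm in OPTDecoder).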

Usage

Used internally by the LoRAX server when serving OPT-based models (e.g., facebook/opt-125m through facebook/opt-66b); the class is loaded through the model registry.

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/models/custom_modeling/opt_modeling.py
  • Lines: 1-736

Signature

class OPTForCausalLM(OPTPreTrainedModel):
    def __init__(self, config, weights):
        ...

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        ...

Import

from lorax_server.models.custom_modeling.opt_modeling import OPTForCausalLM

I/O Contract

Inputs

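Name | Type | Required | Description
input_ids | torch.LongTensor | No | Input token IDs of shape (batch_size, sequence_length)
attention_mask | Optional[torch.Tensor] | No | Attention mask for padding tokens
head_mask | Optional[torch.Tensor] | No | Per-head attention mask
past_key_values | Optional[List[torch.FloatTensor]] | No | Cached KV states for autoregressive decoding
inputs_embeds | Optional[torch.FloatTensor] | No | Pre-computed input embeddings
labels | Optional[torch.LongTensor] | No | Labels for the language modeling loss
use_cache | Optional[bool] | No | Whether to return the KV cache
output_attentions | Optional[bool] | No | Whether to return attention weights from all layers
output_hidden_states | Optional[bool] | No | Whether to return hidden states from all layers
return_dict | Optional[bool] | No | Whether to return a CausalLMOutputWithPast instead of a plain tuple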

Outputs

Name | Type | Description
loss | Optional[torch.Tensor] | Language modeling loss (when labels are provided)
logits | torch.Tensor | Prediction logits of shape (batch_size, sequence_length, vocab_size)
past_key_values | Tuple | Cached KV states for subsequent decoding
hidden_states | Optional[Tuple[torch.Tensor]] | Hidden states from all layers
attentions | Optional[Tuple[torch.Tensor]] | Attention weights from all layers

Usage Examples

# Internal usage within LoRAX server
from lorax_server.models.custom_modeling.opt_modeling import OPTForCausalLM
# Instantiated by the model registry with an OPTConfig and pre-loaded weights, e.g.:
# model = OPTForCausalLM(config, weights)
