Implementation:Predibase Lorax OPT Modeling
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Provides a tensor-parallel implementation of the OPT (Open Pre-trained Transformer) causal language model with learned positional embeddings for inference within the LoRAX serving framework.
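The tensor parallelism here follows the column-parallel/row-parallel pattern named in the layer descriptions below (for example, fc1 column-parallel and fc2 row-parallel). The following pure-Python sketch simulates that split in a single process. It shows why the two layers need no communication between them: each rank keeps its own slice of the intermediate activations, and one final sum, standing in for an all-reduce, recovers the unsharded result.

Assumptions:
- Lists stand in for tensors.
- Biases are omitted.
- The activation is ReLU, OPT's usual default.
- `mlp` and `sharded_mlp` are hypothetical helper names, not part of LoRAX.

```python
# Single-process simulation of a column-parallel fc1 followed by a row-parallel fc2.
# Assumption: plain lists stand in for tensors, biases are omitted, and the final
# sum stands in for the all-reduce a real multi-GPU implementation would perform.
from typing import List

Matrix = List[List[float]]  # weight layout (out_features, in_features)
Vector = List[float]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Compute w @ x for w of shape (out, in) and x of shape (in,)."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def relu(v: Vector) -> Vector:
    """Elementwise ReLU; elementwise ops can be applied independently per shard."""
    return [max(0.0, a) for a in v]


def mlp(fc1: Matrix, fc2: Matrix, x: Vector) -> Vector:
    """Unsharded reference: fc2(relu(fc1(x)))."""
    return matvec(fc2, relu(matvec(fc1, x)))


def sharded_mlp(fc1: Matrix, fc2: Matrix, x: Vector, world_size: int) -> Vector:
    """Shard fc1 by output features and fc2 by input features across world_size ranks."""
    ffn_dim = len(fc1)
    assert ffn_dim % world_size == 0, "ffn_dim must divide evenly across ranks"
    block = ffn_dim // world_size
    partials = []
    for rank in range(world_size):
        lo, hi = rank * block, (rank + 1) * block
        # Column-parallel fc1: this rank computes only its slice of intermediate features.
        h_local = relu(matvec(fc1[lo:hi], x))
        # Row-parallel fc2: this rank uses the matching slice of fc2's input features.
        fc2_local = [row[lo:hi] for row in fc2]
        partials.append(matvec(fc2_local, h_local))
    # Stand-in for all-reduce(sum): add the per-rank partial outputs together.
    return [sum(vals) for vals in zip(*partials)]


if __name__ == "__main__":
    fc1 = [[1.0, -1.0], [0.5, 2.0], [-1.0, 1.0], [2.0, 0.0]]  # (ffn_dim=4, hidden=2)
    fc2 = [[1.0, 0.0, 1.0, -1.0], [0.0, 1.0, 1.0, 1.0]]      # (hidden=2, ffn_dim=4)
    x = [1.0, 2.0]
    print(mlp(fc1, fc2, x), sharded_mlp(fc1, fc2, x, world_size=2))
```

This split is valid because the activation is elementwise. In a real row-parallel layer, the bias would be added once after the reduction rather than on every rank.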
Description
This module implements Meta AI's OPT architecture, a decoder-only transformer with learned positional embeddings. OPT uses standard multi-head attention with separate query, key, and value (Q, K, V) projections rather than a single fused QKV projection.
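The following minimal pure-Python sketch makes the separate-versus-fused distinction concrete. It assumes bias-free projections and equal Q/K/V widths, as in standard multi-head attention. Fusing simply stacks the three weight matrices so that one projection produces Q, K, and V together. The outputs are identical, so the difference lies in how the weights are stored, loaded, and sharded, not in the math. The helper names are hypothetical.

```python
# Separate Q/K/V projections (OPT style) versus one fused QKV projection.
# Assumptions: no biases, and Q, K, V all have the same output width.
from typing import List, Tuple

Matrix = List[List[float]]  # weight layout (out_features, in_features)
Vector = List[float]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Compute w @ x for a row-major weight matrix."""
    return [sum(a * b for a, b in zip(row, x)) for row in w]


def separate_qkv(q_w: Matrix, k_w: Matrix, v_w: Matrix, x: Vector) -> Tuple[Vector, Vector, Vector]:
    """Three independent projections, analogous to q_proj, k_proj, and v_proj."""
    return matvec(q_w, x), matvec(k_w, x), matvec(v_w, x)


def fused_qkv(q_w: Matrix, k_w: Matrix, v_w: Matrix, x: Vector) -> Tuple[Vector, Vector, Vector]:
    """Stack the three weights, run a single projection, then split the result."""
    qkv_w = q_w + k_w + v_w  # list concatenation stacks the output rows
    out = matvec(qkv_w, x)
    d = len(q_w)
    return out[:d], out[d:2 * d], out[2 * d:]


if __name__ == "__main__":
    q_w = [[1.0, 0.0], [0.0, 1.0]]
    k_w = [[2.0, 0.0], [0.0, 2.0]]
    v_w = [[1.0, 1.0], [1.0, -1.0]]
    x = [3.0, 4.0]
    print(separate_qkv(q_w, k_w, v_w, x) == fused_qkv(q_w, k_w, v_w, x))
```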
Key classes:
- OPTLearnedPositionalEmbedding -- Learned positional embeddings with a fixed offset of 2. Positions are computed from the attention mask so that padding is handled correctly, and incremental decoding is supported via past_key_values_length.
- OPTAttention -- Multi-headed self-attention with separate query, key, and value projections via TensorParallelColumnLinear, and an output projection via TensorParallelRowLinear. Supports decoder self-attention with KV caching.
- OPTDecoderLayer -- Single decoder layer combining self-attention, two feed-forward layers (fc1 as column-parallel, fc2 as row-parallel), and layer normalization. Supports both pre-norm and post-norm configurations via do_layer_norm_before.
- OPTDecoder -- Full decoder stack with token embeddings (TensorParallelEmbedding), learned positional embeddings, optional project-in/project-out linear layers (for models where word_embed_proj_dim != hidden_size), the decoder layer stack, and an optional final layer norm.
- OPTModel -- Wraps the OPTDecoder and returns BaseModelOutputWithPast.
- OPTForCausalLM -- Top-level causal LM wrapper combining OPTModel with an LM head (TensorParallelHead). Returns CausalLMOutputWithPast.
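The position computation can be sketched as follows. The sketch mirrors the approach of Hugging Face's `OPTLearnedPositionalEmbedding`, which this module is assumed to follow:

1. Take a cumulative sum of the attention mask, so real tokens are numbered 0, 1, 2, ... regardless of left padding.
2. Mark padding with -1.
3. Keep only the positions after `past_key_values_length`.
4. Add the offset of 2 before the table lookup.

Plain lists stand in for tensors, and `opt_position_ids` is a hypothetical helper name.

```python
# Sketch of OPT-style position ids derived from an attention mask.
# Assumption: mirrors the cumsum-based scheme of Hugging Face's
# OPTLearnedPositionalEmbedding; lists stand in for tensors.
from typing import List

OFFSET = 2  # OPT shifts every position by 2 before indexing the embedding table


def opt_position_ids(attention_mask: List[List[int]], past_key_values_length: int = 0) -> List[List[int]]:
    """Return embedding-table rows (offset included) for the tokens being processed.

    attention_mask covers the full sequence (cached + current tokens);
    1 marks a real token and 0 marks padding.
    """
    result = []
    for row in attention_mask:
        positions = []
        running = 0
        for m in row:
            running += m
            # cumsum(mask) * mask - 1: real tokens get 0, 1, 2, ...; padding gets -1.
            positions.append(running * m - 1)
        # With a KV cache, only the new tokens still need positions.
        positions = positions[past_key_values_length:]
        result.append([p + OFFSET for p in positions])
    return result


if __name__ == "__main__":
    print(opt_position_ids([[0, 0, 1, 1, 1]]))
    print(opt_position_ids([[0, 0, 1, 1, 1]], past_key_values_length=4))
```

With left padding, `[[0, 0, 1, 1, 1]]` maps to table rows `[[1, 1, 2, 3, 4]]`. The first real token therefore always uses row 2, whether or not the sequence was padded.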
Usage
Used internally by the LoRAX server when serving OPT-based models (e.g., facebook/opt-125m through facebook/opt-66b). The class is loaded through the model registry.
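The two layer-norm placements that `do_layer_norm_before` selects between differ only in where normalization sits relative to each residual addition. Comments in the Hugging Face transformers OPT implementation note that opt-350m applies layer norm after attention, while the other released sizes apply it before, so serving the whole family requires both paths. The pure-Python sketch below shows the ordering. `attn`, `mlp`, and the two norm callables are hypothetical stand-ins, and the toy `layer_norm` has no learnable scale or bias.

```python
# Residual ordering for pre-norm vs post-norm decoder layers.
# Assumptions: attn/mlp/norms are hypothetical callables on plain lists;
# layer_norm is a toy version without learnable scale or bias.
import math
from typing import Callable, List

Vector = List[float]
Fn = Callable[[Vector], Vector]


def add(a: Vector, b: Vector) -> Vector:
    return [x + y for x, y in zip(a, b)]


def layer_norm(v: Vector, eps: float = 1e-5) -> Vector:
    mean = sum(v) / len(v)
    var = sum((a - mean) ** 2 for a in v) / len(v)
    return [(a - mean) / math.sqrt(var + eps) for a in v]


def decoder_layer(x: Vector, attn: Fn, mlp: Fn, ln_attn: Fn, ln_mlp: Fn,
                  do_layer_norm_before: bool) -> Vector:
    # Self-attention sub-block: normalize before (pre-norm) or after (post-norm) the residual add.
    h = ln_attn(x) if do_layer_norm_before else x
    x = add(x, attn(h))
    if not do_layer_norm_before:
        x = ln_attn(x)
    # Feed-forward sub-block (fc1 -> activation -> fc2 in the real layer).
    h = ln_mlp(x) if do_layer_norm_before else x
    x = add(x, mlp(h))
    if not do_layer_norm_before:
        x = ln_mlp(x)
    return x


if __name__ == "__main__":
    x = [1.0, 2.0, 3.0]
    attn = lambda v: [2.0 * a for a in v]
    mlp = lambda v: [a + 1.0 for a in v]
    print(decoder_layer(x, attn, mlp, layer_norm, layer_norm, do_layer_norm_before=True))
    print(decoder_layer(x, attn, mlp, layer_norm, layer_norm, do_layer_norm_before=False))
```

In the post-norm path, the layer output is always normalized. In the pre-norm path, the residual stream is left un-normalized, which is why pre-norm stacks end with a separate final layer norm.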
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/models/custom_modeling/opt_modeling.py
- Lines: 1-736
Signature
```python
class OPTForCausalLM(OPTPreTrainedModel):
    def __init__(self, config, weights):
        ...

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        ...
```
Import
```python
from lorax_server.models.custom_modeling.opt_modeling import OPTForCausalLM
```
I/O Contract
Inputs
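| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.LongTensor | No | Input token IDs of shape (batch_size, sequence_length); provide either this or inputs_embeds |
| attention_mask | Optional[torch.Tensor] | No | Attention mask marking padding tokens |
| head_mask | Optional[torch.Tensor] | No | Mask to disable selected attention heads |
| past_key_values | Optional[List[torch.FloatTensor]] | No | Cached KV states for autoregressive decoding |
| inputs_embeds | Optional[torch.FloatTensor] | No | Pre-computed input embeddings (alternative to input_ids) |
| labels | Optional[torch.LongTensor] | No | Labels for the language modeling loss |
| use_cache | Optional[bool] | No | Whether to return the KV cache |
| output_attentions | Optional[bool] | No | Whether to return attention weights from all layers |
| output_hidden_states | Optional[bool] | No | Whether to return hidden states from all layers |
| return_dict | Optional[bool] | No | Whether to return a CausalLMOutputWithPast instead of a plain tuple |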
Outputs
| Name | Type | Description |
|---|---|---|
| loss | Optional[torch.Tensor] | Language modeling loss (when labels are provided) |
| logits | torch.Tensor | Prediction logits of shape (batch_size, sequence_length, vocab_size) |
| past_key_values | Tuple | Cached KV states for subsequent decoding |
| hidden_states | Optional[Tuple[torch.Tensor]] | Hidden states from all layers |
| attentions | Optional[Tuple[torch.Tensor]] | Attention weights from all layers |
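Because `logits` covers every position, generation only needs the last position of each sequence. Here is a minimal greedy-decoding sketch. It assumes the logits have been converted to nested Python lists of shape (batch_size, sequence_length, vocab_size). It is an illustration, not the LoRAX server's token-selection logic, and `greedy_next_tokens` is a hypothetical name.

```python
# Greedy next-token selection from (batch_size, sequence_length, vocab_size) logits.
# Assumption: logits are plain nested lists rather than a torch.Tensor.
from typing import List


def greedy_next_tokens(logits: List[List[List[float]]]) -> List[int]:
    """Return the argmax vocabulary id at the last position of each sequence."""
    next_tokens = []
    for seq in logits:
        last = seq[-1]  # logits for the final position, length vocab_size
        next_tokens.append(max(range(len(last)), key=last.__getitem__))
    return next_tokens


if __name__ == "__main__":
    logits = [
        [[0.1, 0.9, 0.0], [0.5, 0.2, 0.3]],
        [[0.0, 0.0, 1.0], [0.1, 0.2, 0.7]],
    ]
    print(greedy_next_tokens(logits))
```

In a KV-cached loop, the chosen tokens would be fed back as the next `input_ids` together with the returned `past_key_values`, so only the new positions are processed.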
Usage Examples
```python
# Internal usage within the LoRAX server
from lorax_server.models.custom_modeling.opt_modeling import OPTForCausalLM

# Instantiated by the model registry with an OPTConfig and pre-loaded weights
```