Implementation:Predibase Lorax Flash RoBERTa
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Provides the high-level XLM-RoBERTa model wrapper for GPU-accelerated embedding inference with LoRA adapter support, composing flash-attention RoBERTa layers into a complete model integrated with the LoRax server's Model base class.
Description
This module defines the server-level wrappers for XLM-RoBERTa embedding models with full adapter loading support.
Key classes:
- `RobertaEncoder`: stacks multiple `RobertaLayer` instances and passes `adapter_data` through to each layer for LoRA adapter application.
- `FlashRobertaModel` (extends `torch.nn.Module`): combines `RobertaEmbeddings` and `RobertaEncoder`. Returns the first token of each sequence (CLS token extraction via `cu_seqlens[:-1]` indexing) as the embedding output. Accepts an `adapter_data` parameter for LoRA adapter integration.
- `FlashXlmRoberta` (extends `Model`): the main server-level wrapper, which:
  - Enables adapter loading (`supports_adapter_loading = True`) with Q, K, and V attention adapter layers.
  - Supports merged adapter weight files via `create_merged_weight_files`.
  - Uses `XLMRobertaConfig` for model configuration.
  - Implements `adapter_target_to_layer`, mapping the attention projections (`ATTN_Q`, `ATTN_K`, `ATTN_V`) to their corresponding model weight tensors.
  - Reports `supports_embeddings = True` and `supports_text_generation = False`.
  - `embed`: creates adapter data from the batch metadata, runs the model forward pass within a FlashInfer context, extracts the CLS embeddings, and returns the results on CPU.
  - Manages FlashInfer prefill state for efficient attention computation.
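The encoder loop and CLS extraction described above can be sketched in plain Python. This is a toy model under stated assumptions: hidden states are lists of floats rather than tensors, layers are arbitrary callables, and `run_encoder` / `extract_cls` are hypothetical names, not LoRax functions. The point it illustrates is that in the packed (varlen) layout, `cu_seqlens[:-1]` lists each sequence's start offset, which is exactly where the CLS token sits.

```python
from typing import Callable, List, Optional, Sequence

Vector = List[float]
# Hypothetical layer type: any callable taking packed hidden states plus adapter data.
Layer = Callable[[List[Vector], Optional[dict]], List[Vector]]


def run_encoder(layers: Sequence[Layer], hidden_states: List[Vector],
                adapter_data: Optional[dict]) -> List[Vector]:
    """Apply each layer in order, handing the same adapter_data to every layer."""
    for layer in layers:
        hidden_states = layer(hidden_states, adapter_data)
    return hidden_states


def extract_cls(hidden_states: List[Vector], cu_seqlens: List[int]) -> List[Vector]:
    """Pick the first token of every sequence in a packed batch.

    cu_seqlens holds cumulative offsets [0, len0, len0 + len1, ...], so
    cu_seqlens[:-1] is the start index of each sequence.
    """
    return [hidden_states[start] for start in cu_seqlens[:-1]]


def add_one(hs: List[Vector], adapter_data: Optional[dict]) -> List[Vector]:
    """Toy layer: adds 1.0 to every element and ignores the adapter data."""
    return [[x + 1.0 for x in row] for row in hs]


# Two sequences of lengths 3 and 2, packed into 5 rows of width 2.
packed = [[float(i), float(i)] for i in range(5)]
out = run_encoder([add_one, add_one], packed, adapter_data=None)
print(extract_cls(out, [0, 3, 5]))  # [[2.0, 2.0], [5.0, 5.0]]
```

Packing avoids padding entirely; the cost is that per-sequence positions must be recovered from the offsets, which is why the CLS lookup goes through `cu_seqlens` rather than a fixed index along a batch dimension.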
Usage
FlashXlmRoberta is instantiated by the LoRax model registry when loading XLM-RoBERTa models for embedding tasks. It is the only embedding model wrapper that supports dynamic LoRA adapter loading, allowing fine-tuned embedding models to be served with adapter swapping.
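The difference between applying an adapter dynamically and merging it (`merge_adapter_weights=True`) can be illustrated with the standard LoRA update on a single projection. This is a generic sketch, not LoRax's kernel path: matrices are plain lists, the adapter dict layout (`A`, `B`, `scale`) and the helper names are hypothetical, and the row-vector convention `x @ W` is an assumption.

```python
from typing import Dict, List, Optional

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def add(a: Matrix, b: Matrix, scale: float = 1.0) -> Matrix:
    """Elementwise a + scale * b."""
    return [[x + scale * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def lora_projection(x: Matrix, w: Matrix, adapter: Optional[Dict]) -> Matrix:
    """Base projection x @ W plus an optional low-rank term scale * (x @ A) @ B.

    adapter=None serves the base model; swapping the adapter dict changes the
    fine-tune per request without touching W. scale is typically alpha / rank.
    """
    out = matmul(x, w)
    if adapter is not None:
        out = add(out, matmul(matmul(x, adapter["A"]), adapter["B"]), adapter["scale"])
    return out


def merge_adapter(w: Matrix, adapter: Dict) -> Matrix:
    """Fold the adapter into the base weight: W' = W + scale * (A @ B)."""
    return add(w, matmul(adapter["A"], adapter["B"]), adapter["scale"])


x = [[1.0, 2.0]]                                                  # one token, hidden size 2
w = [[1.0, 0.0], [0.0, 1.0]]                                      # identity base weight
adapter = {"A": [[1.0], [0.0]], "B": [[0.0, 1.0]], "scale": 0.5}  # rank-1 adapter
print(lora_projection(x, w, adapter))                  # [[1.0, 2.5]]
print(lora_projection(x, merge_adapter(w, adapter), None))  # [[1.0, 2.5]]
```

Both paths give the same output. Merging removes the extra low-rank matmuls at inference time but fixes the model to one adapter, whereas the dynamic path keeps `W` shared so different requests can use different adapters.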
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: `server/lorax_server/models/flash_roberta.py`
- Lines: 1-233
Signature
```python
from typing import Optional, Type

import torch

# Model, FlashEmbeddingClassificationBatch and Embedding are defined in other lorax_server modules.


class RobertaEncoder:
    def __init__(self, prefix, weights, device, dtype, config):
        ...

    def forward(self, hidden_states, cu_seqlens, max_s, adapter_data):
        ...


class FlashRobertaModel(torch.nn.Module):
    def __init__(self, prefix, weights, device, dtype, config):
        ...

    def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s, adapter_data):
        ...


class FlashXlmRoberta(Model):
    def __init__(
        self,
        model_id: str,
        adapter_id: str,
        adapter_source: str,
        revision: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        merge_adapter_weights: bool = False,
    ):
        ...

    @property
    def batch_type(self) -> Type[FlashEmbeddingClassificationBatch]:
        ...

    @property
    def supports_adapter_loading(self) -> bool:
        ...

    @property
    def supports_embeddings(self) -> bool:
        ...

    def adapter_target_to_layer(self) -> dict[str, tuple[str, torch.Tensor]]:
        ...

    @property
    def adapter_layers(self) -> list[str]:
        ...

    @property
    def default_traced_adapter_layers(self) -> list[str]:
        ...

    def get_num_layers_for_type(self, layer_type: str) -> int:
        ...

    def embed(self, batch: FlashEmbeddingClassificationBatch) -> Embedding:
        ...
```
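How the Q, K, and V adapter targets might resolve to concrete per-layer weights can be sketched as follows. This is hypothetical: the string values standing in for `ATTN_Q` / `ATTN_K` / `ATTN_V`, the helper names, and the return shape are assumptions, and the real `adapter_target_to_layer` returns weight tensors rather than names. The only grounded detail is Hugging Face's XLM-RoBERTa parameter naming (`encoder.layer.{i}.attention.self.query` / `key` / `value`).

```python
from typing import Dict, List

# Placeholder target names standing in for LoRax's ATTN_Q / ATTN_K / ATTN_V constants.
ATTN_Q, ATTN_K, ATTN_V = "attn_q", "attn_k", "attn_v"
ADAPTER_LAYERS = [ATTN_Q, ATTN_K, ATTN_V]

# Hugging Face XLM-RoBERTa names its self-attention projections query / key / value.
_HF_SUFFIX = {ATTN_Q: "query", ATTN_K: "key", ATTN_V: "value"}


def attention_weight_names(prefix: str, num_layers: int) -> Dict[str, List[str]]:
    """Hypothetical: map each adapter target to the per-layer weight names it covers."""
    return {
        target: [
            f"{prefix}.encoder.layer.{i}.attention.self.{_HF_SUFFIX[target]}.weight"
            for i in range(num_layers)
        ]
        for target in ADAPTER_LAYERS
    }


def get_num_layers_for_type(layer_type: str, num_layers: int) -> int:
    """Hypothetical: each encoder layer has one Q, K and V projection; other types get 0."""
    return num_layers if layer_type in ADAPTER_LAYERS else 0


names = attention_weight_names("roberta", num_layers=2)
print(names[ATTN_K])
# ['roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.1.attention.self.key.weight']
```

Whether a checkpoint's weights carry a `roberta.` prefix depends on the class it was saved from, which is why the prefix is a parameter here.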
Import
```python
from lorax_server.models.flash_roberta import FlashXlmRoberta
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_id | str | Yes | HuggingFace model identifier for an XLM-RoBERTa model |
| adapter_id | str | Yes | Adapter identifier for LoRA loading |
| adapter_source | str | Yes | Source of adapter weights (e.g., "hub") |
| revision | Optional[str] | No | Model revision/commit hash |
| dtype | Optional[torch.dtype] | No | Model dtype (defaults to float16 on GPU) |
| merge_adapter_weights | bool | No | Whether to merge adapter weights into base model |
| batch | FlashEmbeddingClassificationBatch | Yes | Argument to `embed`; batch containing input_ids, token_type_ids, position_ids, cu_seqlens, max_s |
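The batch fields listed above (`input_ids`, `token_type_ids`, `position_ids`, `cu_seqlens`, `max_s`) describe a packed, padding-free layout. The sketch below shows one way such fields could be built from tokenized sequences; `pack_batch` is a hypothetical helper, not part of LoRax, and the position offset of 2 is an assumption based on RoBERTa's convention of starting positions at `padding_idx + 1` with `padding_idx = 1`.

```python
from typing import List, Tuple


def pack_batch(sequences: List[List[int]], position_offset: int = 2
               ) -> Tuple[List[int], List[int], List[int], List[int], int]:
    """Flatten token-id lists into a varlen layout.

    Returns (input_ids, token_type_ids, position_ids, cu_seqlens, max_s).
    position_offset=2 assumes RoBERTa-style positions (padding_idx + 1).
    """
    input_ids: List[int] = []
    token_type_ids: List[int] = []
    position_ids: List[int] = []
    cu_seqlens = [0]
    for seq in sequences:
        input_ids.extend(seq)
        token_type_ids.extend([0] * len(seq))  # XLM-RoBERTa uses a single segment type
        position_ids.extend(range(position_offset, position_offset + len(seq)))
        cu_seqlens.append(cu_seqlens[-1] + len(seq))  # running end offset
    max_s = max((len(s) for s in sequences), default=0)
    return input_ids, token_type_ids, position_ids, cu_seqlens, max_s


# Toy token ids for two sequences of lengths 3 and 2.
ids, tt, pos, cu, max_s = pack_batch([[0, 100, 2], [0, 2]])
print(cu, max_s, pos)  # [0, 3, 5] 3 [2, 3, 4, 2, 3]
```

Positions restart for each sequence, and `max_s` records the longest sequence so an attention kernel can size its work without padding every sequence to that length.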
Outputs
| Name | Type | Description |
|---|---|---|
| embeddings | List[List[float]] | CLS token embeddings for each input in the batch |
Usage Examples
```python
# Internal LoRax server usage
from lorax_server.models.flash_roberta import FlashXlmRoberta

# Instantiated by the model registry for XLM-RoBERTa embedding models:
# model = FlashXlmRoberta(
#     model_id="xlm-roberta-base",
#     adapter_id="my-lora-adapter",
#     adapter_source="hub",
# )
```