

Implementation:Predibase Lorax Flash RoBERTa

From Leeroopedia


Knowledge Sources
Domains: Model_Architecture, Inference
Last Updated: 2026-02-08 00:00 GMT

Overview

This module provides the high-level XLM-RoBERTa model wrapper for GPU-accelerated embedding inference with LoRA adapter support. It composes flash-attention RoBERTa layers into a complete model and integrates that model with the LoRax server's Model base class.

Description

This module defines the server-level wrappers for XLM-RoBERTa embedding models with full adapter loading support.

Key classes:

  • RobertaEncoder - Stacks multiple RobertaLayer instances and passes adapter_data through to each layer for LoRA adapter application.
  • FlashRobertaModel (extends torch.nn.Module) - Chains RobertaEmbeddings and RobertaEncoder and returns the hidden state of each sequence's first (CLS) token as the embedding output, selected by indexing the packed output with cu_seqlens[:-1]. Accepts an adapter_data parameter for LoRA adapter integration.
  • FlashXlmRoberta (extends Model) - The main server-level wrapper that:
    • Enables adapter loading (supports_adapter_loading = True) with Q, K, V attention adapter layers.
    • Supports merged adapter weight files via create_merged_weight_files.
    • Uses XLMRobertaConfig for model configuration.
    • Implements adapter_target_to_layer, which maps the attention projection targets (ATTN_Q, ATTN_K, ATTN_V) to their corresponding model weight tensors.
    • Reports supports_embeddings = True, supports_text_generation = False.
    • embed - Builds adapter data from the batch metadata, runs the model forward pass inside the FlashInfer context, extracts the CLS embeddings, and returns the results on the CPU.
    • Manages FlashInfer prefill state for efficient attention computation.
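The CLS extraction described for FlashRobertaModel relies on the packed (unpadded) layout used by variable-length flash attention: all sequences in a batch are concatenated end to end, and cu_seqlens records where each one starts. The pure-Python sketch below uses a hypothetical helper, extract_cls_embeddings, with nested lists standing in for tensors, to show why cu_seqlens[:-1] selects exactly one first-token row per sequence. It is an illustration under those assumptions, not the LoRax code.

```python
from typing import List


def extract_cls_embeddings(
    hidden_states: List[List[float]], cu_seqlens: List[int]
) -> List[List[float]]:
    """Return the first-token (CLS) row of every sequence in a packed batch.

    hidden_states holds one row per token for all sequences concatenated
    with no padding. cu_seqlens holds cumulative sequence lengths starting
    at 0, so cu_seqlens[i] is the offset of sequence i's first token and
    cu_seqlens[-1] is the total number of tokens.
    """
    if not cu_seqlens or cu_seqlens[0] != 0:
        raise ValueError("cu_seqlens must start at 0")
    if cu_seqlens[-1] != len(hidden_states):
        raise ValueError("cu_seqlens[-1] must equal the number of token rows")
    # Dropping the final total leaves exactly one start offset per sequence.
    return [hidden_states[start] for start in cu_seqlens[:-1]]


# Two sequences of lengths 3 and 2 packed into 5 token rows (hidden size 2).
packed = [[1.0, 1.1], [1.2, 1.3], [1.4, 1.5], [2.0, 2.1], [2.2, 2.3]]
cu_seqlens = [0, 3, 5]
print(extract_cls_embeddings(packed, cu_seqlens))  # [[1.0, 1.1], [2.0, 2.1]]
```

With a tensor of packed hidden states, the same selection can be written as hidden_states[cu_seqlens[:-1]] using integer-tensor indexing, which avoids padding each sequence to max_s before reading its first token.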

Usage

FlashXlmRoberta is instantiated by the LoRax model registry when loading XLM-RoBERTa models for embedding tasks. It is the only embedding model wrapper in LoRax that supports dynamic LoRA adapter loading, so fine-tuned embedding models can be served by swapping adapters.
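Adapter swapping is possible because a LoRA adapter only adds a low-rank correction to a base projection, so one set of base weights can serve several adapters. The sketch below shows the standard LoRA formulation, y = Wx + (alpha / r) * B(Ax), applied to a single attention projection such as the query weight, with the adapter chosen per request. LoraAdapter, project, and merge are hypothetical names; the alpha / r scaling is the common LoRA convention, assumed rather than confirmed for LoRax; and lists stand in for tensors. The merge function shows what folding an adapter into the base weight would compute under the same assumptions.

```python
from typing import Dict, List, Optional

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Matrix-vector product m @ v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Matrix product (n x k) @ (k x m) -> (n x m)."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


class LoraAdapter:
    """Hypothetical container for one adapter's low-rank factors."""

    def __init__(self, a: Matrix, b: Matrix, alpha: float):
        self.a = a  # r x d_in (down-projection)
        self.b = b  # d_out x r (up-projection)
        self.scale = alpha / len(a)  # alpha / r, the usual LoRA scaling


def project(w: Matrix, x: Vector, adapter: Optional[LoraAdapter] = None) -> Vector:
    """Base projection W x, plus scale * B (A x) when an adapter is active."""
    y = matvec(w, x)
    if adapter is None:
        return y
    delta = matvec(adapter.b, matvec(adapter.a, x))
    return [yi + adapter.scale * di for yi, di in zip(y, delta)]


def merge(w: Matrix, adapter: LoraAdapter) -> Matrix:
    """Fold the adapter into the base weight: W + scale * (B A)."""
    ba = matmul(adapter.b, adapter.a)
    return [
        [wij + adapter.scale * dij for wij, dij in zip(w_row, d_row)]
        for w_row, d_row in zip(w, ba)
    ]


w_q = [[1.0, 0.0], [0.0, 1.0]]  # shared 2 x 2 base query weight
adapters: Dict[str, LoraAdapter] = {
    "adapter-a": LoraAdapter(a=[[1.0, 1.0]], b=[[1.0], [0.0]], alpha=1.0),
    "adapter-b": LoraAdapter(a=[[0.0, 2.0]], b=[[0.0], [1.0]], alpha=2.0),
}
x = [1.0, 2.0]
print(project(w_q, x))                         # [1.0, 2.0] (base model only)
print(project(w_q, x, adapters["adapter-a"]))  # [4.0, 2.0]
print(project(w_q, x, adapters["adapter-b"]))  # [1.0, 10.0] (swapped adapter)
```

The merged weight produces the same output as the unmerged path, so merging removes the two extra low-rank products per projection at the cost of fixing that projection to a single adapter.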

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/models/flash_roberta.py
  • Lines: 1-233

Signature

class RobertaEncoder:
    def __init__(self, prefix, weights, device, dtype, config):
        ...
    def forward(self, hidden_states, cu_seqlens, max_s, adapter_data):
        ...

class FlashRobertaModel(torch.nn.Module):
    def __init__(self, prefix, weights, device, dtype, config):
        ...
    def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s, adapter_data):
        ...

class FlashXlmRoberta(Model):
    def __init__(
        self,
        model_id: str,
        adapter_id: str,
        adapter_source: str,
        revision: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        merge_adapter_weights: bool = False,
    ):
        ...
    @property
    def batch_type(self) -> Type[FlashEmbeddingClassificationBatch]:
        ...
    @property
    def supports_adapter_loading(self) -> bool:
        ...
    @property
    def supports_embeddings(self) -> bool:
        ...
    def adapter_target_to_layer(self) -> dict[str, tuple[str, torch.Tensor]]:
        ...
    @property
    def adapter_layers(self) -> list[str]:
        ...
    @property
    def default_traced_adapter_layers(self) -> list[str]:
        ...
    def get_num_layers_for_type(self, layer_type: str) -> int:
        ...
    def embed(self, batch: FlashEmbeddingClassificationBatch) -> Embedding:
        ...
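The return annotation of adapter_target_to_layer, dict[str, tuple[str, torch.Tensor]], pairs each adapter target with a weight name and the tensor that an adapter modifies. The sketch below assumes one plausible shape for that mapping and for get_num_layers_for_type. The constant values, the "<index>.<target>" key format, the Hugging Face-style submodule names (attention.self.query and so on), and the 0 returned for non-attention layer types are all hypothetical, and strings stand in for tensors.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical values; the real ATTN_Q/ATTN_K/ATTN_V constants are defined in LoRax.
ATTN_Q, ATTN_K, ATTN_V = "q", "k", "v"
ADAPTER_LAYERS = [ATTN_Q, ATTN_K, ATTN_V]

# Assumed Hugging Face RoBERTa submodule name for each adapter target.
HF_SUBMODULE = {
    ATTN_Q: "attention.self.query",
    ATTN_K: "attention.self.key",
    ATTN_V: "attention.self.value",
}


def adapter_target_to_layer(
    layer_weights: List[Dict[str, Any]], prefix: str = "encoder.layer"
) -> Dict[str, Tuple[str, Any]]:
    """Map '<layer index>.<target>' to (checkpoint weight name, weight object)."""
    mapping: Dict[str, Tuple[str, Any]] = {}
    for i, weights in enumerate(layer_weights):
        for target in ADAPTER_LAYERS:
            name = f"{prefix}.{i}.{HF_SUBMODULE[target]}"
            mapping[f"{i}.{target}"] = (name, weights[target])
    return mapping


def get_num_layers_for_type(layer_type: str, num_hidden_layers: int) -> int:
    """Each encoder layer has one Q, K and V projection; other types get none here."""
    return num_hidden_layers if layer_type in ADAPTER_LAYERS else 0


# Two encoder layers; strings such as "Wk1" stand in for torch.Tensor weights.
layers = [{t: f"W{t}{i}" for t in ADAPTER_LAYERS} for i in range(2)]
mapping = adapter_target_to_layer(layers)
print(mapping["1.k"])  # ('encoder.layer.1.attention.self.key', 'Wk1')
print(get_num_layers_for_type(ATTN_V, len(layers)))  # 2
```

In the real class the weights would come from the loaded FlashRobertaModel rather than from a list of dictionaries.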

Import

from lorax_server.models.flash_roberta import FlashXlmRoberta

I/O Contract

Inputs

Name | Type | Required | Description
--- | --- | --- | ---
model_id | str | Yes | Hugging Face model identifier for an XLM-RoBERTa model
adapter_id | str | Yes | Adapter identifier for LoRA loading
adapter_source | str | Yes | Source of adapter weights (e.g., "hub")
revision | Optional[str] | No | Model revision or commit hash
dtype | Optional[torch.dtype] | No | Model dtype (defaults to float16 on GPU)
merge_adapter_weights | bool | No | Whether to merge adapter weights into the base model
batch | FlashEmbeddingClassificationBatch | Yes | Argument to embed: a batch containing input_ids, token_type_ids, position_ids, cu_seqlens, and max_s
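The batch fields follow the packed, variable-length layout used by flash attention. A minimal sketch of how tokenized inputs could be flattened into those fields: pack_sequences is a hypothetical helper, not the code that builds FlashEmbeddingClassificationBatch, and the all-zero token types and zero-based per-sequence positions are stated assumptions.

```python
from typing import Dict, List


def pack_sequences(token_ids: List[List[int]]) -> Dict[str, object]:
    """Concatenate tokenized sequences into flat, unpadded batch fields."""
    if not token_ids or any(len(seq) == 0 for seq in token_ids):
        raise ValueError("need at least one non-empty sequence")
    input_ids: List[int] = []
    position_ids: List[int] = []
    cu_seqlens: List[int] = [0]
    for seq in token_ids:
        input_ids.extend(seq)
        # Positions restart at 0 for each sequence. Hugging Face RoBERTa
        # offsets positions by padding_idx + 1; that offset is not modeled here.
        position_ids.extend(range(len(seq)))
        # Each entry is the running token total, i.e. the next sequence's start.
        cu_seqlens.append(cu_seqlens[-1] + len(seq))
    return {
        "input_ids": input_ids,
        # Assumed all zeros: XLM-RoBERTa checkpoints use a single token type.
        "token_type_ids": [0] * len(input_ids),
        "position_ids": position_ids,
        "cu_seqlens": cu_seqlens,
        # Length of the longest sequence in the batch.
        "max_s": max(len(seq) for seq in token_ids),
    }


# Made-up token IDs for two sequences of lengths 3 and 4.
batch = pack_sequences([[0, 11, 2], [0, 12, 13, 2]])
print(batch["cu_seqlens"], batch["max_s"])  # [0, 3, 7] 4
```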

Outputs

Name | Type | Description
--- | --- | ---
embeddings | List[List[float]] | CLS token embedding for each input in the batch

Usage Examples

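# Internal LoRax server usage: the model registry constructs this class
# for XLM-RoBERTa embedding models; it is not normally created by hand.
from lorax_server.models.flash_roberta import FlashXlmRoberta

# Example construction (requires a CUDA GPU with flash-attention support;
# "my-lora-adapter" is a placeholder adapter ID):
# model = FlashXlmRoberta(
#     model_id="xlm-roberta-base",
#     adapter_id="my-lora-adapter",
#     adapter_source="hub",
# )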
