

Implementation:Predibase Lorax Flash BERT

From Leeroopedia


Knowledge Sources
Domains: Model_Architecture, Inference
Last Updated: 2026-02-08 00:00 GMT

Overview

Provides the high-level BERT model wrapper for GPU-accelerated embedding and token classification inference, composing flash-attention BERT layers into a complete model and integrating with the LoRax server's Model base class.

Description

This module defines the server-level wrappers that compose the low-level flash BERT layers from flash_bert_modeling.py into complete models suitable for embedding generation and token classification.
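The composition described above, where an encoder threads the packed hidden states through each layer in order, can be sketched in plain Python. This is a minimal illustration under stated assumptions, not LoRax code: `ToyLayer` and `ToyEncoder` are hypothetical stand-ins for BertLayer and BertEncoder, and each toy layer just adds a constant so the sequential order is easy to check.

```python
from typing import List

Vector = List[float]


class ToyLayer:
    """Hypothetical stand-in for a flash BertLayer: adds a bias to every value."""

    def __init__(self, bias: float) -> None:
        self.bias = bias

    def forward(self, hidden_states: List[Vector], cu_seqlens: List[int], max_s: int) -> List[Vector]:
        # A real layer would run variable-length flash attention using
        # cu_seqlens/max_s; this toy ignores them and transforms each token.
        return [[x + self.bias for x in token] for token in hidden_states]


class ToyEncoder:
    """Hypothetical stand-in for BertEncoder: runs its layers one after another."""

    def __init__(self, layers: List[ToyLayer]) -> None:
        self.layers = layers

    def forward(self, hidden_states: List[Vector], cu_seqlens: List[int], max_s: int) -> List[Vector]:
        # Each layer consumes the previous layer's output; the packing
        # metadata (cu_seqlens, max_s) is passed unchanged to every layer.
        for layer in self.layers:
            hidden_states = layer.forward(hidden_states, cu_seqlens, max_s)
        return hidden_states


encoder = ToyEncoder([ToyLayer(1.0), ToyLayer(2.0), ToyLayer(3.0)])
# Two packed sequences of lengths 2 and 1 (3 tokens total, hidden size 2).
out = encoder.forward([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]], cu_seqlens=[0, 2, 3], max_s=2)
print(out)  # [[6.0, 6.0], [7.0, 7.0], [8.0, 8.0]]
```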

Key classes:

  • BertEncoder - Stacks multiple BertLayer instances and runs the forward pass through them sequentially, feeding each layer's output to the next.
  • FlashBertModel (extends torch.nn.Module) - Combines BertEmbeddings and BertEncoder. For embedding output, it returns the hidden state of the first token (the CLS token) of each sequence, selected by indexing the packed hidden states with cu_seqlens[:-1].
  • FlashBertModelForClassification (extends torch.nn.Module) - Uses the same embeddings-plus-encoder stack as FlashBertModel, with a linear classification head added on top. Returns logits of shape (batch_size, max_s, num_labels).
  • FlashBert (extends Model) - The main server-level wrapper that:
    • Initializes distributed processing with initialize_torch_distributed.
    • Loads weights from safetensors files via the Weights utility.
    • Handles differences in the weight-name prefix (e.g., "bert" versus None for specific model IDs such as "WhereIsAI/UAE-Large-V1").
    • Supports both embedding mode and token-classification mode, selected by the classifcation_head constructor parameter.
    • Reports supports_embeddings = True, supports_text_generation = False, requires_block_allocator = False.
    • Manages FlashInfer prefill state for attention computation.
    • embed - Runs the model forward pass inside the FlashInfer context, extracts the CLS embeddings, reshapes them to the model's hidden size, and returns the results on the CPU.
    • classify - Runs the model forward pass, applies a softmax over the logits, and returns the predicted class for each token together with its confidence score.
    • warmup - Performs a warmup forward pass and returns max_s.
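FlashBertModel's CLS selection relies on the packed layout: cu_seqlens holds cumulative sequence boundaries, starting at 0 and ending at the total token count, so cu_seqlens[:-1] is exactly the index of each sequence's first token. The sketch below shows that indexing in pure Python with made-up hidden states; the function name `extract_cls` is hypothetical, and the real code indexes a torch tensor on the GPU.

```python
from typing import List


def extract_cls(hidden_states: List[List[float]], cu_seqlens: List[int]) -> List[List[float]]:
    """Return the first-token (CLS) vector of every packed sequence.

    hidden_states: one vector per token, all sequences concatenated.
    cu_seqlens: cumulative lengths; cu_seqlens[:-1] are the start offsets.
    """
    return [hidden_states[start] for start in cu_seqlens[:-1]]


# Three sequences of lengths 3, 2 and 4 packed into 9 tokens.
cu_seqlens = [0, 3, 5, 9]
# Toy hidden states: token i has the vector [i, 10 * i].
hidden = [[float(i), 10.0 * i] for i in range(9)]

cls = extract_cls(hidden, cu_seqlens)
print(cls)  # [[0.0, 0.0], [3.0, 30.0], [5.0, 50.0]]
```

One embedding comes back per sequence, so the output length is always len(cu_seqlens) - 1 regardless of how long each sequence is.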

Usage

FlashBert is instantiated by the LoRax model registry when loading BERT-family models for embedding or token classification tasks. It does not support text generation. The model uses flash attention over variable-length sequences: the sequences in a batch are concatenated into a single token stream whose boundaries are given by cu_seqlens, so no compute is spent on padding tokens.
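To illustrate the padding-free layout, the sketch below packs already-tokenized sequences into the flat fields a FlashEmbeddingClassificationBatch carries (input_ids, token_type_ids, position_ids, cu_seqlens, max_s). The helper name `pack_sequences`, the example token ids, and the assumption of all-zero token_type_ids (single-segment inputs) are illustrative and not taken from LoRax.

```python
from itertools import accumulate
from typing import Dict, List


def pack_sequences(sequences: List[List[int]]) -> Dict[str, object]:
    """Hypothetical helper: concatenate token-id lists without padding."""
    lengths = [len(seq) for seq in sequences]
    input_ids = [tok for seq in sequences for tok in seq]
    # Assumes single-segment inputs, so every token type id is 0.
    token_type_ids = [0] * len(input_ids)
    # Positions restart at 0 for every sequence.
    position_ids = [pos for n in lengths for pos in range(n)]
    # Cumulative boundaries: [0, len0, len0 + len1, ...].
    cu_seqlens = [0] + list(accumulate(lengths))
    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "position_ids": position_ids,
        "cu_seqlens": cu_seqlens,
        "max_s": max(lengths),
    }


batch = pack_sequences([[101, 7592, 102], [101, 102], [101, 2023, 2003, 102]])
print(batch["cu_seqlens"], batch["max_s"])  # [0, 3, 5, 9] 4

# Tokens actually processed vs. tokens a padded (num_sequences x max_s) layout would need.
padded = (len(batch["cu_seqlens"]) - 1) * batch["max_s"]
print(len(batch["input_ids"]), padded)  # 9 12
```

In this toy batch, packing processes 9 tokens where a padded rectangle would process 12; the gap grows as sequence lengths in a batch diverge.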

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/models/flash_bert.py
  • Lines: 1-217

Signature

class BertEncoder:
    def __init__(self, prefix, weights, device, dtype, config: BertConfig):
        ...
    def forward(self, hidden_states, cu_seqlens, max_s):
        ...

class FlashBertModel(torch.nn.Module):
    def __init__(self, prefix, weights, device, dtype, config: BertConfig):
        ...
    def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s):
        ...

class FlashBertModelForClassification(torch.nn.Module):
    def __init__(self, prefix, weights, device, dtype, config: BertConfig):
        ...
    def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s):
        ...

class FlashBert(Model):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        classifcation_head: bool = False,
    ):
        ...
    @property
    def batch_type(self) -> Type[FlashEmbeddingClassificationBatch]:
        ...
    @property
    def supports_embeddings(self) -> bool:
        ...
    @property
    def supports_text_generation(self) -> bool:
        ...
    def warmup(self, batch, max_new_tokens) -> int | None:
        ...
    def embed(self, batch: FlashEmbeddingClassificationBatch) -> Embedding:
        ...
    def classify(self, batch: FlashEmbeddingClassificationBatch):
        ...

Import

from lorax_server.models.flash_bert import FlashBert

I/O Contract

Inputs

Name | Type | Required | Description
model_id | str | Yes | HuggingFace model identifier for a BERT model
revision | Optional[str] | No | Model revision/commit hash
dtype | Optional[torch.dtype] | No | Model dtype (defaults to float16 on GPU)
classifcation_head | bool | No | Whether to include classification head (default: False)
batch | FlashEmbeddingClassificationBatch | Yes | Batch containing input_ids, token_type_ids, position_ids, cu_seqlens, max_s

Outputs

Name | Type | Description
embeddings | List[List[float]] | CLS token embeddings for each input in the batch (embed mode)
predicted_token_class | List[List[str]] | Predicted class labels per token (classify mode)
confidence_scores | List[List[float]] | Confidence scores for predicted classes (classify mode)
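The classify outputs can be derived from per-token logits with a softmax followed by an argmax, taking the winning probability as the confidence score. The sketch below handles the tokens of a single sequence and is illustrative only: the `id2label` mapping, the logit values, and the function names are hypothetical, and the real method operates on torch tensors.

```python
import math
from typing import Dict, List, Tuple


def softmax(logits: List[float]) -> List[float]:
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classify_tokens(
    token_logits: List[List[float]], id2label: Dict[int, str]
) -> Tuple[List[str], List[float]]:
    """Map each token's logits to (predicted label, confidence)."""
    labels, scores = [], []
    for logits in token_logits:
        probs = softmax(logits)
        # max() returns the first index on ties.
        best = max(range(len(probs)), key=probs.__getitem__)
        labels.append(id2label[best])
        scores.append(probs[best])
    return labels, scores


id2label = {0: "O", 1: "B-PER", 2: "I-PER"}  # hypothetical label set
logits = [[2.0, 0.5, 0.1], [0.2, 3.0, 0.3], [0.0, 0.0, 0.0]]
labels, scores = classify_tokens(logits, id2label)
print(labels)  # ['O', 'B-PER', 'O']
```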

Usage Examples

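# Internal LoRax server usage; requires a GPU and the lorax_server package.
from lorax_server.models.flash_bert import FlashBert

# Normally instantiated by the model registry for BERT embedding models.
# Constructing it directly loads the model's safetensors weights.
model = FlashBert(model_id="BAAI/bge-base-en-v1.5")
assert model.supports_embeddings and not model.supports_text_generation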
