Implementation:Predibase Lorax Flash BERT
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Provides the high-level BERT model wrapper for GPU-accelerated embedding and token classification inference. It composes flash-attention BERT layers into a complete model and integrates that model with the LoRax server's Model base class.
Description
This module defines the server-level wrappers that compose the low-level flash BERT layers from flash_bert_modeling.py into complete models suitable for embedding generation and token classification.
Key classes:
- `BertEncoder` - Stacks multiple `BertLayer` instances and runs sequential forward through all layers.
- `FlashBertModel` (extends `torch.nn.Module`) - Combines `BertEmbeddings` and `BertEncoder`. Returns the first token of each sequence (CLS token extraction via `cu_seqlens[:-1]` indexing) for embedding output.
- `FlashBertModelForClassification` (extends `torch.nn.Module`) - Similar to `FlashBertModel` but adds a linear classification head on top. Returns logits of shape `(batch_size, max_s, num_labels)`.
- `FlashBert` (extends `Model`) - The main server-level wrapper that:
  - Initializes distributed processing with `initialize_torch_distributed`.
  - Loads weights from safetensors files via the `Weights` utility.
  - Handles model prefix differences (e.g., "bert" vs None for specific model IDs like "WhereIsAI/UAE-Large-V1").
  - Supports both embedding mode and classification mode via `classifcation_head` parameter.
  - Reports `supports_embeddings = True`, `supports_text_generation = False`, `requires_block_allocator = False`.
  - Manages FlashInfer prefill state for attention computation.
  - `embed` - Runs model forward within FlashInfer context, extracts CLS embeddings, reshapes to hidden size, and returns CPU results.
  - `classify` - Runs model forward, applies softmax, returns predicted token classes with confidence scores.
  - `warmup` - Performs a warmup forward pass and returns `max_s`.
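The CLS extraction described for `FlashBertModel` depends on the packed layout: all sequences in a batch are concatenated along one token axis, and `cu_seqlens` holds the cumulative offsets, so `cu_seqlens[:-1]` gives the index of each sequence's first token. The following is a minimal sketch of that indexing using plain Python lists. The real module operates on torch tensors, and `extract_cls` and the toy values are illustrative names, not taken from the source.

```python
def extract_cls(hidden_states, cu_seqlens):
    """Return the first-token (CLS) vector of every packed sequence.

    hidden_states: list of per-token vectors, all sequences concatenated.
    cu_seqlens: cumulative offsets, e.g. [0, 3, 5] for lengths 3 and 2.
    """
    # cu_seqlens[:-1] drops the trailing total, leaving one start offset per sequence.
    return [hidden_states[start] for start in cu_seqlens[:-1]]


# Two sequences of lengths 3 and 2 with hidden size 2 (toy values).
hidden_states = [
    [0.1, 0.2],  # sequence 0, token 0 (CLS)
    [0.3, 0.4],
    [0.5, 0.6],
    [0.7, 0.8],  # sequence 1, token 0 (CLS)
    [0.9, 1.0],
]
cu_seqlens = [0, 3, 5]
cls_vectors = extract_cls(hidden_states, cu_seqlens)
```

With tensors, the same selection would presumably be a single advanced-indexing operation rather than a Python loop.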
Usage
FlashBert is instantiated by the LoRax model registry when loading BERT-family models for embedding or token classification tasks. It does not support text generation. The model uses flash attention for efficient variable-length sequence processing without padding overhead.
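To make "variable-length sequence processing without padding" concrete, the sketch below packs a batch of token-id sequences into the flat layout implied by the `cu_seqlens` and `max_s` arguments. It is an assumption-laden illustration in plain Python, not the construction code of `FlashEmbeddingClassificationBatch`; the helper name `pack_batch` and the token ids are hypothetical.

```python
from itertools import accumulate


def pack_batch(sequences):
    """Concatenate non-empty, variable-length token-id sequences without padding."""
    lengths = [len(seq) for seq in sequences]
    # One flat token axis: sequence 0's tokens, then sequence 1's, and so on.
    input_ids = [tok for seq in sequences for tok in seq]
    # Positions restart at 0 for every sequence.
    position_ids = [pos for n in lengths for pos in range(n)]
    # Start offset of each sequence, followed by the total token count.
    cu_seqlens = [0, *accumulate(lengths)]
    # Longest sequence in the batch (assumes at least one sequence).
    max_s = max(lengths)
    return input_ids, position_ids, cu_seqlens, max_s


input_ids, position_ids, cu_seqlens, max_s = pack_batch(
    [[101, 7592, 102], [101, 102]]  # toy token ids
)
```

A padded layout would store `len(sequences) * max_s` token slots (6 here), while the packed layout stores only the 5 real tokens.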
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: `server/lorax_server/models/flash_bert.py`
- Lines: 1-217
Signature
class BertEncoder:
    def __init__(self, prefix, weights, device, dtype, config: BertConfig):
        ...

    def forward(self, hidden_states, cu_seqlens, max_s):
        ...


class FlashBertModel(torch.nn.Module):
    def __init__(self, prefix, weights, device, dtype, config: BertConfig):
        ...

    def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s):
        ...


class FlashBertModelForClassification(torch.nn.Module):
    def __init__(self, prefix, weights, device, dtype, config: BertConfig):
        ...

    def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s):
        ...


class FlashBert(Model):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        classifcation_head: bool = False,
    ):
        ...

    @property
    def batch_type(self) -> Type[FlashEmbeddingClassificationBatch]:
        ...

    @property
    def supports_embeddings(self) -> bool:
        ...

    @property
    def supports_text_generation(self) -> bool:
        ...

    def warmup(self, batch, max_new_tokens) -> int | None:
        ...

    def embed(self, batch: FlashEmbeddingClassificationBatch) -> Embedding:
        ...

    def classify(self, batch: FlashEmbeddingClassificationBatch):
        ...
Import
from lorax_server.models.flash_bert import FlashBert
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_id | str | Yes | HuggingFace model identifier for a BERT model |
| revision | Optional[str] | No | Model revision/commit hash |
| dtype | Optional[torch.dtype] | No | Model dtype (defaults to float16 on GPU) |
| classifcation_head | bool | No | Whether to include classification head (default: False) |
| batch | FlashEmbeddingClassificationBatch | Yes | Batch passed to warmup, embed, and classify (not a constructor argument); contains input_ids, token_type_ids, position_ids, cu_seqlens, max_s |
Outputs
| Name | Type | Description |
|---|---|---|
| embeddings | List[List[float]] | CLS token embeddings for each input in the batch (embed mode) |
| predicted_token_class | List[List[str]] | Predicted class labels per token (classify mode) |
| confidence_scores | List[List[float]] | Confidence scores for predicted classes (classify mode) |
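The classify-mode outputs can be pictured as a per-token softmax followed by an argmax. The sketch below shows that post-processing for one sequence in plain Python. It assumes the confidence score is the softmax probability of the predicted class, which the description suggests but does not state outright; `decode_token_classes` and the `id2label` mapping are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one token's class logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode_token_classes(token_logits, id2label):
    """Map per-token logits to (labels, confidence scores) for one sequence."""
    labels, scores = [], []
    for logits in token_logits:
        probs = softmax(logits)
        # Index of the most probable class for this token.
        best = max(range(len(probs)), key=probs.__getitem__)
        labels.append(id2label[best])
        scores.append(probs[best])
    return labels, scores


id2label = {0: "O", 1: "B-PER"}  # hypothetical label mapping
labels, scores = decode_token_classes([[2.0, 0.0], [0.0, 3.0]], id2label)
```

Running this once per sequence in a batch would yield the nested `List[List[...]]` shapes listed in the table.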
Usage Examples
# Internal LoRax server usage
from lorax_server.models.flash_bert import FlashBert
# Instantiated by model registry for BERT embedding models
# model = FlashBert(model_id="BAAI/bge-base-en-v1.5")