Implementation:Predibase Lorax Flash BERT Modeling
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements flash-attention-optimized BERT and DistilBERT model layers (embeddings, attention, and transformer blocks) used for efficient GPU-based embedding inference within the LoRax server.
Description
This module provides low-level PyTorch layer implementations for both BERT and DistilBERT architectures, optimized with flash attention for high-throughput embedding generation. The implementation is based on the HuggingFace text-embeddings-inference flash BERT backend. Note that these classes do not extend nn.Module; they are plain Python classes that manually manage weight tensors.
DistilBERT Components:
- DistilBertEmbeddings - Combines word embeddings and position embeddings with layer normalization (no token type embeddings, matching the DistilBERT architecture).
- DistilBertAttention - Fused QKV attention using flash attention kernel. Concatenates query, key, and value weights into a single matrix for efficient computation. Uses FastLayerNorm for post-attention normalization with residual connection.
- DistilBertLayer - Full transformer block combining attention and feed-forward network with GELU activation and layer normalization.
BERT Components:
- BertEmbeddings - Word, token type, and absolute position embeddings with layer normalization. Only supports absolute position embeddings.
- BertAttention - Fused QKV attention using flash attention kernel with non-causal masking. Identical structure to DistilBERT attention but with BERT-specific weight naming.
- BertLayer - Full transformer block with attention and intermediate/output feed-forward layers, each with its own layer normalization and residual connection.
All attention classes use torch.addmm for fused bias-add-matmul operations and cu_seqlens-based flash attention for variable-length sequence processing without padding.
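The fused projection described above can be illustrated with a minimal plain-Python sketch. It is not the module's code: the helper names (`matmul`, `addmm`, `hconcat`) and the toy weights are hypothetical, and the real classes operate on GPU tensors. The `addmm` helper mirrors what `torch.addmm(bias, x, w)` computes with its default scaling factors, namely `bias + x @ w`.

```python
# Plain-Python reference sketch; the real module works on CUDA tensors.

def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix, both given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def addmm(bias, x, w):
    """Reference for torch.addmm(bias, x, w) with default alpha/beta: bias + x @ w."""
    return [[v + b for v, b in zip(row, bias)] for row in matmul(x, w)]

def hconcat(*mats):
    """Concatenate matrices column-wise; all must have the same number of rows."""
    return [sum(rows, []) for rows in zip(*mats)]

# Hypothetical toy weights with hidden size 2, stored as (in_features x out_features).
w_q = [[1.0, 0.0], [0.0, 1.0]]
w_k = [[2.0, 0.0], [0.0, 2.0]]
w_v = [[0.0, 1.0], [1.0, 0.0]]
b_q, b_k, b_v = [0.5, 0.5], [0.0, 0.0], [1.0, -1.0]

# Fuse once at load time: one (2 x 6) weight matrix and one length-6 bias.
qkv_weight = hconcat(w_q, w_k, w_v)
qkv_bias = b_q + b_k + b_v

# Three tokens from a packed (unpadded) batch, one row per token.
hidden_states = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

# A single fused call replaces three separate projections.
qkv = addmm(qkv_bias, hidden_states, qkv_weight)

# Split the fused output back into query, key, and value.
hidden = len(w_q[0])
q = [row[:hidden] for row in qkv]
k = [row[hidden:2 * hidden] for row in qkv]
v = [row[2 * hidden:] for row in qkv]
```

Splitting the fused output recovers exactly what three separate projections would produce, so the fusion saves kernel launches without changing the result.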
Usage
These layer classes are composed by the higher-level FlashBertModel in flash_bert.py to build complete BERT or DistilBERT models for embedding generation. They are not meant to be instantiated directly by callers; FlashBertModel imports them and stacks them into an encoder.
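As a rough illustration of that composition, the sketch below stacks stand-in layers behind a stand-in embedding step. `StubEmbeddings`, `StubLayer`, and `encode` are hypothetical names invented for this example and do not reproduce FlashBertModel; only the `forward` signatures follow the Signature section. Because the real classes are plain Python objects rather than nn.Module subclasses, the sketch calls `forward` explicitly, which assumes no `__call__` shortcut is defined.

```python
class StubEmbeddings:
    """Hypothetical stand-in for BertEmbeddings: one 1-d vector per token."""

    def forward(self, input_ids, token_type_ids, position_ids):
        return [[float(tok)] for tok in input_ids]


class StubLayer:
    """Hypothetical stand-in for BertLayer: scales every hidden value."""

    def __init__(self, scale):
        self.scale = scale

    def forward(self, hidden_states, cu_seqlens, max_s):
        return [[value * self.scale for value in row] for row in hidden_states]


def encode(embeddings, layers, input_ids, token_type_ids, position_ids, cu_seqlens, max_s):
    """Run the embedding step, then each layer in order on the packed tokens."""
    hidden_states = embeddings.forward(input_ids, token_type_ids, position_ids)
    for layer in layers:
        # Every layer receives the same cu_seqlens and max_s metadata.
        hidden_states = layer.forward(hidden_states, cu_seqlens, max_s)
    return hidden_states


# One sequence of two tokens, passed through two stacked stub layers.
output = encode(
    StubEmbeddings(),
    [StubLayer(2.0), StubLayer(3.0)],
    input_ids=[1, 2],
    token_type_ids=[0, 0],
    position_ids=[0, 1],
    cu_seqlens=[0, 2],
    max_s=2,
)
```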
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/models/custom_modeling/flash_bert_modeling.py
- Lines: 1-212
Signature
class DistilBertEmbeddings:
    def __init__(self, prefix, weights, device, dtype, config: DistilBertConfig):
        ...
    def forward(self, input_ids, token_type_ids, position_ids):
        ...

class DistilBertAttention:
    def __init__(self, prefix, weights, device, dtype, config: DistilBertConfig):
        ...
    def forward(self, hidden_states, cu_seqlens, max_s):
        ...

class DistilBertLayer:
    def __init__(self, prefix, weights, device, dtype, config: DistilBertConfig):
        ...
    def forward(self, hidden_states, cu_seqlens, max_s):
        ...

class BertEmbeddings:
    def __init__(self, prefix, weights, device, dtype, config: BertConfig):
        ...
    def forward(self, input_ids, token_type_ids, position_ids):
        ...

class BertAttention:
    def __init__(self, prefix, weights, device, dtype, config: BertConfig):
        ...
    def forward(self, hidden_states, cu_seqlens, max_s):
        ...

class BertLayer:
    def __init__(self, prefix, weights, device, dtype, config: BertConfig):
        ...
    def forward(self, hidden_states, cu_seqlens, max_s):
        ...
Import
from lorax_server.models.custom_modeling.flash_bert_modeling import BertEmbeddings, BertLayer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.Tensor | Yes | Token IDs for the input sequence (flattened across batch) |
| token_type_ids | torch.Tensor | Yes | Token type IDs (segment IDs for BERT, unused for DistilBERT) |
| position_ids | torch.Tensor | Yes | Position indices for each token |
| cu_seqlens | torch.Tensor | Yes | Cumulative sequence lengths for flash attention |
| max_s | int | Yes | Maximum sequence length in the batch |
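To make the flattened layout concrete, the following sketch packs a small batch into the inputs listed in the table. It uses plain lists instead of tensors, and `pack_batch` is a hypothetical helper, not part of the module. Two assumptions are baked in: positions restart at 0 for each sequence, and every token belongs to segment 0.

```python
from itertools import accumulate


def pack_batch(sequences):
    """Flatten token-ID lists and derive the metadata the layers expect."""
    lengths = [len(seq) for seq in sequences]
    # All tokens of all sequences, back to back, with no padding.
    input_ids = [tok for seq in sequences for tok in seq]
    # Assumption: absolute positions restart at 0 for each sequence.
    position_ids = [pos for n in lengths for pos in range(n)]
    # Assumption: single-segment inputs, so every token gets segment ID 0.
    token_type_ids = [0] * len(input_ids)
    # Cumulative offsets: batch_size + 1 entries, starting at 0.
    cu_seqlens = [0] + list(accumulate(lengths))
    max_s = max(lengths)
    return input_ids, token_type_ids, position_ids, cu_seqlens, max_s


# Hypothetical token IDs for three sequences of lengths 3, 2, and 4.
batch = [[101, 7592, 102], [101, 102], [101, 2088, 999, 102]]
input_ids, token_type_ids, position_ids, cu_seqlens, max_s = pack_batch(batch)

# Sequence i occupies input_ids[cu_seqlens[i]:cu_seqlens[i + 1]].
second = input_ids[cu_seqlens[1]:cu_seqlens[2]]
```

Here `cu_seqlens` is `[0, 3, 5, 9]` and `max_s` is `4`, so the attention kernel can locate each sequence inside the packed token dimension without any padding tokens.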
Outputs
| Name | Type | Description |
|---|---|---|
| hidden_states | torch.Tensor | Encoded hidden states for each token position |
Usage Examples
# Internal LoRax server usage - composed by FlashBertModel
from lorax_server.models.custom_modeling.flash_bert_modeling import BertEmbeddings, BertLayer
# Layers are instantiated within FlashBertModel.__init__
# layer = BertLayer(f"bert.encoder.layer.{i}", weights, device, dtype, config)