Implementation:Predibase Lorax Flash BERT Modeling

From Leeroopedia
Revision as of 16:20, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Predibase_Lorax_Flash_BERT_Modeling.md)


Knowledge Sources
Domains: Model_Architecture, Inference
Last Updated: 2026-02-08 00:00 GMT

Overview

Implements flash-attention-optimized BERT and DistilBERT model layers (embeddings, attention, and transformer blocks) used for efficient GPU-based embedding inference within the LoRAX server.

Description

This module provides low-level PyTorch layer implementations for both the BERT and DistilBERT architectures, optimized with flash attention for high-throughput embedding generation. The implementation is based on the flash BERT backend of Hugging Face text-embeddings-inference. Note that these classes do not extend nn.Module; they are plain Python classes that manage their weight tensors manually.

DistilBERT Components:

  • DistilBertEmbeddings - Combines word embeddings and position embeddings, followed by layer normalization (no token type embeddings, matching the DistilBERT architecture).
  • DistilBertAttention - Fused QKV attention using a flash attention kernel. Concatenates the query, key, and value weights into a single matrix so that one matrix multiplication produces all three projections. Uses FastLayerNorm for post-attention normalization with a residual connection.
  • DistilBertLayer - Full transformer block combining attention and a feed-forward network with GELU activation and layer normalization.
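The embedding step described in the first bullet can be sketched in plain Python as follows. This is only an illustration of the arithmetic (look up, sum, layer-normalize), not the module's code: the function names, the toy tables, the `eps` value, and the `gamma`/`beta` parameters are all assumptions made for the example, and the real class operates on GPU tensors.

```python
import math

def layer_norm(vec, gamma, beta, eps=1e-12):
    """Normalize one vector to zero mean and unit variance, then scale and shift."""
    mean = sum(vec) / len(vec)
    var = sum((x - mean) ** 2 for x in vec) / len(vec)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [g * (x - mean) * inv_std + b for x, g, b in zip(vec, gamma, beta)]

def distilbert_style_embeddings(input_ids, position_ids, word_table, pos_table, gamma, beta):
    """Sum word and position embeddings per token, then apply layer norm.

    No token-type table is involved, mirroring the DistilBERT description above.
    """
    out = []
    for tok, pos in zip(input_ids, position_ids):
        summed = [w + p for w, p in zip(word_table[tok], pos_table[pos])]
        out.append(layer_norm(summed, gamma, beta))
    return out

# Toy tables (hypothetical values): 3 vocabulary entries, 2 positions, hidden size 4.
word_table = [[1.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 0.0, 3.0, 1.0]]
pos_table = [[0.1, 0.1, 0.1, 0.1], [0.0, 0.5, 0.0, 0.5]]
gamma, beta = [1.0] * 4, [0.0] * 4

embedded = distilbert_style_embeddings([2, 0], [0, 1], word_table, pos_table, gamma, beta)
```

A BERT-style variant would add a third lookup, indexed by token_type_ids, to the sum before normalization.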

BERT Components:

  • BertEmbeddings - Word, token type, and absolute position embeddings with layer normalization. Only absolute position embeddings are supported.
  • BertAttention - Fused QKV attention using a flash attention kernel with non-causal masking. Identical in structure to the DistilBERT attention but with BERT-specific weight naming.
  • BertLayer - Full transformer block with attention and intermediate/output feed-forward layers, each with its own layer normalization and residual connection.
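The fused QKV idea used by both attention classes can be checked with a small pure-Python sketch: concatenating the three projection matrices along the output dimension and doing one bias-add-matmul yields the same Q, K, and V as three separate projections. The helper names, the (hidden x hidden) weight layout, and the Q/K/V split order are assumptions for illustration; they are not taken from the module. In PyTorch, the single fused step corresponds to torch.addmm(bias, x, weight), which computes bias + x @ weight.

```python
def matmul(a, b):
    """Multiply a (n x k) by b (k x m), both given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def add_bias(rows, bias):
    """Add the same bias vector to every row."""
    return [[v + b for v, b in zip(row, bias)] for row in rows]

def fuse_columns(*mats):
    """Concatenate matrices side by side (along the output dimension)."""
    return [sum(rows, []) for rows in zip(*mats)]

hidden = 2
x = [[1, 2], [3, 4], [5, 6]]            # 3 packed tokens, hidden size 2

# Separate projection weights stored as (hidden x hidden), plus biases (toy integers).
w_q, b_q = [[1, 0], [0, 1]], [1, 1]
w_k, b_k = [[2, 1], [0, 1]], [0, 2]
w_v, b_v = [[1, 1], [1, -1]], [3, 0]

# Fuse once at load time: one (hidden x 3*hidden) matrix and one bias of length 3*hidden.
w_qkv = fuse_columns(w_q, w_k, w_v)
b_qkv = b_q + b_k + b_v

# One bias-add-matmul per forward pass, then slice the result back into Q, K, V.
qkv = add_bias(matmul(x, w_qkv), b_qkv)
q = [row[:hidden] for row in qkv]
k = [row[hidden:2 * hidden] for row in qkv]
v = [row[2 * hidden:] for row in qkv]
```

The benefit on a GPU is one larger kernel launch instead of three smaller ones; the arithmetic itself is unchanged.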

All attention classes use torch.addmm to fuse the bias add with the matrix multiplication, and cu_seqlens-based flash attention to process variable-length sequences without padding.
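To make the cu_seqlens convention concrete, here is a minimal pure-Python sketch. It assumes the usual layout for variable-length flash attention, where cu_seqlens has one more entry than the batch has sequences and starts at 0, so sequence i occupies the packed rows from cu_seqlens[i] up to (but not including) cu_seqlens[i + 1]. The helper names are hypothetical, and the real code passes a tensor rather than a Python list.

```python
from itertools import accumulate

def build_cu_seqlens(lengths):
    """Return cumulative sequence lengths with a leading 0, e.g. [3, 2] -> [0, 3, 5]."""
    return [0] + list(accumulate(lengths))

def split_packed(packed, cu_seqlens):
    """Recover each sequence from a packed (padding-free) list of per-token items."""
    return [packed[start:end] for start, end in zip(cu_seqlens, cu_seqlens[1:])]

lengths = [3, 1, 2]
cu_seqlens = build_cu_seqlens(lengths)     # boundaries of each sequence in the packed batch
max_s = max(lengths)                       # longest sequence in the batch

# Six tokens from three sequences, stored back to back with no padding.
packed_tokens = ["a0", "a1", "a2", "b0", "c0", "c1"]
sequences = split_packed(packed_tokens, cu_seqlens)
```

With padding, this batch would occupy 3 x 3 = 9 token slots; packed, it occupies 6, and attention is restricted to tokens within the same boundary pair.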

Usage

These layer classes are composed by the higher-level FlashBertModel in flash_bert.py to build complete BERT or DistilBERT models for embedding generation. They are not meant to be used on their own; FlashBertModel imports them and stacks them into an encoder.
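The stacking pattern can be sketched with a stand-in layer. Everything here is hypothetical: ToyLayer, run_encoder, and the weight names are invented for the example and do not reproduce FlashBertModel. The sketch only shows the shape of the pattern, namely plain classes that fetch their own weights by prefix and expose a forward(hidden_states, cu_seqlens, max_s) method that is called explicitly, since a plain class does not get the nn.Module call dispatch.

```python
class ToyLayer:
    """Hypothetical stand-in for a layer: a plain class holding its own weights."""

    def __init__(self, prefix, weights):
        # Look up this layer's parameter by name, as a prefix-based loader might.
        self.scale = weights[f"{prefix}.scale"]

    def forward(self, hidden_states, cu_seqlens, max_s):
        # Placeholder computation; a real layer would run attention and a feed-forward network.
        return [h * self.scale for h in hidden_states]


def run_encoder(layers, hidden_states, cu_seqlens, max_s):
    """Feed hidden states through each layer in order, reusing the same sequence metadata."""
    for layer in layers:
        hidden_states = layer.forward(hidden_states, cu_seqlens, max_s)
    return hidden_states


# Toy weights keyed by prefix (values 2, 3, 4 for layers 0, 1, 2).
weights = {f"encoder.layer.{i}.scale": i + 2 for i in range(3)}
layers = [ToyLayer(f"encoder.layer.{i}", weights) for i in range(3)]
output = run_encoder(layers, [1, 2, 3], cu_seqlens=[0, 3], max_s=3)
```

Note that cu_seqlens and max_s are computed once per batch and handed unchanged to every layer.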

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/models/custom_modeling/flash_bert_modeling.py
  • Lines: 1-212

Signature

class DistilBertEmbeddings:
    def __init__(self, prefix, weights, device, dtype, config: DistilBertConfig):
        ...
    def forward(self, input_ids, token_type_ids, position_ids):
        ...

class DistilBertAttention:
    def __init__(self, prefix, weights, device, dtype, config: DistilBertConfig):
        ...
    def forward(self, hidden_states, cu_seqlens, max_s):
        ...

class DistilBertLayer:
    def __init__(self, prefix, weights, device, dtype, config: DistilBertConfig):
        ...
    def forward(self, hidden_states, cu_seqlens, max_s):
        ...

class BertEmbeddings:
    def __init__(self, prefix, weights, device, dtype, config: BertConfig):
        ...
    def forward(self, input_ids, token_type_ids, position_ids):
        ...

class BertAttention:
    def __init__(self, prefix, weights, device, dtype, config: BertConfig):
        ...
    def forward(self, hidden_states, cu_seqlens, max_s):
        ...

class BertLayer:
    def __init__(self, prefix, weights, device, dtype, config: BertConfig):
        ...
    def forward(self, hidden_states, cu_seqlens, max_s):
        ...

Import

from lorax_server.models.custom_modeling.flash_bert_modeling import BertEmbeddings, BertLayer

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.Tensor | Yes | Token IDs for the input sequence (flattened across batch) |
| token_type_ids | torch.Tensor | Yes | Token type IDs (segment IDs for BERT, unused for DistilBERT) |
| position_ids | torch.Tensor | Yes | Position indices for each token |
| cu_seqlens | torch.Tensor | Yes | Cumulative sequence lengths for flash attention |
| max_s | int | Yes | Maximum sequence length in the batch |
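As a rough illustration of how these five inputs relate to an ordinary batch, the sketch below flattens a list of variable-length sequences into padding-free lists. The function name pack_batch and the toy token IDs are hypothetical, and restarting position_ids at 0 for each sequence is an assumption consistent with absolute position embeddings; the server itself would hold these values as tensors on the target device.

```python
def pack_batch(batch_token_ids, batch_type_ids=None):
    """Flatten a batch of variable-length sequences into padding-free inputs."""
    input_ids, token_type_ids, position_ids = [], [], []
    cu_seqlens, max_s = [0], 0
    for i, tokens in enumerate(batch_token_ids):
        # Default to segment 0 everywhere when no segment IDs are supplied.
        types = batch_type_ids[i] if batch_type_ids is not None else [0] * len(tokens)
        if len(types) != len(tokens):
            raise ValueError(f"sequence {i}: token/type length mismatch")
        input_ids.extend(tokens)
        token_type_ids.extend(types)
        position_ids.extend(range(len(tokens)))          # assumed: positions restart at 0
        cu_seqlens.append(cu_seqlens[-1] + len(tokens))  # running total of tokens so far
        max_s = max(max_s, len(tokens))
    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "position_ids": position_ids,
        "cu_seqlens": cu_seqlens,
        "max_s": max_s,
    }

# Two toy sequences of lengths 3 and 2; the second one switches segment mid-way.
packed = pack_batch([[11, 12, 13], [21, 22]], [[0, 0, 0], [0, 1]])
```

The three per-token lists always have the same length (the total token count), while cu_seqlens has one entry per sequence plus the leading 0.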

Outputs

| Name | Type | Description |
|---|---|---|
| hidden_states | torch.Tensor | Encoded hidden states for each token position |

Usage Examples

# Internal LoRAX server usage - composed by FlashBertModel
from lorax_server.models.custom_modeling.flash_bert_modeling import BertEmbeddings, BertLayer

# Layers are instantiated within FlashBertModel.__init__
# layer = BertLayer(f"bert.encoder.layer.{i}", weights, device, dtype, config)
