Implementation:Predibase Lorax Flash Qwen2 Modeling
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Optimized Qwen2 transformer implementation for LoRax inference serving with flash attention, sliding window attention, and LoRA adapter support for both causal language modeling and embedding tasks.
Description
FlashQwen2ForCausalLM implements the Qwen2 architecture with flash attention for efficient batched inference. The module also defines an embedding variant, FlashQwen2ForEmbeddings, which makes it one of the few model implementations in LoRax that can serve the same architecture for both text generation and embeddings.
The file contains seven classes organized as a layered architecture:
- Qwen2RMSNorm -- RMS normalization with a fused dropout_layer_norm kernel for hidden dimensions up to 8192, falling back to a manual implementation for larger dimensions. Returns both the normalized output and the residual connection.
- FlashQwen2Attention -- Multi-head attention with GQA support and bias on QKV projections, rotary position embeddings, flash attention for prefill and paged attention for decode. Uses custom adapter projection names (ATTN_Q_PROJ, ATTN_K_PROJ, ATTN_V_PROJ, ATTN_O_PROJ) for adapter-aware layers.
- Qwen2MLP -- Gated MLP with fused gate-up projections and adapter-aware gate/up/down projection layers using custom adapter names (MLP_GATE_PROJ, MLP_UP_PROJ, MLP_DOWN_PROJ).
- FlashQwen2Layer -- Single transformer decoder layer combining attention and MLP with pre-norm RMS normalization.
- FlashQwen2Model -- Full transformer model stacking N decoder layers with token embeddings and final normalization.
- FlashQwen2ForCausalLM -- Top-level causal language model with sliding window attention support. When max_past (the sliding-window size) is set, it clamps max_s and seqlen during decode so attention never reaches further back than the window.
- FlashQwen2ForEmbeddings -- Embedding extraction variant that mean-pools hidden states across the sequence dimension and applies a linear projection, producing fixed-dimensional embeddings.
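The residual-returning normalization described for Qwen2RMSNorm can be sketched in plain Python. This is a minimal reference, not the fused kernel: it assumes the incoming residual is added to the hidden states before normalizing and that the sum is handed back as the new residual stream. The function name rms_norm_with_residual and the eps default of 1e-6 are assumptions for illustration.

```python
import math
from typing import List, Optional, Tuple


def rms_norm_with_residual(
    hidden: List[float],
    weight: List[float],
    residual: Optional[List[float]] = None,
    eps: float = 1e-6,  # assumed default; real models read it from the config
) -> Tuple[List[float], List[float]]:
    """Hypothetical reference for a pre-norm RMSNorm that also returns the residual.

    If a residual is given it is added to the input first; that sum becomes the
    new residual the caller carries into the next sub-layer.
    """
    if residual is not None:
        hidden = [h + r for h, r in zip(hidden, residual)]
    new_residual = list(hidden)
    # Root mean square over the hidden dimension, with eps for stability.
    rms = math.sqrt(sum(h * h for h in hidden) / len(hidden) + eps)
    normed = [w * h / rms for w, h in zip(weight, hidden)]
    return normed, new_residual
```

Returning the pre-norm sum lets each decoder layer fold the residual add into the next normalization call instead of doing a separate elementwise pass.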
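With the GQA support mentioned for FlashQwen2Attention, several query heads share one key/value head, so the KV cache only stores num_key_value_heads heads per layer. The sketch below shows the query-to-KV head mapping, assuming the usual contiguous grouping and that num_heads is a multiple of num_key_value_heads. The helper name is hypothetical.

```python
from typing import List


def kv_head_for_query_heads(num_heads: int, num_key_value_heads: int) -> List[int]:
    """Hypothetical helper: for each query head, the KV head it attends with.

    Assumes contiguous grouping: query heads [0, g) use KV head 0,
    [g, 2g) use KV head 1, and so on, where g = num_heads // num_key_value_heads.
    """
    if num_heads % num_key_value_heads != 0:
        raise ValueError("num_heads must be a multiple of num_key_value_heads")
    group_size = num_heads // num_key_value_heads
    return [q // group_size for q in range(num_heads)]
```

For example, kv_head_for_query_heads(8, 2) gives [0, 0, 0, 0, 1, 1, 1, 1], so the cache holds a quarter of the key/value heads that full multi-head attention would need.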
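The fused gate-up projection in Qwen2MLP computes the gate and up projections in one matmul and splits the result. A list-based sketch follows. It assumes the fused weight stacks the gate rows before the up rows and uses SiLU as the activation (Qwen2 configs specify hidden_act "silu"). Function names are illustrative, and the adapter-aware layers are omitted.

```python
import math
from typing import List

Matrix = List[List[float]]  # row-major: [out_features][in_features]


def matvec(w: Matrix, x: List[float]) -> List[float]:
    """Dense y = W x."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def silu(v: float) -> float:
    """SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def gated_mlp(x: List[float], gate_up_w: Matrix, down_w: Matrix) -> List[float]:
    """Sketch of down(silu(gate(x)) * up(x)) using one fused gate-up matmul.

    gate_up_w holds the gate rows first and the up rows second (assumed layout).
    """
    gate_up = matvec(gate_up_w, x)
    inter = len(gate_up) // 2
    gate, up = gate_up[:inter], gate_up[inter:]
    hidden = [silu(g) * u for g, u in zip(gate, up)]
    return matvec(down_w, hidden)
```

Fusing the two projections halves the number of input reads and kernel launches for the MLP's first stage, which is why inference servers commonly load gate and up weights into a single matrix.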
The implementation supports FP8 KV cache quantization and tensor parallelism for multi-GPU serving.
Usage
Used internally by the LoRax server to serve Qwen2-based models for both text generation and embedding tasks. The model registry selects this implementation when the checkpoint config's model type is Qwen2 ("qwen2" in Hugging Face configs).
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py
- Lines: 1-605
Signature
```python
class FlashQwen2ForCausalLM(torch.nn.Module):
    def __init__(self, prefix: str, config, weights):
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        adapter_data: AdapterBatchData,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        skip_lm_head: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        ...
```
Import
```python
from lorax_server.models.custom_modeling.flash_qwen2_modeling import FlashQwen2ForCausalLM
from lorax_server.models.custom_modeling.flash_qwen2_modeling import FlashQwen2ForEmbeddings
```
I/O Contract
Inputs (FlashQwen2ForCausalLM)
| Name | Type | Required | Description |
|---|---|---|---|
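| input_ids | torch.Tensor | Yes | Flattened token IDs of shape [total_tokens]: all sequences in the batch concatenated (one token per sequence during decode), with prefill boundaries given by cu_seqlen_prefill |
| position_ids | torch.Tensor | Yes | Position index of each token, shape [total_tokens], used for rotary embeddings |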
| cu_seqlen_prefill | Optional[torch.Tensor] | Yes | Cumulative sequence lengths for flash attention prefill (None during decode) |
| kv_cache | List[Tuple[torch.Tensor, torch.Tensor]] | Yes | Key-value cache tensors per layer |
| block_tables | torch.Tensor | Yes | Block table indices for paged attention |
| slots | torch.Tensor | Yes | Slot indices for KV cache placement |
| seqlen | Seqlen | Yes | Sequence length metadata wrapper |
| max_s | int | Yes | Maximum sequence length in the batch |
| adapter_data | AdapterBatchData | Yes | LoRA adapter weights and indices for the batch |
| prefill_cache_indices | Optional[torch.Tensor] | No | Indices for selective KV cache population during prefill; triggers slot slicing for sliding window |
| lm_head_indices | Optional[torch.Tensor] | No | Indices to select specific positions for LM head output |
| skip_lm_head | bool | No | If True, return hidden states without applying the LM head |
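Because flash-attention prefill works on one packed token dimension, cu_seqlen_prefill records where each sequence starts and ends. The sketch below builds these offsets and the matching position_ids from a list of prompt lengths. It assumes a fresh prefill with no prefix caching, so positions restart at 0 for every sequence; the helper name is hypothetical and plain lists stand in for int32 tensors.

```python
from itertools import accumulate
from typing import List, Tuple


def pack_prefill_batch(input_lengths: List[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical helper returning (cu_seqlens, position_ids) for a packed prefill.

    cu_seqlens has len(input_lengths) + 1 entries; sequence i occupies the
    flat token range cu_seqlens[i] : cu_seqlens[i + 1].
    """
    cu_seqlens = [0] + list(accumulate(input_lengths))
    # Positions restart at 0 for each sequence (assumes no cached prefix).
    position_ids = [p for n in input_lengths for p in range(n)]
    return cu_seqlens, position_ids
```

For two prompts of lengths 3 and 2, this yields cu_seqlens [0, 3, 5] and position_ids [0, 1, 2, 0, 1], matching a five-token flattened input_ids.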
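The sliding-window handling tied to prefill_cache_indices and max_past can be sketched as a small branch. This list-based mirror assumes the two cases are mutually exclusive (slot slicing when prefill_cache_indices is given, otherwise clamping when max_past is set). The function name is hypothetical; the real forward operates on tensors and a Seqlen object rather than a list of lengths.

```python
from typing import List, Optional, Tuple


def apply_sliding_window(
    slots: List[int],
    input_lengths: List[int],
    max_s: int,
    max_past: Optional[int],
    prefill_cache_indices: Optional[List[int]] = None,
) -> Tuple[List[int], List[int], int]:
    """Hypothetical list-based mirror of the sliding-window branch in forward()."""
    if prefill_cache_indices is not None:
        # Prefill: keep only the KV slots for tokens that stay inside the window.
        slots = [slots[i] for i in prefill_cache_indices]
    elif max_past is not None:
        # Decode: attention never looks back further than the window.
        max_s = min(max_past, max_s)
        input_lengths = [min(n, max_past) for n in input_lengths]
    return slots, input_lengths, max_s
```

With max_past = 4096, a decode batch whose sequences are 5000 and 20 tokens long is treated as having lengths 4096 and 20, and max_s drops from 5000 to 4096.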
Outputs (FlashQwen2ForCausalLM)
| Name | Type | Description |
|---|---|---|
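| logits | torch.Tensor | Next-token logits of shape [num_positions, vocab_size]: one row per entry of lm_head_indices when provided, otherwise one row per input token (i.e. [batch_size, vocab_size] during decode). Hidden states are returned instead if skip_lm_head is True |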
| speculative_logits | Optional[torch.Tensor] | Speculative decoding logits from the multi-adapter head, or None |
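During prefill the server usually needs logits only at the last token of each sequence, and lm_head_indices selects those rows from the flattened hidden states before the LM head. The sketch assumes the caller derives the indices from the prefill offsets as cu_seqlens[1:] - 1; both helper names are hypothetical and plain lists stand in for tensor indexing.

```python
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def last_token_indices(cu_seqlens: List[int]) -> List[int]:
    """Index of each sequence's final token in the flattened batch."""
    return [end - 1 for end in cu_seqlens[1:]]


def select_rows(hidden_states: Sequence[T], lm_head_indices: Optional[List[int]]) -> List[T]:
    """List-based stand-in for hidden_states[lm_head_indices] before the LM head."""
    if lm_head_indices is None:
        # No selection: every input token gets a logits row.
        return list(hidden_states)
    return [hidden_states[i] for i in lm_head_indices]
```

For cu_seqlens [0, 3, 5], the selected rows are 2 and 4, so the LM head runs on two positions instead of five.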
Outputs (FlashQwen2ForEmbeddings)
| Name | Type | Description |
|---|---|---|
| embeddings | torch.Tensor | Mean-pooled and linearly projected embeddings [batch_size, output_dim] |
| speculative_logits | None | Always None (embedding mode produces no speculative logits) |
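The embedding path of FlashQwen2ForEmbeddings (mean pooling over the sequence, then a linear projection) reduces to the arithmetic below. This is a list-based sketch: it takes a plain mean over every position of each sequence and does not model padding or masking, which this page does not describe. Function names are illustrative, and the weight and bias stand in for the checkpoint's projection parameters.

```python
from typing import List

Matrix = List[List[float]]  # [out_features][in_features]


def mean_pool(hidden_states: List[List[float]]) -> List[float]:
    """Average a [seq_len][hidden] block over the sequence dimension."""
    seq_len = len(hidden_states)
    return [sum(col) / seq_len for col in zip(*hidden_states)]


def linear(x: List[float], weight: Matrix, bias: List[float]) -> List[float]:
    """y = W x + b."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def embed(batch: List[List[List[float]]], weight: Matrix, bias: List[float]) -> List[List[float]]:
    """One fixed-size embedding per sequence: shape [batch_size][output_dim]."""
    return [linear(mean_pool(seq), weight, bias) for seq in batch]
```

The output dimension comes from the projection weight rather than the model's hidden size, which is why the embeddings table lists [batch_size, output_dim].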
Usage Examples
```python
# Internal usage within LoRax server
from lorax_server.models.custom_modeling.flash_qwen2_modeling import FlashQwen2ForCausalLM

# Model is instantiated by the model registry, not directly by users.
# See server/lorax_server/models/__init__.py for registration.
```