Implementation:Turboderp org Exllamav2 ExLlamaV2ParallelDecoder
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Decoder |
| Last Updated | 2026-02-15 00:00 GMT |
Overview
ExLlamaV2ParallelDecoder is a decoder block that processes the attention and MLP branches in parallel rather than sequentially, combining their outputs via residual addition.
Description
This class implements a parallel decoder layer architecture where a single shared layernorm feeds both an attention sublayer and an MLP sublayer simultaneously. The normalized hidden states are cloned and routed to each branch independently. Both branch outputs are then added back to the original residual stream. This contrasts with standard sequential transformer blocks where the MLP receives the post-attention residual. The parallel design is used by architectures such as GPT-J and GPT-NeoX.
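The difference in data flow can be illustrated with a toy sketch in plain Python. The `norm`, `attn`, and `mlp` functions below are hypothetical stand-ins that operate on lists of floats; they are not ExLlamaV2 code, and only the wiring of the residual stream is meant to mirror the description above.

```python
def norm(x):
    # Stand-in normalization: scale the vector to roughly unit root-mean-square
    rms = (sum(v * v for v in x) / len(x)) ** 0.5
    return [v / (rms + 1e-6) for v in x]

def attn(x):
    # Hypothetical stand-in for the attention sublayer (mixes positions)
    return [0.5 * v for v in reversed(x)]

def mlp(x):
    # Hypothetical stand-in for the MLP sublayer (elementwise nonlinearity)
    return [v * v for v in x]

def parallel_block(x):
    # One shared norm feeds both branches; both outputs join the residual
    n = norm(x)
    a = attn(n)
    m = mlp(list(n))  # each branch receives its own copy of the normalized input
    return [r + da + dm for r, da, dm in zip(x, a, m)]

def sequential_block(x):
    # The MLP sees the residual stream after the attention output was added
    h = [r + da for r, da in zip(x, attn(norm(x)))]
    return [r + dm for r, dm in zip(h, mlp(norm(h)))]

x = [1.0, -2.0, 3.0]
print("parallel:  ", parallel_block(x))
print("sequential:", sequential_block(x))
```

In the parallel variant the two branches have no data dependency on each other, which is what allows them to be computed from the same normalized input.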
The module supports both RMSNorm and LayerNorm for the shared input normalization, configurable via archparams. It delegates LoRA updates, quantization checks, device placement, and rank reduction to its child attention and MLP submodules.
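For readers unfamiliar with the two normalization variants, the following is a minimal sketch of each in plain Python. The selector function and its `"rmsnorm"` / `"layernorm"` string values are assumptions made for illustration; the actual field and values used in archparams are not shown on this page.

```python
import math

def rms_norm(x, weight, eps=1e-6):
    # RMSNorm: rescale by the root-mean-square; no mean subtraction, no bias
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for v, w in zip(x, weight)]

def layer_norm(x, weight, bias, eps=1e-5):
    # LayerNorm: subtract the mean, divide by the standard deviation,
    # then apply a learned scale and shift
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [w * (v - mean) * inv_std + b for v, w, b in zip(x, weight, bias)]

def shared_input_norm(x, norm_type):
    # Hypothetical selector; the string values are assumed, not taken from the library
    ones = [1.0] * len(x)
    if norm_type == "rmsnorm":
        return rms_norm(x, ones)
    if norm_type == "layernorm":
        return layer_norm(x, ones, [0.0] * len(x))
    raise ValueError(f"unknown norm type: {norm_type}")

x = [1.0, -2.0, 3.0, 6.0]
print(shared_input_norm(x, "rmsnorm"))
print(shared_input_norm(x, "layernorm"))
```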
When intermediates=True is requested, the forward_interm() method captures detailed per-branch outputs, including post_norm, attn_output, pre_down, hidden_states_attn, and hidden_states_mlp, for debugging or analysis.
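The shape of that return value can be sketched with a toy function. Everything below is illustrative: the arithmetic is a placeholder, and the meaning attached to each key (for example, that pre_down is the MLP activation before its final projection) is an assumption inferred from the key names, not taken from the library source.

```python
def forward_sketch(hidden_states, intermediates=False):
    residual = list(hidden_states)

    # Shared normalization (toy RMS scaling), kept for inspection
    rms = (sum(v * v for v in residual) / len(residual)) ** 0.5
    post_norm = [v / (rms + 1e-6) for v in residual]

    # Attention branch: toy "attention output", then a toy output projection
    attn_output = list(reversed(post_norm))
    hidden_states_attn = [0.5 * v for v in attn_output]

    # MLP branch: toy activation before the final projection, then the projection
    pre_down = [v * v for v in post_norm]
    hidden_states_mlp = [0.1 * v for v in pre_down]

    # Both branch outputs are added to the original residual stream
    out = [r + a + m for r, a, m in
           zip(residual, hidden_states_attn, hidden_states_mlp)]

    if not intermediates:
        return out
    return {
        "post_norm": post_norm,
        "attn_output": attn_output,
        "pre_down": pre_down,
        "hidden_states_attn": hidden_states_attn,
        "hidden_states_mlp": hidden_states_mlp,
        "hidden_states": out,
    }

result = forward_sketch([1.0, -2.0, 3.0], intermediates=True)
for key, value in result.items():
    print(key, value)
```

A caller that only needs the final activations reads `result["hidden_states"]`, which matches what the function returns when intermediates is left at False.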
Usage
Use ExLlamaV2ParallelDecoder when loading a model architecture that specifies parallel attention+MLP blocks (e.g., GPT-J style). It is instantiated automatically by the model loader based on the architecture configuration. Users do not typically instantiate this class directly but interact with it through the model's layer stack.
Code Reference
Source Location
- Repository: Turboderp_org_Exllamav2
- File: exllamav2/parallel_decoder.py
- Lines: 1-191
Signature
class ExLlamaV2ParallelDecoder(ExLlamaV2Module):

    name: str = "ParallelDecoder"
    layer_idx: int
    input_layernorm: ExLlamaV2RMSNorm | ExLlamaV2LayerNorm | None
    attn: ExLlamaV2Attention
    mlp: ExLlamaV2MLP

    def __init__(
        self,
        model: ExLlamaV2,
        key: str,
        layer_idx: int,
        sliding_window: int = 0,
        archparams=None,
        rope_index: int = 0,
    ): ...

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache: ExLlamaV2CacheBase | None = None,
        attn_params: ExLlamaV2Attention.Params | None = None,
        past_len: int | None = None,
        intermediates: bool = False,
        loras: list[ExLlamaV2Lora] | None = None,
        **kwargs
    ) -> torch.Tensor | dict[str: torch.Tensor]: ...

    def forward_interm(
        self,
        hidden_states: torch.Tensor,
        cache: ExLlamaV2CacheBase | None = None,
        attn_params: ExLlamaV2Attention.Params | None = None,
        past_len: int | None = None,
        intermediates: bool = False,
        loras: list[ExLlamaV2Lora] | None = None,
        **kwargs
    ) -> torch.Tensor | dict[str: torch.Tensor]: ...

    def load(self): ...
    def unload(self): ...
    def update_loras(self): ...
    def numel(self) -> int: ...
    def is_quant(self) -> bool: ...
    def rank_reduce(self, k: float): ...
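The housekeeping methods at the end of the signature (update_loras, numel, is_quant, rank_reduce) are described above as delegating to the child submodules. A minimal sketch of that delegation pattern follows. The `ChildSketch` and `ParallelDecoderSketch` classes are hypothetical, and the specific combination rules (summing element counts, requiring both branches to be quantized) are assumptions rather than confirmed library behavior.

```python
class ChildSketch:
    """Hypothetical stand-in for an attention, MLP, or norm submodule."""

    def __init__(self, n_params, quantized=False):
        self.n_params = n_params
        self.quantized = quantized
        self.lora_updates = 0

    def numel(self):
        return self.n_params

    def is_quant(self):
        return self.quantized

    def update_loras(self):
        self.lora_updates += 1


class ParallelDecoderSketch:
    """Hypothetical container that forwards housekeeping calls to its children."""

    def __init__(self, input_layernorm, attn, mlp):
        self.input_layernorm = input_layernorm
        self.attn = attn
        self.mlp = mlp

    def numel(self):
        # Assumed rule: total size is the sum over all children
        return (self.input_layernorm.numel()
                + self.attn.numel()
                + self.mlp.numel())

    def is_quant(self):
        # Assumed rule: the block counts as quantized only if both branches are
        return self.attn.is_quant() and self.mlp.is_quant()

    def update_loras(self):
        # The norm has no LoRA weights in this sketch, so only branches are updated
        self.attn.update_loras()
        self.mlp.update_loras()


block = ParallelDecoderSketch(
    input_layernorm=ChildSketch(8),
    attn=ChildSketch(256, quantized=True),
    mlp=ChildSketch(512, quantized=True),
)
block.update_loras()
print(block.numel(), block.is_quant())
```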
Import
from exllamav2.parallel_decoder import ExLlamaV2ParallelDecoder
I/O Contract
forward()
| Parameter | Type | Description |
|---|---|---|
| hidden_states | torch.Tensor | Input hidden states tensor of shape (batch, seq_len, hidden_size) |
| cache | ExLlamaV2CacheBase or None | KV cache for attention, or None to run without a cache |
| attn_params | ExLlamaV2Attention.Params or None | Attention parameters including past lengths and position offsets |
| past_len | int or None | Number of past tokens already in the cache |
| intermediates | bool | If True, return a dict with intermediate activations from both branches |
| loras | list[ExLlamaV2Lora] or None | Optional list of LoRA adapters to apply |

| Return | Type | Description |
|---|---|---|
| hidden_states | torch.Tensor | Updated hidden states with both attention and MLP residuals added (when intermediates=False) |
| result dict | dict | Keys: post_norm, attn_output, pre_down, hidden_states_attn, hidden_states_mlp, hidden_states (when intermediates=True) |
__init__()
| Parameter | Type | Description |
|---|---|---|
| model | ExLlamaV2 | Parent model reference |
| key | str | Weight key prefix for this layer |
| layer_idx | int | Index of this layer in the model stack |
| sliding_window | int | Sliding window size for attention (0 = disabled) |
| archparams | object | Architecture-specific parameters (norm type, key mappings) |
| rope_index | int | RoPE frequency index for multi-RoPE architectures |
Usage Examples
# Parallel decoder forward pass (within model execution)
# The parallel decoder is typically managed by ExLlamaV2 model internals
from exllamav2.parallel_decoder import ExLlamaV2ParallelDecoder

# During model construction (handled internally):
decoder = ExLlamaV2ParallelDecoder(
    model=model,
    key="model.layers.0.",
    layer_idx=0,
    sliding_window=0,
)
decoder.load()

# Forward pass - attention and MLP run in parallel
output = decoder.forward(
    hidden_states=hidden_states,
    cache=cache,
    attn_params=attn_params,
    past_len=past_len,
)

# Forward pass with intermediates for inspection
result = decoder.forward(
    hidden_states=hidden_states,
    cache=cache,
    attn_params=attn_params,
    past_len=past_len,
    intermediates=True,
)
attn_out = result["attn_output"]
mlp_pre_down = result["pre_down"]
Related Pages
- Turboderp_org_Exllamav2_ExLlamaV2RMSNorm - RMS normalization used as the shared input layernorm
- Turboderp_org_Exllamav2_RoPE - Rotary position embeddings used in the attention sublayer