Overview
Concrete tool for running SigLIP vision encoder layers with W8A8 (8-bit weight, 8-bit activation) quantization, pre-allocated activation buffers, and flash attention support.
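To make the W8A8 term concrete, the following is a minimal pure-Python sketch of an 8-bit weight, 8-bit activation linear layer. It assumes symmetric quantization with one scale per token for activations and one scale per output channel for weights; this page does not confirm the exact scheme used by the kernels, and the helper names `quantize_rows` and `w8a8_linear` are hypothetical.

```python
def quantize_rows(matrix):
    """Symmetric int8 quantization with one scale per row (assumed scheme)."""
    q_rows, scales = [], []
    for row in matrix:
        max_abs = max(abs(v) for v in row)
        scale = max_abs / 127.0 if max_abs > 0 else 1.0
        q_rows.append([max(-127, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return q_rows, scales


def w8a8_linear(x, w):
    """x: (tokens, in_features); w: (out_features, in_features)."""
    qx, sx = quantize_rows(x)  # one scale per token (activations)
    qw, sw = quantize_rows(w)  # one scale per output channel (weights)
    out = []
    for q_tok, s_tok in zip(qx, sx):
        row = []
        for q_ch, s_ch in zip(qw, sw):
            acc = sum(a * b for a, b in zip(q_tok, q_ch))  # integer accumulation
            row.append(acc * s_tok * s_ch)                 # dequantize to float
        out.append(row)
    return out


x = [[0.5, -1.0, 0.25], [2.0, 0.1, -0.3]]
w = [[0.2, 0.4, -0.6], [1.0, -0.5, 0.3]]
approx = w8a8_linear(x, w)
exact = [[sum(a * b for a, b in zip(xr, wr)) for wr in w] for xr in x]
max_err = max(abs(a - e) for ar, er in zip(approx, exact) for a, e in zip(ar, er))
```

The integer products are accumulated first and rescaled once per output element, which is what lets a real kernel run the inner loop entirely in 8-bit arithmetic.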
Description
QuantSiglipEncoder wraps the standard SiglipEncoder layers with quantized counterparts and manages shared ActivationBuffer objects for efficient memory reuse across layers.
QuantSiglipFlashAttention2 fuses the Q, K, and V projections into a single W8A8OF16LinearDynamicInputScale layer and uses flash_attn_func for the attention computation.
QuantSiglipMLP quantizes the feed-forward layers, fusing the GELU activation with quantization via awq_inference_engine.gelu_and_quant.
QuantSiglipEncoderLayer composes attention, MLP, and layer normalization into a full quantized encoder block with residual connections.
RMSNormGeneral implements RMS normalization with integrated per-token or per-tensor quantized output, using the awq_inference_engine.rms_norm_general kernel.
The constant CLIP_RANGE = 5 defines the activation clipping range used during quantization.
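The fusion of the Q, K, and V projections can be illustrated in plain Python. This is only a sketch of the idea in floating point: it stacks the three weight matrices, runs one projection, and splits the result. The function names are hypothetical, and the real layer additionally quantizes its inputs and weights.

```python
def matmul(x, w):
    """x: (tokens, in_features); w: (out_features, in_features)."""
    return [[sum(a * b for a, b in zip(row, w_row)) for w_row in w] for row in x]


def fused_qkv(x, w_q, w_k, w_v):
    """Project once with stacked weights, then split into Q, K, V."""
    w_qkv = w_q + w_k + w_v      # stacked rows: (3 * hidden, in_features)
    out = matmul(x, w_qkv)       # a single projection instead of three
    h = len(w_q)
    q = [row[:h] for row in out]
    k = [row[h:2 * h] for row in out]
    v = [row[2 * h:] for row in out]
    return q, k, v


x = [[1.0, 2.0], [0.5, -1.0]]
w_q = [[1.0, 0.0], [0.0, 1.0]]
w_k = [[2.0, 0.0], [0.0, 2.0]]
w_v = [[0.0, 1.0], [1.0, 0.0]]
q, k, v = fused_qkv(x, w_q, w_k, w_v)
```

The fused call gives the same values as three separate projections while reading the input only once, which is the usual motivation for this layout.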
Usage
Import QuantSiglipEncoder when deploying SigLIP-based vision-language models (e.g., LLaVA variants) with W8A8 vision encoder quantization for reduced memory footprint and faster inference throughput.
Code Reference
Source Location
Signature
CLIP_RANGE = 5

class QuantSiglipEncoder(nn.Module):
    def __init__(self, module: SiglipEncoder, bsz=64, seqlen=1024):
        """Wrap SiglipEncoder layers with quantized versions and allocate activation buffers."""
    def forward(self, inputs_embeds, attention_mask=None, output_attentions=None,
                output_hidden_states=None, return_dict=None) -> BaseModelOutput: ...

class QuantSiglipFlashAttention2(nn.Module):
    def __init__(self, module: SiglipAttention, init_only=False): ...
    def forward(self, buffer: ActivationBuffer, bsz=64, seqlen=1024
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ...

class QuantSiglipMLP(nn.Module):
    def __init__(self, siglipmlp, init_only=False): ...
    def forward(self, buffer: ActivationBuffer) -> torch.Tensor: ...

class QuantSiglipEncoderLayer(nn.Module):
    def __init__(self, module: SiglipEncoderLayer): ...
    def forward(self, hidden_states: torch.Tensor, buffer: ActivationBuffer,
                attention_mask, bsz, seqlen) -> Tuple[torch.FloatTensor]: ...

class RMSNormGeneral(nn.Module):
    def __init__(self, weight: torch.tensor, bias: torch.tensor,
                 eps: float = 1e-6, use_per_token_quant: bool = True): ...
    def forward(self, x: torch.Tensor, quantized_hidden_states_buffer: torch.Tensor,
                quantized_scale_buffer: torch.Tensor,
                quantized_sum_buffer: torch.Tensor = None) -> torch.Tensor: ...
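As a reading aid for the RMSNormGeneral.forward signature, here is a pure-Python reference sketch of RMS normalization followed by per-token int8 quantization, writing into caller-provided buffers. It is an assumption-laden illustration, not the awq_inference_engine.rms_norm_general kernel: the function name is hypothetical, the bias and quantized_sum_buffer arguments are omitted because this page does not describe how they are used, and the symmetric max-abs scale is an assumed scheme.

```python
import math


def rms_norm_quant_reference(x, weight, q_buffer, scale_buffer, eps=1e-6):
    """Normalize each token row by its RMS, then quantize it to int8.

    Results are written in place into q_buffer and scale_buffer,
    mirroring the pre-allocated output buffers in the signature above.
    """
    for i, row in enumerate(x):
        rms = math.sqrt(sum(v * v for v in row) / len(row) + eps)
        normed = [v / rms * w for v, w in zip(row, weight)]
        max_abs = max(abs(v) for v in normed)
        scale = max_abs / 127.0 if max_abs > 0 else 1.0  # per-token scale
        q_buffer[i] = [max(-127, min(127, round(v / scale))) for v in normed]
        scale_buffer[i] = scale
    return q_buffer


x = [[3.0, 4.0], [0.0, 0.0]]
weight = [1.0, 1.0]
q_buffer = [None] * len(x)        # stands in for quantized_hidden_states_buffer
scale_buffer = [0.0] * len(x)     # stands in for quantized_scale_buffer
rms_norm_quant_reference(x, weight, q_buffer, scale_buffer)
```

Multiplying a quantized row by its stored scale recovers an approximation of the normalized row, which is how a following quantized linear layer would consume these buffers.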
Import
from tinychat.modules.fused_siglipdecoder import QuantSiglipEncoder
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| module | SiglipEncoder | Yes | Pre-trained SigLIP encoder to quantize |
| bsz | int | No | Maximum batch size for buffer allocation (default: 64) |
| seqlen | int | No | Maximum sequence length for buffer allocation (default: 1024) |
| inputs_embeds | torch.Tensor | Yes | Input embeddings from vision patch encoding, shape (batch, seq_len, hidden_size) |
| attention_mask | torch.Tensor | No | Optional attention mask |
Outputs
| Name | Type | Description |
|---|---|---|
| forward returns | BaseModelOutput | Encoder output with last_hidden_state and an optional hidden_states tuple |
Internal Buffers
| Name | Description |
|---|---|
| ActivationBuffer | Shared pre-allocated buffer for intermediate activations (quantized hidden states, scales, QKV projections, MLP intermediates). Automatically reallocated if batch size or sequence length changes. |
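The reallocate-on-shape-change behaviour described for ActivationBuffer can be sketched as follows. The class below is a hypothetical stand-in written in plain Python, not the real ActivationBuffer: its field names, the `ensure` method, and the allocation counter are assumptions made for illustration.

```python
class SketchActivationBuffer:
    """Hypothetical stand-in for a shared, pre-allocated activation buffer."""

    def __init__(self, bsz, seqlen, hidden_size):
        self.hidden_size = hidden_size
        self.allocations = 0
        self._allocate(bsz, seqlen)

    def _allocate(self, bsz, seqlen):
        self.bsz, self.seqlen = bsz, seqlen
        tokens = bsz * seqlen
        # One byte per int8 activation, one float scale per token.
        self.quantized_hidden_states = bytearray(tokens * self.hidden_size)
        self.quantized_scale = [0.0] * tokens
        self.allocations += 1

    def ensure(self, bsz, seqlen):
        """Reuse the existing storage unless the requested shape changed."""
        if (bsz, seqlen) != (self.bsz, self.seqlen):
            self._allocate(bsz, seqlen)
        return self


buf = SketchActivationBuffer(bsz=1, seqlen=4, hidden_size=8)
buf.ensure(1, 4)   # same shape: storage is reused
buf.ensure(2, 4)   # batch size changed: storage is reallocated
```

Sharing one such buffer across all encoder layers means the per-layer forward passes write into the same memory rather than allocating new tensors on every call.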
Usage Examples
Quantize SigLIP Encoder
from tinychat.modules.fused_siglipdecoder import QuantSiglipEncoder

# `model` is an already loaded vision-language model with a SigLIP vision tower.
# Wrap the pre-trained SigLIP encoder with the quantized version.
quant_encoder = QuantSiglipEncoder(
    model.vision_tower.vision_model.encoder, bsz=1, seqlen=729
)

# Replace the encoder in the model.
model.vision_tower.vision_model.encoder = quant_encoder

# Forward pass (same interface as the original encoder).
# `pixel_embeddings` is the inputs_embeds tensor, shape (batch, seq_len, hidden_size).
output = quant_encoder(pixel_embeddings)
last_hidden = output.last_hidden_state
Access Clipping Range
from tinychat.modules.fused_siglipdecoder import CLIP_RANGE
# CLIP_RANGE = 5 defines the activation clipping range for quantization
print(f"Activation clipping range: [-{CLIP_RANGE}, {CLIP_RANGE}]")
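This page does not say how CLIP_RANGE is consumed inside the kernels. One plausible reading, offered only as an assumption, is that activations are clamped to [-CLIP_RANGE, CLIP_RANGE] before conversion to 8 bits with a fixed scale. The sketch below illustrates that reading in plain Python; `clip_and_quantize` is a hypothetical helper, and CLIP_RANGE is redefined locally so the snippet runs on its own.

```python
CLIP_RANGE = 5  # mirrors the module constant; redefined here for a standalone sketch


def clip_and_quantize(values, clip=CLIP_RANGE):
    """Clamp values to [-clip, clip], then map them onto the int8 range."""
    scale = clip / 127.0                                 # fixed scale from the clip range
    clipped = [max(-clip, min(clip, v)) for v in values]
    return [round(v / scale) for v in clipped], scale


q_values, scale = clip_and_quantize([10.0, -7.5, 0.0, 1.0])
```

Under this reading, outliers beyond the clipping range saturate at the largest 8-bit magnitude instead of stretching the scale for every other value.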