Principle:FMInference FlexLLMGen Compression Aware Training Layers
| Field | Value |
|---|---|
| Sources | Paper: FlexGen, DeepSpeed Compression Documentation |
| Domains | Model_Compression, Quantization, Pruning, Neural_Network_Layers |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A training methodology where neural network layers are augmented with differentiable compression operators (quantization, pruning) that are applied during forward passes, enabling the model to learn representations that are robust to compression-induced information loss.
Description
Compression-aware training (which subsumes quantization-aware training and pruning-aware training) integrates compression operations directly into the training loop rather than applying them as a post-processing step. This allows gradients to flow through the compression operators, enabling the model to adapt its weights to compensate for compression artifacts.
Key principles:
- Straight-through estimator (STE) -- Quantization operations are non-differentiable (rounding). The STE passes gradients through the quantization step unchanged, allowing backpropagation to update the underlying full-precision weights while the forward pass uses quantized values.
- Progressive quantization (MoQ) -- Rather than immediately quantizing to the target bit-width, the system starts at a higher bit-width and gradually reduces it over training. This is controlled by start_bits, target_bits, and quantization_period parameters, giving the model time to adapt at each precision level.
- Differentiable pruning masks -- Pruning can use either fixed masks (L1-norm) or learned masks (TopK with differentiable thresholding via TopKBinarizer). The TopK method treats the pruning mask as a learnable parameter, enabling gradient-based optimization of which weights to prune.
- Multi-granularity pruning -- Different pruning granularities serve different purposes:
  - Sparse pruning -- Individual weight elements (unstructured).
  - Row pruning -- Entire output neurons (structured).
  - Head pruning -- Entire attention heads (structured, attention-specific).
  - Channel pruning -- Entire convolutional filters (structured, CNN-specific).
- Activation quantization -- Input activations can be quantized using either static range (exponential moving average of observed min/max values) or dynamic range (per-batch min/max). Static range is faster at inference but less accurate for variable distributions.
- Composability -- Multiple compression techniques can be enabled simultaneously on the same layer. The forward pass applies them in sequence: weight quantization, sparse pruning, row/head pruning, activation quantization, then the linear/conv operation.
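The STE idea behind these principles can be sketched without any framework. In the toy example below (all names hypothetical, plain Python rather than PyTorch autograd), the forward pass uses a uniformly quantized weight, while the update step treats the rounding as an identity function, so the gradient computed at the quantized value is applied directly to the underlying full-precision weight:

```python
def fake_quantize(w, n_bits=8, lo=-1.0, hi=1.0):
    """Uniform 'fake' quantization: snap w to the nearest of 2^n levels in [lo, hi]."""
    levels = (1 << n_bits) - 1
    q = round((w - lo) / (hi - lo) * levels)   # non-differentiable rounding step
    return q * (hi - lo) / levels + lo

def ste_update(w, grad_at_quantized, lr=0.1):
    """Straight-through estimator: pretend d(quantize)/dw = 1, so the gradient
    computed downstream of the quantized value updates the full-precision weight."""
    return w - lr * grad_at_quantized

# Forward pass uses the quantized weight...
w = 0.37
w_q = fake_quantize(w, n_bits=4)
# ...backward pass passes the gradient straight through the rounding:
grad = 2 * (w_q - 0.5)          # e.g. gradient of the loss (w_q - 0.5)^2
w = ste_update(w, grad)         # full-precision w moves even though w_q is discrete
```

Without the STE the gradient of the rounding step would be zero almost everywhere and the full-precision weights would never move; the identity approximation is what makes quantization-aware training trainable at all.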
Usage
Use compression-aware training layers when the target deployment requires reduced model size or inference latency. Replace standard PyTorch layers with their _Compress variants, then enable the desired compression techniques with schedule offsets to introduce them gradually during training.
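A configuration for such a setup might look like the sketch below. The key names follow the parameters discussed above (start_bits, target_bits, quantization_period, schedule offsets); this is an illustrative shape only, not the verbatim DeepSpeed schema, which should be taken from the DeepSpeed compression documentation:

```python
# Illustrative sketch only -- key names mirror the parameters described in
# this document, not necessarily the exact DeepSpeed config schema.
compression_config = {
    "weight_quantization": {
        "schedule_offset": 1000,     # steps of full-precision warm-up
        "start_bits": 12,            # initial (looser) bit-width for MoQ
        "target_bits": 4,            # final bit-width
        "quantization_period": 500,  # steps between bit-width reductions
    },
    "sparse_pruning": {
        "schedule_offset": 2000,     # introduce pruning after quantization has settled
        "method": "topk",            # learned masks via the TopK binarizer
        "dense_ratio": 0.5,          # fraction of weights kept
    },
}
```

Staggering the schedule offsets (quantization first, pruning later) gives the model a chance to stabilize under one form of compression before the next is introduced.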
Theoretical Basis
For weight quantization with n bits and g groups:
For each group i of weights W_i:
q_i = round((W_i - min(W_i)) / (max(W_i) - min(W_i)) * (2^n - 1))
W_i_approx = q_i * (max(W_i) - min(W_i)) / (2^n - 1) + min(W_i)
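The two formulas above can be checked numerically. The helper below (hypothetical name, direct transcription of the equations) quantizes and dequantizes one group; the group endpoints are represented exactly, and every interior value lands within half a quantization step of its original:

```python
def quantize_group(weights, n_bits):
    """Asymmetric min/max quantization of one group, per the two formulas above."""
    w_min, w_max = min(weights), max(weights)
    scale = (w_max - w_min) / (2 ** n_bits - 1)
    q = [round((w - w_min) / scale) for w in weights]  # integer codes in [0, 2^n - 1]
    return [qi * scale + w_min for qi in q]            # dequantized approximation

group = [-0.8, -0.1, 0.3, 0.9]
approx = quantize_group(group, n_bits=4)
# min and max of the group survive exactly; other values round to the nearest level,
# so the error is bounded by scale / 2 = (0.9 - (-0.8)) / 15 / 2.
```

Smaller groups tighten the min/max range and therefore the step size, which is why group count g trades accuracy against per-group metadata overhead.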
For TopK pruning with density ratio r:
threshold = sorted(|scores|)[floor((1-r) * len(scores))]
mask = (|scores| >= threshold).float()
The TopK binarizer uses a differentiable approximation where gradients flow through the threshold operation, enabling the pruning scores to be optimized via standard gradient descent.
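The forward pass of that thresholding can be sketched as follows (hypothetical helper name; only the mask computation is shown, with the gradient path left to the STE). Note the boundary element at the threshold is kept, so a cut at index floor((1-r)·N) of the ascending magnitudes retains approximately r·N elements:

```python
def topk_mask(scores, density):
    """Binary mask keeping the largest `density` fraction of |scores|.
    This is the forward pass of a TopK binarizer; in training, a
    straight-through estimator lets gradients reach `scores` itself."""
    mags = sorted(abs(s) for s in scores)              # ascending magnitudes
    threshold = mags[int((1 - density) * len(scores))]
    return [1.0 if abs(s) >= threshold else 0.0 for s in scores]

scores = [0.05, -0.9, 0.4, -0.2, 0.7, 0.1]
mask = topk_mask(scores, density=0.5)   # keeps the 3 largest-magnitude entries
```

Because the mask is recomputed from the learnable scores every forward pass, weights can move in and out of the pruned set during training, unlike a fixed L1-norm mask chosen once.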