Principle: Kornia Differentiable Training Augmentation
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Vision, Training |
| Last Updated | 2026-02-09 15:00 GMT |
Overview
A technique for integrating GPU-accelerated, differentiable augmentations directly into the training loop while preserving gradients through the augmentation operations.
Description
Unlike CPU-based augmentation (e.g., torchvision transforms), differentiable augmentation runs on GPU and preserves gradients through the augmentation operations. This enables:
- Augmentation as part of the computation graph for adversarial training or augmentation optimization, where augmentation parameters can be learned jointly with the model.
- Significantly faster augmentation by leveraging GPU parallelism, avoiding CPU-to-GPU data transfer bottlenecks.
- Batch-consistent augmentation where the same random transform is optionally applied to the entire batch via the same_on_batch parameter.
Because augmentation operations are implemented as differentiable PyTorch operations, the entire pipeline from augmented input to loss computation forms a single computation graph. This is fundamentally different from pre-computed offline augmentation or CPU-side transforms that break the gradient chain.
Usage
Use when:
- Augmentation speed is critical -- GPU augmentation is faster than CPU augmentation for batch processing
- Gradient flow through augmentation is needed -- required for adversarial training, differentiable data augmentation (DDA), or learned augmentation strategies
- Augmentation parameters must be learned during training -- the augmentation module's parameters can receive gradients
Theoretical Basis
Given the loss L = l(f(T(x)), y), where T is the augmentation and f is the model, gradients flow through the augmentation:
# Standard case: gradient with respect to the augmented input
# dL/d(T(x)) = dL/df * df/d(T(x))
# Learnable augmentation case: gradient also flows to augmentation params
# dL/dtheta_aug = dL/df * df/d(T(x)) * dT/dtheta_aug
The .forward() method of AugmentationSequential preserves this gradient chain. Pre-computed params enable reproducible augmentation, where the same random parameters can be applied to multiple forward passes:
# Reproducible augmentation via pre-computed params
augmented_1 = aug(images, params=saved_params)
augmented_2 = aug(images, params=saved_params)
# augmented_1 and augmented_2 are identical
This is useful for techniques that require multiple views of the same augmentation (e.g., consistency regularization).