Principle:Kornia Differentiable Training Augmentation

From Leeroopedia


Knowledge Sources
Domains Deep_Learning, Vision, Training
Last Updated 2026-02-09 15:00 GMT

Overview

A technique for integrating GPU-accelerated, differentiable augmentations directly into the training loop while preserving gradients through the augmentation operations.

Description

Unlike CPU-based augmentation (e.g., torchvision transforms), differentiable augmentation runs on GPU and preserves gradients through the augmentation operations. This enables:

  1. Augmentation as part of the computation graph for adversarial training or augmentation optimization, where augmentation parameters can be learned jointly with the model.
  2. Significantly faster augmentation by leveraging GPU parallelism, avoiding CPU-side preprocessing bottlenecks in the data loading pipeline.
  3. Batch-consistent augmentation where the same random transform is optionally applied to the entire batch via the same_on_batch parameter.

Because augmentation operations are implemented as differentiable PyTorch operations, the entire pipeline from augmented input to loss computation forms a single computation graph. This is fundamentally different from pre-computed offline augmentation or CPU-side transforms that break the gradient chain.

Usage

Use when:

  • Augmentation speed is critical -- GPU augmentation is typically much faster than CPU augmentation for large-batch processing
  • Gradient flow through augmentation is needed -- required for adversarial training, differentiable data augmentation (DDA), or learned augmentation strategies
  • Augmentation parameters must be learned during training -- the augmentation module's parameters can receive gradients
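The last point can be illustrated without Kornia's built-in modules, whose parameters are randomly sampled rather than learned. Below is a pure-PyTorch sketch of a hypothetical learnable augmentation (a single brightness gain, here named gain; both the parameter and the stand-in model are assumptions for illustration):

```python
import torch

# Hypothetical learnable augmentation parameter theta_aug: a brightness gain
gain = torch.tensor(1.0, requires_grad=True)
model = torch.nn.Conv2d(3, 1, kernel_size=3)   # stand-in for the model f

images = torch.rand(2, 3, 8, 8)
augmented = images * gain                      # T(x; theta_aug), differentiable
loss = model(augmented).mean()
loss.backward()

# Because T is differentiable, theta_aug receives a gradient
print(gain.grad is not None)
```

In a real learned-augmentation setup, gain would be registered in an nn.Module and updated by the optimizer alongside the model weights.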

Theoretical Basis

Given the loss L = l(f(T(x)), y), where T is the augmentation, f the model, and l the loss function, the gradient chain includes the augmentation:

# Standard case: gradient with respect to the augmented input
# dL/d(T(x)) = dL/df * df/d(T(x))

# Learnable augmentation case: gradient also flows to augmentation params
# dL/dtheta_aug = dL/df * df/dT * dT/dtheta_aug
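The learnable-augmentation chain rule can be verified numerically on a scalar toy case. A sketch, assuming torch is available, with T(x) = a*x as the augmentation (a plays the role of theta_aug) and f(z) = z^2 as the model:

```python
import torch

# Toy case: T(x) = a * x, f(z) = z**2, L = f(T(x))
a = torch.tensor(3.0, requires_grad=True)   # augmentation parameter theta_aug
x = torch.tensor(2.0)

L = (a * x) ** 2
L.backward()

# Manual chain rule: dL/da = df/dT * dT/da = 2*T(x) * x = 2*a*x*x
manual = 2 * (a * x).item() * x.item()
print(a.grad.item(), manual)   # both 24.0
```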

The .forward() method of AugmentationSequential preserves this gradient chain. Pre-computed params enable reproducible augmentation, where the same random parameters can be applied to multiple forward passes:

# Reproducible augmentation via pre-computed params
augmented_1 = aug(images, params=saved_params)
augmented_2 = aug(images, params=saved_params)
# augmented_1 and augmented_2 are identical

This is useful for techniques that require multiple views of the same augmentation (e.g., consistency regularization).

Related Pages

Implemented By
