Implementation:Kornia Kornia VisionTransformer
| Knowledge Sources | |
|---|---|
| Domains | Vision, Image_Classification, Transformer |
| Last Updated | 2026-02-09 15:00 GMT |
Overview
VisionTransformer implements the standard Vision Transformer (ViT) architecture, which splits an image into fixed-size patches and processes the resulting token sequence with a transformer encoder, as described in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (arXiv:2010.11929).
Description
This module provides the VisionTransformer class and its constituent blocks within the Kornia library. The architecture follows the standard ViT paradigm: PatchEmbedding divides the image into patches, prepends a learnable [CLS] token, and adds positional embeddings; the resulting sequence is then passed through a TransformerEncoder composed of TransformerEncoderBlock modules. Each block combines MultiHeadAttention and a FeedForward network with ResidualAdd connections. Some implementation details, such as the attention scaling, follow the timm library. The module supports multiple ViT variants (Ti, S, B, L, H) with optional AugReg pre-trained weights hosted on Hugging Face.
Usage
Import this module when you need a standard Vision Transformer for image feature extraction, classification, or as a backbone in larger vision pipelines. Use VisionTransformer.from_config() to build a named variant (for example "vit_b/16"), optionally with pre-trained weights via pretrained=True.
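For classification, a common pattern is to take the [CLS] token from the encoder output and feed it to a linear head. The sketch below assumes the (B, N+1, embed_dim) output layout with the [CLS] token at index 0; the head, `num_classes`, and the random stand-in tensor are illustrative and not part of Kornia's API.

```python
import torch
from torch import nn

embed_dim, num_classes = 768, 10  # num_classes is a hypothetical label count

# Stand-in for `vit(img)`: (B, N + 1, embed_dim) with the [CLS] token first.
tokens = torch.rand(2, 197, embed_dim)

head = nn.Linear(embed_dim, num_classes)  # illustrative classification head
cls_token = tokens[:, 0]                  # (B, embed_dim)
logits = head(cls_token)                  # (B, num_classes)
print(logits.shape)  # torch.Size([2, 10])
```

In a real pipeline the stand-in tensor would be replaced by the model's output, and the head would be trained together with (or on top of) the encoder.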
Code Reference
Source Location
- Repository: Kornia
- File: kornia/models/vit.py
- Lines: 1-316
Signature
class VisionTransformer(nn.Module):
    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        dropout_rate: float = 0.0,
        dropout_attn: float = 0.0,
        backbone: nn.Module | None = None,
    ) -> None: ...

    def forward(self, x: torch.Tensor) -> torch.Tensor: ...

    @staticmethod
    def from_config(variant: str, pretrained: bool = False, **kwargs: Any) -> VisionTransformer: ...
Import
from kornia.models.vit import VisionTransformer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| image_size | int | No | Size of the input image (default 224). |
| patch_size | int | No | Size of each image patch (default 16). |
| in_channels | int | No | Number of input channels (default 3). |
| embed_dim | int | No | Embedding dimension for the transformer encoder (default 768). |
| depth | int | No | Number of transformer encoder blocks (default 12). |
| num_heads | int | No | Number of attention heads (default 12). |
| dropout_rate | float | No | Dropout rate (default 0.0). |
| dropout_attn | float | No | Attention dropout rate (default 0.0). |
| backbone | nn.Module or None | No | Optional backbone used to compute the patch embeddings (default None, in which case a Conv2d projection is used). |
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Encoded token sequence of shape (B, N+1, embed_dim), where N is the number of patches (196 with the default image_size=224 and patch_size=16) and the extra token is the [CLS] token. |
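The sequence length in the output shape follows directly from the image and patch sizes. A minimal sketch of that arithmetic, assuming square images, non-overlapping patches, and the default Conv2d patch embedding (the helper name `vit_sequence_length` is hypothetical):

```python
def vit_sequence_length(image_size: int = 224, patch_size: int = 16) -> int:
    """Tokens seen by the encoder: one per patch plus one [CLS] token."""
    if image_size % patch_size != 0:
        raise ValueError("this sketch assumes image_size is divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side**2 + 1


print(vit_sequence_length(224, 16))  # 197
print(vit_sequence_length(224, 32))  # 50
```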
Key Components
PatchEmbedding
Converts 2D images into patch embeddings. Uses either a Conv2d with kernel_size=patch_size or a custom backbone. Prepends a learnable [CLS] token and adds positional embeddings.
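The patch embedding step can be sketched in plain PyTorch as follows. This is an illustration of the technique rather than Kornia's actual class: the name `PatchEmbeddingSketch`, the zero initialisation, and the choice of `stride=patch_size` (non-overlapping patches) are assumptions.

```python
import torch
from torch import nn


class PatchEmbeddingSketch(nn.Module):
    """Illustrative patch embedding; not Kornia's implementation."""

    def __init__(self, in_channels: int = 3, embed_dim: int = 768,
                 patch_size: int = 16, image_size: int = 224) -> None:
        super().__init__()
        # A strided convolution cuts the image into non-overlapping patches
        # and projects each patch to embed_dim in one step (stride is assumed).
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        num_patches = (image_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)                  # (B, D, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)  # (B, N, D)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1)    # (B, N + 1, D), [CLS] first
        return x + self.pos_embed         # add positional embeddings


tokens = PatchEmbeddingSketch(embed_dim=64)(torch.rand(2, 3, 224, 224))
print(tokens.shape)  # torch.Size([2, 197, 64])
```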
TransformerEncoder
A sequence of TransformerEncoderBlock modules. During the forward pass it stores the output of each block in its results list, which the model exposes as encoder_results for downstream access to features from different layers.
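The pattern of recording per-block outputs can be sketched as below. The class `EncoderSketch` is hypothetical, and it substitutes PyTorch's built-in `nn.TransformerEncoderLayer` for Kornia's TransformerEncoderBlock, so only the bookkeeping of intermediate results should be read as representative.

```python
import torch
from torch import nn


class EncoderSketch(nn.Module):
    """Stack of encoder blocks that keeps every block's output."""

    def __init__(self, embed_dim: int = 64, depth: int = 4, num_heads: int = 4) -> None:
        super().__init__()
        # Stand-in blocks; Kornia uses its own TransformerEncoderBlock here.
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(embed_dim, num_heads,
                                       dim_feedforward=4 * embed_dim,
                                       batch_first=True, norm_first=True)
            for _ in range(depth)
        ])
        self.results = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.results = []  # reset so the list reflects the latest call only
        for block in self.blocks:
            x = block(x)
            self.results.append(x)
        return x


encoder = EncoderSketch()
out = encoder(torch.rand(2, 17, 64))
print(len(encoder.results))  # 4
```

Resetting the list at the start of each forward pass keeps it from growing across calls; the last entry is the same tensor the encoder returns.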
MultiHeadAttention
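Fused QKV linear projection followed by scaled dot-product attention. Following timm, the attention scores are scaled by the per-head dimension (head_dim = embed_dim // num_heads, scale = head_dim ** -0.5) rather than by the full embedding dimension.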
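A minimal sketch of fused-QKV multi-head attention in plain PyTorch, written to illustrate the idea rather than to reproduce Kornia's code. The class name, the bias settings, and the per-head scale of `head_dim ** -0.5` are assumptions based on the common timm-style formulation.

```python
import torch
from torch import nn


class FusedQKVAttentionSketch(nn.Module):
    """Illustrative multi-head self-attention with one fused QKV projection."""

    def __init__(self, embed_dim: int = 64, num_heads: int = 4, attn_drop: float = 0.0) -> None:
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5           # 1 / sqrt(head_dim)
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)  # queries, keys, values at once
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)
        attn = (q @ k.transpose(-2, -1)) * self.scale   # (B, heads, N, N)
        attn = self.attn_drop(attn.softmax(dim=-1))
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)  # merge heads
        return self.proj(out)


x = torch.rand(2, 17, 64)
attention = FusedQKVAttentionSketch(embed_dim=64, num_heads=4)
out = attention(x)
print(out.shape)  # torch.Size([2, 17, 64])
```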
FeedForward
Two-layer MLP with GELU activation and dropout, used within each transformer block.
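Such a feed-forward block might look like the sketch below. The class name, the placement of the two dropout layers, and the hidden width of four times the embedding dimension (the usual ViT choice) are assumptions, not details confirmed by the source.

```python
import torch
from torch import nn


class FeedForwardSketch(nn.Sequential):
    """Illustrative two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, embed_dim: int = 64, hidden_dim: int = 256, dropout: float = 0.0) -> None:
        super().__init__(
            nn.Linear(embed_dim, hidden_dim),  # expand (hidden_dim = 4 * embed_dim assumed)
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),  # project back to embed_dim
            nn.Dropout(dropout),
        )


mlp = FeedForwardSketch()
y = mlp(torch.rand(2, 17, 64))
print(y.shape)  # torch.Size([2, 17, 64])
```

The MLP acts on each token independently, so the sequence length passes through unchanged.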
Variant Configurations
| Variant | embed_dim | depth | num_heads |
|---|---|---|---|
| vit_ti | 192 | 12 | 3 |
| vit_s | 384 | 12 | 6 |
| vit_b | 768 | 12 | 12 |
| vit_l | 1024 | 24 | 16 |
| vit_h | 1280 | 32 | 16 |
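A variant string such as "vit_b/16" combines a row of this table with a patch size. The following sketch shows one way such a string could be resolved into constructor arguments; `VIT_CONFIGS` and `parse_variant` are hypothetical names, and the validation shown is not necessarily what from_config() does.

```python
# Values copied from the variant table above.
VIT_CONFIGS = {
    "vit_ti": {"embed_dim": 192, "depth": 12, "num_heads": 3},
    "vit_s": {"embed_dim": 384, "depth": 12, "num_heads": 6},
    "vit_b": {"embed_dim": 768, "depth": 12, "num_heads": 12},
    "vit_l": {"embed_dim": 1024, "depth": 24, "num_heads": 16},
    "vit_h": {"embed_dim": 1280, "depth": 32, "num_heads": 16},
}


def parse_variant(variant: str) -> dict:
    """Split '<model_type>/<patch_size>' and merge in the table's settings."""
    model_type, sep, patch = variant.partition("/")
    if not sep or model_type not in VIT_CONFIGS or not patch.isdigit():
        raise ValueError(f"unrecognised variant string: {variant!r}")
    return {**VIT_CONFIGS[model_type], "patch_size": int(patch)}


config = parse_variant("vit_b/16")
print(config)
# {'embed_dim': 768, 'depth': 12, 'num_heads': 12, 'patch_size': 16}
```

The resulting dictionary could then be unpacked into the constructor, for example `VisionTransformer(**config)`.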
Usage Examples
import torch
from kornia.models.vit import VisionTransformer

# Basic usage
img = torch.rand(1, 3, 224, 224)
vit = VisionTransformer(image_size=224, patch_size=16)
output = vit(img)  # shape: (1, 197, 768)

# From config with pretrained AugReg weights
vit_model = VisionTransformer.from_config("vit_b/16", pretrained=True)

# Access intermediate encoder results (populated by the forward pass)
_ = vit(img)
intermediate_features = vit.encoder_results  # list of tensors, one per block