Implementation:Microsoft Onnxruntime CUDA Gist
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
CUDA implementation of the Gist encoding/decoding operators, which compress activations to save memory during ONNX Runtime training.
Description
Implements multiple Gist encoder/decoder pairs for CUDA. The encoders compress activations during the forward pass to reduce memory usage, and the decoders reconstruct them during the backward pass. Five compression strategies are provided:
- GistBinarize: encodes each value as a boolean (its sign bit); supported for float, MLFloat16, and double.
- GistPack1: packs boolean or float values into uint8, eight elements per byte (GIST_PACK1_FACTOR = 8).
- GistPack8: compresses each float or MLFloat16 element to a single uint8.
- GistPack16: compresses float to half precision (MLFloat16).
- GistPackMsfp15: uses the MSFP15 (Microsoft Floating Point) block floating-point encoding, in which each tile of 8 elements (tile_size = 8) shares one exponent.
Each pair consists of an encoder (forward pass) and a decoder (backward pass), both registered in kMSDomain at version 1.
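As a rough illustration of the Binarize and Pack1 ideas, the following CPU-only C++ sketch (not the ONNX Runtime CUDA kernels) binarizes a float buffer and then packs the flags eight per byte. The function names, the `x > 0` threshold, and the MSB-first bit order are assumptions made for illustration; the real kernels may choose differently.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference, NOT the ORT CUDA kernels.
// Assumption: a value maps to true when it is > 0 (e.g. a ReLU output mask).
std::vector<bool> binarize_encode(const std::vector<float>& x) {
  std::vector<bool> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] = x[i] > 0.0f;
  return y;
}

constexpr std::size_t kPackFactor = 8;  // mirrors GIST_PACK1_FACTOR = 8

// Pack eight flags into each byte, rounding the byte count up.
// Assumption: MSB-first bit order within a byte.
std::vector<std::uint8_t> pack1_encode(const std::vector<bool>& bits) {
  std::vector<std::uint8_t> out((bits.size() + kPackFactor - 1) / kPackFactor, 0);
  for (std::size_t i = 0; i < bits.size(); ++i) {
    if (bits[i]) {
      out[i / kPackFactor] |= static_cast<std::uint8_t>(0x80u >> (i % kPackFactor));
    }
  }
  return out;
}

// Recover the first n flags from the packed bytes (same assumed bit order).
std::vector<bool> pack1_decode(const std::vector<std::uint8_t>& packed, std::size_t n) {
  std::vector<bool> bits(n);
  for (std::size_t i = 0; i < n; ++i) {
    bits[i] = (packed[i / kPackFactor] >> (7 - i % kPackFactor)) & 1u;
  }
  return bits;
}
```

Because the last byte is padded with zero bits, the decoder needs the original element count `n`; a real operator would get it from the stashed tensor shape.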
Usage
Used during training to compress saved activations between forward and backward passes, trading compute for memory when GPU memory is constrained.
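The memory savings can be estimated with simple arithmetic. The hypothetical helper below computes the stash size for `n` float32 activations under each scheme, assuming the output element types listed in the Description, a 1-byte `bool`, and no padding or metadata. GistPackMsfp15 is left out because the storage of its shared exponents is not specified here.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical enum and helper for illustration only; not part of ORT.
enum class GistScheme { kNone, kBinarize, kPack1, kPack8, kPack16 };

// Bytes needed to stash n float32 activations under each scheme.
std::size_t stash_bytes(GistScheme s, std::size_t n) {
  switch (s) {
    case GistScheme::kNone:     return n * 4;        // float32: 4 bytes per element
    case GistScheme::kBinarize: return n * 1;        // one bool (assumed 1 byte) per element
    case GistScheme::kPack1:    return (n + 7) / 8;  // 8 elements per uint8, rounded up
    case GistScheme::kPack8:    return n * 1;        // one uint8 per element
    case GistScheme::kPack16:   return n * 2;        // one half (2 bytes) per element
  }
  return 0;  // unreachable for valid enum values
}
```

For 2^20 activations this gives 4 MiB uncompressed, 2 MiB with Pack16, 1 MiB with Binarize or Pack8, and 128 KiB with Pack1.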
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/gist/gist.cc
- Lines: 1-333
Signature
template <typename T> Status GistBinarizeEncoderOp<T>::ComputeInternal(OpKernelContext* context) const;
template <typename T> Status GistBinarizeDecoderOp<T>::ComputeInternal(OpKernelContext* context) const;
template <typename T> Status GistPack1EncoderOp<T>::ComputeInternal(OpKernelContext* context) const;
template <typename T> Status GistPack1DecoderOp<T>::ComputeInternal(OpKernelContext* context) const;
template <typename T> Status GistPack8EncoderOp<T>::ComputeInternal(OpKernelContext* context) const;
template <typename T> Status GistPack8DecoderOp<T>::ComputeInternal(OpKernelContext* context) const;
template <typename T> Status GistPack16EncoderOp<T>::ComputeInternal(OpKernelContext* context) const;
template <typename T> Status GistPack16DecoderOp<T>::ComputeInternal(OpKernelContext* context) const;
template <typename T> Status GistPackMsfp15EncoderOp<T>::ComputeInternal(OpKernelContext* context) const;
template <typename T> Status GistPackMsfp15DecoderOp<T>::ComputeInternal(OpKernelContext* context) const;
Import
#include "orttraining/training_ops/cuda/gist/gist.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| X | Tensor | Yes | Input tensor to encode (encoder; type T) or compressed tensor to decode (decoder; bool, uint8, or MLFloat16 depending on the encoding scheme) |
Outputs
| Name | Type | Description |
|---|---|---|
| Y | Tensor | Compressed output (encoder) or reconstructed tensor (decoder); type depends on encoding scheme |
Usage Examples
// GistBinarize: float -> bool (encoder), bool -> float (decoder)
// GistPack1: bool/float -> uint8, eight elements packed per byte
// GistPack8: float/MLFloat16 -> uint8 per element
// GistPack16: float -> MLFloat16 (half precision)
// GistPackMsfp15: float -> uint8 with shared exponents in tiles of 8
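To make the tile-based shared exponent concrete, here is a generic block floating-point sketch in the spirit of MSFP. It is not the GistPackMsfp15 kernel: the layout (one `int` exponent per tile, plus a sign bit and a 7-bit magnitude per element) and all names are assumptions chosen for readability.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>

constexpr std::size_t kTileSize = 8;  // mirrors tile_size = 8

// Hypothetical tile layout, assumed for illustration.
struct BfpTile {
  int shared_exp;                // exponent shared by every element in the tile
  std::uint8_t code[kTileSize];  // per element: sign bit (0x80) | 7-bit magnitude
};

// Encode up to kTileSize values against the exponent of the largest magnitude.
BfpTile bfp_encode_tile(const float* x, std::size_t n) {
  assert(n <= kTileSize);
  float max_abs = 0.0f;
  for (std::size_t i = 0; i < n; ++i) max_abs = std::fmax(max_abs, std::fabs(x[i]));
  BfpTile t{};
  // frexp writes e such that max_abs = m * 2^e with m in [0.5, 1), so |x| < 2^e.
  std::frexp(max_abs, &t.shared_exp);
  // Quantization step: 7 magnitude bits below 2^shared_exp.
  const float step = std::ldexp(1.0f, t.shared_exp - 7);
  for (std::size_t i = 0; i < n; ++i) {
    const float q = std::fmin(std::round(std::fabs(x[i]) / step), 127.0f);
    const int sign = std::signbit(x[i]) ? 0x80 : 0x00;
    t.code[i] = static_cast<std::uint8_t>(sign | static_cast<int>(q));
  }
  return t;
}

// Decode element i: magnitude times the shared step, with the stored sign.
float bfp_decode(const BfpTile& t, std::size_t i) {
  const float step = std::ldexp(1.0f, t.shared_exp - 7);
  const float mag = static_cast<float>(t.code[i] & 0x7F) * step;
  return (t.code[i] & 0x80) ? -mag : mag;
}
```

Sharing one exponent per tile means small values next to a large one lose precision: in a tile whose maximum is 3.0, the value 0.1 decodes to 0.09375. That accuracy loss is the price of storing a single exponent per 8 elements.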