Implementation:Neuml Txtai Torch ANN
| Knowledge Sources | |
|---|---|
| Domains | Vector_Search, GPU_Computing, Quantization |
| Last Updated | 2026-02-09 17:00 GMT |
Overview
Torch is a GPU-accelerated approximate nearest neighbor index backed by PyTorch tensors, with support for 4-bit and 8-bit quantization via the bitsandbytes library.
Description
The Torch class inherits from NumPy and replaces NumPy array operations with their PyTorch equivalents (torch.mm, torch.cat, etc.) to leverage GPU acceleration via CUDA. It adds quantization support through bitsandbytes: 8-bit integer quantization using int8_vectorwise_quant and 4-bit float quantization using quantize_4bit with configurable block size and quantization type (e.g., nf4). A QuantizeContext context manager transparently dequantizes data before modifications (index, append, delete) and re-quantizes afterward. Quantization state is persisted using the safetensors format.
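To make the row-wise idea concrete, here is a small NumPy sketch of 8-bit absmax quantization in the vectorwise style: each row is scaled by its own largest absolute value onto the integer range [-127, 127]. The helper names are hypothetical. The sketch ignores any additional handling the real bitsandbytes int8_vectorwise_quant may perform, so treat it as an illustration of the technique, not the library's code.

```python
import numpy as np


def quantize_int8_rowwise(x: np.ndarray):
    """Hypothetical helper: row-wise absmax int8 quantization."""
    # One scale per row: the largest absolute value in that row
    absmax = np.abs(x).max(axis=1).astype(np.float32)
    # All-zero rows (e.g. soft-deleted entries) would divide by zero, so use 1.0 there
    safe = np.where(absmax == 0, np.float32(1.0), absmax)
    # Map [-absmax, absmax] onto [-127, 127] and round to the nearest integer code
    q = np.round(x / safe[:, None] * 127).astype(np.int8)
    return q, absmax


def dequantize_int8_rowwise(q: np.ndarray, absmax: np.ndarray) -> np.ndarray:
    """Hypothetical helper: invert the scaling, code * (row absmax / 127)."""
    return q.astype(np.float32) * (absmax[:, None] / 127)


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)
q, absmax = quantize_int8_rowwise(x)
restored = dequantize_int8_rowwise(q, absmax)

# Rounding error is at most half a quantization step (absmax / 254) per value
print(q.dtype, float(np.abs(x - restored).max()))
```

Storing one int8 code per value plus one scale per row is what makes the 8-bit mode roughly four times smaller than float32.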
Usage
Use Torch when you need GPU-accelerated nearest neighbor search or want to reduce memory usage through quantization. It is particularly useful for large embedding indexes that would otherwise not fit in GPU memory: relative to float32 storage, 8-bit quantization cuts memory by roughly 4x and 4-bit quantization by up to roughly 8x (less the overhead of the stored scaling constants), while maintaining acceptable search quality.
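The 4x and 8x figures follow from bytes per value relative to float32. The hypothetical helper below estimates raw storage for a dense embedding matrix, assuming one float32 scale per row for int8 and one float32 absmax per block for 4-bit. bitsandbytes may store these statistics differently (for example with nested quantization of the statistics), so the numbers are estimates only.

```python
def index_bytes(rows: int, dims: int, fmt: str, blocksize: int = 64) -> int:
    """Hypothetical helper: rough storage estimate for a rows x dims embedding matrix."""
    n = rows * dims
    if fmt == "float32":
        # 4 bytes per value: the unquantized baseline
        return 4 * n
    if fmt == "int8":
        # 1 byte per value plus an assumed float32 scale per row
        return n + 4 * rows
    if fmt == "4bit":
        # Two values per byte plus an assumed float32 absmax per block (ceiling division)
        return -(-n // 2) + 4 * -(-n // blocksize)
    raise ValueError(f"unknown format: {fmt}")


baseline = index_bytes(1_000_000, 768, "float32")
for fmt in ("float32", "int8", "4bit"):
    size = index_bytes(1_000_000, 768, fmt)
    print(f"{fmt:>7}: {size / 1e9:.3f} GB, {baseline / size:.2f}x smaller than float32")
```

For 1,000,000 vectors of 768 dimensions this gives about 3.98x for int8 and 7.11x for 4-bit at blocksize 64: the per-block scales are why 4-bit lands somewhat under a full 8x.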
Code Reference
Source Location
- Repository: Neuml_Txtai
- File: src/python/txtai/ann/dense/torch.py
- Lines: 1-223
Signature
```python
class Torch(NumPy):
    def __init__(self, config):
        """
        Creates a new Torch ANN index.

        Args:
            config: index configuration dict
        """
```
Import
```python
from txtai.ann.dense import Torch
```
Key Methods
| Method | Description |
|---|---|
| index(embeddings) | Creates the index from embeddings. Wraps the parent method in a QuantizeContext to apply quantization after indexing. |
| append(embeddings) | Appends new embeddings. Dequantizes before appending and re-quantizes after via QuantizeContext. |
| delete(ids) | Soft-deletes entries. Dequantizes, zeros out rows, tracks the deleted count for quantized data, then re-quantizes. |
| count() | Returns the count of active entries. For quantized data, uses the stored shape minus the deleted count. |
| quantize() | Quantizes the backend tensor in place using either 8-bit integer or 4-bit float quantization, based on config. |
| dequantize() | Dequantizes the backend tensor in place to restore full-precision data for modifications. |
| matmul8bit(query, data) | Performs 8-bit integer matrix multiplication using bitsandbytes for efficient search on quantized data. |
| matmul4bit(query, data) | Performs 4-bit float matrix multiplication using bitsandbytes matmul_4bit. |
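Searching quantized data means multiplying low-precision operands. As an illustration of the general technique behind an 8-bit similarity matmul (not the actual matmul8bit code), the sketch below quantizes query and data rows to int8, accumulates the products in int32, and rescales by the outer product of the row scales. quantize_rows and matmul_int8 are hypothetical names.

```python
import numpy as np


def quantize_rows(x: np.ndarray):
    """Hypothetical helper: int8 codes plus a per-row float32 scale (absmax)."""
    absmax = np.abs(x).max(axis=1, keepdims=True).astype(np.float32)
    absmax[absmax == 0] = 1.0  # all-zero rows quantize to zeros either way
    return np.round(x / absmax * 127).astype(np.int8), absmax


def matmul_int8(query: np.ndarray, data: np.ndarray) -> np.ndarray:
    """Hypothetical sketch of an int8 similarity matmul: (m, d) x (n, d) -> (m, n)."""
    qq, sq = quantize_rows(query)
    qd, sd = quantize_rows(data)
    # Integer dot products, accumulated in int32 so sums of int8 products cannot overflow
    acc = qq.astype(np.int32) @ qd.astype(np.int32).T
    # Undo both scalings: multiply by each pair of row scales and divide by 127 * 127
    return acc.astype(np.float32) * (sq * sd.T) / (127 * 127)


rng = np.random.default_rng(1)
data = rng.standard_normal((100, 64)).astype(np.float32)
data /= np.linalg.norm(data, axis=1, keepdims=True)
query = rng.standard_normal((3, 64)).astype(np.float32)
query /= np.linalg.norm(query, axis=1, keepdims=True)

exact = query @ data.T
approx = matmul_int8(query, data)
print("max abs error:", float(np.abs(exact - approx).max()))
print("top-1 agrees:", (exact.argmax(axis=1) == approx.argmax(axis=1)).tolist())
```

The int32 accumulator is the key design choice: a single int8 product can reach 127 * 127, so summing many of them needs a wider type before converting back to float scores.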
QuantizeContext
The QuantizeContext class is a context manager that facilitates safe modifications to quantized tensors.
```python
class QuantizeContext:
    """
    Quantization context. Facilitates modifications to quantized tensors.
    """

    def __init__(self, ann):
        self.ann = ann

    def __enter__(self):
        self.ann.dequantize()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.ann.quantize()
```
On entry, it dequantizes the backend tensor to full precision; on exit, it re-quantizes it. This lets index, append, and delete operate on full-precision data while the index stays quantized at rest. Because __exit__ runs even when the wrapped operation raises an exception, the tensor is re-quantized on error paths as well.
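The dequantize, modify, re-quantize cycle can be exercised without a GPU. The toy class below is hypothetical: it stores int8 codes at rest with a per-row scale (a simplification of the bitsandbytes formats), repeats the QuantizeContext class from this page, and mimics the documented delete semantics of zeroing rows and tracking a deleted count. It is not txtai's Torch class.

```python
import numpy as np


class QuantizeContext:
    """Dequantize on entry and re-quantize on exit, even if the body raises."""

    def __init__(self, ann):
        self.ann = ann

    def __enter__(self):
        self.ann.dequantize()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.ann.quantize()


class ToyQuantizedIndex:
    """Hypothetical toy index: int8 codes at rest, float32 only inside QuantizeContext."""

    def __init__(self):
        self.data = None    # float32 while dequantized, int8 codes at rest
        self.scales = None  # per-row absmax, needed to dequantize
        self.deleted = 0    # number of soft-deleted rows

    def quantize(self):
        absmax = np.abs(self.data).max(axis=1, keepdims=True)
        self.scales = np.where(absmax == 0, 1.0, absmax).astype(np.float32)
        self.data = np.round(self.data / self.scales * 127).astype(np.int8)

    def dequantize(self):
        # Nothing to restore before the first index() call
        if self.data is not None:
            self.data = self.data.astype(np.float32) * self.scales / 127

    def index(self, embeddings):
        with QuantizeContext(self):
            self.data = np.asarray(embeddings, dtype=np.float32)

    def append(self, embeddings):
        with QuantizeContext(self):
            self.data = np.concatenate([self.data, np.asarray(embeddings, dtype=np.float32)])

    def delete(self, ids):
        # Soft delete: zero the rows and count them (assumes each id is deleted only once)
        with QuantizeContext(self):
            self.data[ids] = 0
            self.deleted += len(ids)

    def count(self):
        return self.data.shape[0] - self.deleted


rng = np.random.default_rng(2)
toy = ToyQuantizedIndex()
toy.index(rng.standard_normal((10, 8)))
toy.append(rng.standard_normal((5, 8)))
toy.delete([0, 3])
print(toy.data.dtype, toy.data.shape, toy.count())  # int8 (15, 8) 13
```

Each mutating method touches only full-precision data inside the with block, and the stored matrix is back to int8 as soon as the method returns.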
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| config | dict | Yes | Index configuration. Key options include dimensions, backend (the backend name), and, nested under the backend key (e.g. "torch"), quantize: a boolean for the default 4-bit mode, or a dict with type ("int8", "nf4", etc.) and blocksize. |
| embeddings | numpy.ndarray | Yes (for index/append) | 2D NumPy array of normalized embedding vectors. Converted to PyTorch tensors and optionally moved to GPU. |
| queries | numpy.ndarray | Yes (for search) | 2D NumPy array of query vectors. Converted to PyTorch tensors for GPU-accelerated search. |
| limit | int | Yes (for search) | Maximum number of nearest neighbors to return per query. |
| ids | list of int | Yes (for delete) | Row indices to soft-delete. |
Outputs
| Name | Type | Description |
|---|---|---|
| search results | list of list of tuple | For each query, a list of (id, score) tuples sorted by descending similarity. |
| count | int | Number of active (non-deleted) entries in the index. |
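The search output contract can be illustrated with a plain NumPy brute-force search: score every stored row by inner product (equivalent to cosine similarity for normalized vectors), take the top limit per query, and return (id, score) tuples. The function below is hypothetical and only sketches the shape of the inputs and outputs; it is not the Torch implementation, which runs its operations on PyTorch tensors.

```python
import numpy as np


def search(data: np.ndarray, queries: np.ndarray, limit: int):
    """Hypothetical brute-force search returning [[(id, score), ...], ...], one list per query."""
    # Inner-product similarity of every query against every stored row: shape (queries, rows)
    scores = queries @ data.T
    # Never ask for more neighbors than there are rows
    limit = min(limit, data.shape[0])
    # Row ids of the `limit` highest scores per query, best first
    top = np.argsort(-scores, axis=1)[:, :limit]
    return [[(int(i), float(scores[q, i])) for i in row] for q, row in enumerate(top)]


data = np.eye(3)                       # three orthogonal unit vectors
queries = np.array([[0.6, 0.8, 0.0]])  # a unit-length query
print(search(data, queries, limit=2))  # [[(1, 0.8), (0, 0.6)]]
```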
Usage Examples
Basic Usage
```python
import numpy as np
from txtai.ann.dense.torch import Torch

# Configuration with 8-bit quantization
config = {
    "dimensions": 384,
    "offset": 0,
    "backend": "torch",
    "torch": {
        "quantize": {"type": "int8"},
        "safetensors": True,
    },
}
ann = Torch(config)

# Generate normalized embeddings
embeddings = np.random.rand(5000, 384).astype(np.float32)
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

# Build quantized index
ann.index(embeddings)
print(f"Indexed {ann.count()} vectors")

# Search
query = np.random.rand(1, 384).astype(np.float32)
query = query / np.linalg.norm(query, axis=1, keepdims=True)
results = ann.search(query, limit=5)
for uid, score in results[0]:
    print(f"ID: {uid}, Score: {score:.4f}")
```
4-bit Quantization
```python
import numpy as np
from txtai.ann.dense.torch import Torch

# Configuration with 4-bit nf4 quantization
config = {
    "dimensions": 768,
    "offset": 0,
    "backend": "torch",
    "torch": {
        "quantize": {"type": "nf4", "blocksize": 64},
        "safetensors": True,
    },
}
ann = Torch(config)

embeddings = np.random.rand(10000, 768).astype(np.float32)
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

ann.index(embeddings)
print(f"Indexed {ann.count()} vectors with 4-bit quantization")
```
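The blocksize option controls how many consecutive values share one scale. The sketch below shows blockwise absmax 4-bit quantization with two codes packed per byte. It uses 15 evenly spaced levels rather than the NF4 codebook (whose levels are placed at quantiles of a normal distribution), so it illustrates the blocking and packing only, not the exact nf4 format used by bitsandbytes; all names are hypothetical.

```python
import numpy as np


def quantize_blockwise_4bit(x: np.ndarray, blocksize: int = 64):
    """Hypothetical sketch: blockwise absmax 4-bit quantization with uniform levels (not NF4)."""
    flat = x.astype(np.float32).ravel()
    # Pad so the values split evenly into blocks (blocksize is assumed to be even)
    pad = (-flat.size) % blocksize
    blocks = np.pad(flat, (0, pad)).reshape(-1, blocksize)
    # One absmax scale per block of `blocksize` consecutive values
    absmax = np.abs(blocks).max(axis=1)
    absmax[absmax == 0] = 1.0
    # Map each value to an integer in [-7, 7], then shift to [0, 14] so it fits in 4 bits
    codes = (np.round(blocks / absmax[:, None] * 7) + 7).astype(np.uint8).ravel()
    # Pack two 4-bit codes into each byte
    packed = (codes[0::2] << 4) | codes[1::2]
    return packed, absmax, x.shape, pad


def dequantize_blockwise_4bit(packed, absmax, shape, pad, blocksize: int = 64):
    """Hypothetical inverse of quantize_blockwise_4bit."""
    # Unpack the high and low nibbles back into one code per value
    codes = np.empty(packed.size * 2, dtype=np.uint8)
    codes[0::2] = packed >> 4
    codes[1::2] = packed & 0x0F
    # Shift back to [-7, 7], rescale to [-1, 1], then multiply by each block's absmax
    values = (codes.astype(np.float32) - 7) / 7
    flat = (values.reshape(-1, blocksize) * absmax[:, None]).ravel()
    return flat[: flat.size - pad].reshape(shape)


rng = np.random.default_rng(3)
x = rng.standard_normal((10, 768)).astype(np.float32)
packed, absmax, shape, pad = quantize_blockwise_4bit(x, blocksize=64)
restored = dequantize_blockwise_4bit(packed, absmax, shape, pad, blocksize=64)
print(packed.nbytes, absmax.size, float(np.abs(x - restored).max()))
```

Smaller blocks track local magnitudes more closely and reduce error, at the cost of storing more scales per vector.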