"""
TorchNN: Pure-PyTorch nearest neighbor search replacing FAISS.
Computes the squared L2 distance a² + b² - 2abᵀ via torch.mm (CANN GEMM on NPU), then takes its square root.
No FAISS dependency needed.
"""
import logging
from typing import Union, Tuple
import numpy as np
import torch
# Module-level logger, named after this module.
LOGGER = logging.getLogger(name=__name__)
class TorchNN:
    """Nearest neighbour search using pure PyTorch (torch.mm based L2 distance).

    Replaces a FAISS ``IndexFlatL2``-style searcher: the distance matrix is
    built from the expansion ``a² + b² - 2abᵀ`` so that the dominant cost is a
    single matrix multiply (intended to map onto CANN GEMM on Ascend NPU).

    NOTE(review): every tensor in this class stays on the CPU; ``on_gpu`` is
    stored but never used to move data to an accelerator — confirm whether
    device placement is expected to happen elsewhere.

    NOTE(review): FAISS ``IndexFlatL2.search`` returns *squared* L2 distances,
    whereas ``run`` returns their square root (Euclidean distance). Scores
    derived from these distances are on a different scale than FAISS's —
    confirm callers expect this.
    """

    def __init__(self, on_gpu: bool = False, num_workers: int = 4) -> None:
        """Create a searcher with no index.

        Args:
            on_gpu: Stored on the instance; not used by this class.
            num_workers: Stored on the instance; not used by this class.
        """
        self.on_gpu = on_gpu
        self.num_workers = num_workers
        # [N x D] float32 tensor set by fit()/load(); None until then.
        self.search_index = None

    def fit(self, features: np.ndarray) -> None:
        """Store ``features`` ([N x D]) as the float32 search index.

        When ``features`` is already float32 the stored tensor shares memory
        with the NumPy array (``torch.from_numpy`` plus a no-op ``float()``),
        so later in-place changes to the array are visible to the index.
        """
        LOGGER.info("TorchNN fit: features shape %s", features.shape)
        self.search_index = torch.from_numpy(features).float()

    def run(
        self,
        n_nearest_neighbours: int,
        query_features: np.ndarray,
        index_features: Union[np.ndarray, None] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Nearest neighbour search via L2 = sqrt(a² + b² - 2abᵀ).

        Args:
            n_nearest_neighbours: Number of nearest neighbours K to return.
            query_features: [M x D] query features.
            index_features: [N x D] index features (if None, use stored index).

        Returns:
            distances: [M x min(K, N)] L2 distances, ascending along each row.
            indices: [M x min(K, N)] int64 positions in the index, aligned
                column-for-column with ``distances``.

        Raises:
            AssertionError: if ``index_features`` is None and no index has
                been fitted or loaded.
        """
        if index_features is not None:
            index = torch.from_numpy(index_features).float()
        else:
            assert self.search_index is not None, "TorchNN: no index fitted."
            index = self.search_index
        query = torch.from_numpy(query_features).float()

        # ||q||² as a column [M x 1] and ||x||² as a row [1 x N]; broadcasting
        # them against the [M x N] cross term gives all squared distances.
        a2 = (query ** 2).sum(dim=1, keepdim=True)
        b2 = (index ** 2).sum(dim=1, keepdim=True).T
        ab = torch.mm(query, index.T)
        # Floating-point cancellation can push tiny squared distances slightly
        # below zero; clamp before sqrt to avoid NaNs.
        dist = (a2 + b2 - 2 * ab).clamp(min=0).sqrt()

        # Cap K at the index size: topk accepts k == N and still sorts, so the
        # results stay nearest-first even when K >= N.
        k = min(n_nearest_neighbours, dist.shape[1])
        distances, indices = torch.topk(dist, k, dim=1, largest=False)
        return distances.numpy(), indices.numpy()

    def save(self, filename: str) -> None:
        """Write the index tensor to ``filename`` with ``torch.save``.

        Does nothing (and logs nothing) when no index is set.
        """
        if self.search_index is not None:
            torch.save(self.search_index, filename)
            LOGGER.info("TorchNN index saved to %s", filename)

    def load(self, filename: str) -> None:
        """Load an index tensor from ``filename`` (as written by ``save``)."""
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code from a malicious file. save() only
        # writes a tensor, so weights_only=True should suffice; until that is
        # changed, only load trusted files.
        self.search_index = torch.load(filename, weights_only=False)
        LOGGER.info(
            "TorchNN index loaded from %s, shape %s",
            filename,
            self.search_index.shape,
        )

    def reset_index(self) -> None:
        """Drop the stored search index."""
        self.search_index = None