# Snapshot of commit b970215e (created 12 hours ago; from the commit history view).
"""
TorchNN: Pure-PyTorch nearest neighbor search replacing FAISS.

Computes L2 distance = a² + b² - 2abᵀ via torch.mm (CANN GEMM on NPU).
No FAISS dependency needed.
"""

import logging
from typing import Union, Tuple

import numpy as np
import torch

LOGGER = logging.getLogger(__name__)


class TorchNN:
    """Brute-force nearest-neighbour search using pure PyTorch.

    Replaces FAISS IndexFlatL2 with the a² + b² - 2abᵀ formulation, so the
    dominant cost is a single ``torch.mm`` call (the operation intended to map
    onto CANN GEMM on Ascend NPU). Every tensor this class creates comes from
    ``torch.from_numpy`` and therefore lives on the CPU.

    NOTE(review): FAISS ``IndexFlatL2`` reports *squared* L2 distances, whereas
    ``run`` returns the square root. Confirm callers expect true L2 distances.
    """

    def __init__(self, on_gpu: bool = False, num_workers: int = 4) -> None:
        """Create an empty searcher.

        Args:
            on_gpu: Stored on the instance; not read anywhere else in this class.
            num_workers: Stored on the instance; not read anywhere else in this
                class.
        """
        self.on_gpu = on_gpu
        self.num_workers = num_workers
        self.search_index = None  # torch.Tensor [N x D] once fitted/loaded

    def fit(self, features: np.ndarray) -> None:
        """Store ``features`` ([N x D]) as the float32 search index.

        The tensor is built with ``torch.from_numpy``; when ``features`` is
        already float32 the index shares memory with the caller's array, so
        later in-place edits to that array are visible to the index.
        """
        LOGGER.info("TorchNN fit: features shape %s", features.shape)
        self.search_index = torch.from_numpy(features).float()

    def run(
        self,
        n_nearest_neighbours: int,
        query_features: np.ndarray,
        index_features: Union[np.ndarray, None] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Nearest neighbour search via L2 = sqrt(a² + b² - 2abᵀ).

        Args:
            n_nearest_neighbours: Number of nearest neighbours (K).
            query_features: [M x D] query features.
            index_features: [N x D] index features (if None, use stored index).

        Returns:
            distances: [M x min(K, N)] L2 distances, ascending along axis 1.
            indices:   [M x min(K, N)] index rows matching ``distances``.

        Raises:
            AssertionError: If ``index_features`` is None and no index has been
                fitted or loaded.
        """
        if index_features is not None:
            index = torch.from_numpy(index_features).float()
        elif self.search_index is not None:
            index = self.search_index
        else:
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O``; the exception type is unchanged.
            raise AssertionError("TorchNN: no index fitted.")

        query = torch.from_numpy(query_features).float()

        # Pairwise distances, with a: query [M, D] and b: index [N, D].
        a2 = (query ** 2).sum(dim=1, keepdim=True)       # [M, 1]
        b2 = (index ** 2).sum(dim=1, keepdim=True).T     # [1, N]
        ab = torch.mm(query, index.T)                    # [M, N]
        # Rounding in the expansion can yield tiny negative values; clamp them
        # to zero so that sqrt never produces NaN.
        dist = (a2 + b2 - 2 * ab).clamp(min=0).sqrt()    # [M, N]

        if n_nearest_neighbours >= dist.shape[1]:
            # Every index row is requested: return all of them ordered by
            # distance so that column 0 is always the nearest neighbour, the
            # same ordering the top-k branch produces.
            distances, indices = torch.sort(dist, dim=1)
        else:
            distances, indices = torch.topk(
                dist, n_nearest_neighbours, dim=1, largest=False
            )

        return distances.numpy(), indices.numpy()

    def save(self, filename: str) -> None:
        """Save the index tensor to ``filename``; do nothing if no index is set."""
        if self.search_index is None:
            LOGGER.warning("TorchNN: no index to save, skipping %s", filename)
            return
        torch.save(self.search_index, filename)
        LOGGER.info("TorchNN index saved to %s", filename)

    def load(self, filename: str) -> None:
        """Load the index tensor from ``filename`` onto the CPU."""
        # map_location="cpu": ``run`` builds its query tensor on the CPU, so the
        # index must be on the CPU as well for torch.mm to accept both operands.
        # NOTE(review): weights_only=False allows arbitrary unpickling; only
        # load files from trusted sources. ``save`` writes a plain tensor, so
        # weights_only=True is probably sufficient -- confirm before switching.
        self.search_index = torch.load(
            filename, map_location="cpu", weights_only=False
        )
        LOGGER.info(
            "TorchNN index loaded from %s, shape %s",
            filename,
            self.search_index.shape,
        )

    def reset_index(self) -> None:
        """Drop the stored search index."""
        self.search_index = None