import logging
import numpy as np
import faiss
import ascendfaiss
# ---------------------------------------------------------------------------
# Synthetic data
# ---------------------------------------------------------------------------
DIM = 512       # dimensionality of every vector
BASE = 100000   # number of vectors added to the index
QUERY = 100     # number of query vectors

# Fixed seed so that every run generates the same data.
np.random.seed(1234)
xb = np.random.random(size=(BASE, DIM)).astype('float32')
# Add a per-row offset (row_number / 1000) to the first component.
xb[:, 0] += np.arange(BASE) / 1000.0
# The queries are the first QUERY rows of the base set (a view, not a copy).
xq = xb[:QUERY]

# ---------------------------------------------------------------------------
# Flat L2 index built directly through ascendfaiss
# ---------------------------------------------------------------------------
# IntVector holding the single value 0.
# NOTE(review): presumably the list of Ascend device ids -- confirm.
dev = ascendfaiss.IntVector()
dev.push_back(0)

config = ascendfaiss.AscendIndexFlatConfig(dev)
ascend_index_flat = ascendfaiss.AscendIndexFlatL2(DIM, config)
ascend_index_flat.add(xb)

# ---------------------------------------------------------------------------
# Search and report the k best matches for every query
# ---------------------------------------------------------------------------
k = 10
logging.basicConfig(level=logging.INFO)

distances, indices = ascend_index_flat.search(xq, k)
logging.info("indices: %s", indices)
logging.info("distances: %s", distances)

# ---------------------------------------------------------------------------
# Removal by id selector, then clear the index
# ---------------------------------------------------------------------------
# Copy of the first k // 2 result ids of query 0; not consumed further below.
ids_remove_batch = indices[0][:k // 2].copy()
# Range selector built from the bounds (0, 1).
ids_remove = faiss.IDSelectorRange(0, 1)

num_removed = ascend_index_flat.remove_ids(ids_remove)
ascend_index_flat.reset()

# ---------------------------------------------------------------------------
# Round trip: CPU index -> Ascend index -> CPU index
# ---------------------------------------------------------------------------
cpu_index_flat = faiss.IndexFlatL2(DIM)
cpu_index_flat.add(xb)

# A fresh IntVector (again containing only 0) is used for the conversion.
dev = ascendfaiss.IntVector()
dev.push_back(0)

ascend_index_flat = ascendfaiss.index_cpu_to_ascend(dev, cpu_index_flat)
_, indices = ascend_index_flat.search(xq, k)

cpu_index_flat = ascendfaiss.index_ascend_to_cpu(ascend_index_flat)
# Mirror the dimension and vector count of the Ascend index onto the
# converted CPU index.
# NOTE(review): unclear whether index_ascend_to_cpu already sets these
# fields -- confirm before dropping the two assignments.
cpu_index_flat.d = ascend_index_flat.d
cpu_index_flat.ntotal = ascend_index_flat.ntotal

_, indices = cpu_index_flat.search(xq, k)
logging.info("after search indices: %s", indices)