/*
 * Copyright(C) 2025. Huawei Technologies Co.,Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// 需要生成aicpu算子+ivfflat算子+ivfpq算子(-d 128 -c 1024)

#include <faiss/ascend/AscendIndexIVFPQ.h>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <random>
#include <vector>
#include <iostream>
#include <cfloat>

void Norm(float *data, size_t n, size_t dim)
{
#pragma omp parallel for if (n>1)
    for (size_t i = 0; i < n; ++i) {
        float l2norm = 0.0;
        for (size_t j = 0; j < dim; ++j) {
            l2norm += data[i * dim + j] * data[i * dim + j];
        }
        l2norm = std::sqrt(l2norm);
        if (fabs(l2norm) < FLT_EPSILON) {
            std::cerr << "Error: Invalid l2norm value." << std::endl;
        }
        for (size_t j = 0; j < dim; ++j) {
            data[i * dim + j] = data[i * dim + j] / l2norm;
        }
    }
}

int main()
{
    size_t dim = 128;
    size_t ntotal = 1000000;
    int ncentroids = 1024;
    int nprobe = 32;
    int M = 4;
    int nbits = 8;

    printf("generate data\n");
    std::vector<float> data(dim * ntotal);
    for (size_t i = 0; i < data.size(); i++) {
        data[i] = drand48();
    }
    Norm(data.data(), ntotal, dim);
    std::vector<int64_t> ids(ntotal);
    for (size_t i = 0; i < ids.size(); i++) {
        ids[i] = i;
    }

    faiss::ascend::AscendIndexIVFPQ *index = nullptr;
    try {
        faiss::ascend::AscendIndexIVFPQConfig conf{{0}};
        printf("create index\n");
        index = new faiss::ascend::AscendIndexIVFPQ(dim, faiss::METRIC_L2, ncentroids, M, nbits, conf);
        index->verbose = true;
        index->setNumProbes(nprobe);

        printf("start train\n");
        index->train(ntotal, data.data());
        printf("start add\n");
        index->add_with_ids(ntotal, data.data(), ids.data());

        size_t n = 10;
        size_t k = 10;
        std::vector<float> dist(n * k, 0.0);
        std::vector<faiss::idx_t> label(n * k, 0);
        printf("start search\n");
        index->search(n, data.data(), k, dist.data(), label.data());
    } catch (std::exception &e) {
        printf("exceptin caught: %s\n", e.what());
        delete index;
        return -1;
    }
    delete index;
    printf("search success\n");
    return 0;
}