// 需要生成 aicpu算子+ivfsqt算子(-d 256 -c 16384)+flat_at算子(-d 256 -c 16384)



#include <sys/time.h>
#include <unistd.h>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <stdexcept>
#include <thread>
#include <unordered_map>
#include <vector>

#include <faiss/ascend/custom/AscendIndexIVFSQT.h>



namespace {

// Recall bookkeeping: maps K (as in recall@K) to a hit count, later a percentage.
using recallMap = std::unordered_map<int, float>;

// The K values reported as recall@1, recall@10 and recall@100.
constexpr int RECMAP_KEY_1 = 1;
constexpr int RECMAP_KEY_10 = 10;
constexpr int RECMAP_KEY_100 = 100;

// Milliseconds per second; converts timings to seconds and scales QPS.
constexpr int MILLI_SECOND = 1000;

// Bit width and fixed seed of the random byte generator used for synthetic data.
constexpr int BITS = 8;
constexpr int SEED = 1;

// Probe counts handed to setNumProbes() and updateTParams() at search time.
constexpr int NPROBE = 64;
constexpr int L2NPROBE = 360;
constexpr int L3NPROBE = 360;

// DIM_IN: dimension of the float base/query vectors.
// DIM_OUT: second dimension argument of AscendIndexIVFSQT (presumably the compressed dim — confirm).
constexpr int DIM_IN = 256;
constexpr int DIM_OUT = 64;

/**
 * @brief Credit one ground-truth hit to every recall@K counter it qualifies for.
 *
 * A hit found at 0-based rank j counts towards recall@K iff j < K:
 * - rank 0 feeds recall@1, @10 and @100;
 * - ranks 1..9 feed recall@10 and @100;
 * - ranks 10..99 feed recall@100 only;
 * - larger ranks feed nothing.
 * This uses plain comparisons instead of the GNU-only `case 1 ... 9` range label,
 * so it builds with any standard-conforming compiler.
 *
 * @param recMap Hit counters keyed by K (missing keys are created at 0 by operator[]).
 * @param j      0-based rank of the hit; negative ranks are ignored.
 */
void computeRecall(recallMap &recMap, int j)
{
    if (j < 0) {
        return;
    }
    if (j < RECMAP_KEY_1) {
        recMap[RECMAP_KEY_1]++;
    }
    if (j < RECMAP_KEY_10) {
        recMap[RECMAP_KEY_10]++;
    }
    if (j < RECMAP_KEY_100) {
        recMap[RECMAP_KEY_100]++;
    }
}



/**
 * @brief Compute recall@1/@10/@100, in percent, of search results against ground truth.
 *
 * For each of the `shape` queries, the first ground-truth id of that query's gt row is
 * looked up in the query's row of `label`. The 0-based position of the first match
 * decides which recall counters are credited (see computeRecall).
 *
 * @param label    Result ids, row-major: `shape` rows of k = label.size() / shape entries.
 * @param gt       Ground-truth ids; row i starts at gt[i * stride]. Must not be null.
 * @param shape    Number of queries (rows); must be > 0.
 * @param gtStride Row stride of `gt`. 0 (the default) means "same as k", which is the
 *                 historical behaviour; pass the real stride when gt rows differ in length.
 * @return Map from K to recall@K in percent; all zeros (plus a message on stderr) on invalid input.
 */
template<class T>
recallMap calRecallNomal(const std::vector<T> &label, const int64_t *gt, int shape, size_t gtStride = 0)
{
    recallMap recMap;
    recMap[RECMAP_KEY_1] = 0;
    recMap[RECMAP_KEY_10] = 0;
    recMap[RECMAP_KEY_100] = 0;
    if (shape <= 0) {
        std::cerr << "Error: Invalid shape value." << std::endl;
        return recMap;
    }
    if (gt == nullptr) {
        std::cerr << "Error: Invalid ground truth pointer." << std::endl;
        return recMap;
    }

    const size_t rows = static_cast<size_t>(shape);
    const size_t k = label.size() / rows;
    const size_t gtRowStride = (gtStride == 0) ? k : gtStride;
    for (size_t i = 0; i < rows; i++) {
        const int64_t target = gt[i * gtRowStride];
        for (size_t j = 0; j < k; j++) {
            if (target == label[i * k + j]) {
                computeRecall(recMap, static_cast<int>(j));
                break;  // only the first (best-ranked) match counts
            }
        }
    }

    // Turn hit counts into percentages of the number of queries (100 = percent factor).
    for (int key : {RECMAP_KEY_1, RECMAP_KEY_10, RECMAP_KEY_100}) {
        recMap[key] = recMap[key] / shape * 100;
    }
    return recMap;
}



/**
 * @brief Current reading of a monotonic clock, in milliseconds.
 *
 * Only differences between two readings are meaningful; the epoch is unspecified.
 * A steady clock is used instead of gettimeofday() so that system-time adjustments
 * (NTP steps, manual changes) cannot distort measured intervals.
 *
 * @return Milliseconds since the steady clock's epoch, with sub-millisecond precision.
 */
inline double GetMillisecs()
{
    const auto sinceEpoch = std::chrono::steady_clock::now().time_since_epoch();
    return std::chrono::duration<double, std::milli>(sinceEpoch).count();
}



// Fixed-seed source of BITS-wide unsigned values (0..255 for BITS = 8), so that the
// synthetic base/learn data is identical across runs. It is shared mutable state
// with no locking, so it must not be called concurrently.
std::independent_bits_engine<std::mt19937, BITS, uint8_t> engine(SEED);

/**
 * @brief Draw the next signed byte from the shared fixed-seed engine.
 * @return A value in [-128, 127]: the engine's [0, 255] output shifted down by 128.
 */
int8_t RandomInt8()
{
    const int raw = engine();
    return static_cast<int8_t>(raw - 128);
}



/**
 * @brief Draw a uniformly distributed index in [0, ntotal).
 *
 * Uses a per-thread Mersenne Twister seeded from std::random_device, so the drawn
 * sequence differs from run to run.
 *
 * @param ntotal Size of the index range; must be > 0.
 * @param idx    Receives the drawn index.
 * @throws std::invalid_argument if ntotal == 0. Otherwise `ntotal - 1` would wrap and the
 *         "index" could be any size_t, leading to out-of-bounds accesses in callers.
 */
void GeneratorRandomIndex(size_t ntotal, size_t& idx)
{
    if (ntotal == 0) {
        throw std::invalid_argument("GeneratorRandomIndex: ntotal must be > 0");
    }
    static thread_local std::mt19937 generator(std::random_device{}());
    std::uniform_int_distribution<size_t> dis(0, ntotal - 1);
    idx = dis(generator);
}

/**
 * @brief Non-owning bundle of the datasets used by the benchmark.
 *
 * Every member is a reference to a vector owned by the caller, so the referenced
 * vectors must outlive this object. Copying a dataFloat copies the references,
 * not the data.
 */
struct dataFloat {
    dataFloat(std::vector<float> &baseVecs,
        std::vector<float> &learnVecs,
        std::vector<float> &queryVecs,
        std::vector<int64_t> &groundTruth)
        : base(baseVecs), learn(learnVecs), query(queryVecs), gt(groundTruth)
    {
    }

    std::vector<float> &base;   // vectors added to the index
    std::vector<float> &learn;  // vectors used for training
    std::vector<float> &query;  // query vectors
    std::vector<int64_t> &gt;   // ground-truth ids per query
};



void DataProccess(int searchNum, int topk, dataFloat data, faiss::ascend::AscendIndexIVFSQT &index,

    std::vector<faiss::idx_t>& labelTopk)

{

    int flatK = 1;

    double searchStart = GetMillisecs();



    std::vector<faiss::idx_t> labelFlat(searchNum * static_cast<size_t>(topk), 0);

    std::vector<float> distanceTopk(searchNum * static_cast<size_t>(topk), 0);

    std::vector<faiss::idx_t> labelRet(searchNum * static_cast<size_t>(topk), 0);

    std::vector<float> distanceFlat(searchNum * static_cast<size_t>(flatK), 0);



    int cpuNum = 48;

    auto cpuSearchFunctor = [searchNum, topk, flatK, &distanceTopk, &labelTopk,

        &distanceFlat, &labelRet, &data] (int cpuIdx, int start, int end) {

        for (int qId = start; qId < std::min(end, searchNum); qId++) {

            cpu_set_t cpuset;

            CPU_ZERO(&cpuset);

            CPU_SET(cpuIdx, &cpuset);

            float max = 0;

            size_t offset = 0;

            for (size_t topkId = 0; topkId <  static_cast<size_t>(topk); topkId++) {

                bool isValid = false;

                if (labelTopk[qId * topk + topkId] == -1) {

                    continue;

                }

                for (size_t dimId = 0; dimId < static_cast<size_t>(DIM_IN); dimId++) {

                    // 计算IP距离

                    distanceTopk[qId * topk + topkId] +=

                    data.query[qId * DIM_IN + dimId] * \

                    data.base[labelTopk[qId * topk + topkId] * DIM_IN + dimId];

                    isValid = true;

                }

                if (isValid && max < distanceTopk[qId * topk + topkId]) {

                    max = distanceTopk[qId * topk + topkId];

                    offset = topkId;

                }

                labelRet[qId * topk] = labelTopk[qId * topk + offset];

                distanceFlat[qId * flatK] = distanceTopk[qId * topk + offset];

            }

        }

    };

    std::thread threads[cpuNum];

    int block = (searchNum + cpuNum - 1) / cpuNum; // div up

    for (int i = 0; i < cpuNum; i++) {

        threads[i] = std::thread(cpuSearchFunctor, i, i * block, (i + 1) * block);

    }

    for (int i = 0; i < cpuNum; i++) {

        threads[i].join();

    }

    double cpuSearchEnd = GetMillisecs();

    recallMap recall = calRecallNomal(labelRet, data.gt.data(), searchNum);

    // QPS = 并发线程数 * (1000 / 平均耗时ms)

    printf("searchNum = %d, r1 = %.2f, r10 = %.2f, r100 = %.2f, qps = %lf\n",

        searchNum, recall[RECMAP_KEY_1], recall[RECMAP_KEY_10], recall[RECMAP_KEY_100],

        MILLI_SECOND * searchNum / (cpuSearchEnd - searchStart));

}



/**
 * @brief Train, populate and query the IVFSQT index, printing the timing of each phase.
 *
 * Sequence:
 * 1. Configure fuzzyK and the threshold.
 * 2. train() on `learn`, add() `ntotal` vectors from `base`, then update().
 * 3. Set the search probe counts.
 * 4. Search the first `searchNum` queries for their top-k candidates.
 * 5. Hand the candidates to DataProccess for CPU re-ranking and recall reporting.
 *
 * @param index         Index to build and search.
 * @param ntotal        Number of DIM_IN-float vectors in `base` to add.
 * @param base          Vectors added to the index.
 * @param learn         Training vectors (learn.size() / DIM_IN of them).
 * @param dataBaseFloat Dataset bundle; its query vectors are searched.
 */
void SearchProccess(faiss::ascend::AscendIndexIVFSQT &index, size_t ntotal, std::vector<float>& base,
    std::vector<float>& learn, dataFloat dataBaseFloat)
{
    // Search-time knobs for this run.
    int topk = 100;
    int searchNum = 8;
    int fuzzyK = 3;
    float threshold = 1.6;
    int trainSize = learn.size() / DIM_IN;

    index.verbose = true;
    index.setFuzzyK(fuzzyK);
    index.setThreshold(threshold);

    // Each phase is timed in milliseconds and reported in seconds (hence / MILLI_SECOND).
    double tick = GetMillisecs();
    index.train(trainSize, learn.data());
    printf("train time cost : %.2fs\n", (GetMillisecs() - tick) / MILLI_SECOND);

    tick = GetMillisecs();
    index.add(ntotal, base.data());
    printf("add time cost : %.2fs\n", (GetMillisecs() - tick) / MILLI_SECOND);

    tick = GetMillisecs();
    index.update();
    printf("update time cost : %.2fs\n", (GetMillisecs() - tick) / MILLI_SECOND);

    printf("=> start Qps Test!\n");

    index.updateTParams(L2NPROBE, L3NPROBE);
    index.setNumProbes(NPROBE);

    const size_t resultLen = static_cast<size_t>(searchNum) * topk;
    std::vector<float> dist(resultLen, 0);
    std::vector<faiss::idx_t> labelTopk(resultLen, 0);

    tick = GetMillisecs();
    index.search(searchNum, dataBaseFloat.query.data(), topk, dist.data(), labelTopk.data());
    const double sqtCostMs = GetMillisecs() - tick;

    printf("searchNum = %d, sqt search cost: %lf\n", searchNum, sqtCostMs);
    DataProccess(searchNum, topk, dataBaseFloat, index, labelTopk);
}



/**
 * @brief End-to-end IVFSQT benchmark on synthetic data.
 *
 * Steps:
 * - Build an AscendIndexIVFSQT (inner-product metric, 8-bit scalar quantizer). The config uses
 *   device list `{ 0 }` (presumably NPU 0 — confirm) and a 1 GiB resource size.
 * - Generate random int8-valued base and learn vectors, scaled into [-1, 1).
 * - Build queries by copying randomly chosen base vectors, recording each chosen id in column 0
 *   of that query's ground-truth row.
 * - Run SearchProccess.
 * Any std::exception is printed to stdout.
 *
 * @param niter      Training iterations, forwarded to conf.cp.niter.
 * @param ncentroids Number of IVF lists (buckets).
 */
void TestIVFSQT(int niter, int ncentroids)
{
    const size_t ntotal = 1000000;
    // NOTE(review): SearchProccess only searches its first `searchNum` (8) queries, so queryNum
    // (and the 100-slot gt rows) could likely be reduced a lot; kept unchanged here.
    const size_t queryNum = 1000000;
    const size_t learnNum = ntotal / 10;
    const int gtNum = 100;
    const int maxPointsPerCentroid = 256;
    const size_t dim = static_cast<size_t>(DIM_IN);
    // int8 samples are divided by 128 to map them into [-1.0, 1.0).
    const float intToFloat = 128.0;

    try {
        // resource size 1024 * 1024 * 1024 = 1GB
        faiss::ascend::AscendIndexIVFSQTConfig conf({ 0 }, static_cast<int64_t>(1024 * 1024 * 1024));
        conf.cp.niter = niter;
        conf.useKmeansPP = true;
        conf.cp.max_points_per_centroid = maxPointsPerCentroid;
        faiss::ascend::AscendIndexIVFSQT index(DIM_IN, DIM_OUT, ncentroids,
             faiss::ScalarQuantizer::QuantizerType::QT_8bit, faiss::METRIC_INNER_PRODUCT, conf);

        // Convert each random byte to float as it is drawn, instead of first materialising int8
        // copies. The values and the engine draw order (all base, then all learn) are unchanged,
        // but peak host memory drops by roughly 0.5 GB (no int8 staging buffers).
        std::vector<float> base(ntotal * dim);
        for (float &value : base) {
            value = static_cast<float>(RandomInt8()) / intToFloat;
        }
        std::vector<float> learn(learnNum * dim);
        for (float &value : learn) {
            value = static_cast<float>(RandomInt8()) / intToFloat;
        }

        // Each query is an exact copy of a random base vector, so that vector's id is the ground truth.
        std::vector<float> query(queryNum * dim);
        std::vector<int64_t> gt(queryNum * gtNum, 0);
        for (size_t q = 0; q < queryNum; q++) {
            size_t idx = 0;
            GeneratorRandomIndex(ntotal, idx);
            const float *src = base.data() + idx * dim;
            std::copy(src, src + dim, query.data() + q * dim);
            gt[q * gtNum] = static_cast<int64_t>(idx);
        }

        dataFloat dataBaseFloat(base, learn, query, gt);
        SearchProccess(index, ntotal, base, learn, dataBaseFloat);
    } catch (const std::exception &e) {
        printf("%s\n", e.what());
    }
}



} // namespace



/**
 * @brief Entry point: run the IVFSQT benchmark with a fixed configuration.
 *
 * Command-line arguments are ignored. The bucket count presumably has to match the
 * `-c 16384` operators named in the file's header comment.
 */
int main(int argc, char **argv)
{
    static_cast<void>(argc);
    static_cast<void>(argv);

    const int niter = 16;          // training iterations, forwarded to conf.cp.niter
    const int ncentroids = 16384;  // number of IVF buckets

    printf("Start Test\n");
    TestIVFSQT(niter, ncentroids);
    return 0;
}