/* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
        limitations under the License.
==============================================================================*/

#include <fstream>
#include <limits>
#include <stdexcept>
#include <string>
#include <system_error>
#include <utility>
#include <vector>

#include "gtest/gtest.h"
#include "mpi.h"
#include "emock/emock.hpp"

#include "utils/common.h"
#include "error/error.h"
#include "ssd_engine/file.h"

using namespace std;
using namespace MxRec;
using namespace testing;

// Constructing a File from just an ID and a (per-rank) directory must not throw a runtime_error.
TEST(File, CreateEmptyFile)
{
    int rankId = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rankId);
    GlogConfig::gRankId = to_string(rankId);

    // Each MPI rank works in a directory named after its rank id.
    string fileDir = GlogConfig::gRankId;

    bool thrown = false;
    try {
        auto created = make_shared<File>(0, fileDir);
    } catch (runtime_error& err) {
        LOG_ERROR(err.what());
        thrown = true;
    }
    ASSERT_FALSE(thrown);

    fs::remove_all(fileDir);
}

// Hand-writes a one-key meta/data checkpoint pair for step 0, then checks that the loading
// File constructor accepts it without throwing a runtime_error.
TEST(File, LoadFromFile)
{
    // prepare
    int rankId = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rankId);
    GlogConfig::gRankId = to_string(rankId);

    string fileDir = GlogConfig::gRankId;
    auto absDir = fs::absolute(fileDir);
    bool dirReady = fs::exists(absDir) || fs::create_directories(absDir);
    if (!dirReady) {
        throw runtime_error("fail to create Save directory");
    }

    emb_key_t key = 0;
    offset_t offset = 0;
    vector<float> val = {1.0};

    // Meta file: the key followed by its offset (presumably into the data file).
    {
        fstream metaOut(fileDir + "/0.meta.0", ios::out | ios::trunc | ios::binary);
        metaOut.write(reinterpret_cast<char const*>(&key), sizeof(key));
        metaOut.write(reinterpret_cast<char const*>(&offset), sizeof(offset));
        metaOut.flush();
        if (metaOut.fail()) {
            throw runtime_error("fail to prepare meta file");
        }
    }  // stream closed by its destructor

    // Data file: the embedding length as uint64_t, then the raw float values.
    {
        fstream dataOut(fileDir + "/0.data.0", ios::out | ios::trunc | ios::binary);
        uint64_t embSize = val.size();
        dataOut.write(reinterpret_cast<char const*>(&embSize), sizeof(embSize));
        dataOut.write(reinterpret_cast<char const*>(val.data()), val.size() * sizeof(float));
        dataOut.flush();
        if (dataOut.fail()) {
            throw runtime_error("fail to prepare data file");
        }
    }  // stream closed by its destructor

    // start test
    string loadDir = fileDir;  // for test convenience
    bool thrown = false;
    try {
        auto loaded = make_shared<File>(0, fileDir, loadDir, 0);
    } catch (runtime_error& err) {
        LOG_ERROR(err.what());
        thrown = true;
    }
    ASSERT_FALSE(thrown);

    fs::remove_all(fileDir);
}

// Embeddings inserted by value are fetched back unchanged, and a deleted key is no longer reported.
TEST(File, WriteAndRead)
{
    int rankId = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rankId);
    GlogConfig::gRankId = to_string(rankId);

    string savePath = GlogConfig::gRankId;
    auto file = make_shared<File>(0, savePath);

    // Ten keys, each with a two-element embedding {k + 0.1, k + 0.2}.
    vector<emb_cache_key_t> keys;
    vector<vector<float>> embeddings;
    keys.reserve(10);
    embeddings.reserve(10);
    for (emb_cache_key_t k = 0; k < 10; ++k) {
        keys.emplace_back(k);
        embeddings.push_back({static_cast<float>(k + 0.1), static_cast<float>(k + 0.2)});
    }

    file->InsertEmbeddings(keys, embeddings);
    ASSERT_EQ(embeddings, file->FetchEmbeddings(keys));

    file->DeleteEmbedding(0);
    ASSERT_FALSE(file->IsKeyExist(0));

    fs::remove_all(savePath);
}

// An embedding saved at a checkpoint step is returned unchanged by a File loaded from that step.
TEST(File, SaveAndLoad)
{
    int rankId = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rankId);
    GlogConfig::gRankId = to_string(rankId);

    int step = 0;
    string workDir = GlogConfig::gRankId;
    auto writer = make_shared<File>(0, workDir);

    // Insert a single embedding and persist it as checkpoint `step`.
    vector<emb_cache_key_t> keys = {0};
    vector<vector<float>> saved = {{1.0, 1.1}};
    writer->InsertEmbeddings(keys, saved);
    string saveDir = workDir;  // save into the working directory for test convenience
    writer->Save(saveDir, step);

    // Reopen the same file ID from that checkpoint and read the embedding back.
    string loadDir = workDir;  // load from the same directory for test convenience
    auto reader = make_shared<File>(0, workDir, loadDir, step);
    ASSERT_EQ(saved, reader->FetchEmbeddings(keys));

    fs::remove_all(workDir);
}

// Embeddings inserted by raw address (one float per key) are fetched back with the same values.
TEST(File, WriteByAddrAndRead)
{
    int rankId = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rankId);
    GlogConfig::gRankId = to_string(rankId);

    string savePath = GlogConfig::gRankId;
    auto f = make_shared<File>(0, savePath);

    uint64_t extEmbeddingSize = 1;
    vector<emb_cache_key_t> keys;
    // `storage` owns the embedding values for the whole test. This replaces the former per-key
    // new/delete, which leaked every allocation whenever an assertion returned early.
    vector<float> storage;
    for (emb_cache_key_t k = 0; k < 10; k++) {
        keys.emplace_back(k);
        storage.emplace_back(static_cast<float>(k + 0.1));
    }
    // Take addresses only after `storage` is fully populated, so no reallocation can invalidate them.
    vector<float*> embeddings;
    embeddings.reserve(storage.size());
    for (float& value : storage) {
        embeddings.emplace_back(&value);
    }

    f->InsertEmbeddingsByAddr(keys, embeddings, extEmbeddingSize);
    auto ret = f->FetchEmbeddings(keys);

    // Check the result count before indexing, so a short result fails cleanly instead of reading out of bounds.
    ASSERT_EQ(ret.size(), keys.size());
    for (size_t i = 0; i < ret.size(); ++i) {
        ASSERT_FALSE(ret[i].empty()) << "empty embedding returned at index " << i;
        EXPECT_NEAR(ret[i][0], storage[i], std::numeric_limits<float>::epsilon())
            << "embedding result not equal to input";
    }

    fs::remove_all(savePath);
}

// Saves an incremental (delta) checkpoint restricted by keyInfo. It then renames the delta files
// to the regular checkpoint names and verifies that the key listed in keyInfo loads back correctly.
TEST(File, SaveAndLoadForIncrementalCkpt)
{
    int rankId = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rankId);
    GlogConfig::gRankId = to_string(rankId);

    int saveStep = 0;
    string fileDir = GlogConfig::gRankId;
    auto fTmp = make_shared<File>(0, fileDir);

    vector<emb_cache_key_t> fullKeys = {0, 1};
    vector<vector<float>> fullEmbeddings = {{0.1, 0.2}, {0.3, 0.4}};
    vector<emb_cache_key_t> expectedKeys = {1};
    vector<vector<float>> expectedEmbeddings = {{0.3, 0.4}};
    fTmp->InsertEmbeddings(fullKeys, fullEmbeddings);
    string saveDir = fileDir;  // for test convenience
    // keyInfo presumably selects the keys written to the incremental checkpoint (only key 1 here).
    map<emb_key_t, KeyInfo> keyInfo = {{1, KeyInfo()}};
    fTmp->Save(saveDir, saveStep, keyInfo);

    // In incremental ckpt, saved model's name contains 'delta-', but File interface has no 'delta-' for
    // loading model. So rename the saved files to the names the loading constructor expects.
    const string stepSuffix = std::to_string(saveStep);
    const vector<pair<string, string>> renames = {
        {saveDir + "/delta-0.meta." + stepSuffix, saveDir + "/0.meta." + stepSuffix},
        {saveDir + "/delta-0.data." + stepSuffix, saveDir + "/0.data." + stepSuffix},
    };
    for (const auto& item : renames) {
        const string& from = item.first;
        const string& to = item.second;
        // Fail fast with a clear message. Previously a missing file or a failed rename was only
        // printed, and the test then failed obscurely during the load below.
        ASSERT_TRUE(fs::exists(from)) << "incremental checkpoint file not found: " << from;
        std::error_code ec;
        fs::rename(from, to, ec);
        ASSERT_TRUE(!ec) << "failed to rename " << from << " to " << to << ": " << ec.message();
    }

    string loadDir = fileDir;  // for test convenience
    auto fLoad = make_shared<File>(0, fileDir, loadDir, saveStep);
    auto actualEmbeddings = fLoad->FetchEmbeddings(expectedKeys);
    ASSERT_EQ(expectedEmbeddings, actualEmbeddings);
    fs::remove_all(fileDir);
}

class EmbFileTest : public ::testing::Test {
protected:
    void SetUp() override
    {
        emock::GlobalMockObject::reset();
    }
};

// A freshly created File reports its ID and empty counters/key set through the getters.
TEST_F(EmbFileTest, AllGetMethodsOk)
{
    int rankId = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rankId);

    auto fileDir = "file-test-dir-" + std::to_string(rankId);
    {
        // Scope the File so it is destroyed before its directory is removed below.
        auto file = File(0, fileDir);

        EXPECT_EQ(file.GetFileID(), 0);
        EXPECT_EQ(file.GetStaleDataCnt(), 0);
        EXPECT_EQ(file.GetDataCnt(), 0);
        EXPECT_EQ(file.GetKeys().size(), 0U);  // unsigned literal: avoids a signed/unsigned comparison
    }
    // Remove the per-rank directory used by the File, as the other File tests do. Previously it
    // was left behind after every run.
    fs::remove_all(fileDir);
}

// Both File error helpers surface as std::runtime_error, even when given an empty message.
TEST_F(EmbFileTest, AllThrowMethodsOK)
{
    std::string noMessage;
    EXPECT_THROW(File::ThrowRuntimeError(ErrorType::ACL_ERROR, noMessage), std::runtime_error);
    EXPECT_THROW(File::ThrowInvalidArgError(ErrorType::ACL_ERROR, noMessage), std::runtime_error);
}