/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
 * MemFabric_Hybrid is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
*/
#include <iostream>
#include <sstream>
#include <map>
#include <vector>
#include <iomanip>
#include <algorithm>
#include <sys/time.h>
#include <unistd.h>
#include <random>
#include <cstring>
#include <thread>
#include <chrono>
#include <atomic>
#include <mutex>

#include "acl/acl.h"

#include "smem.h"
#include "smem_shm.h"
#include "smem_trans.h"

/*
 * CHECK_ACL(x): evaluates the ACL call `x` exactly once; if the result is not ACL_ERROR_NONE,
 * prints "<file>:<line> <call text> aclError:<code>" to std::cerr. Execution continues after a
 * failure - this macro only reports.
 * The temporary is named acl_ret_ because identifiers containing a double underscore are
 * reserved for the implementation in C++.
 */
#define CHECK_ACL(x)                                                                                       \
    do {                                                                                                   \
        aclError acl_ret_ = (x);                                                                           \
        if (acl_ret_ != ACL_ERROR_NONE) {                                                                  \
            std::cerr << __FILE__ << ":" << __LINE__ << " " << #x << " aclError:" << acl_ret_ << std::endl; \
        }                                                                                                  \
    } while (0)

/*
 * LOG_ERROR(msg): writes "<file>:<line>[ERR]<msg>" followed by a newline to std::cout.
 * `msg` may be a chain of stream insertions, e.g. LOG_ERROR("ret:" << ret).
 */
#define LOG_ERROR(msg) \
    std::cout << __FILE__ << ":" << __LINE__ << "[ERR]" << msg << std::endl

/*
 * CHECK_GOTO_ERR(cond, msg, label): when `cond` compares unequal to 0, logs `msg` through
 * LOG_ERROR and jumps to `label`, which must be a label in the calling function.
 */
#define CHECK_GOTO_ERR(cond, msg, label) \
    do {                                 \
        if ((cond) != 0) {               \
            LOG_ERROR(msg);              \
            goto label;                  \
        }                                \
    } while (0)

// Size in bytes of every benchmark buffer and of the shm segment: 1 GiB.
constexpr uint64_t GVA_SIZE = 1ULL << 30;
// Upper bound of the random values generated by init_warmup_data.
constexpr uint32_t MAX_UINT32 = 0xFFFFFFFFU;
// Number of '=' characters printed on each side of a banner title.
constexpr uint32_t SEP_HALF_WIDTH = 50;

// memType values accepted on the command line (printed by main as 0:hbm 1:dram 2:hbm + dram).
constexpr uint32_t RUN_TRANS_PERF_WITH_HBM_ONLY = 0;
constexpr uint32_t RUN_TRANS_PERF_WITH_DRAM_ONLY = 1;
constexpr uint32_t RUN_TRANS_PERF_WITH_DRAM_AND_HBM = 2;

// Thread count passed to every trans_perf_test round.
constexpr uint32_t TRANS_PERF_CONCURRENCY = 2;

// Number of bytes in one rate unit: decimal (GB/MB/KB), binary (GiB/MiB/KiB) and bit-based (Gb/Mb/Kb).
const static std::map<std::string, uint64_t> RATE_UNIT_MP = {{"GB", 1000ull * 1000ull * 1000ull},
                                                             {"GiB", 1ull << 30},
                                                             {"Gb", 1000ull * 1000ull * 1000ull / 8},
                                                             {"MB", 1000ull * 1000ull},
                                                             {"MiB", 1ull << 20},
                                                             {"Mb", 1000ull * 1000ull / 8},
                                                             {"KB", 1000ull},
                                                             {"KiB", 1ull << 10},
                                                             {"Kb", 1000ull / 8}};

/*
 * Formats a transfer rate as "<value> <unit>/s" with two decimals.
 * data_bytes:  number of bytes transferred.
 * duration:    elapsed time in microseconds.
 * report_unit: key of RATE_UNIT_MP to report in (default "GB"); an unknown key makes
 *              RATE_UNIT_MP.at throw std::out_of_range.
 * Returns "N/A" when duration is not positive (or NaN), where no meaningful rate exists.
 */
static inline std::string calculateRate(uint64_t data_bytes, double duration,
                                        const std::string &report_unit = "GB")
{
    // duration is in microseconds, so scale bytes/us up to bytes/s.
    constexpr double MICROSECONDS_PER_SECOND = 1000000.0;
    if (!(duration > 0.0)) {
        return "N/A";
    }
    std::ostringstream oss;
    oss << std::fixed << std::setprecision(2U)
        << 1.0 * data_bytes * MICROSECONDS_PER_SECOND / duration / RATE_UNIT_MP.at(report_unit) << " "
        << report_unit << "/s";
    return oss.str();
}

/*
 * Fills the first `length` bytes of `warmup_data` with pseudo-random data: consecutive 4-byte
 * words drawn uniformly from [0, MAX_UINT32] by a std::mt19937 seeded from std::random_device.
 * A trailing partial word (length % 4 bytes) is filled from one more random value.
 * Does nothing when the buffer is null or length is 0.
 *
 * Each step stores exactly 4 bytes via memcpy, so the final store never runs past `length`.
 * memcpy also avoids any alignment or strict-aliasing assumptions about the char buffer.
 */
static inline void init_warmup_data(char *&warmup_data, size_t length)
{
    if (warmup_data == nullptr || length == 0) {
        return;
    }
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<uint32_t> dis(0, MAX_UINT32);

    const size_t word_size = sizeof(uint32_t);
    size_t offset = 0;
    for (; offset + word_size <= length; offset += word_size) {
        const uint32_t value = dis(gen);
        std::memcpy(warmup_data + offset, &value, word_size);
    }
    if (offset < length) {
        // Tail shorter than one word: copy only the bytes that fit.
        const uint32_t value = dis(gen);
        std::memcpy(warmup_data + offset, &value, length - offset);
    }
}


int32_t trans_perf_test(smem_trans_t trans_handle, smem_shm_t shm_handle, int rank_id, int use_malloc,
                        int num_threads = 1)
{
    char *warmup_data = nullptr;
    int32_t ret = 0;
    void *dev_addr = nullptr;
    void *gather_addr[2];
    void *dst_dev_addr = nullptr;
    const uint32_t KB_SIZE = 1024;

    // malloc device mem
    if (use_malloc) {
        dev_addr = smem_trans_malloc(trans_handle, GVA_SIZE);
        if (dev_addr == nullptr || dev_addr == nullptr) {
            std::cout << "malloc dram failed" << std::endl;
            return -1;
        } else {
            std::cout << "malloc dram success, dev_addr:" << dev_addr << std::endl;
        }
    } else {
        aclError aclret = aclrtMalloc(&dev_addr, GVA_SIZE, ACL_MEM_MALLOC_HUGE_FIRST);
        CHECK_GOTO_ERR(aclret != ACL_ERROR_NONE, "failed to allocate device memory, ret:" << aclret, out);
    }

    std::cout << "[" << rank_id << "]" << " malloc dev mem " << dev_addr << std::endl;
    /* gather peer addr */
    ret = smem_shm_control_allgather(shm_handle, (char *)&dev_addr, sizeof(void *), (char *)gather_addr,
                                     sizeof(void *) * 2U);
    CHECK_GOTO_ERR(ret, "failed to allgather dev memory, ret:" << ret, out);

    ret = smem_shm_control_barrier(shm_handle);
    CHECK_GOTO_ERR(ret, "barrier failed, ret:" << ret << " rank:" << rank_id, out);

    if (rank_id == 1 && !use_malloc) {
        ret = smem_trans_register_mem(trans_handle, dev_addr, GVA_SIZE, 0);
        CHECK_GOTO_ERR(ret, "failed to register device memory, ret:" << ret, out);
    }
    ret = smem_shm_control_barrier(shm_handle);
    std::this_thread::sleep_for(std::chrono::seconds(10UL)); // wait for register
    CHECK_GOTO_ERR(ret, "barrier failed, ret:" << ret << " rank:" << rank_id, out);

    if (rank_id == 0) {
        uint32_t block_iteration = 10;
        uint32_t base_block_size = 32 << 10; // 32k
        uint32_t times = 100;
        uint32_t batch_size = 32;
        std::string dstSessionId = "127.0.0.1:10001";
        dst_dev_addr = gather_addr[1];
        CHECK_GOTO_ERR(!dst_dev_addr, "dev memory error", out);
        std::cout << "[" << rank_id << "]" << " get dst dev addr " << dst_dev_addr << std::endl;

        /* init warmup data */
        warmup_data = (char *)malloc(GVA_SIZE * sizeof(char));
        CHECK_GOTO_ERR(!warmup_data, "warmup data malloc failed", out);
        std::cout << "Warmup Start" << std::endl;
        init_warmup_data(warmup_data, GVA_SIZE);
        if (!use_malloc) {
            aclrtMemcpy(dev_addr, GVA_SIZE, warmup_data, GVA_SIZE, ACL_MEMCPY_HOST_TO_DEVICE);
        } else {
            memcpy(dev_addr, warmup_data, GVA_SIZE);
        }

        // warmup
        ret = smem_trans_write(trans_handle, dev_addr, dstSessionId.c_str(), dst_dev_addr, base_block_size, 0);
        CHECK_GOTO_ERR(ret, "trans copy failed, ret:" << ret << " rank:" << rank_id, out);
        std::cout << "Warmup End" << std::endl;
        CHECK_GOTO_ERR(ret, "barrier failed, ret:" << ret << " rank:" << rank_id, out);

        auto test_title = use_malloc ? "Dram Trans Test Start" : "HBM Trans Test Start";
        std::string separator(SEP_HALF_WIDTH, '=');
        std::cout << separator << test_title << separator << std::endl;

        for (uint32_t i = 0; i < block_iteration; i++) {
            uint32_t block_size = base_block_size * (1 << i);
            struct timeval start_tv, stop_tv;

            gettimeofday(&start_tv, nullptr);

            std::vector<std::thread> threads;
            std::atomic<int32_t> err_count{0};

            for (int t = 0; t < num_threads; ++t) {
                threads.emplace_back([&, t]() {
                    for (uint32_t j = 0; j < times; j++) {
                        int32_t local_ret =
                            smem_trans_write(trans_handle, dev_addr, dstSessionId.c_str(), dst_dev_addr, block_size, 0);
                        if (local_ret != 0) {
                            err_count++;
                            return;
                        }
                    }
                });
            }

            for (auto &th : threads) {
                if (th.joinable())
                    th.join();
            }

            if (err_count.load() > 0) {
                std::cerr << "Latency test failed with errors" << std::endl;
                ret = -1;
                goto out;
            }

            gettimeofday(&stop_tv, nullptr);
            double duration1 = (stop_tv.tv_sec - start_tv.tv_sec) * 1000000.0 + (stop_tv.tv_usec - start_tv.tv_usec);

            duration1 /= (num_threads * times);

            std::vector<void *> laddrv;
            std::vector<void *> raddrv;
            std::vector<uint64_t> lengthv;
            laddrv.reserve(batch_size);
            raddrv.reserve(batch_size);
            lengthv.reserve(batch_size);
            for (uint32_t j = 0; j < batch_size; j++) {
                void *laddr = (uint8_t *)dev_addr + j * block_size;
                void *raddr = (uint8_t *)dst_dev_addr + j * block_size;
                laddrv.push_back(laddr);
                raddrv.push_back(raddr);
                lengthv.push_back(block_size);
            }

            gettimeofday(&start_tv, nullptr);

            err_count = 0;
            threads.clear();

            for (int t = 0; t < num_threads; ++t) {
                threads.emplace_back([&, t]() {
                    for (uint32_t j = 0; j < times; j++) {
                        int32_t local_ret = smem_trans_batch_write(
                            trans_handle, const_cast<const void **>(laddrv.data()), dstSessionId.c_str(), raddrv.data(),
                            reinterpret_cast<size_t *>(lengthv.data()), lengthv.size(), 0);
                        if (local_ret != 0) {
                            err_count++;
                            return;
                        }
                    }
                });
            }

            for (auto &th : threads) {
                if (th.joinable())
                    th.join();
            }

            if (err_count.load() > 0) {
                std::cerr << "BW test failed with errors" << std::endl;
                ret = -1;
                goto out;
            }

            gettimeofday(&stop_tv, nullptr);
            double duration2 = (stop_tv.tv_sec - start_tv.tv_sec) * 1000000.0 + (stop_tv.tv_usec - start_tv.tv_usec);
            duration2 /= (num_threads * times);
            uint64_t total_bytes = static_cast<uint64_t>(num_threads) * times * batch_size * block_size;
            double total_time_us =
                (stop_tv.tv_sec - start_tv.tv_sec) * 1000000.0 + (stop_tv.tv_usec - start_tv.tv_usec);

            std::cout << "Test completed: latency " << duration1 << "us, block size " << (block_size / KB_SIZE)
                      << "KB, "
                      << "total threads=" << num_threads << ", per-thread times=" << times << ", "
                      << "aggregated throughput " << calculateRate(total_bytes, total_time_us) << std::endl;
        }
        std::cout << separator << "Test End" << separator << std::endl;
    }

    smem_shm_control_barrier(shm_handle);
out:
    if (warmup_data) {
        free(warmup_data);
    }
    if (dev_addr) {
        if (use_malloc) {
            ret = smem_trans_free(trans_handle, dev_addr);
            if (ret != 0) {
                std::cout << "free dram failed, dev_addr:" << dev_addr << std::endl;
            } else {
                std::cout << "free dram success, dev_addr:" << dev_addr << std::endl;
            }
        } else {
            aclrtFree(dev_addr);
        }
    }
    return ret;
}


int32_t trans_test(int rank_id, int rank_size, int device_id, int use_sdma, std::string &ip_port, int memType)
{
    void *shm_gva = nullptr;
    smem_trans_config_t config;
    smem_shm_config_t config2;
    smem_trans_t trans_handle;
    smem_shm_t shm_handle;
    std::string sessionId;
    int32_t ret;

    if (rank_id == 0) {
        ret = smem_create_config_store(ip_port.c_str(), SMEM_STORE_SKIP_RECOVER);
        CHECK_GOTO_ERR(ret, "smem create config store failed, ret:" << " rank:" << rank_id, err1);
    }

    std::cout << std::endl << std::endl;

    smem_trans_config_init(&config);
    if (rank_id == 0) {
        /* Prefill */
        config.role = SMEM_TRANS_SENDER;
        sessionId = "127.0.0.1:10000";
    } else {
        /* Decode */
        config.role = SMEM_TRANS_RECEIVER;
        sessionId = "127.0.0.1:10001";
    }
    config.deviceId = device_id;
    if (use_sdma) {
        config.dataOpType = SMEMB_DATA_OP_SDMA;
    } else {
        config.dataOpType = SMEMB_DATA_OP_DEVICE_RDMA;
    }
    
    ret = smem_trans_init(&config);
    if (ret != 0) {
        std::cout << "[Failed to init smem_trans, ret=" << ret << "]" << std::endl;
        return ret;
    }
    trans_handle = smem_trans_create(ip_port.c_str(), sessionId.c_str(), &config);
    CHECK_GOTO_ERR(!trans_handle, "smem trans create failed, ret:" << " rank:" << rank_id, err1);
    std::cout << "[" << rank_id << "]" << " smem trans create done" << std::endl;

    (void)smem_shm_config_init(&config2);
    config2.startConfigStoreServer = false;

    ret = smem_shm_init(ip_port.c_str(), rank_size, rank_id, device_id, &config2);
    CHECK_GOTO_ERR(ret, "smem shmem init failed, ret:" << ret << " rank:" << rank_id, err2);

    shm_handle = smem_shm_create(0, rank_size, rank_id, GVA_SIZE, SMEMS_DATA_OP_MTE, 0, &shm_gva);
    CHECK_GOTO_ERR((shm_handle == nullptr), "smem_shm_create failed, rank:" << rank_id, err3);
    std::cout << "[" << rank_id << "]" << " smem shmem create done" << std::endl;

    ret = smem_shm_control_barrier(shm_handle);
    CHECK_GOTO_ERR(ret, "barrier failed, ret:" << ret << " rank:" << rank_id, err4);
    if (memType == RUN_TRANS_PERF_WITH_DRAM_AND_HBM) {
        trans_perf_test(trans_handle, shm_handle, rank_id, RUN_TRANS_PERF_WITH_HBM_ONLY, TRANS_PERF_CONCURRENCY);
        trans_perf_test(trans_handle, shm_handle, rank_id, RUN_TRANS_PERF_WITH_DRAM_ONLY, TRANS_PERF_CONCURRENCY);
    } else {
        trans_perf_test(trans_handle, shm_handle, rank_id, memType, TRANS_PERF_CONCURRENCY);
    }

err4:
    smem_shm_destroy(shm_handle, 0);
err3:
    smem_shm_uninit(0);
err2:
    smem_trans_destroy(trans_handle, 0);
    smem_trans_uninit(0);
err1:
    return 0;
}

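/*
 * Usage: transfer_perf {rank_size} {rank_id} {deviceID} {use_sdma} tcp://{Ip}:{port} {memType}
 *   rank_size must be 2 (rank 0 sends, rank 1 receives);
 *   use_sdma: non-zero selects SDMA, 0 selects device RDMA;
 *   memType:  0 = HBM, 1 = DRAM, 2 = HBM + DRAM.
 * Example: transfer_perf 2 0 2 1 tcp://127.0.0.1:12050 2
 */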

int32_t main(int32_t argc, char *argv[])
{
    int rank_size = atoi(argv[1]);
    int rank_id = atoi(argv[2]);
    int device_id = atoi(argv[3]);
    int use_sdma = atoi(argv[4]);
    std::string ipPort = argv[5];
    int memType = atoi(argv[6]);
    std::cout << "[TEST] input rank_size: " << rank_size << " rank_id:" << rank_id << " device_id: " << device_id <<
        " use_sdma: " << use_sdma << " store_ip: " << ipPort <<
        " memType(0:hbm 1:dram 2:hbm + dram):" << memType << std::endl;

    const size_t RANK_ID_SIZE = 2;
    if (rank_size != RANK_ID_SIZE) {
        std::cout << "[TEST] input rank_size: " << rank_size << " is not 2" << std::endl;
        return -1;
    }

    /* init aclrt */
    CHECK_ACL(aclInit(nullptr));
    CHECK_ACL(aclrtSetDevice(device_id));

    const uint32_t LOG_LEVEL_WARNING = 2;
    smem_set_log_level(LOG_LEVEL_WARNING);
    auto ret = smem_init(0);
    CHECK_GOTO_ERR(ret, "smem init failed, ret:" << ret << " rank:" << rank_id, err1);

    (void)trans_test(rank_id, rank_size, device_id, use_sdma, ipPort, memType);

    smem_uninit();
err1:
    CHECK_ACL(aclrtResetDevice(device_id));
    CHECK_ACL(aclFinalize());
    return 0;
}