/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file hccl_tiling.cpp
 * \brief
 */

#include "tiling/platform/platform_ascendc.h"
#include "include/adv_api/hccl/hccl_tiling.h"
#include "include/adv_api/hccl/hccl_common.h"
#include "securec.h"
#include "../../detail/host_log.h"
#include "include/adv_api/hccl/internal/hccl_msg.h"
#include "include/adv_api/hccl/internal/hccl_tiling_msg.h"
#include "tiling/platform/platform_ascendc.h"
#include "hccl_symbol_loader.h"

using namespace std;
using namespace HcclApi;

namespace AscendC {
namespace {
void PrintMc2InitTiling(const Mc2InitTilingInner& tiling)
{
    TILING_LOG_DEBUG("Mc2InitTiling msg begin.");
    TILING_LOG_DEBUG("Mc2InitTiling msg version:%u", tiling.version);
    TILING_LOG_DEBUG("Mc2InitTiling msg mc2HcommCnt:%u", tiling.mc2HcommCnt);
    for (uint32_t i = 0; i < tiling.mc2HcommCnt; i++) {
        TILING_LOG_DEBUG("Mc2InitTiling msg offset%u:%u", i, tiling.offset[i]);
    }
    TILING_LOG_DEBUG("Mc2InitTiling msg debugMode:%u", tiling.debugMode);
    TILING_LOG_DEBUG("Mc2InitTiling msg preparePosition:%u", tiling.preparePosition);
    TILING_LOG_DEBUG("Mc2InitTiling msg queueNum:%u", tiling.queueNum);
    TILING_LOG_DEBUG("Mc2InitTiling msg commBlockNum:%u", tiling.commBlockNum);
    TILING_LOG_DEBUG("Mc2InitTiling msg devType:%u", tiling.devType);
    TILING_LOG_DEBUG("Mc2InitTiling msg end.");
}

void PrintMc2CcTiling(const Mc2CcTilingInner& tiling)
{
    TILING_LOG_DEBUG("Mc2CcTiling msg begin.");
    TILING_LOG_DEBUG("Mc2CcTiling msg skipLocalRankCopy:%u", tiling.skipLocalRankCopy);
    TILING_LOG_DEBUG("Mc2CcTiling msg skipBufferWindowCopy:%u", tiling.skipBufferWindowCopy);
    TILING_LOG_DEBUG("Mc2CcTiling msg stepSize:%u", tiling.stepSize);
    TILING_LOG_DEBUG("Mc2CcTiling msg version:%u", tiling.version);
    TILING_LOG_DEBUG("Mc2CcTiling msg groupName:%s", tiling.groupName);
    TILING_LOG_DEBUG("Mc2CcTiling msg algConfig:%s", tiling.algConfig);
    TILING_LOG_DEBUG("Mc2CcTiling msg opType:%u", tiling.opType);
    TILING_LOG_DEBUG("Mc2CcTiling msg reduceType:%u", tiling.reduceType);
    TILING_LOG_DEBUG("Mc2CcTiling msg dstDataType:%u", tiling.dstDataType);
    TILING_LOG_DEBUG("Mc2CcTiling msg srcDataType:%u", tiling.srcDataType);
    TILING_LOG_DEBUG("Mc2CcTiling msg commEngine:%u", tiling.commEngine);
    TILING_LOG_DEBUG("Mc2CcTiling msg end.");
}

uint32_t SetDevType(Mc2InitTilingInner* tilingInner)
{
    auto getSocVerFunc =
        HcclSymbolLoader::GetInstance().Load<void (*)(char*, uint32_t)>("libruntime.so", "rtGetSocVersion");
    ASCENDC_HOST_ASSERT(getSocVerFunc != nullptr, return EXIT_FAILURE, "Failed to get soc version.");
    char socVersion[50];
    getSocVerFunc(socVersion, sizeof(socVersion));
    std::string devType = std::string(socVersion);
    if (devType.find("Ascend910B") != std::string::npos) {
        tilingInner->devType = static_cast<uint8_t>(platform_ascendc::SocVersion::ASCEND910B);
    } else {
        tilingInner->devType = UINT8_MAX;
    }
    return EXIT_SUCCESS;
}

uint32_t UpdateMc2InitTiling(uint64_t initTilingAddr, uint64_t ccTilingAddr)
{
    Mc2InitTilingInner* tilingInner = reinterpret_cast<Mc2InitTilingInner*>(static_cast<uintptr_t>(initTilingAddr));
    uint32_t& cnt = tilingInner->mc2HcommCnt;
    ASCENDC_HOST_ASSERT(
        cnt < MAX_CC_TILING_NUM, return EXIT_FAILURE, "mc2HcommCnt(%u) must be less than or equal %u.", cnt + 1U,
        MAX_CC_TILING_NUM);
    tilingInner->offset[cnt] = static_cast<uint32_t>(ccTilingAddr - initTilingAddr);
    TILING_LOG_INFO(
        "Update Mc2InitTiling, index:%u, offset:%u, initTilingAddr:%#lx, ccTilingAddr:%#lx.", cnt,
        tilingInner->offset[cnt], initTilingAddr, ccTilingAddr);
    ++cnt;
    PrintMc2InitTiling(*tilingInner);
    return EXIT_SUCCESS;
}
} // namespace

Mc2CcTilingConfig::Mc2CcTilingConfig(
    const std::string& groupName, uint32_t opType, const std::string& algConfig, uint32_t reduceType,
    uint8_t dstDataType, uint8_t srcDataType, uint8_t commEngine)
{
    impl_.groupName_ = groupName;
    impl_.opType_ = opType;
    impl_.algConfig_ = algConfig;
    impl_.reduceType_ = reduceType;
    impl_.srcDataType_ = srcDataType;
    impl_.dstDataType_ = dstDataType;
    impl_.commEngine_ = commEngine;
    TILING_LOG_INFO(
        "Init groupName_:%s, opType_:%u, algConfig_:%s, reduceType_:%u, srcDataType_:%u, dstDataType_:%u, "
        "commEngine_:%u.",
        impl_.groupName_.c_str(), impl_.opType_, impl_.algConfig_.c_str(), impl_.reduceType_, impl_.srcDataType_,
        impl_.dstDataType_, impl_.commEngine_);
}

Mc2CcTilingConfig::~Mc2CcTilingConfig() = default;

uint32_t Mc2CcTilingConfig::GetTiling(::Mc2InitTiling& tiling)
{
    // It is not a reduce type or a reduce type, and the reduce type is valid
    const bool reduceFlag =
        (impl_.opType_ == static_cast<uint8_t>(HcclCMDType::HCCL_CMD_ALLREDUCE) ||
         impl_.opType_ == static_cast<uint8_t>(HcclCMDType::HCCL_CMD_REDUCE) ||
         impl_.opType_ == static_cast<uint8_t>(HcclCMDType::HCCL_CMD_REDUCE_SCATTER));
    ASCENDC_HOST_ASSERT(
        !reduceFlag || impl_.reduceType_ < HCCL_REDUCE_RESERVED, return EXIT_FAILURE,
        "when opType(%u) is reduce, reduceType must be less than %u.", impl_.opType_,
        static_cast<uint8_t>(HCCL_REDUCE_RESERVED));

    impl_.initTilingAddr_ = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(&tiling));
    Mc2InitTilingInner* tilingInner = reinterpret_cast<Mc2InitTilingInner*>(&tiling);
    tilingInner->version = INIT_TILING_VERSION;
    tilingInner->mc2HcommCnt = 0;
    tilingInner->debugMode = impl_.debugMode_;
    tilingInner->preparePosition = 0;
    tilingInner->queueNum = impl_.queueNum_;
    tilingInner->commBlockNum = impl_.commBlockNum_;
    for (uint32_t i = 0; i < MAX_CC_TILING_NUM; i++) {
        tilingInner->offset[i] = 0;
    }
    ASCENDC_HOST_ASSERT(SetDevType(tilingInner) == EXIT_SUCCESS, return EXIT_FAILURE, "SetDevType failed.");

    TILING_LOG_DEBUG("Mc2InitTiling addr %#lx", impl_.initTilingAddr_);
    PrintMc2InitTiling(*tilingInner);
    return EXIT_SUCCESS;
}

uint32_t Mc2CcTilingConfig::GetTiling(::Mc2CcTiling& tiling)
{
    ASCENDC_HOST_ASSERT(impl_.initTilingAddr_ != 0, return EXIT_FAILURE, "must be set Mc2InitTiling first.");
    ASCENDC_HOST_ASSERT(
        impl_.groupName_.length() <= GROUP_NAME_SIZE, return EXIT_FAILURE,
        "groupName(%s) must be less than or equal %u.", impl_.groupName_.c_str(), GROUP_NAME_SIZE);
    ASCENDC_HOST_ASSERT(
        impl_.algConfig_.length() <= ALG_CONFIG_SIZE, return EXIT_FAILURE,
        "algConfig(%s) must be less than or equal %u.", impl_.algConfig_.c_str(), ALG_CONFIG_SIZE);
    bool isValid = (0 < impl_.opType_) && (impl_.opType_ < static_cast<uint8_t>(HcclCMDType::HCCL_CMD_ALL));
    ASCENDC_HOST_ASSERT(
        isValid, return EXIT_FAILURE, "opType(%u) must be less than %u, and not 0.", impl_.opType_,
        static_cast<uint8_t>(HcclCMDType::HCCL_CMD_ALL));
    ASCENDC_HOST_ASSERT(
        impl_.reduceType_ < HCCL_REDUCE_RESERVED, return EXIT_FAILURE, "reduceType(%u) must be less than %u.",
        impl_.reduceType_, static_cast<uint8_t>(HCCL_REDUCE_RESERVED));
    ASCENDC_HOST_ASSERT(
        impl_.dstDataType_ < HCCL_DATA_TYPE_RESERVED, return EXIT_FAILURE, "dstDataType(%u) must be less than %u.",
        impl_.dstDataType_, static_cast<uint8_t>(HCCL_DATA_TYPE_RESERVED));
    ASCENDC_HOST_ASSERT(
        impl_.srcDataType_ < HCCL_DATA_TYPE_RESERVED, return EXIT_FAILURE, "srcDataType(%u) must be less than %u.",
        impl_.srcDataType_, static_cast<uint8_t>(HCCL_DATA_TYPE_RESERVED));

    Mc2CcTilingInner* tilingInner = reinterpret_cast<Mc2CcTilingInner*>(&tiling);
    tilingInner->skipLocalRankCopy = impl_.skipLocalRankCopy_;
    tilingInner->skipBufferWindowCopy = impl_.skipBufferWindowCopy_;
    tilingInner->stepSize = impl_.stepSize_;
    tilingInner->version = 1U;
    (void)memset_s(tilingInner->reserved, sizeof(tilingInner->reserved), 0, sizeof(tilingInner->reserved));
    auto ret = strcpy_s(tilingInner->groupName, sizeof(tilingInner->groupName), impl_.groupName_.c_str());
    ASCENDC_HOST_ASSERT(ret == EOK, return EXIT_FAILURE, "groupName(%s) copy failed.", impl_.groupName_.c_str());
    ret = strcpy_s(tilingInner->algConfig, sizeof(tilingInner->algConfig), impl_.algConfig_.c_str());
    ASCENDC_HOST_ASSERT(ret == EOK, return EXIT_FAILURE, "algConfig(%s) copy failed.", impl_.algConfig_.c_str());
    tilingInner->opType = impl_.opType_;
    tilingInner->reduceType = impl_.reduceType_;
    tilingInner->srcDataType = impl_.srcDataType_;
    tilingInner->dstDataType = impl_.dstDataType_;
    tilingInner->commEngine = impl_.commEngine_;

    uint64_t ccTilingAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(&tiling));
    ASCENDC_HOST_ASSERT(
        UpdateMc2InitTiling(impl_.initTilingAddr_, ccTilingAddr) == EXIT_SUCCESS, return EXIT_FAILURE,
        "UpdateMc2InitTiling failed.");
    PrintMc2CcTiling(*tilingInner);

    return EXIT_SUCCESS;
}

uint32_t Mc2CcTilingConfig::SetOpType(uint32_t opType)
{
    bool isValid = (0 < opType) && (opType < static_cast<uint8_t>(HcclCMDType::HCCL_CMD_ALL));
    ASCENDC_HOST_ASSERT(
        isValid, return EXIT_FAILURE, "opType(%u) must be less than %u, and not 0.", opType,
        static_cast<uint8_t>(HcclCMDType::HCCL_CMD_ALL));
    impl_.opType_ = opType;
    return EXIT_SUCCESS;
}

uint32_t Mc2CcTilingConfig::SetGroupName(const std::string& groupName)
{
    ASCENDC_HOST_ASSERT(
        groupName.length() <= GROUP_NAME_SIZE, return EXIT_FAILURE, "groupName(%s) must be less than or equal %u.",
        groupName.c_str(), GROUP_NAME_SIZE);
    impl_.groupName_ = groupName;
    return EXIT_SUCCESS;
}

uint32_t Mc2CcTilingConfig::SetAlgConfig(const std::string& algConfig)
{
    ASCENDC_HOST_ASSERT(
        algConfig.length() <= ALG_CONFIG_SIZE, return EXIT_FAILURE, "algConfig(%s) must be less than or equal %u.",
        algConfig.c_str(), ALG_CONFIG_SIZE);
    impl_.algConfig_ = algConfig;
    return EXIT_SUCCESS;
}

uint32_t Mc2CcTilingConfig::SetReduceType(uint32_t reduceType, uint8_t dstDataType, uint8_t srcDataType)
{
    ASCENDC_HOST_ASSERT(
        reduceType < HCCL_REDUCE_RESERVED, return EXIT_FAILURE, "reduceType(%u) must be less than %u.", reduceType,
        static_cast<uint8_t>(HCCL_REDUCE_RESERVED));
    ASCENDC_HOST_ASSERT(
        dstDataType < HCCL_DATA_TYPE_RESERVED, return EXIT_FAILURE, "dstDataType(%u) must be less than %u.",
        dstDataType, static_cast<uint8_t>(HCCL_DATA_TYPE_RESERVED));
    ASCENDC_HOST_ASSERT(
        srcDataType < HCCL_DATA_TYPE_RESERVED, return EXIT_FAILURE, "srcDataType(%u) must be less than %u.",
        srcDataType, static_cast<uint8_t>(HCCL_DATA_TYPE_RESERVED));
    impl_.reduceType_ = reduceType;
    impl_.srcDataType_ = srcDataType;
    impl_.dstDataType_ = dstDataType;
    return EXIT_SUCCESS;
}


uint32_t Mc2CcTilingConfig::SetStepSize(uint8_t stepSize)
{
    auto ascendcPlatform = platform_ascendc::PlatformAscendCManager::GetInstance();
    auto npuArch = ascendcPlatform->GetCurNpuArch();
    if (npuArch == NpuArch::DAV_2201) {
        // This API is only supported on Atlas A3
        impl_.stepSize_ = stepSize;
        return EXIT_SUCCESS;
    } else {
        return EXIT_FAILURE;
    }
}

uint32_t Mc2CcTilingConfig::SetSkipLocalRankCopy(uint8_t skipLocalRankCopy)
{
    ASCENDC_HOST_ASSERT(
        skipLocalRankCopy <= 1, return EXIT_FAILURE, "skipLocalRankCopy(%u) must be less than or equal 1.",
        skipLocalRankCopy);
    impl_.skipLocalRankCopy_ = skipLocalRankCopy;
    return EXIT_SUCCESS;
}

uint32_t Mc2CcTilingConfig::SetSkipBufferWindowCopy(uint8_t skipBufferWindowCopy)
{
    ASCENDC_HOST_ASSERT(
        skipBufferWindowCopy <= 2, return EXIT_FAILURE, "skipBufferWindowCopy(%u) must be less than or equal 2.",
        skipBufferWindowCopy);
    impl_.skipBufferWindowCopy_ = skipBufferWindowCopy;
    return EXIT_SUCCESS;
}

uint32_t Mc2CcTilingConfig::SetDebugMode(uint8_t debugMode)
{
    bool ret = (1 <= debugMode && debugMode <= 4) || (debugMode >= 250);
    ASCENDC_HOST_ASSERT(ret, return EXIT_FAILURE, "debugMode(%u) only support [1,4] or [250,255].", debugMode);
    impl_.debugMode_ = debugMode;
    return EXIT_SUCCESS;
}

uint32_t Mc2CcTilingConfig::SetQueueNum(uint16_t num)
{
    impl_.queueNum_ = num;
    return EXIT_SUCCESS;
}

uint32_t Mc2CcTilingConfig::SetCommBlockNum(uint16_t num)
{
    impl_.commBlockNum_ = num;
    return EXIT_SUCCESS;
}

uint32_t Mc2CcTilingConfig::SetCommEngine(uint8_t commEngine)
{
    impl_.commEngine_ = commEngine;
    return EXIT_SUCCESS;
}
} // namespace AscendC