/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License")
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/* !
 * \file cast.cpp
 */
#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
#include "arch35/cast_struct.h"
#include "arch35/cast_impl.h"

namespace AscendcCast {
template <int t>
__aicore__ constexpr inline AscendC::MicroAPI::LoadDist ToLoadDist()
{
    if constexpr (t == CAST_MODE_REG_COPYIN_NORM) {
        return AscendC::MicroAPI::LoadDist::DIST_NORM;
    } else if constexpr (t == CAST_MODE_REG_COPYIN_DS_B8) {
        return AscendC::MicroAPI::LoadDist::DIST_DS_B8;
    } else if constexpr (t == CAST_MODE_REG_COPYIN_DS_B16) {
        return AscendC::MicroAPI::LoadDist::DIST_DS_B16;
    } else if constexpr (t == CAST_MODE_REG_COPYIN_UNPACK_B8) {
        return AscendC::MicroAPI::LoadDist::DIST_UNPACK_B8;
    } else if constexpr (t == CAST_MODE_REG_COPYIN_UNPACK_B16) {
        return AscendC::MicroAPI::LoadDist::DIST_UNPACK_B16;
    } else if constexpr (t == CAST_MODE_REG_COPYIN_UNPACK_B32) {
        return AscendC::MicroAPI::LoadDist::DIST_UNPACK_B32;
    } else if constexpr (t == CAST_MODE_REG_COPYIN_UNPACK4_B8) {
        return AscendC::MicroAPI::LoadDist::DIST_UNPACK4_B8;
    }

    return AscendC::MicroAPI::LoadDist::DIST_NORM;
}

template <int t>
__aicore__ constexpr inline AscendC::MicroAPI::StoreDist ToStoreDist()
{
    if constexpr (t == CAST_MODE_REG_COPYOUT_NORM) {
        return AscendC::MicroAPI::StoreDist::DIST_NORM;
    } else if constexpr (t == CAST_MODE_REG_COPYOUT_PACK_B16) {
        return AscendC::MicroAPI::StoreDist::DIST_PACK_B16;
    } else if constexpr (t == CAST_MODE_REG_COPYOUT_PACK_B32) {
        return AscendC::MicroAPI::StoreDist::DIST_PACK_B32;
    } else if constexpr (t == CAST_MODE_REG_COPYOUT_PACK_B64) {
        return AscendC::MicroAPI::StoreDist::DIST_PACK_B64;
    } else if constexpr (t == CAST_MODE_REG_COPYOUT_PACK4_B32) {
        return AscendC::MicroAPI::StoreDist::DIST_PACK4_B32;
    }
    return AscendC::MicroAPI::StoreDist::DIST_NORM;
}

template <int t>
__aicore__ constexpr inline AscendC::RoundMode ToRoundMode()
{
    if constexpr (t == CAST_ROUND_MODE_UNKNOWN) {
        return AscendC::RoundMode::UNKNOWN;
    } else if constexpr (t == CAST_ROUND_MODE_NONE) {
        return AscendC::RoundMode::CAST_NONE;
    } else if constexpr (t == CAST_ROUND_MODE_RINT) {
        return AscendC::RoundMode::CAST_RINT;
    } else if constexpr (t == CAST_ROUND_MODE_FLOOR) {
        return AscendC::RoundMode::CAST_FLOOR;
    } else if constexpr (t == CAST_ROUND_MODE_CEIL) {
        return AscendC::RoundMode::CAST_CEIL;
    } else if constexpr (t == CAST_ROUND_MODE_ROUND) {
        return AscendC::RoundMode::CAST_ROUND;
    } else if constexpr (t == CAST_ROUND_MODE_TRUNC) {
        return AscendC::RoundMode::CAST_TRUNC;
    } else if constexpr (t == CAST_ROUND_MODE_ODD) {
        return AscendC::RoundMode::CAST_ODD;
    } else if constexpr (t == CAST_ROUND_MODE_HYBRID) {
        return AscendC::RoundMode::CAST_HYBRID;
    }
    return AscendC::RoundMode::CAST_NONE;
}
}

template<int id>
__global__ __aicore__ void cast(GM_ADDR x, GM_ADDR y, GM_ADDR workspace, GM_ADDR tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);
    AscendC::TPipe pipe;
    if constexpr (!(GetCastPolicy<ORIG_DTYPE_X, ORIG_DTYPE_Y>::isValid_)) {
        return;
    }
    constexpr int templateId = GetCastPolicy<ORIG_DTYPE_X, ORIG_DTYPE_Y>::id_;
    constexpr int mapDtypeX = GetCastPolicy<ORIG_DTYPE_X, ORIG_DTYPE_Y>::srcMapType_;
    constexpr int mapDtypeMid = GetCastPolicy<ORIG_DTYPE_X, ORIG_DTYPE_Y>::midType_;
    constexpr int mapDtypeY = GetCastPolicy<ORIG_DTYPE_X, ORIG_DTYPE_Y>::dstMapType_;
    constexpr int castMode1 = GetCastPolicy<ORIG_DTYPE_X, ORIG_DTYPE_Y>::castMode1_;
    constexpr int castMode2 = GetCastPolicy<ORIG_DTYPE_X, ORIG_DTYPE_Y>::castMode2_;
    constexpr int regCopyInMode = GetCastPolicy<ORIG_DTYPE_X, ORIG_DTYPE_Y>::regCopyInMode_;
    constexpr int regCopyOutMode = GetCastPolicy<ORIG_DTYPE_X, ORIG_DTYPE_Y>::regCopyOutMode_;
    using dTypeX = typename GetCastPolicy<ORIG_DTYPE_X, ORIG_DTYPE_Y>::srcType;
    using dTypeY = typename GetCastPolicy<ORIG_DTYPE_X, ORIG_DTYPE_Y>::dstType;

    using mapInType = typename AscendcCast::TypeGetTool<mapDtypeX>::type;
    using mapMidType = typename AscendcCast::TypeGetTool<mapDtypeMid>::type;
    using mapOutType = typename AscendcCast::TypeGetTool<mapDtypeY>::type;
    if constexpr (templateId == CAST_TEMPLATE_DIRECT_CAST) {
        AscendcCast::CastDirect<mapInType, mapOutType> op;
        constexpr AscendC::RoundMode rMode1 = AscendcCast::ToRoundMode<castMode1>();
        op.Init(x, y, rMode1, &tilingData, &pipe);
        op.Process();
    } else if constexpr (templateId == CAST_TEMPLATE_DST_BOOL) {
#if ORIG_DTYPE_Y == DT_BOOL
        AscendcCast::CastDstBool<dTypeX> op;
        op.Init(x, y, &tilingData, &pipe);
        op.Process();
#endif
    } else if constexpr (templateId == CAST_TEMPLATE_THROUGH) {
        AscendcCast::CastThrough<dTypeY> op;
        op.Init(x, y, &tilingData, &pipe);
        op.Process();
    } else if constexpr (templateId == CAST_TEMPLATE_SRC_UINT1) {
        AscendcCast::CastUint1<dTypeY> op;
        op.Init(x, y, &tilingData, &pipe);
        op.Process();
    } else if constexpr (templateId == CAST_TEMPLATE_TWO_CAST) {
        AscendcCast::CastTwo<mapInType, mapMidType, mapOutType> op;
        constexpr AscendC::RoundMode rMode1 = AscendcCast::ToRoundMode<castMode1>();
        constexpr AscendC::RoundMode rMode2 = AscendcCast::ToRoundMode<castMode2>();
        op.Init(x, y, rMode1, rMode2, &tilingData, &pipe);
        op.Process();
    } else if constexpr (templateId == CAST_TEMPLATE_MIRCRO_INOUT ||
        templateId == CAST_TEMPLATE_MIRCRO_CAST || templateId == CAST_TEMPLATE_MIRCRO_CAST_INTER ||
        templateId == CAST_TEMPLATE_MIRCRO_CAST_DEINTER || templateId == CAST_TEMPLATE_MIRCRO_CAST_CAST_DEINTER ||
        templateId == CAST_TEMPLATE_MIRCRO_CAST_CAST || templateId == CAST_TEMPLATE_MIRCRO_CAST_INTER_CAST ||
        templateId == CAST_TEMPLATE_MIRCRO_CAST_DEINTER_CAST || templateId == CAST_TEMPLATE_MIRCRO_CAST_CAST_DEINTER_CAST ||
        templateId == CAST_TEMPLATE_MIRCRO_CAST_INTER_CAST_CAST || templateId == CAST_TEMPLATE_MIRCRO_DEINTER_SHIFT) {
        constexpr AscendC::MicroAPI::LoadDist ldDist = AscendcCast::ToLoadDist<regCopyInMode>();
        constexpr AscendC::MicroAPI::StoreDist stDist = AscendcCast::ToStoreDist<regCopyOutMode>();
        constexpr AscendC::RoundMode rMode1 = AscendcCast::ToRoundMode<castMode1>();
        constexpr AscendC::RoundMode rMode2 = AscendcCast::ToRoundMode<castMode2>();
        AscendcCast::CastMicro<templateId, dTypeX, dTypeY, mapInType, mapMidType, mapOutType,
            ldDist, stDist, rMode1, rMode2> op;
        op.Init(x, y, &tilingData, &pipe);
        op.Process();
    }
    return;
}