/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

/*!
 * \file kernel_simt_cast_impl.h
 * \brief
 */
#ifndef IMPL_SIMT_API_CPP_DAV_C310_KERNEL_SIMT_CAST_IMPL_H
#define IMPL_SIMT_API_CPP_DAV_C310_KERNEL_SIMT_CAST_IMPL_H

#include "kernel_utils.h"
#include "impl/simt_api/cpp/dav_3510/kernel_simt_common_impl.h"
#include "impl/simt_api/cpp/dav_3510/kernel_simt_cast_sat_impl.h"

namespace AscendC {
namespace Simt {

#ifndef ASCENDC_CPU_DEBUG
#define REG_ROUND_VEC(type, func_name, len)                \
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline void func_name(type &dst, type &src) \
    {                                                      \
        for (int i = 0; i < len; i++) {                    \
            dst[i] = func_name(src[i]);\
        }                                                  \
    }

#define REG_ROUND_VEC_(dst_type, src_type, d_type, s_type, func_name, len) \
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline void func_name##_(dst_type &dst, src_type &src)      \
    {                                                                      \
        for (int i = 0; i < len; i++) {                                    \
            dst[i] = func_name##_<d_type, s_type>(src[i]);                 \
        }                                                                  \
    }

#define REG_CAST_IMPL_VEC(type, func_name) \
    REG_ROUND_VEC(type##1, func_name, 1)   \
    REG_ROUND_VEC(type##2, func_name, 2)   \
    REG_ROUND_VEC(type##3, func_name, 3)   \
    REG_ROUND_VEC(type##4, func_name, 4)

#define REG_CAST_IMPL_VEC_(dst_type, src_type, func_name)                      \
    REG_ROUND_VEC_(dst_type##1, src_type##1, dst_type, src_type, func_name, 1) \
    REG_ROUND_VEC_(dst_type##2, src_type##2, dst_type, src_type, func_name, 2) \
    REG_ROUND_VEC_(dst_type##3, src_type##3, dst_type, src_type, func_name, 3) \
    REG_ROUND_VEC_(dst_type##4, src_type##4, dst_type, src_type, func_name, 4)

#define REG_CAST_HF_IMPL_VEC(type, func_name) REG_ROUND_VEC(type##2, func_name, 2)

#define REG_CAST_HF_IMPL_VEC_(dst_type, src_type, func_name) \
    REG_ROUND_VEC_(dst_type##2, src_type##2, dst_type, src_type, func_name, 2)
#else
#define REG_ROUND_VEC_1(type, func_name)                   \
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline void func_name(type &dst, type &src) \
    {                                                      \
        dst.x = func_name(src.x);                          \
    }

#define REG_ROUND_VEC_2(type, func_name)                   \
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline void func_name(type &dst, type &src) \
    {                                                      \
        dst.x = func_name(src.x);                          \
        dst.y = func_name(src.y);                          \
    }

#define REG_ROUND_VEC_3(type, func_name)                   \
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline void func_name(type &dst, type &src) \
    {                                                      \
        dst.x = func_name(src.x);                          \
        dst.y = func_name(src.y);                          \
        dst.z = func_name(src.z);                          \
    }

#define REG_ROUND_VEC_4(type, func_name)                   \
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline void func_name(type &dst, type &src) \
    {                                                      \
        dst.x = func_name(src.x);                          \
        dst.y = func_name(src.y);                          \
        dst.z = func_name(src.z);                          \
        dst.w = func_name(src.w);                          \
    }

#define REG_ROUND_VEC_1_(dst_type, src_type, d_type, s_type, func_name) \
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline void func_name##_(dst_type &dst, src_type &src)   \
    {                                                                   \
        dst.x = func_name##_<d_type, s_type>(src.x);                    \
    }

#define REG_ROUND_VEC_2_(dst_type, src_type, d_type, s_type, func_name) \
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline void func_name##_(dst_type &dst, src_type &src)   \
    {                                                                   \
        dst.x = func_name##_<d_type, s_type>(src.x);                    \
        dst.y = func_name##_<d_type, s_type>(src.y);                    \
    }

#define REG_ROUND_VEC_3_(dst_type, src_type, d_type, s_type, func_name) \
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline void func_name##_(dst_type &dst, src_type &src)   \
    {                                                                   \
        dst.x = func_name##_<d_type, s_type>(src.x);                    \
        dst.y = func_name##_<d_type, s_type>(src.y);                    \
        dst.z = func_name##_<d_type, s_type>(src.z);                    \
    }

#define REG_ROUND_VEC_4_(dst_type, src_type, d_type, s_type, func_name) \
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline void func_name##_(dst_type &dst, src_type &src)   \
    {                                                                   \
        dst.x = func_name##_<d_type, s_type>(src.x);                    \
        dst.y = func_name##_<d_type, s_type>(src.y);                    \
        dst.z = func_name##_<d_type, s_type>(src.z);                    \
        dst.w = func_name##_<d_type, s_type>(src.w);                    \
    }

#define REG_CAST_IMPL_VEC(type, func_name) \
    REG_ROUND_VEC_1(type##1, func_name)    \
    REG_ROUND_VEC_2(type##2, func_name)    \
    REG_ROUND_VEC_3(type##3, func_name)    \
    REG_ROUND_VEC_4(type##4, func_name)

#define REG_CAST_IMPL_VEC_(dst_type, src_type, func_name)                     \
    REG_ROUND_VEC_1_(dst_type##1, src_type##1, dst_type, src_type, func_name) \
    REG_ROUND_VEC_2_(dst_type##2, src_type##2, dst_type, src_type, func_name) \
    REG_ROUND_VEC_3_(dst_type##3, src_type##3, dst_type, src_type, func_name) \
    REG_ROUND_VEC_4_(dst_type##4, src_type##4, dst_type, src_type, func_name)

#define REG_CAST_HF_IMPL_VEC(type, func_name) REG_ROUND_VEC_2(type##2, func_name)

#define REG_CAST_HF_IMPL_VEC_(dst_type, src_type, func_name) \
    REG_ROUND_VEC_2_(dst_type##2, src_type##2, dst_type, src_type, func_name)
#endif

#define REG_ROUND(type, func_name)                         \
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline void func_name(type &dst, type &src) \
    {                                                      \
        dst = func_name(src);          \
    }

#define REG_ROUND_(d_type, s_type, func_name)                     \
    __SIMT_DEVICE_FUNCTIONS_DECL__ inline void func_name##_(d_type &dst, s_type &src) \
    {                                                             \
        dst = func_name##_<d_type, s_type>(src);                  \
    }

#define REG_CAST_IMPL_(dst_type, src_type, func_name) REG_ROUND_(dst_type, src_type, func_name)

#define REG_CAST_VEC(round_mode)         \
    REG_CAST_IMPL_VEC(float, round_mode) \
    REG_CAST_HF_IMPL_VEC(half, round_mode)

#define REG_CAST_(round_mode)                \
    REG_CAST_IMPL_(half, float, round_mode)  \
    REG_CAST_IMPL_(int, float, round_mode)   \
    REG_CAST_IMPL_(long, float, round_mode)  \
    REG_CAST_IMPL_(bhalf, float, round_mode) \
    REG_CAST_IMPL_(float, half, round_mode)  \
    REG_CAST_IMPL_(float, int, round_mode)   \
    REG_CAST_IMPL_(float, long, round_mode)

#define REG_CAST_VEC_(round_mode)                   \
    REG_CAST_IMPL_VEC_(int, float, round_mode)      \
    REG_CAST_IMPL_VEC_(long, float, round_mode)     \
    REG_CAST_IMPL_VEC_(float, int, round_mode)      \
    REG_CAST_IMPL_VEC_(float, long, round_mode)     \
    REG_CAST_HF_IMPL_VEC_(float, half, round_mode)

#ifdef ASCENDC_CPU_DEBUG
REG_CAST_HF_IMPL_VEC_(half, float, Rint)
REG_CAST_HF_IMPL_VEC_(half, float, Floor)
REG_CAST_HF_IMPL_VEC_(half, float, Ceil)
REG_CAST_HF_IMPL_VEC_(half, float, Trunc)
REG_CAST_HF_IMPL_VEC_(half, float, CastNone)
#endif

REG_CAST_(Rint)
REG_CAST_(Floor)
REG_CAST_(Ceil)
REG_CAST_(Trunc)
REG_CAST_(CastNone)

REG_CAST_IMPL_(float, bfloat16_t, CastNone)
REG_CAST_IMPL_(float, bfloat16_t, Ceil)
REG_CAST_IMPL_(float, bfloat16_t, Floor)
REG_CAST_IMPL_(float, bfloat16_t, Trunc)
REG_CAST_IMPL_(float, bfloat16_t, Rint)

REG_CAST_VEC_(Rint)
REG_CAST_VEC_(Floor)
REG_CAST_VEC_(Ceil)
REG_CAST_VEC_(Trunc)
REG_CAST_VEC_(CastNone)

template <typename T, typename U, RoundMode roundMode>
__SIMT_DEVICE_FUNCTIONS_DECL__ inline T CastFallback(U x)
{
    T y;
    switch (roundMode) {
#if (__NPU_ARCH__ == 3510) || (__NPU_ARCH__ == 5102)
        case RoundMode::CAST_EVEN:
            Rint_(y, x);
            break;
        case RoundMode::CAST_ZERO:
            Trunc_(y, x);
            break;
#endif
        case RoundMode::CAST_FLOOR:
            Floor_(y, x);
            break;
        case RoundMode::CAST_CEIL:
            Ceil_(y, x);
            break;
        case RoundMode::CAST_NONE:
            CastNone_(y, x);
            break;
    }
    return y;
}

template <typename T, typename U, RoundMode roundMode, SatMode satMode>
__SIMT_DEVICE_FUNCTIONS_DECL__ inline T CastImpl(U x)
{
#if defined(ASCENDC_CPU_DEBUG)
    return CastFallback<T, U, roundMode>(x);
#else
    if constexpr ((roundMode == RoundMode::CAST_EVEN || roundMode == RoundMode::CAST_ZERO) &&
                  SupportTypeSimtInternel<Tuple<U, T>, Tuple<float, int>, Tuple<int, float>, Tuple<float, int64_t>,
                              Tuple<int64_t, float>, Tuple<float, half>, Tuple<float, bfloat16_t>>) {
        return CastFallback<T, U, roundMode>(x);
    }
    if constexpr (roundMode == RoundMode::CAST_NONE &&
                  SupportTypeSimtInternel<Tuple<U, T>, Tuple<half, float>, Tuple<bfloat16_t, float>>) {
        return CastFallback<T, U, roundMode>(x);
    }
    T y;
    if constexpr (SupportTypeSimtInternel<Tuple<T, U>, Tuple<uint32_t, half>, Tuple<int32_t, half>, Tuple<uint32_t, float>,
                              Tuple<int32_t, float>, Tuple<uint64_t, float>, Tuple<int64_t, float>,
                              Tuple<uint32_t, bfloat16_t>, Tuple<int32_t, bfloat16_t>>) {
        y = CastSat<T, U, roundMode>(x);
    } else if constexpr (SupportTypeSimtInternel<Tuple<T, U>, Tuple<half, uint32_t>, Tuple<float, uint32_t>,
                                     Tuple<bfloat16_t, uint32_t>, Tuple<half, int32_t>, Tuple<float, int32_t>,
                                     Tuple<bfloat16_t, int32_t>, Tuple<float, uint64_t>, Tuple<float, int64_t>,
                                     Tuple<float, half>, Tuple<bfloat16_t, half>, Tuple<half, float>,
                                     Tuple<bfloat16_t, float>, Tuple<half, bfloat16_t>, Tuple<float, bfloat16_t>>) {
        switch (satMode) {
            case SatMode::SAT:
                y = CastSat<T, U, roundMode>(x);
                break;
            case SatMode::NO_SAT:
                y = CastNoSat<T, U, roundMode>(x);
                break;
        }
    }
    return y;
#endif
}

template <typename T>
__SIMT_DEVICE_FUNCTIONS_DECL__ inline T RoundImpl(T x)
{
    return RoundIntrinsicsImpl(x);
}

template <typename T>
__SIMT_DEVICE_FUNCTIONS_DECL__ inline T RintImpl(T x)
{
    return RintIntrinsicsImpl(x);
}

template <typename T>
 __SIMT_DEVICE_FUNCTIONS_DECL__ inline T FloorImpl(T x)
{
    return FloorIntrinsicsImpl(x);
}

template <typename T>
__SIMT_DEVICE_FUNCTIONS_DECL__ inline T CeilImpl(T x)
{
    return CeilIntrinsicsImpl(x);
}

template <typename T>
__SIMT_DEVICE_FUNCTIONS_DECL__ inline T TruncImpl(T x)
{
    if (x > (T)0) {
        return FloorImpl(x);
    } else {
        return CeilImpl(x);
    }
}

}  // namespace Simt
}  // namespace AscendC
#endif  // IMPL_SIMT_API_CPP_DAV_C310_KERNEL_SIMT_CAST_IMPL_H