/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

/*!
 * \file kernel_scalar_convert.h
 * \brief
 */
#ifndef ASCENDC_SCALAR_CONVERT_H
#define ASCENDC_SCALAR_CONVERT_H
#include "kernel_type_conversion_utils.h"

namespace AscendC {
#if defined(__NPU_ARCH__) &&                                    \
        ((__NPU_ARCH__ == 2201) || (__NPU_ARCH__ == 3002) ||    \
         (__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 5102))
#if defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 5102))
__aicore__ inline bfloat16_t Cast(const float& fVal)
{
    constexpr uint32_t fp32SignIdx = 31;
    constexpr uint32_t fp32ManLen = 23;
    constexpr int16_t bf16ExpMask = 0x7F80;
    constexpr uint16_t bf16AbsMax = 0x7FFFu;
    constexpr uint16_t bf16ManHideBit = 0x0080u;
    constexpr uint32_t bf16ManLen = 7;
    constexpr int16_t bf16MaxMan = 0x07F;
    constexpr int8_t bf16ManMask = 0x7F;
    constexpr uint16_t bf16MaxExp = 0xFFu;
    constexpr uint32_t fp32ManHideBit = 0x00800000u;
    constexpr uint8_t bf16SignIndex = 15;

    float fpVal = fVal;
    uint32_t u32Val = AscendC::GetScalarBitcodeValue<float, uint32_t>(fpVal); // 1:8:23bit sign:exp:man
    uint16_t u16Num = 0;
    uint16_t sign = (u32Val >> fp32SignIdx) & 1;
    if (IsFp32Inf(u32Val)) {
        u16Num = (bf16ExpMask | (sign << bf16SignIndex));
        return GetScalarBitcodeValue<uint16_t, bfloat16_t>(u16Num);
    }

    if (IsFp32Nan(u32Val)) {
        u16Num = bf16AbsMax;
        return GetScalarBitcodeValue<uint16_t, bfloat16_t>(u16Num);
    }
    uint32_t expFp = ((u32Val) & FP32_EXP_PART_MASK) >> fp32ManLen;
    uint32_t manFp = (u32Val & FP32_MAN_PART_MASK); // 23 bit mantissa dont't need to care about denormal
    manFp = (manFp | fp32ManHideBit);

    constexpr uint32_t mLenDelta = fp32ManLen - bf16ManLen;
    uint16_t eRet = static_cast<uint16_t>(expFp);
    uint16_t mRet = static_cast<uint16_t>(manFp >> mLenDelta);

    bool needRound = IsBf16NeedRound(manFp);
    if (needRound) {
        if ((eRet == (bf16MaxExp - 1)) && ((mRet & bf16ManMask) == bf16MaxMan)) {
            u16Num = (bf16ExpMask | (sign << bf16SignIndex));
            return GetScalarBitcodeValue<uint16_t, bfloat16_t>(u16Num);
        }
        ++mRet;
    }

    if (mRet > (bf16ManHideBit | bf16ManMask)) {
        ++eRet;
    }
    u16Num = ((sign) << bf16SignIndex) | ((eRet) << bf16ManLen) | ((mRet) & bf16MaxMan);
    return GetScalarBitcodeValue<uint16_t, bfloat16_t>(u16Num);
}

__aicore__ inline bfloat16_t ToBfloat16(const float& fVal)
{
    return Cast(fVal);
}

template <typename T>
__aicore__ constexpr inline float Cast(const T& bVal)
{
    static_assert(SupportType<T, bfloat16_t, hifloat8_t, fp8_e5m2_t, fp8_e4m3fn_t, fp4x2_e1m2_t, fp4x2_e2m1_t>(),
        "ToFloat only support bfloat16_t/hifloat8_t/fp8_e5m2_t/fp8_e4m3fn_t/fp4x2_e1m2_t/fp4x2_e2m1_t data type on current device!");
    uint8_t uiNum = 0;
    uint32_t result = 0;
    float fNum = 0;

    if constexpr (IsSameType<T, bfloat16_t>::value) {
        fNum = FPTranslation::Bf16ToFp32(bVal);
    } else if constexpr (IsSameType<T, hifloat8_t>::value) {
        uiNum = GetScalarBitcodeValue<hifloat8_t, uint8_t>(bVal);
        result = FPTranslation::Hif8ToFp32(uiNum);
        return GetScalarBitcodeValue<uint32_t, float>(result);
    } else if constexpr (IsSameType<T, fp8_e5m2_t>::value) {
        uiNum = GetScalarBitcodeValue<fp8_e5m2_t, uint8_t>(bVal);
        result  = FPTranslation::Fp8e5m2ToFp32(uiNum);
        return GetScalarBitcodeValue<uint32_t, float>(result);
    } else if constexpr (IsSameType<T, fp8_e4m3fn_t>::value) {
        uiNum = GetScalarBitcodeValue<fp8_e4m3fn_t, uint8_t>(bVal);
        result = FPTranslation::Fp8e4m3ToFp32(uiNum);
        return GetScalarBitcodeValue<uint32_t, float>(result);
    } else if constexpr (IsSameType<T, fp4x2_e1m2_t>::value) {
        uiNum = GetScalarBitcodeValue<fp4x2_e1m2_t, uint8_t>(bVal);
        bfloat16_t bf16Val = FPTranslation::Fp4e1m2ToBfloat(uiNum);
        return FPTranslation::Bf16ToFp32(bf16Val);
    } else {
        uiNum = GetScalarBitcodeValue<T, uint8_t>(bVal);
        bfloat16_t bf16Val = FPTranslation::Fp4e2m1ToBfloat(uiNum);
        return FPTranslation::Bf16ToFp32(bf16Val);
    }
    return fNum;
}

template <typename T>
__aicore__ constexpr inline float ToFloat(const T& bVal)
{
    static_assert(SupportType<T, bfloat16_t, hifloat8_t, fp8_e5m2_t, fp8_e4m3fn_t, fp4x2_e1m2_t, fp4x2_e2m1_t>(),
        "Cast to float only support bfloat16_t/hifloat8_t/fp8_e5m2_t/fp8_e4m3fn_t/fp4x2_e1m2_t/fp4x2_e2m1_t data type on current device!");
    return Cast<T>(bVal);
}
#else
__aicore__ inline bfloat16_t Cast(const float& fVal)
{
    float fNum = fVal;
    union ToBfloat16Union {
        __aicore__ ToBfloat16Union() {}
        uint16_t val;
        bfloat16_t bNum;
    } bf16Union;
    union FloattoInt32Union {
        __aicore__ FloattoInt32Union() {}
        float ftmp;
        uint32_t uret;
    } int32Union;
    int32Union.ftmp = fNum;
    bf16Union.val = int32Union.uret >> BF16_TO_FP32_MAN_LEN;
    return bf16Union.bNum;
}

__aicore__ inline bfloat16_t ToBfloat16(const float& fVal)
{
    return Cast(fVal);
}

__aicore__ inline float Cast(const bfloat16_t& bVal)
{
    bfloat16_t bNum = bVal;
    union ToFloatUnion {
        __aicore__ ToFloatUnion() {}
        uint32_t val;
        float fNum;
    } floatUnion;
    union ToUint16Union {
        __aicore__ ToUint16Union() {}
        bfloat16_t uret;
        uint16_t num;
    } u16Union;
    u16Union.uret = bNum;
    floatUnion.val = u16Union.num << BF16_TO_FP32_MAN_LEN;
    return floatUnion.fNum;
}

__aicore__ inline float ToFloat(const bfloat16_t& bVal)
{
    return Cast(bVal);
}
#endif
#elif defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 3003) || (__NPU_ARCH__ == 3113))
template <typename T>
__aicore__ constexpr inline float Cast(const T& bVal)
{
    ASCENDC_ASSERT((false), { KERNEL_LOG(KERNEL_ERROR, "unsupport Cast to float"); });
    return 0.0f;
}

template <typename T>
__aicore__ constexpr inline float ToFloat(const T& bVal)
{
    ASCENDC_ASSERT((false), { KERNEL_LOG(KERNEL_ERROR, "unsupport ToFloat"); });
    return 0.0f;
}

__aicore__ inline bfloat16_t Cast(const float& fVal)
{
    ASCENDC_ASSERT((false), { KERNEL_LOG(KERNEL_ERROR, "unsupport Cast to bfloat16"); });
    bfloat16_t bNum;
    return bNum;
}

__aicore__ inline bfloat16_t ToBfloat16(const float& fVal)
{
    ASCENDC_ASSERT((false), { KERNEL_LOG(KERNEL_ERROR, "unsupport ToBfloat"); });
    bfloat16_t bNum;
    return bNum;
}

#endif
} // namespace AscendC
#endif // ASCENDC_SCALAR_CONVERT_H