/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/*!
 * \file kernel_fp16.cpp
 * \brief Software implementation of half-precision (fp16) addition, comparison and conversion helpers.
 */
#include "kernel_fp16.h"

#include <cstring>
namespace {
// Mantissa bit-length constant (11); not referenced in the visible part of this file.
constexpr uint16_t K_MAN_BIT_LENGTH{11};
// fp16 exponent value that Fp16ToFloat handles as its special case.
constexpr int16_t HFEXP_NUM{31};
// fp32 exponent value that Fp16ToFloat emits for that special case.
constexpr uint32_t ERET_NUM{255};
} // namespace
/**
 * @ingroup half global field
 * @brief rounding mode applied to the last valid digit
 */
const enum TagFp16RoundMode ROUND_MODE = TagFp16RoundMode::K_ROUND_TO_NEAREST;
/**
 * @brief Split a 16-bit half pattern into sign, exponent and mantissa fields.
 * @param [in] val raw bits of the half value
 * @param [out] s result of FP16_EXTRAC_SIGN(val)
 * @param [out] e result of FP16_EXTRAC_EXP(val), with 0 replaced by 1
 * @param [out] m result of FP16_EXTRAC_MAN(val)
 */
void ExtractFp16(const uint16_t& val, uint16_t& s, int16_t& e, uint16_t& m)
{
    s = FP16_EXTRAC_SIGN(val);
    const int16_t rawExp = FP16_EXTRAC_EXP(val);
    // A zero exponent field is reported as 1 to the callers.
    e = (rawExp == 0) ? static_cast<int16_t>(1) : rawExp;
    m = FP16_EXTRAC_MAN(val);
}
/**
 * @ingroup half static method
 * @param [in] man mantissa before truncation
 * @param [in] truncLen number of low-order bits of man that will be truncated
 * @brief judge whether one must be added to the truncated mantissa when
 * rounding to nearest
 * @return Return true if one must be added, otherwise false
 */
// Decides whether dropping the low truncLen bits of man must round the kept part up by one.
static bool IsRoundOne(uint64_t man, uint16_t truncLen)
{
    // Any mode other than round-to-nearest never rounds up here.
    if (ROUND_MODE != TagFp16RoundMode::K_ROUND_TO_NEAREST) {
        return false;
    }

    const uint16_t shift = static_cast<uint16_t>(truncLen - static_cast<uint16_t>(DimIndex::K_DIM2));
    const uint64_t keptLsbMask = static_cast<uint64_t>(0x4) << shift; // lowest bit that survives truncation
    const uint64_t halfMask = static_cast<uint64_t>(0x2) << shift;    // highest bit that is dropped
    const uint64_t belowHalfMask = halfMask - 1;                      // every dropped bit below halfMask

    if ((man & halfMask) == 0) {
        return false;
    }
    // With the highest dropped bit set, round up when any lower dropped bit is set
    // or when the lowest kept bit is set.
    return ((man & belowHalfMask) != 0) || ((man & keptLsbMask) != 0);
}
/**
 * @ingroup half public method
 * @param [in,out] exp exponent of half value
 * @param [in,out] man mantissa of half value
 * @brief normalize half value
 * @return
 */
// Clamps an out-of-range exponent and folds a mantissa carry into the exponent.
static void Fp16Normalize(int16_t& exp, uint16_t& man)
{
    const int16_t maxExp = static_cast<int16_t>(Fp16BasicParam::K_FP16_MAX_EXP);
    if (exp >= maxExp) {
        // Exponent at or above the maximum: pin both fields to their maximum values.
        exp = maxExp;
        man = static_cast<uint16_t>(Fp16BasicParam::K_FP16_MAX_MAN);
        return;
    }

    const bool mantissaIsHideBit = (man == static_cast<uint16_t>(Fp16BasicParam::K_FP16_MAN_HIDE_BIT));
    if ((exp == 0) && mantissaIsHideBit) {
        // A zero exponent whose mantissa equals the hide bit becomes exponent 1 with an empty mantissa.
        ++exp;
        man = 0;
    }
}
/**
 * @ingroup half math conversion static method
 * @param [in] fpVal uint16_t bit pattern of the half object
 * @brief Convert half to float/fp32
 * @return Return the float/fp32 value represented by fpVal
 */
// Expands a 16-bit half pattern into the float it represents.
static float Fp16ToFloat(const uint16_t& fpVal)
{
    uint16_t hfSign;
    uint16_t hfMan;
    int16_t hfExp;
    ExtractFp16(fpVal, hfSign, hfExp, hfMan);

    // Shift a non-zero mantissa left until the hide bit is set, lowering the exponent to match.
    while ((hfMan != 0) && ((hfMan & static_cast<uint16_t>(Fp16BasicParam::K_FP16_MAN_HIDE_BIT)) == 0)) {
        hfMan <<= 1;
        hfExp--;
    }

    // Distance between the fp32 and fp16 mantissa widths.
    const uint32_t manShift = static_cast<uint32_t>(Fp32BasicParam::K_FP32_MAN_LEN) -
                              static_cast<uint32_t>(Fp16BasicParam::K_FP16_MAN_LEN);
    const uint32_t sRet = hfSign;
    uint32_t eRet = 0;
    uint32_t mRet = 0;
    if (hfExp == HFEXP_NUM) {
        // Special fp16 exponent: emit the special fp32 exponent and carry the mantissa bits over unmasked.
        eRet = ERET_NUM;
        mRet = hfMan << manShift;
    } else if (hfMan != 0) {
        // Re-bias the exponent; unsigned wrap-around cancels out when hfExp went negative above.
        eRet = (static_cast<uint32_t>(hfExp) - static_cast<uint32_t>(Fp16BasicParam::K_FP16_EXP_BIAS)) +
               static_cast<uint32_t>(Fp32BasicParam::K_FP32_EXP_BIAS);
        mRet = static_cast<uint32_t>(hfMan & static_cast<uint16_t>(Fp16BasicParam::K_FP16_MAN_MASK));
        mRet = mRet << manShift;
    }
    // A zero mantissa leaves eRet and mRet at 0, so only the sign is packed.

    const uint32_t fVal = FP32_CONSTRUCTOR(sRet, eRet, mRet);

    // Copy the bits into a float with memcpy: reading a uint32_t object through a float*
    // obtained by reinterpret_cast violates the strict-aliasing rule.
    static_assert(sizeof(float) == sizeof(uint32_t), "float is expected to be 32 bits wide");
    float result = 0.0F;
    std::memcpy(&result, &fVal, sizeof(result));
    return result;
}
/**
 * @brief Renormalize, round and pack the raw mantissa sum produced by Fp16Add.
 * @param [in] sRet sign value handed to FP16_CONSTRUCTOR
 * @param [in] eRet exponent of the unrounded sum (modified locally)
 * @param [in] mRet mantissa of the unrounded sum, carrying shiftOut extra low-order bits
 * @param [in] mTrunc 32-bit word of bits that Fp16Add shifted out while aligning the operands;
 *                    its top bit (K_FP32_SIGN_MASK) is the most significant shifted-out bit
 * @param [in] shiftOut number of extra low-order bits carried in mRet
 * @return 16-bit pattern built by FP16_CONSTRUCTOR from sRet, the final exponent and the final mantissa
 */
static uint16_t Fp16AddCalVal(const uint16_t& sRet, int16_t eRet, uint16_t mRet, uint32_t mTrunc, uint16_t shiftOut)
{
// Mantissa window [mMin, mMax): the hide bit scaled by shiftOut, and twice that value.
uint16_t mMin = static_cast<uint16_t>(Fp16BasicParam::K_FP16_MAN_HIDE_BIT) << shiftOut;
uint16_t mMax = mMin << 1;
// Mantissa too small: shift it left, pulling the top bit of mTrunc in at the bottom,
// until it reaches mMin or the exponent hits zero.
while ((mRet < mMin) && (eRet > 0)) {
mRet = mRet << 1;
mRet += (static_cast<uint32_t>(Fp32BasicParam::K_FP32_SIGN_MASK) & mTrunc) >>
static_cast<uint16_t>(Fp32BasicParam::K_FP32_SIGN_INDEX);
mTrunc = mTrunc << 1;
eRet = eRet - 1;
}
// Mantissa too large: shift it right, pushing its lowest bit into the top of mTrunc.
while (mRet >= mMax) {
mTrunc = mTrunc >> 1;
mTrunc = mTrunc | (static_cast<uint32_t>(Fp32BasicParam::K_FP32_SIGN_MASK) * (mRet & 1));
mRet = mRet >> 1;
eRet = eRet + 1;
}
// Rounding inputs: lowest kept bit, highest dropped bit, and whether any K_FP32_ABS_MAX bit of mTrunc is set.
// The two mTrunc-derived flags are forced to false unless ROUND_MODE is K_ROUND_TO_NEAREST.
bool bLastBit = ((mRet & 1) > 0);
bool bTruncHigh = (ROUND_MODE == TagFp16RoundMode::K_ROUND_TO_NEAREST) &&
((mTrunc & static_cast<uint32_t>(Fp32BasicParam::K_FP32_SIGN_MASK)) > 0);
bool bTruncLeft = (ROUND_MODE == TagFp16RoundMode::K_ROUND_TO_NEAREST) &&
((mTrunc & static_cast<uint32_t>(Fp32BasicParam::K_FP32_ABS_MAX)) > 0);
// ManRoundToNearest is defined outside this file; presumably it applies the rounding
// decision and removes the shiftOut scaling - confirm against its definition.
mRet = ManRoundToNearest(bLastBit, bTruncHigh, bTruncLeft, mRet, shiftOut);
// Bring the mantissa back below mMax if the previous step pushed it over.
while (mRet >= mMax) {
mRet = mRet >> 1;
eRet = eRet + 1;
}
// The loop above guarantees mRet < mMax, so this test reduces to eRet == 0.
if ((eRet == 0) && (mRet <= mMax)) {
mRet = mRet >> 1;
}
Fp16Normalize(eRet, mRet);
uint16_t ret = FP16_CONSTRUCTOR(sRet, static_cast<uint16_t>(eRet), mRet);
return ret;
}
/**
 * @ingroup half math operator
 * @param [in] v1 bit pattern of the left operand
 * @param [in] v2 bit pattern of the right operand
 * @brief Perform half addition
 * @return Return the half bit pattern of v1 + v2
 */
// Adds two half values given as raw 16-bit patterns and returns the 16-bit pattern of the sum.
// ReverseMan, GetManSum and RightShift are defined outside this file; notes about them below are assumptions.
static uint16_t Fp16Add(uint16_t v1, uint16_t v2)
{
uint16_t sa;
uint16_t sb;
int16_t ea;
int16_t eb;
uint32_t ma;
uint32_t mb;
uint16_t maTmp;
uint16_t mbTmp;
uint16_t shiftOut = 0;
// Split both operands into sign / exponent / mantissa, then widen the mantissas to 32 bits.
ExtractFp16(v1, sa, ea, maTmp);
ExtractFp16(v2, sb, eb, mbTmp);
ma = maTmp;
mb = mbTmp;
uint16_t sum;
uint16_t sRet;
if (sa != sb) {
// Signs differ: ReverseMan presumably negates the mantissa whose flag is true - TODO confirm.
ReverseMan(sa > 0, ma);
ReverseMan(sb > 0, mb);
sum = static_cast<uint16_t>(GetManSum(ea, ma, eb, mb));
// The result sign is taken from the sign bit of the 16-bit sum.
sRet = (sum & static_cast<uint16_t>(Fp16BasicParam::K_FP16_SIGN_MASK)) >>
static_cast<uint16_t>(Fp16BasicParam::K_FP16_SIGN_INDEX);
ReverseMan(sRet > 0, ma);
ReverseMan(sRet > 0, mb);
} else {
// Signs match: the result keeps the common sign.
sum = static_cast<uint16_t>(GetManSum(ea, ma, eb, mb));
sRet = sa;
}
// When the 16-bit sum is zero, both mantissas are scaled left by 3 bits and
// Fp16AddCalVal is told about the scaling through shiftOut.
if (sum == 0) {
shiftOut = 3;
ma = ma << shiftOut;
mb = mb << shiftOut;
}
uint32_t mTrunc = 0;
int16_t eRet = std::max(ea, eb);
uint32_t eTmp = static_cast<uint32_t>(std::abs(ea - eb));
// Align the operand with the smaller exponent: mTrunc receives its mantissa shifted left by
// (K_BIT_SHIFT32 - eTmp), and the mantissa itself goes through RightShift by eTmp.
if (ea > eb) {
mTrunc = (mb << (static_cast<uint32_t>(BitShift::K_BIT_SHIFT32) - eTmp));
mb = RightShift(mb, eTmp);
} else if (ea < eb) {
mTrunc = (ma << (static_cast<uint32_t>(BitShift::K_BIT_SHIFT32) - eTmp));
ma = RightShift(ma, eTmp);
}
// Only the low 16 bits of the aligned mantissa sum are passed on.
auto mRet = static_cast<uint16_t>(ma + mb);
return Fp16AddCalVal(sRet, eRet, mRet, mTrunc, shiftOut);
}
/**
 * @brief Add two half values.
 * @param [in] fp right-hand operand
 * @return new half whose bits are Fp16Add(val, fp.val)
 */
half half::operator + (const half fp) const
{
    half sum;
    sum.val = Fp16Add(val, fp.val);
    return sum;
}
/**
 * @brief Add fp to this half in place.
 * @param [in] fp value to add
 * @return copy of this object after the addition
 */
half half::operator += (const half fp)
{
    const uint16_t sumBits = Fp16Add(val, fp.val);
    val = sumBits;
    return *this;
}
/**
 * @brief Equality comparison of two half values.
 * @param [in] fp right-hand operand
 * @return true when FP16_IS_ZERO holds for both operands, or when their masked bit patterns are identical
 */
bool half::operator == (const half& fp) const
{
    // Two zeros always compare equal, whatever the rest of their bits.
    if (FP16_IS_ZERO(val) && FP16_IS_ZERO(fp.val)) {
        return true;
    }
    const uint16_t bitMask = static_cast<uint16_t>(NumBitMax::K_BIT_LEN16_MAX);
    return (val & bitMask) == (fp.val & bitMask);
}
/**
 * @brief Inequality comparison of two half values.
 * @param [in] fp right-hand operand
 * @return false when FP16_IS_ZERO holds for both operands; otherwise whether their masked bit patterns differ
 */
bool half::operator != (const half& fp) const
{
    const bool bothZero = FP16_IS_ZERO(val) && FP16_IS_ZERO(fp.val);
    if (bothZero) {
        return false;
    }
    const uint16_t bitMask = static_cast<uint16_t>(NumBitMax::K_BIT_LEN16_MAX);
    return (val & bitMask) != (fp.val & bitMask);
}
/**
 * @brief Greater-than test on (exponent, mantissa) pairs, used by operator> when both signs are clear.
 * @param [in] ea exponent of the left operand
 * @param [in] eb exponent of the right operand
 * @param [in] ma mantissa of the left operand
 * @param [in] mb mantissa of the right operand
 * @return true when ea > eb, or when the exponents match and ma > mb
 */
static bool CmpPosNums(const uint16_t& ea, const uint16_t& eb, const uint16_t& ma, const uint16_t& mb)
{
    // The exponent decides unless both are equal; then the mantissa breaks the tie.
    return (ea != eb) ? (ea > eb) : (ma > mb);
}
/**
 * @brief Greater-than test on (exponent, mantissa) pairs, used by operator> when both signs are set.
 * @param [in] ea exponent of the left operand
 * @param [in] eb exponent of the right operand
 * @param [in] ma mantissa of the left operand
 * @param [in] mb mantissa of the right operand
 * @return true when ea < eb, or when the exponents match and ma < mb
 */
static bool CmpNegNums(const uint16_t& ea, const uint16_t& eb, const uint16_t& ma, const uint16_t& mb)
{
    // With both signs set the ordering is inverted: the smaller magnitude is the greater value.
    return (ea != eb) ? (ea < eb) : (ma < mb);
}
/**
 * @brief Greater-than comparison of two half values.
 * @param [in] fp right-hand operand
 * @return true when this value orders after fp according to sign, exponent and mantissa fields
 */
bool half::operator > (const half& fp) const
{
    const uint16_t sa = FP16_EXTRAC_SIGN(val);
    const uint16_t sb = FP16_EXTRAC_SIGN(fp.val);
    const bool lhsNegative = (sa != 0);
    const bool rhsNegative = (sb != 0);

    if (lhsNegative != rhsNegative) {
        if (lhsNegative) {
            // Sign set on the left only: never greater.
            return false;
        }
        // Sign set on the right only: greater unless both operands are zero.
        return !(FP16_IS_ZERO(val) && FP16_IS_ZERO(fp.val));
    }

    const uint16_t ea = static_cast<uint16_t>(FP16_EXTRAC_EXP(val));
    const uint16_t eb = static_cast<uint16_t>(FP16_EXTRAC_EXP(fp.val));
    const uint16_t ma = FP16_EXTRAC_MAN(val);
    const uint16_t mb = FP16_EXTRAC_MAN(fp.val);
    // Same sign: compare magnitudes, inverted when both signs are set.
    return lhsNegative ? CmpNegNums(ea, eb, ma, mb) : CmpPosNums(ea, eb, ma, mb);
}
/**
 * @brief Greater-or-equal comparison, built from operator> and operator==.
 * @param [in] fp right-hand operand
 * @return true when (*this > fp) or (*this == fp)
 */
bool half::operator >= (const half& fp) const
{
    return ((*this) > fp) || ((*this) == fp);
}
/**
 * @brief Less-or-equal comparison, defined as the negation of operator>.
 * @param [in] fp right-hand operand
 * @return true when (*this > fp) is false
 */
bool half::operator <= (const half& fp) const
{
    return !((*this) > fp);
}
/**
 * @brief Less-than comparison, defined as the negation of operator>=.
 * @param [in] fp right-hand operand
 * @return true when (*this >= fp) is false
 */
bool half::operator < (const half& fp) const
{
    return !((*this) >= fp);
}
/**
 * @brief Prefix increment: add a half constructed from 1.0 to this value.
 * @return copy of this object after the increment
 */
half half::operator ++ ()
{
    half unit = 1.0;
    const uint16_t incremented = Fp16Add(val, unit.val);
    val = incremented;
    return *this;
}
/**
 * @brief Postfix increment.
 * @return the value this object held before the increment
 */
half half::operator ++ (int)
{
    half before = *this;
    ++(*this);
    return before;
}
/**
 * @brief Logical AND of two half values.
 *
 * An operand counts as false when FP16_IS_ZERO reports it as zero, the same test
 * operator== uses. Comparing the raw bits against 0 treated a zero with the sign
 * bit set (pattern 0x8000) as true.
 * NOTE(review): assumes FP16_IS_ZERO ignores the sign bit - confirm against the macro definition.
 *
 * @param [in] fp right-hand operand
 * @return true when neither operand is zero
 */
bool half::operator && (const half fp) const
{
    return !FP16_IS_ZERO(val) && !FP16_IS_ZERO(fp.val);
}
/**
 * @brief Logical OR of two half values.
 *
 * An operand counts as false when FP16_IS_ZERO reports it as zero, the same test
 * operator== uses. Comparing the raw bits against 0 treated a zero with the sign
 * bit set (pattern 0x8000) as true.
 * NOTE(review): assumes FP16_IS_ZERO ignores the sign bit - confirm against the macro definition.
 *
 * @param [in] fp right-hand operand
 * @return true when at least one operand is not zero
 */
bool half::operator || (const half fp) const
{
    return !FP16_IS_ZERO(val) || !FP16_IS_ZERO(fp.val);
}
/**
 * @brief Convert an fp32 value to the 16-bit pattern of a half.
 * @param [in] fVal value to convert
 * @return bit pattern built by FP16_CONSTRUCTOR from the computed sign, exponent and mantissa
 */
uint16_t half::FloatToFp16(const float& fVal) const
{
    // Read the fp32 bit pattern with memcpy: dereferencing a uint32_t pointer obtained by
    // reinterpret_cast from a float object violates the strict-aliasing rule.
    static_assert(sizeof(uint32_t) == sizeof(float), "float is expected to be 32 bits wide");
    uint32_t ui32V = 0;
    std::memcpy(&ui32V, &fVal, sizeof(ui32V));

    // Split the fp32 pattern into sign, biased exponent and mantissa fields.
    const uint16_t sRet =
        static_cast<uint16_t>((ui32V & static_cast<uint32_t>(Fp32BasicParam::K_FP32_SIGN_MASK)) >>
                              static_cast<uint16_t>(Fp32BasicParam::K_FP32_SIGN_INDEX));
    const uint32_t ef = (ui32V & static_cast<uint32_t>(Fp32BasicParam::K_FP32_EXP_MASK)) >>
                        static_cast<uint16_t>(Fp32BasicParam::K_FP32_MAN_LEN);
    uint32_t mf = (ui32V & static_cast<uint32_t>(Fp32BasicParam::K_FP32_MAN_MASK));
    // Number of fp32 mantissa bits that do not fit into the fp16 mantissa.
    const uint32_t mLenDelta =
        static_cast<uint16_t>(Fp32BasicParam::K_FP32_MAN_LEN) - static_cast<uint16_t>(Fp16BasicParam::K_FP16_MAN_LEN);

    uint16_t mRet;
    int16_t eRet;
    bool needRound = false;
    if (ef > 0x8Fu) {
        // Biased fp32 exponent above 0x8F (143): saturate to exponent K_FP16_MAX_EXP - 1 with the maximum mantissa.
        // NOTE(review): ef == 0x8F takes the normal path below with eRet == 0x1F and relies on
        // Fp16Normalize - confirm that threshold is intended rather than 0x8E.
        eRet = static_cast<uint16_t>(Fp16BasicParam::K_FP16_MAX_EXP) - 1;
        mRet = static_cast<uint16_t>(Fp16BasicParam::K_FP16_MAX_MAN);
    } else if (ef <= 0x70u) {
        // Biased fp32 exponent of at most 0x70 (112): the half exponent field is 0.
        eRet = 0;
        if (ef >= 0x67) {
            // Restore the hide bit, scale by (ef - 0x67), then drop K_FP32_MAN_LEN bits with rounding.
            mf = (mf | static_cast<uint32_t>(Fp32BasicParam::K_FP32_MAN_HIDE_BIT));
            const uint16_t shiftOut = static_cast<uint16_t>(Fp32BasicParam::K_FP32_MAN_LEN);
            const uint64_t mTmp = (static_cast<uint64_t>(mf)) << (ef - 0x67);
            needRound = IsRoundOne(mTmp, shiftOut);
            mRet = static_cast<uint16_t>(mTmp >> shiftOut);
            if (needRound) {
                mRet++;
            }
        } else if ((ef == 0x66) && (mf > 0)) {
            // One exponent step below the range handled above, with a non-zero mantissa: result mantissa is 1.
            mRet = 1;
        } else {
            mRet = 0;
        }
    } else {
        // Remaining exponents: re-bias by 0x70 and round the mantissa down to the fp16 width.
        eRet = static_cast<int16_t>(ef - 0x70u);
        needRound = IsRoundOne(mf, static_cast<uint16_t>(mLenDelta));
        mRet = static_cast<uint16_t>(mf >> mLenDelta);
        if (needRound) {
            mRet++;
        }
        // Rounding carried into the hide-bit position: bump the exponent.
        if ((mRet & static_cast<uint16_t>(Fp16BasicParam::K_FP16_MAN_HIDE_BIT)) != 0) {
            eRet++;
        }
    }
    Fp16Normalize(eRet, mRet);
    return FP16_CONSTRUCTOR(sRet, static_cast<uint16_t>(eRet), mRet);
}
/**
 * @brief Placeholder for double-to-half conversion.
 * NOTE(review): the argument is ignored and 0 is always returned - confirm this stub is intentional.
 * @param [in] dVal unused
 * @return always 0
 */
uint16_t half::DoubleToFp16(const double& dVal)
{
    static_cast<void>(dVal);
    const uint16_t placeholderBits = 0;
    return placeholderBits;
}
/**
 * @brief Return iVal converted to uint16_t.
 * NOTE(review): no fp16 encoding is performed here; the integer is passed through - confirm this is intended.
 * @param [in] iVal input integer
 * @return iVal converted to uint16_t
 */
uint16_t half::Int16ToFp16(const int16_t& iVal) const
{
    return static_cast<uint16_t>(iVal);
}
/**
 * @brief Return uiVal unchanged.
 * NOTE(review): no fp16 encoding is performed here; the integer is passed through - confirm this is intended.
 * @param [in] uiVal input integer
 * @return uiVal
 */
uint16_t half::UInt16ToFp16(const uint16_t& uiVal)
{
    const uint16_t passthroughBits = uiVal;
    return passthroughBits;
}
/**
 * @brief Return iVal converted to uint16_t (only the low 16 bits are kept).
 * NOTE(review): no fp16 encoding is performed here; the integer is narrowed - confirm this is intended.
 * @param [in] iVal input integer
 * @return iVal converted to uint16_t
 */
uint16_t half::Int32ToFp16(const int32_t& iVal) const
{
    return static_cast<uint16_t>(iVal);
}
/**
 * @brief Return uiVal converted to uint16_t (only the low 16 bits are kept).
 * NOTE(review): no fp16 encoding is performed here; the integer is narrowed - confirm this is intended.
 * @param [in] uiVal input integer
 * @return uiVal converted to uint16_t
 */
uint16_t half::UInt32ToFp16(const uint32_t& uiVal) const
{
    return static_cast<uint16_t>(uiVal);
}
/**
 * @brief Copy-assign the bit pattern of another half.
 * @param [in] fp source value
 * @return reference to this object
 */
half& half::operator = (const half& fp)
{
    // Self-assignment leaves the object untouched.
    if (this != &fp) {
        val = fp.val;
    }
    return *this;
}
/**
 * @brief Conversion operator to float.
 * @return result of Fp16ToFloat applied to the stored bit pattern
 */
half::operator float() const
{
    const float converted = Fp16ToFloat(val);
    return converted;
}
/**
 * @brief Named conversion to float.
 * @return result of Fp16ToFloat applied to the stored bit pattern
 */
float half::ToFloat() const
{
    const float converted = Fp16ToFloat(val);
    return converted;
}