* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file kernel_fp8_e5m2.cpp
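 * \brief Scalar conversions between FP32 and FP8 E5M2 (1 sign bit, 5 exponent bits with bias 15, 2 mantissa bits).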
*/
#include "kernel_fp8_e5m2.h"
#include "kernel_fp32.h"
#include "kernel_utils.h"
namespace float8_e5m2 {
namespace {
// E5M2 byte layout: bit 7 = sign, bits 6..2 = exponent (bias 15), bits 1..0 = mantissa.
// Position of the sign bit.
#define FP8_SIGN_INDEX (7)
// NaN pattern returned by the FP32 -> E5M2 conversion: exponent all ones, mantissa 0b11, sign clear.
#define FP8_T_NAN (0x7F)
// NOTE(review): 0x7 is a 3-bit all-ones value, wider than the 2-bit E5M2 mantissa field
// (FP8E5M2_MAN_LEN); confirm it is really meant for this format.
#define FP8_MAX_MAN (0x7)
// Mask selecting the exponent field (bits 6..2).
constexpr int16_t FP8E5M2_EXP_MASK = 0x7C;
// Mask selecting the mantissa field (bits 1..0).
constexpr int16_t FP8E5M2_MAN_MASK = 0x3;
// Width of the mantissa field in bits.
constexpr uint32_t FP8E5M2_MAN_LEN = 2;
// Exponent bias of the E5M2 format.
constexpr uint32_t FP8E5M2_EXP_BIAS = 15;
// Implicit leading 1, placed just above the 2-bit mantissa field for non-zero exponents.
#define FP8E5M2_MAN_HIDE_BIT (0x4)
// Largest finite magnitude encoding: exponent 30, mantissa 3 (1.75 * 2^15 = 57344).
#define FP8E5M2_T_MAX (0x7B)
// All-ones exponent field value (used by infinity and NaN encodings).
#define FP8E5M2_MAX_EXP (0x1F)
// All-ones mantissa field value.
#define FP8E5M2_MAX_MAN (0x3)
// +infinity (exponent all ones, mantissa 0); OR in the sign bit for -infinity.
#define FP8E5M2_INF (0X7C)
// Magnitude mask (every bit except the sign). "MAKS" is presumably a typo for "MASK".
#define FP8E5M2_ABS_MAKS (0X7F)
/**
 * Packs sign, exponent field and mantissa into an E5M2 byte.
 * @param s  Sign bit, expected to be 0 or 1 (not masked).
 * @param e  Biased exponent field, expected to be 0..31 (not masked; larger values spill upward).
 * @param m  Mantissa; only its low 2 bits are kept, so a set hidden bit is dropped.
 * @return The low 8 bits of (s << 7) | (e << 2) | (m & 0x3).
 */
uint8_t Fp8e5m2Constructor(uint16_t s, uint16_t e, uint16_t m)
{
    const uint32_t signBits = static_cast<uint32_t>(s) << FP8_SIGN_INDEX;
    const uint32_t expBits = static_cast<uint32_t>(e) << FP8E5M2_MAN_LEN;
    const uint32_t manBits = static_cast<uint32_t>(m) & FP8E5M2_MAX_MAN;
    return static_cast<uint8_t>(signBits | expBits | manBits);
}
/**
 * Reports whether x holds an E5M2 NaN: exponent field all ones and a non-zero mantissa field.
 * Only bits 6..0 are inspected, so the sign bit (and any higher bits) do not matter.
 */
bool Fp8e5m2IsNan(const uint16_t& x)
{
    const bool exponentAllOnes = (x & FP8E5M2_EXP_MASK) == FP8E5M2_EXP_MASK;
    if (!exponentAllOnes) {
        return false;
    }
    return (x & FP8E5M2_MAN_MASK) != 0;
}
/**
 * Reports whether x holds an E5M2 infinity: 0x7C (+inf) or 0xFC (-inf).
 *
 * Only the low 8 bits are examined. A negative int8_t passed here is sign-extended on its way into
 * the uint16_t parameter (0xFC becomes 0xFFFC). Comparing the full 16-bit value would therefore
 * never match -inf.
 */
bool Fp8e5m2IsInf(const uint16_t& x)
{
    const uint8_t bits = static_cast<uint8_t>(x & 0xFFu);
    return (bits == 0x7Cu) || (bits == 0xFCu);
}
/**
 * Returns the sign bit (bit 7) of an E5M2 byte: 1 for negative encodings, 0 otherwise.
 */
int8_t Fp8e5m2ExtracSign(int8_t x)
{
    // Shift the unsigned byte so a negative value is never right-shifted.
    const uint8_t byte = static_cast<uint8_t>(x);
    const int signBit = (byte >> FP8_SIGN_INDEX) & 0x1;
    return static_cast<int8_t>(signBit);
}
/**
 * Returns the raw (still biased) 5-bit exponent field, bits 6..2, of an E5M2 byte: 0..31.
 */
int8_t Fp8e5m2ExtracExp(int8_t x)
{
    // Shift the unsigned byte so a negative value is never right-shifted.
    const uint8_t byte = static_cast<uint8_t>(x);
    const int exponentField = (byte >> FP8E5M2_MAN_LEN) & 0x1F;
    return static_cast<int8_t>(exponentField);
}
/**
 * Returns the 2-bit mantissa field of an E5M2 byte. When the exponent field is non-zero
 * (a normal number, or inf/NaN), the implicit leading 1 is ORed in as bit 2 (0x4).
 */
int8_t Fp8e5m2ExtracMan(uint8_t x)
{
    const uint8_t fraction = x & 0x3;
    const bool hasHiddenBit = ((x >> FP8E5M2_MAN_LEN) & 0x1F) != 0;
    const uint8_t hiddenBit = hasHiddenBit ? 0x4 : 0x0;
    return static_cast<int8_t>(fraction | hiddenBit);
}
/**
 * Splits an E5M2 byte into its fields.
 * @param val  E5M2 bit pattern.
 * @param s    Receives the sign bit (0 or 1).
 * @param e    Receives the raw, still biased exponent field (0..31).
 * @param m    Receives the 2-bit mantissa with the hidden bit (0x4) set when the exponent is non-zero.
 */
void ExtractFp8e5m2(const int8_t val, uint8_t& s, int8_t& e, uint8_t& m)
{
    s = static_cast<uint8_t>(Fp8e5m2ExtracSign(val));
    // The exponent stays biased. Fp8e5m2ToFp32 detects subnormals by checking e == 0,
    // so no subnormal adjustment is applied here.
    e = Fp8e5m2ExtracExp(val);
    m = static_cast<uint8_t>(Fp8e5m2ExtracMan(static_cast<uint8_t>(val)));
}
/**
 * Decodes an E5M2 byte into an FP32 bit pattern.
 * - 0x00 / 0x80 map to +0 / -0.
 * - NaN encodings map to FP32_NAN (the sign is not preserved).
 * - 0x7C / 0xFC map to +/-infinity.
 * - Subnormals (exponent field 0) are renormalised, so they become normal FP32 values.
 * @param fpVal  E5M2 bit pattern.
 * @return FP32 bit pattern of the same value.
 */
uint32_t Fp8e5m2ToFp32(const int8_t fpVal)
{
    // Classify on the zero-extended byte. Passing the int8_t directly would sign-extend negative
    // encodings (0xFC -> 0xFFFC) before they reach the uint16_t classifiers, and -inf was then missed.
    const uint16_t bits = static_cast<uint8_t>(fpVal);
    if (fpVal == 0x0) {
        return 0x0;
    }
    if (fpVal == static_cast<int8_t>(0x80)) {
        return 0x80000000;
    }
    if (Fp8e5m2IsNan(bits)) {
        return FP32_NAN;
    }
    if (Fp8e5m2IsInf(bits)) {
        // Shift in uint32_t so the sign never lands in bit 31 of a signed int.
        return (static_cast<uint32_t>(Fp8e5m2ExtracSign(fpVal)) << FP32_SIGN_INDEX) | FP32_POS_INF;
    }

    uint8_t fp8Sign = 0;
    uint8_t fp8Man = 0;
    int8_t fp8Exp = 0;
    ExtractFp8e5m2(fpVal, fp8Sign, fp8Exp, fp8Man);
    const bool isDenormal = (fp8Exp == 0);
    // Subnormals carry no hidden bit: shift the mantissa up until bit 2 is set, lowering the
    // exponent once per shift. Normal values already have the hidden bit and are untouched.
    while ((fp8Man != 0) && ((fp8Man & FP8E5M2_MAN_HIDE_BIT) == 0)) {
        fp8Man <<= 1;
        fp8Exp--;
    }

    const uint32_t sRet = fp8Sign;
    uint32_t eRet = 0;
    uint32_t mRet = 0;
    if (fp8Man != 0) {
        // A subnormal's effective exponent is 1 (2^-14), not the stored 0, hence the +1.
        const int32_t fp8BiasedExp = isDenormal ? (fp8Exp + 1) : fp8Exp;
        eRet = static_cast<uint32_t>(fp8BiasedExp - static_cast<int32_t>(FP8E5M2_EXP_BIAS) +
                                     static_cast<int32_t>(FP32_EXP_BIAS));
        // NOTE(review): fp8Man still includes the hidden bit here. This relies on Fp32Constructor
        // masking the mantissa to the FP32 fraction width; confirm in kernel_fp32.h.
        mRet = static_cast<uint32_t>(fp8Man) << (FP32_MAN_LEN - FP8E5M2_MAN_LEN);
    }
    return Fp32Constructor(sRet, eRet, mRet);
}
/**
 * Decides whether discarding the low `truncLen` bits of `man` requires incrementing the kept
 * (upper) part, using round-half-to-even.
 *
 * @param sign      Ignored; the decision depends only on the magnitude bits.
 * @param man       Kept part in the high bits, discarded part in the low `truncLen` bits.
 * @param truncLen  Number of low bits discarded. 0 never rounds, and values above 64 never round.
 * @return true when the discarded part exceeds half an ULP of the kept part, or equals exactly half
 *         while the kept part's lowest bit is 1 (the tie goes to the even neighbour).
 */
bool IsRoundOne(uint32_t sign, uint64_t man, uint16_t truncLen)
{
    (void)sign;
    constexpr uint16_t kManWidth = 64;
    if ((truncLen == 0) || (truncLen > kManWidth)) {
        return false;
    }
    // Use a 64-bit one. `0x1ul` is only 32 bits wide on LLP64 targets, where shifting it by 32 or
    // more is undefined behaviour.
    constexpr uint64_t kOne = 1;
    const uint64_t keptLsbMask = (truncLen == kManWidth) ? 0 : (kOne << truncLen);
    const uint64_t roundBitMask = kOne << (truncLen - 1);
    const uint64_t stickyMask = roundBitMask - 1;
    const bool keptLsb = (man & keptLsbMask) != 0;
    const bool roundBit = (man & roundBitMask) != 0;
    const bool sticky = (man & stickyMask) != 0;
    return roundBit && (sticky || keptLsb);
}
/**
 * Final range fix-up of an E5M2 (exponent field, mantissa) pair before packing.
 *
 * - An exponent at or above the all-ones field value (FP8E5M2_MAX_EXP) is saturated to the largest
 *   finite value: exponent FP8E5M2_MAX_EXP - 1 (30) with mantissa FP8E5M2_MAX_MAN (3), i.e. 0x7B
 *   (57344). Keeping the all-ones exponent would encode infinity or, with a non-zero mantissa, NaN.
 * - A subnormal (exponent 0) whose mantissa rounded up into the hidden-bit position becomes the
 *   smallest normal value: exponent 1, mantissa 0.
 * @param exp  Biased exponent field, adjusted in place.
 * @param man  Mantissa (may have the hidden bit set), adjusted in place.
 */
void Fp8e5m2Normalize(int16_t& exp, uint32_t& man)
{
    if (exp >= FP8E5M2_MAX_EXP) {
        exp = FP8E5M2_MAX_EXP - 1;
        man = FP8E5M2_MAX_MAN;
        return;
    }
    if ((exp == 0) && (man == FP8E5M2_MAN_HIDE_BIT)) {
        exp = 1;
        man = 0;
    }
}
}
/**
 * Converts an FP32 value to its FP8 E5M2 encoding using round-half-to-even.
 *
 * Special inputs (as classified by the Fp32Is* helpers):
 * - infinity -> sign | FP8E5M2_INF (0x7C / 0xFC);
 * - NaN      -> FP8_T_NAN (0x7F, sign dropped);
 * - zero     -> sign | 0x00.
 *
 * Finite inputs:
 * - magnitudes >= 1.875 * 2^15 (61440) saturate to the largest finite encoding 0x7B / 0xFB (57344).
 *   61440 is the midpoint between 57344 and 2^16, so round-half-to-even would otherwise leave the
 *   finite range;
 * - magnitudes below 2^-14 become E5M2 subnormals (multiples of 2^-16) or round to zero;
 * - everything else is rebiased (127 -> 15) and its mantissa is rounded to 2 bits.
 * @param src  Value to convert.
 * @return The E5M2 bit pattern, stored in an int8_t.
 */
int8_t Fp8e5m2T::FloatToFp8e5m2(const float src) const
{
    // FP32 biased exponents (bias 127) and mantissa threshold that delimit the ranges handled below.
    constexpr uint32_t kFp32ExpHalfMinSub = 0x6Eu;    // 2^-17: half of the smallest E5M2 subnormal (2^-16)
    constexpr uint32_t kFp32ExpMaxSub = 0x70u;        // 2^-15: highest binade holding E5M2 subnormals
    constexpr uint32_t kFp32ExpMaxNormal = 0x8Eu;     // 2^15: binade of the largest finite E5M2 value
    constexpr uint32_t kFp32ManOverflow = 0x700000u;  // fraction .875: 1.875 * 2^15 = 61440
    constexpr uint32_t kRebias = 0x70u;               // FP32 bias 127 - E5M2 bias 15

    uint32_t srcVal = AscendC::GetScalarBitcodeValue<float, uint32_t>(src);
    if (Fp32IsInf(srcVal)) {
        return ((Fp32ExtracSign(srcVal) << FP8_SIGN_INDEX) | FP8E5M2_INF);
    }
    if (Fp32IsNan(srcVal)) {
        return FP8_T_NAN;
    }
    if (Fp32IsZero(srcVal)) {
        return ((Fp32ExtracSign(srcVal) << FP8_SIGN_INDEX) | 0x0);
    }
    int8_t ret = 0;
    uint32_t mRet = 0;
    int16_t eRet = 0;
    uint16_t shiftOut = 0;
    uint16_t sRet = Fp32ExtracSign(srcVal);
    uint32_t ef = Fp32ExtracExp(srcVal);
    uint32_t mf = (srcVal & 0x007FFFFF);
    uint32_t mLenDelta = FP32_MAN_LEN - FP8E5M2_MAN_LEN;
    bool needRound = false;
    if (ef == kFp32ExpHalfMinSub) {
        // |src| in [2^-17, 2^-16). Exactly 2^-17 is a tie between 0 and the smallest subnormal and
        // rounds to even (0); anything larger rounds up to the smallest subnormal 0x01.
        ret = (mf == 0x0) ? 0x0 : 0x1;
        return ((sRet << FP8_SIGN_INDEX) | (ret & 0xff));
    }
    if (((ef == kFp32ExpMaxNormal) && (mf >= kFp32ManOverflow)) || (ef > kFp32ExpMaxNormal)) {
        // Saturate to the largest finite value (exponent 30, mantissa 3 -> 0x7B). An all-ones
        // exponent with a non-zero mantissa would be a NaN encoding.
        eRet = FP8E5M2_MAX_EXP - 1;
        mRet = FP8E5M2_MAX_MAN;
    } else if (ef <= kFp32ExpMaxSub) {
        // E5M2 subnormal range: align the mantissa (with its hidden bit) to units of 2^-16.
        eRet = 0;
        if (ef > kFp32ExpHalfMinSub) {
            mf = (mf | FP32_MAN_HIDE_BIT);
            shiftOut = FP32_MAN_LEN - FP8E5M2_MAN_LEN + (kFp32ExpMaxSub - ef) + 1;
            uint64_t mTmp = mf;
            needRound = IsRoundOne(sRet, mTmp, shiftOut);
            mRet = static_cast<uint16_t>(mf >> shiftOut);
            if (needRound) {
                mRet++;
            }
        } else {
            // Below half of the smallest subnormal: rounds to zero.
            mRet = 0;
        }
    } else {
        // Normal range: rebias the exponent and round the mantissa from 23 to 2 bits.
        eRet = static_cast<int16_t>(ef - kRebias);
        needRound = IsRoundOne(sRet, mf, mLenDelta);
        mRet = static_cast<uint16_t>(mf >> mLenDelta);
        if (needRound) {
            mRet++;
        }
        // Rounding carried out of the 2-bit mantissa: move to the next binade.
        if (((mRet & FP8E5M2_MAN_HIDE_BIT) != 0) && (needRound)) {
            eRet++;
            mRet = 0;
        }
    }
    Fp8e5m2Normalize(eRet, mRet);
    ret = Fp8e5m2Constructor(sRet, eRet, mRet);
    return ret;
}
/**
 * Implicit conversion to float. Shares its decoding path with ToFloat().
 */
Fp8e5m2T::operator float() const
{
    return ToFloat();
}
float Fp8e5m2T::ToFloat() const
{
return AscendC::GetScalarBitcodeValue<uint32_t, float>(Fp8e5m2ToFp32(val));
}
}