* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "common/math/hif8_t.h"
#include <cmath>
#include <array>
#include <limits>
#include "securec.h"
#include "common/checker.h"
namespace {
constexpr std::array<uint8_t, 24> kHiF8ExpToMantissaWidth = {{
3, 3, 3, 3,
2, 2, 2, 2,
1, 1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0,
static_cast<uint8_t>(-1),
}};
constexpr std::array<uint8_t, 16> kHiF8ExpToDot = {{
0b0001,
0b0010,
0b0100, 0b0100,
0b1000, 0b1000, 0b1000, 0b1000,
0b1100, 0b1100, 0b1100, 0b1100, 0b1100, 0b1100, 0b1100, 0b1100,
}};
constexpr int8_t kHiF8DmlExpMax = -16;
constexpr int8_t kHiF8DmlExpBias = 23;
constexpr uint8_t kHiF8DotOffset = 3;
constexpr int8_t kHiF8NmlExpMax = 15;
constexpr int8_t kHiF8DmlExpMin = -22;
constexpr uint8_t kHiF8SignMask = 0b10000000;
constexpr uint8_t kIeeeFp32ExpBits = 8;
constexpr uint8_t kIeeeFp32ManBits = 32 - 1 - kIeeeFp32ExpBits;
constexpr uint8_t kIeeeFp16ExpBits = 5;
constexpr uint8_t kIeeeFp16ManBits = 16 - 1 - kIeeeFp16ExpBits;
uint32_t Extract32(uint32_t u32, uint8_t width, uint8_t offset) {
return (u32 >> offset) & ((1U << width) - 1U);
}
uint32_t Deposit32(uint32_t u32, uint8_t width, uint8_t offset, uint32_t value) {
const uint32_t mask = (1U << width) - 1U;
return (u32 & ~(mask << offset)) | ((value & mask) << offset);
}
uint8_t EncodeHiF8(int8_t exp, uint8_t mantissa) {
if (exp <= kHiF8DmlExpMax) {
return static_cast<uint8_t>(exp + kHiF8DmlExpBias);
}
const uint8_t mag = static_cast<uint8_t>(exp < 0 ? -exp : exp);
auto encode_em = [](uint8_t m, uint8_t se_bit) -> uint8_t {
if (m == 0) {
return 0;
}
return Deposit32(m, 1, 7 - (__builtin_clz(m) - 24), se_bit);
};
GE_ASSERT(mag < sizeof(kHiF8ExpToDot));
return static_cast<uint8_t>(kHiF8ExpToDot[mag] << kHiF8DotOffset) |
static_cast<uint8_t>(encode_em(mag, exp < 0) << kHiF8ExpToMantissaWidth[mag]) |
mantissa;
}
enum class IeeeType : int8_t {
kZero,
kNaN,
kInf,
kOrdinary,
};
template <uint8_t exp_bits, uint8_t mantissa_bits>
IeeeType UnpackIeeeFp(uint32_t bits, uint8_t *sign_bit, int8_t *exp, uint32_t *mantissa) {
constexpr uint32_t ieee_sign_bit = 1U << (exp_bits + mantissa_bits);
constexpr int8_t bias = static_cast<int8_t>((1U << (exp_bits - 1U)) - 1U);
if ((bits & ~ieee_sign_bit) == 0) {
return IeeeType::kZero;
}
*sign_bit = (bits & ieee_sign_bit) != 0 ? kHiF8SignMask : 0U;
*exp = static_cast<int8_t>(Extract32(bits, exp_bits, mantissa_bits)) - bias;
*mantissa = Extract32(bits, mantissa_bits, 0);
if (static_cast<uint8_t>(*exp) == (1U << (exp_bits - 1U))) {
return *mantissa != 0 ? IeeeType::kNaN : IeeeType::kInf;
}
if (*exp == -bias && *mantissa != 0) {
const uint8_t shift = mantissa_bits - (32 - __builtin_clz(*mantissa)) + 1;
*exp = -bias + 1 - shift;
*mantissa = Extract32(*mantissa << shift, mantissa_bits, 0);
}
return IeeeType::kOrdinary;
}
template <uint8_t mantissa_bits>
bool RoundMantissa(uint32_t *mantissa, uint32_t probe_bit) {
const uint32_t tmp = *mantissa + probe_bit;
*mantissa = Extract32(tmp, mantissa_bits, 0);
return (tmp & (1U << mantissa_bits)) != 0;
}
template <uint8_t exp_bits, uint8_t mantissa_bits>
uint8_t HiF8FromIeeeBits(uint32_t bits) {
uint8_t sign_bit = 0;
int8_t exp = 0;
uint32_t mantissa = 0;
const auto type = UnpackIeeeFp<exp_bits, mantissa_bits>(bits, &sign_bit, &exp, &mantissa);
if (type == IeeeType::kZero || (type == IeeeType::kOrdinary && exp < kHiF8DmlExpMin - 1)) {
return 0;
} else if (type == IeeeType::kNaN) {
return 0b10000000;
} else if (type == IeeeType::kInf || (type == IeeeType::kOrdinary && exp > kHiF8NmlExpMax)) {
return sign_bit | 0b01101111U;
}
const uint8_t bits_to_discard = mantissa_bits - kHiF8ExpToMantissaWidth[static_cast<ssize_t>(std::abs(exp))];
if (RoundMantissa<mantissa_bits>(&mantissa, 1U << (bits_to_discard - 1U))) {
exp += 1;
if (exp > kHiF8NmlExpMax) {
return sign_bit | 0b01101111U;
}
}
return sign_bit | EncodeHiF8(exp, mantissa >> bits_to_discard);
}
template <typename To, typename From>
auto BitCast(const From &src) -> To {
static_assert(sizeof(To) == sizeof(From), "");
To dst = 0;
(void)memcpy_s(&dst, sizeof(To), &src, sizeof(From));
return dst;
}
}
namespace ge {
HiF8 HiF8::FromRawBits(uint8_t bits) {
HiF8 hif8;
hif8.u8_ = bits;
return hif8;
}
HiF8::HiF8(float f32) : u8_(BitsFromFp32(BitCast<uint32_t>(f32))) {}
HiF8::HiF8(fp16_t f16) : u8_(BitsFromFp16(BitCast<uint16_t>(f16))) {}
uint8_t HiF8::BitsFromFp32(uint32_t f32) {
return HiF8FromIeeeBits<kIeeeFp32ExpBits, kIeeeFp32ManBits>(f32);
}
uint8_t HiF8::BitsFromFp16(uint16_t f16) {
return HiF8FromIeeeBits<kIeeeFp16ExpBits, kIeeeFp16ManBits>(f16);
}
bool HiF8::IsNaN() const {
return u8_ == kHiF8SignMask;
}
bool HiF8::IsInf() const {
return (u8_ & static_cast<uint8_t>(~kHiF8SignMask)) == 0b01101111U;
}
HiF8::operator float() const {
constexpr uint8_t kHiF8DmlMask = 0b01111000;
constexpr uint8_t kHiF8DmlFlag = static_cast<uint8_t>(-1);
constexpr float kExpBase = 2.0F;
if (IsNaN()) {
return NAN;
} else if (IsInf()) {
return (u8_ & kHiF8SignMask) != 0 ? -std::numeric_limits<float>::infinity()
: std::numeric_limits<float>::infinity();
}
static const std::array<std::pair<uint8_t, uint8_t>, 16> kDotTable = {{
{ kHiF8DmlFlag, 4 },
{ 0, 4 },
{ 1, 3 }, { 1, 3 },
{ 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 },
{ 3, 2 }, { 3, 2 }, { 3, 2 }, { 3, 2 },
{ 4, 2 }, { 4, 2 }, { 4, 2 }, { 4, 2 },
}};
const float sign = (u8_ & kHiF8SignMask) != 0 ? -1.0F : 1.0F;
const auto pair = kDotTable[(u8_ & kHiF8DmlMask) >> kHiF8DotOffset];
const uint8_t dot = pair.first;
const uint8_t width = pair.second;
if (dot == kHiF8DmlFlag) {
const int8_t dml_exp = static_cast<int8_t>(u8_ & static_cast<uint8_t>(~(kHiF8SignMask | kHiF8DmlMask)));
if (dml_exp == 0) {
return 0.0F;
}
return sign * powf(kExpBase, static_cast<float>(dml_exp - kHiF8DmlExpBias));
}
const uint8_t mantissa_width = 7 - (dot + width);
float mantissa_value = 0.0F;
for (uint8_t n = 0; n < mantissa_width; n ++) {
if ((u8_ & static_cast<uint8_t>(1U << n)) == 0) {
continue;
}
mantissa_value += powf(kExpBase, -static_cast<float>(mantissa_width - n));
}
auto ei = [mantissa_width](uint8_t d, uint8_t raw) -> int8_t {
if (d == 0) {
return 0;
}
const uint8_t em = Extract32(raw, d, mantissa_width);
const uint8_t se = em & static_cast<uint8_t>(1U << (d - 1U));
const uint8_t mag = Deposit32(em, 1, d - 1U, 1);
return (se != 0 ? -1 : 1) * static_cast<int8_t>(mag);
};
return sign * powf(kExpBase, static_cast<float>(ei(dot, u8_))) * (1.0F + mantissa_value);
}
HiF8::operator fp16_t() const {
if (IsNaN()) {
return fp16_t(Fp16Constructor(0, kFp16MaxExp, 1));
}
fp16_t tmp;
tmp = static_cast<float>(*this);
return tmp;
}
}