* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file float.h
* \brief
*/
#pragma once
#include "interface/utils/common.h"
#include <stdio.h>
#include <stdint.h>
constexpr int SIGN_BIT_ONE = 1;
constexpr int EXP_BIT_FIVE = 5;
constexpr int FRAC_BIT_TEN = 10;
constexpr int EXP_BIT_EIGHT = 8;
constexpr int FRAC_BIT_SEVEN = 7;
namespace npu::tile_fwk {
template <typename TBase, uint32_t signBit, uint32_t expBit, uint32_t fracBit>
class Float {
TBase value;
enum class FloatExp {
bitOfByte = 8,
expZero = (1 << (expBit - 1)) - 1,
fp32ExpBit = 8,
fp32FracBit = 23,
fp32ExpZero = (1 << (fp32ExpBit - 1)) - 1,
};
void InitFromFloat(float fv)
{
void* p = &fv;
uint32_t v = *static_cast<uint32_t*>(p);
value = BaseFromFp32(v);
}
float ToFloat() const
{
uint32_t v = BaseToFp32(value);
void* p = &v;
return *static_cast<float*>(p);
}
static bool isNaN(TBase v)
{
uint32_t exp = (v >> fracBit) & BitOf(expBit);
uint32_t frac = v & BitOf(fracBit);
return (exp == BitOf(expBit)) && (frac != 0);
}
static_assert(
sizeof(TBase) * static_cast<uint32_t>(FloatExp::bitOfByte) >= signBit + expBit + fracBit,
"Invalid bit for float");
static constexpr uint32_t BitOf(uint32_t n) { return (1 << n) - 1; }
static constexpr uint32_t BaseFromFp32DivRound(uint32_t frac, uint32_t shift)
{
return (frac >> shift) + ((frac >> (shift - 1)) & 0x1);
}
static Float FromBase(TBase val)
{
Float ret;
ret.value = val;
return ret;
}
static constexpr TBase BaseFromFp32(uint32_t v32) NO_UBSAN
{
if (expBit == EXP_BIT_EIGHT) {
最接近偶数舍入法:
u32 : 0x3ed2f1aa -> 0011 1110 1101 0010 1111 0001 1010 1010
^ ^ ^
u32 format s(1) e(8) frac(23)
^ ^ ^
bf16 format s(1) e(8) frac(7)
round-nearest-to-even : 1111 0001 1010 1010
Hex half-ULP : 0x8000 0xf1aa > 0x8000 Round up (add 1 to the upper 7 bits)
u16 : 0x3ed3 -> 0011 1110 1101 0011 = 0011 1110 1101 0010 + 1
if 0xf1aa = X
X > Hex half-ULP : Round up (add 1 to the upper 7 bits)
X < Hex half-ULP : Round down (keep the original high 7 bits)
X = Hex half-ULP : Round to the nearest even number
*/
constexpr uint32_t floatToBf16FracShift = (static_cast<uint32_t>(FloatExp::fp32FracBit) - fracBit);
uint32_t fracLastBit = (v32 >> floatToBf16FracShift) & 0x1;
uint32_t hexHalfUlp = (0x1 << (static_cast<uint32_t>(FloatExp::fp32FracBit) - fracBit - 1)) - 0x1;
uint32_t roundingBias = hexHalfUlp + fracLastBit;
v32 += roundingBias;
return (v32 >> floatToBf16FracShift);
}
uint32_t sign =
(v32 >> (static_cast<uint32_t>(FloatExp::fp32ExpBit) + static_cast<uint32_t>(FloatExp::fp32FracBit))) & 0x1;
uint32_t exp32 =
(v32 >> static_cast<uint32_t>(FloatExp::fp32FracBit)) & BitOf(static_cast<uint32_t>(FloatExp::fp32ExpBit));
uint32_t frac32 = v32 & BitOf(static_cast<uint32_t>(FloatExp::fp32FracBit));
uint32_t exp = 0;
uint32_t frac = 0;
if (exp32 <
static_cast<uint32_t>(FloatExp::fp32ExpZero) + -static_cast<uint32_t>(FloatExp::expZero) + 1 + -fracBit) {
* format: 0bS X...X 0...0 == 0bS 0...0 0...01
* (fp32) (fp)
* compute: 2 ^ (X - fp32ExpZero) = 2 ^ (-expZero + 1) * 2 ^ (-fracBit)
* ==>: X = fp32ExpZero + -expZero + 1 + -fracBit
*/
exp = 0;
frac = 0;
} else if (
exp32 < static_cast<uint32_t>(FloatExp::fp32ExpZero) + 1 - static_cast<uint32_t>(FloatExp::expZero)) {
* format: 0bS X...X 0...0 == 0bS 0...1 0...00
* (fp32) (fp)
* compute: 2 ^ (X - fp32ExpZero) == 2 ^ (1 - expZero)
* ==>: X = fp32ExpZero + 1 - expZero
*
* value: 0bS X...X Y...Y = 0bS 0...0 Z...Z
* compute: (2 ^ (fp32FracBit) + Y) * 2 ^ (-fp32FracBit) * 2 ^ (X - fp32ExpZero) = Z * 2 ^ (-expZero +
* 1) * 2 ^ (-fracBit)
* ==>: Z = (2 ^ (fp32FracBit) + Y) * 2 ^ (-fp32FracBit) * 2 ^ (X - fp32ExpZero) * 2 ^ (expZero -
* 1) * 2 ^(fracBit) = (2 ^ (fp32FracBit) + Y) * 2 ^ (-fp32FracBit + X - fp32ExpZero + expZero - 1 +
* fracBit) = (2 ^ (fp32FracBit) + Y) * 2 ^ -(fp32FracBit - X + fp32ExpZero - expZero + 1 - fracBit)
*/
exp = 0;
auto shift = static_cast<uint32_t>(FloatExp::fp32FracBit) - exp32 +
static_cast<uint32_t>(FloatExp::fp32ExpZero) - static_cast<uint32_t>(FloatExp::expZero) + 1 -
fracBit;
frac = BaseFromFp32DivRound((1 << static_cast<uint32_t>(FloatExp::fp32FracBit)) | frac32, shift);
} else if (
exp32 <
static_cast<uint32_t>(FloatExp::fp32ExpZero) + BitOf(expBit) - static_cast<uint32_t>(FloatExp::expZero)) {
* format: 0bS X...X 0...0 == 0bS 1...1 0...00
* (fp32) (fp)
* compute: 2 ^ (X - fp32ExpZero) == 2 ^ (BitOf(expBit) - expZero)
* ==>: X = fp32ExpZero + BitOf(expBit) - expZero
*/
exp = exp32 - static_cast<uint32_t>(FloatExp::fp32ExpZero) + static_cast<uint32_t>(FloatExp::expZero);
frac = BaseFromFp32DivRound(frac32, static_cast<uint32_t>(FloatExp::fp32FracBit) - fracBit);
} else {
if (exp32 == BitOf(static_cast<uint32_t>(FloatExp::fp32ExpBit)) && frac32 != 0) {
exp = BitOf(expBit);
frac = frac32 >> (static_cast<uint32_t>(FloatExp::fp32FracBit) - fracBit);
} else {
exp = BitOf(expBit);
frac = 0;
}
}
return (sign << (expBit + fracBit)) | (exp << fracBit) | frac;
}
static constexpr uint32_t BaseToFp32(TBase v)
{
if (expBit == EXP_BIT_EIGHT) {
return static_cast<uint32_t>(v) << (static_cast<uint32_t>(FloatExp::fp32FracBit) - fracBit);
}
uint32_t sign = (v >> (expBit + fracBit)) & 0x1;
uint32_t exp = (v >> fracBit) & BitOf(expBit);
uint32_t frac = v & BitOf(fracBit);
uint32_t exp32 = 0;
uint32_t frac32 = 0;
if (exp == 0) {
if (frac == 0) {
exp32 = 0;
frac32 = 0;
} else {
* format: 0bS 0000 0...01XY...Y
* |<-leadingBit
* result: 0bS EEEE XY...Y0...0
* |<- fp32FracBit-1
* compute:
* shiftExp = (fracBit - 1) - (leadingBit - 1)
* shiftFrac = (fp32FracBit - 1) - (leadingBit - 1)
* EEEE = fp32ExpZero + -(expZero - 1) - shift
*/
uint32_t leadingBit =
sizeof(unsigned int) * static_cast<uint32_t>(FloatExp::bitOfByte) - __builtin_clz(frac) - 1;
uint32_t rest = frac ^ (1 << leadingBit);
uint32_t shiftExp = (fracBit - 1) - (leadingBit - 1);
uint32_t shiftFrac = (static_cast<uint32_t>(FloatExp::fp32FracBit) - 1) - (leadingBit - 1);
exp32 = static_cast<uint32_t>(FloatExp::fp32ExpZero) + -(static_cast<uint32_t>(FloatExp::expZero) - 1) -
shiftExp;
frac32 = rest << shiftFrac;
}
} else if (exp == BitOf(expBit)) {
* format: 0bS 1...1 X...X
* result: 0bS 1...1 X...X
*/
exp32 = BitOf(static_cast<uint32_t>(FloatExp::fp32ExpBit));
if (frac == 0) {
frac32 = 0;
} else {
frac32 = frac << (static_cast<uint32_t>(FloatExp::fp32FracBit) - fracBit);
}
} else {
* format: 0bS Y...Y Z...Z
* |<-fracBit
* result: 0bS E...E Z...Z0...0
* |<-fp32FracBit
*/
exp32 = exp - static_cast<uint32_t>(FloatExp::expZero) + static_cast<uint32_t>(FloatExp::fp32ExpZero);
frac32 = frac << (static_cast<uint32_t>(FloatExp::fp32FracBit) - fracBit);
}
return (sign << (static_cast<uint32_t>(FloatExp::fp32ExpBit) + static_cast<uint32_t>(FloatExp::fp32FracBit))) |
(exp32 << static_cast<uint32_t>(FloatExp::fp32FracBit)) | frac32;
}
static void PrintMetadata()
{
printf(
"expBit=%d fracBit=%d expZero=%d fp32ExpBit=%d fp32FracBit=%d fp32ExpZero=%d\n", expBit, fracBit,
static_cast<uint32_t>(FloatExp::expZero), static_cast<uint32_t>(FloatExp::fp32ExpBit),
static_cast<uint32_t>(FloatExp::fp32FracBit), static_cast<uint32_t>(FloatExp::fp32ExpZero));
}
public:
Float() : value(0) {}
template <typename T>
Float(T fv)
{
InitFromFloat(static_cast<float>(fv));
}
operator float() const { return ToFloat(); }
template <typename T>
Float operator+(T fv)
{
return Float(this->ToFloat() + static_cast<float>(fv));
}
template <typename T>
Float operator-(T fv)
{
return Float(this->ToFloat() - static_cast<float>(fv));
}
template <typename T>
Float operator*(T fv)
{
return Float(this->ToFloat() * static_cast<float>(fv));
}
template <typename T>
Float operator/(T fv)
{
return Float(this->ToFloat() / static_cast<float>(fv));
}
template <typename T>
Float& operator+=(T fv)
{
InitFromFloat(this->ToFloat() + static_cast<float>(fv));
return *this;
}
template <typename T>
Float& operator-=(T fv)
{
InitFromFloat(this->ToFloat() - static_cast<float>(fv));
return *this;
}
template <typename T>
Float& operator*=(T fv)
{
InitFromFloat(this->ToFloat() * static_cast<float>(fv));
return *this;
}
template <typename T>
Float& operator/=(T fv)
{
InitFromFloat(this->ToFloat() / static_cast<float>(fv));
return *this;
}
template <typename T>
bool operator==(T fv) const
{
TBase thisBase = value;
TBase otherBase;
if constexpr (std::is_same_v<T, Float>) {
otherBase = fv.value;
} else {
float temp = static_cast<float>(fv);
otherBase = BaseFromFp32(*reinterpret_cast<const uint32_t*>(&temp));
}
if ((thisBase & ~(1 << (expBit + fracBit))) == 0 && (otherBase & ~(1 << (expBit + fracBit))) == 0) {
return true;
}
if (isNaN(thisBase) || isNaN(otherBase)) {
return false;
}
return thisBase == otherBase;
}
template <typename T>
bool operator!=(T fv) const
{
return !(*this == fv);
}
template <typename T>
bool operator>=(T fv) const
{
return this->ToFloat() >= static_cast<float>(fv);
}
template <typename T>
bool operator<=(T fv) const
{
return this->ToFloat() <= static_cast<float>(fv);
}
template <typename T>
bool operator>(T fv) const
{
return this->ToFloat() > static_cast<float>(fv);
}
template <typename T>
bool operator<(T fv) const
{
return this->ToFloat() < static_cast<float>(fv);
}
};
using bfloat16 = Float<uint16_t, SIGN_BIT_ONE, EXP_BIT_EIGHT, FRAC_BIT_SEVEN>;
using float16 = Float<uint16_t, SIGN_BIT_ONE, EXP_BIT_FIVE, FRAC_BIT_TEN>;
}