/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef GE_COMMON_FP16_T_H_
#define GE_COMMON_FP16_T_H_
#include <algorithm>
#include <cmath>
#include <cstdint>
#include "graph/types.h"
#include "graph/def_types.h"
namespace ge {
/// @brief Symbolic names for the consecutive indices 0 through 16.
/// Each enumerator kDimN has the underlying value N.
enum class DimIndex : uint32_t {
  kDim0 = 0U,
  kDim1 = 1U,
  kDim2 = 2U,
  kDim3 = 3U,
  kDim4 = 4U,
  kDim5 = 5U,
  kDim6 = 6U,
  kDim7 = 7U,
  kDim8 = 8U,
  kDim9 = 9U,
  kDim10 = 10U,
  kDim11 = 11U,
  kDim12 = 12U,
  kDim13 = 13U,
  kDim14 = 14U,
  kDim15 = 15U,
  kDim16 = 16U,
};
/// @brief Named integer constants; each enumerator kBitShiftN has the underlying value N.
/// NOTE(review): values above 63 cannot be shift counts for the 16/32/64-bit operands used in
/// this header, so the larger entries are presumably generic magic-number replacements - confirm
/// against the users of this enum.
enum class BitShift : uint32_t {
  // Values up to 64.
  kBitShift2 = 2U,
  kBitShift3 = 3U,
  kBitShift4 = 4U,
  kBitShift5 = 5U,
  kBitShift6 = 6U,
  kBitShift7 = 7U,
  kBitShift8 = 8U,
  kBitShift9 = 9U,
  kBitShift10 = 10U,
  kBitShift11 = 11U,
  kBitShift12 = 12U,
  kBitShift13 = 13U,
  kBitShift14 = 14U,
  kBitShift15 = 15U,
  kBitShift16 = 16U,
  kBitShift20 = 20U,
  kBitShift24 = 24U,
  kBitShift27 = 27U,
  kBitShift28 = 28U,
  kBitShift31 = 31U,
  kBitShift32 = 32U,
  kBitShift36 = 36U,
  kBitShift40 = 40U,
  kBitShift44 = 44U,
  kBitShift48 = 48U,
  kBitShift52 = 52U,
  kBitShift56 = 56U,
  kBitShift59 = 59U,
  kBitShift60 = 60U,
  kBitShift63 = 63U,
  kBitShift64 = 64U,

  // Values above 64.
  kBitShift128 = 128U,
  kBitShift255 = 255U,
  kBitShift256 = 256U,
  kBitShift512 = 512U,
  kBitShift768 = 768U,
  kBitShift784 = 784U,
  kBitShift1020 = 1020U,
  kBitShift1024 = 1024U,
  kBitShift3136 = 3136U,
  kBitShift4096 = 4096U,
  kBitShift6144 = 6144U,
  kBitShift10240 = 10240U,
  kBitShift65536 = 65536U
};
/// @ingroup fp16 basic parameter
/// @brief fp16 exponent bias
constexpr uint16_t kFp16ExpBias = 15U;
/// @ingroup fp16 basic parameter
/// @brief the mantissa bit length of fp16 is 10
constexpr uint16_t kFp16ManLen = 10U;
/// @ingroup fp16 basic parameter
/// @brief bit index of sign in fp16
constexpr uint16_t kFp16SignIndex = 15U;
/// @ingroup fp16 basic parameter
/// @brief sign mask of fp16: bit 15 only (1 00000 0000000000)
constexpr uint16_t kFp16SignMask = 0x8000U;
/// @ingroup fp16 basic parameter
/// @brief exponent mask of fp16: bits 14..10 (0 11111 0000000000)
constexpr uint16_t kFp16ExpMask = 0x7C00U;
/// @ingroup fp16 basic parameter
/// @brief mantissa mask of fp16: bits 9..0 (0 00000 1111111111)
constexpr uint16_t kFp16ManMask = 0x03FFU;
/// @ingroup fp16 basic parameter
/// @brief hidden (implicit leading) bit of the fp16 mantissa: bit 10 (1 0000000000)
constexpr uint16_t kFp16ManHideBit = 0x0400U;
/// @ingroup fp16 basic parameter
/// @brief bit pattern with exponent field 11110 and mantissa all ones (0111 1011 1111 1111),
/// i.e. the largest pattern whose exponent field is not all ones
constexpr uint16_t kFp16Max = 0x7BFFU;
/// @ingroup fp16 basic parameter
/// @brief mask of all non-sign bits, bits 14..0 (0111 1111 1111 1111)
constexpr uint16_t kFp16AbsMax = 0x7FFFU;
/// @ingroup fp16 basic parameter
/// @brief maximum value of the 5-bit fp16 exponent field is 31 (11111);
/// also usable as a mask for the exponent after shifting it down to bit 0
constexpr uint16_t kFp16MaxExp = 0x001FU;
/// @ingroup fp16 basic parameter
/// @brief maximum value of the 10-bit fp16 mantissa field (11111 11111); same value as kFp16ManMask
constexpr uint16_t kFp16MaxMan = 0x03FFU;
/// @ingroup fp16 basic operator
/// @brief Extract the sign bit (bit 15) of an fp16 bit pattern
/// @param [in] x raw fp16 bit pattern
/// @return 1 if the sign bit is set, otherwise 0
inline uint16_t Fp16ExtracSign(const uint16_t x) {
  // Move bit 15 down to bit 0, then keep only that bit.
  const uint32_t top_bit = static_cast<uint32_t>(x) >> 15U;
  return static_cast<uint16_t>(top_bit & 1U);
}
/// @ingroup fp16 basic operator
/// @brief Extract the 5-bit exponent field (bits 14..10) of an fp16 bit pattern
/// @param [in] x raw fp16 bit pattern
/// @return The exponent field as stored (still biased), in the range 0..31
inline uint16_t Fp16ExtracExp(const uint16_t x) {
  // Drop the 10 mantissa bits, then mask off the sign bit that now sits above the exponent.
  const uint32_t shifted = static_cast<uint32_t>(x) >> 10U;
  return static_cast<uint16_t>(shifted & kFp16MaxExp);
}
/// @ingroup fp16 basic operator
/// @brief Extract the mantissa of an fp16 bit pattern, including the hidden bit when applicable
/// @param [in] x raw fp16 bit pattern
/// @return The low 10 mantissa bits; bit 10 (0x400) is additionally set whenever the exponent
///         field is non-zero, otherwise it is left clear
inline uint16_t Fp16ExtracMan(const uint16_t x) {
  const uint32_t bits = static_cast<uint32_t>(x);
  const uint32_t fraction = bits & 0x3FFU;
  // A non-zero exponent field means the value carries an implicit leading one.
  const bool has_hidden_bit = ((bits >> 10U) & 0x1FU) != 0U;
  const uint32_t hidden_bit = has_hidden_bit ? 0x400U : 0U;
  return static_cast<uint16_t>(fraction | hidden_bit);
}
/// @ingroup fp16 basic operator
/// @brief Assemble an fp16 bit pattern from sign, exponent and mantissa fields
/// @param [in] s sign, shifted up to bit kFp16SignIndex
/// @param [in] e exponent field, shifted up by kFp16ManLen bits (not masked, so bits of e
///               beyond the 5-bit field spill into higher positions)
/// @param [in] m mantissa; only the bits covered by kFp16MaxMan are used
/// @return The three fields OR-ed together into one 16-bit pattern
inline uint16_t Fp16Constructor(const uint16_t s, const uint16_t e, const uint16_t m) {
  const uint16_t sign_field = static_cast<uint16_t>(s << kFp16SignIndex);
  const uint16_t exp_field = static_cast<uint16_t>(e << kFp16ManLen);
  const uint16_t man_field = static_cast<uint16_t>(m & kFp16MaxMan);
  return static_cast<uint16_t>(sign_field | exp_field | man_field);
}
/// @ingroup fp16 special value judgment
/// @brief Whether an fp16 bit pattern is zero, ignoring the sign bit
/// @param [in] x raw fp16 bit pattern
/// @return true when every bit selected by kFp16AbsMax is clear
inline bool Fp16IsZero(const uint16_t x) {
  // Masking with kFp16AbsMax removes the sign so both signed zeros compare equal to 0.
  const uint16_t magnitude = static_cast<uint16_t>(x & kFp16AbsMax);
  return magnitude == 0U;
}
/// @ingroup fp16 special value judgment
/// @brief Whether an fp16 bit pattern has an all-zero exponent field (denormalized encoding)
/// @param [in] x raw fp16 bit pattern
/// @return true when no bit selected by kFp16ExpMask is set; note that this is also true for
///         zero, because the mantissa is not inspected
inline bool Fp16IsDenorm(const uint16_t x) {
  const uint16_t exp_bits = static_cast<uint16_t>(x & kFp16ExpMask);
  return exp_bits == 0U;
}
/// @ingroup fp16 special value judgment
/// @brief Whether an fp16 bit pattern has an all-ones exponent field
/// @param [in] x raw fp16 bit pattern
/// @return true when every bit selected by kFp16ExpMask is set (the encodings IEEE 754 uses
///         for infinity and NaN); the mantissa and sign are not inspected
inline bool Fp16IsInvalid(const uint16_t x) {
  const uint16_t exp_bits = static_cast<uint16_t>(x & kFp16ExpMask);
  return exp_bits == kFp16ExpMask;
}
/// @ingroup fp32 basic parameter
/// @brief fp32 exponent bias
constexpr uint32_t kFp32ExpBias = 127U;
/// @ingroup fp32 basic parameter
/// @brief the mantissa bit length of float/fp32 is 23
constexpr uint16_t kFp32ManLen = 23U;
/// @ingroup fp32 basic parameter
/// @brief bit index of sign in float/fp32
constexpr uint16_t kFp32SignIndex = 31U;
/// @ingroup fp32 basic parameter
/// @brief sign mask of fp32: bit 31 only (1 followed by 31 zero bits)
constexpr uint32_t kFp32SignMask = 0x80000000U;
/// @ingroup fp32 basic parameter
/// @brief exponent mask of fp32: bits 30..23 (0 11111111 followed by 23 zero bits)
constexpr uint32_t kFp32ExpMask = 0x7F800000U;
/// @ingroup fp32 basic parameter
/// @brief mantissa mask of fp32: bits 22..0 (23 one bits)
constexpr uint32_t kFp32ManMask = 0x007FFFFFU;
/// @ingroup fp32 basic parameter
/// @brief hidden (implicit leading) bit of the fp32 mantissa: bit 23 (1 followed by 23 zero bits)
constexpr uint32_t kFp32ManHideBit = 0x00800000U;
/// @ingroup fp32 basic parameter
/// @brief mask of all non-sign bits, bits 30..0 (0 followed by 31 one bits)
constexpr uint32_t kFp32AbsMax = 0x7FFFFFFFU;
/// @ingroup fp32 basic operator
/// @brief Assemble an fp32 bit pattern from sign, exponent and mantissa fields
/// @param [in] s sign, shifted up to bit kFp32SignIndex
/// @param [in] e exponent field, shifted up by kFp32ManLen bits (not masked)
/// @param [in] m mantissa; only the low 23 bits (0x7FFFFF) are used
/// @return The three fields OR-ed together into one 32-bit pattern
inline uint32_t Fp32Constructor(const uint32_t s, const uint32_t e, const uint32_t m) {
  const uint32_t sign_field = s << kFp32SignIndex;
  const uint32_t exp_field = e << kFp32ManLen;
  const uint32_t man_field = m & 0x7FFFFFU;
  return (sign_field | exp_field | man_field);
}
/// @ingroup fp64 basic parameter
/// @brief fp64 exponent bias
constexpr uint64_t kFp64ExpBias = 1023U;
/// @ingroup fp64 basic parameter
/// @brief the mantissa bit length of double/fp64 is 52
constexpr uint16_t kFp64ManLen = 52U;
/// @ingroup fp64 basic parameter
/// @brief bit index of sign in double/fp64 is 63
constexpr uint16_t kFp64SignIndex = 63U;
/// @ingroup fp64 basic parameter
/// @brief sign mask of fp64: bit 63 only (1 followed by 63 zero bits)
constexpr uint64_t kFp64SignMask = 0x8000000000000000ULL;
/// @ingroup fp64 basic parameter
/// @brief exponent mask of fp64: bits 62..52 (0, then 11 one bits, then 52 zero bits)
constexpr uint64_t kFp64ExpMask = 0x7FF0000000000000ULL;
/// @ingroup fp64 basic parameter
/// @brief mantissa mask of fp64: bits 51..0 (52 one bits)
constexpr uint64_t kFp64ManMask = 0x000FFFFFFFFFFFFFULL;
/// @ingroup fp64 basic parameter
/// @brief hidden (implicit leading) bit of the fp64 mantissa: bit 52 (1 followed by 52 zero bits)
constexpr uint64_t kFp64ManHideBit = 0x0010000000000000ULL;
/// @ingroup integer special value judgment
/// @brief maximum positive value of int8_t (0111 1111)
constexpr int8_t kInt8Max = 0x7F;
/// @ingroup integer special value judgment
/// @brief maximum value of a data with 8 bits length (1111 1111)
constexpr uint8_t kBitLen8Max = 0xFFU;
/// @ingroup integer special value judgment
/// @brief maximum positive value of int16_t (0111 1111 1111 1111)
constexpr int16_t kInt16Max = 0x7FFF;
/// @ingroup integer special value judgment
/// @brief maximum value of a data with 16 bits length (1111 1111 1111 1111)
constexpr uint16_t kBitLen16Max = 0xFFFFU;
/// @ingroup integer special value judgment
/// @brief maximum value of a data with 32 bits length (1111 1111 1111 1111 1111 1111 1111 1111)
constexpr uint32_t kBitLen32Max = 0xFFFFFFFFU;
/// @ingroup fp16_t enum
/// @brief Rounding mode applied to the last retained digit
enum class TagFp16RoundMode : uint32_t {
  /// Round to nearest, ties to even
  kRoundToNearest = 0U,
  /// Round by truncation
  kRoundByTruncated = 1U,
  /// Reserved value following the defined modes
  kRoundModeReserved = 2U,
};
/// @ingroup fp16_t
/// @brief Half precision float, stored as a raw 16-bit pattern
/// bit15: 1 bit SIGN +---+-----+------------+
/// bit14-10: 5 bit EXP | S |EEEEE|MM MMMM MMMM|
/// bit9-0: 10 bit MAN +---+-----+------------+
/// @note Only the constructors and the defaulted copy operations have their behavior visible in
/// this header; every other member is declared here and defined out of line.
using fp16_t = class TagFp16 final {
public:
/// Raw 16-bit storage of the value; the only data member of the class.
uint16_t val;
public:
/// @ingroup fp16_t constructor
/// @brief Default constructor; delegates to the uint16_t constructor with 0x0, so val becomes 0
TagFp16(void) : TagFp16(0x0U) {}
/// @ingroup fp16_t constructor
/// @param [in] ui_val value copied verbatim into val (no numeric conversion is performed)
/// @brief Constructor from a raw uint16_t value
/// @note Not marked explicit, so a uint16_t converts implicitly to fp16_t.
TagFp16(const uint16_t ui_val) : val(ui_val) {}
/// @ingroup fp16_t constructor
/// @brief Copy constructor (compiler generated)
TagFp16(const TagFp16 &fp) = default;
/// @ingroup fp16_t math operator
/// @param [in] fp fp16_t object to be added
/// @brief Addition operator performing fp16_t addition
/// @return Return fp16_t result of adding this and fp
TagFp16 operator+(const TagFp16 fp) const;
/// @ingroup fp16_t math operator
/// @param [in] fp fp16_t object to be subtracted
/// @brief Subtraction operator performing fp16_t subtraction
/// @return Return fp16_t result of subtracting fp from this
TagFp16 operator-(const TagFp16 fp) const;
/// @ingroup fp16_t math operator
/// @param [in] fp fp16_t object to be multiplied
/// @brief Multiplication operator performing fp16_t multiplication
/// @return Return fp16_t result of multiplying this and fp
TagFp16 operator*(const TagFp16 fp) const;
/// @ingroup fp16_t math compare operator
/// @param [in] lhs left-hand fp16_t operand
/// @param [in] rhs right-hand fp16_t operand
/// @brief Equality comparison of two fp16_t objects (friend, noexcept)
/// @return Return boolean result of the if-equal comparison of lhs and rhs.
/// @note NOTE(review): no operator!= or operator< is declared in this class - confirm whether
/// they are provided elsewhere.
friend bool operator==(const TagFp16 lhs, const TagFp16 rhs) noexcept;
/// @ingroup fp16_t math compare operator
/// @param [in] lhs left-hand fp16_t operand
/// @param [in] rhs right-hand fp16_t operand
/// @brief Greater-than comparison of two fp16_t objects (friend, noexcept)
/// @return Return boolean result of the greater-than comparison of lhs and rhs.
friend bool operator>(const TagFp16 lhs, const TagFp16 rhs) noexcept;
/// @ingroup fp16_t math compare operator
/// @param [in] lhs left-hand fp16_t operand
/// @param [in] rhs right-hand fp16_t operand
/// @brief Greater-equal comparison of two fp16_t objects (friend, noexcept)
/// @return Return boolean result of the greater-equal comparison of lhs and rhs.
friend bool operator>=(const TagFp16 lhs, const TagFp16 rhs) noexcept;
/// @ingroup fp16_t math compare operator
/// @param [in] lhs left-hand fp16_t operand
/// @param [in] rhs right-hand fp16_t operand
/// @brief Less-equal comparison of two fp16_t objects (friend, noexcept)
/// @return Return boolean result of the less-equal comparison of lhs and rhs.
friend bool operator<=(const TagFp16 lhs, const TagFp16 rhs) noexcept;
/// @ingroup fp16_t math evaluation operator
/// @param [in] fp fp16_t object to be copied
/// @brief Copy assignment (compiler generated); the trailing & restricts it to lvalue targets
/// @return Return reference to this object
TagFp16 &operator=(const TagFp16 &fp) & = default;
/// @ingroup fp16_t math evaluation operator
/// @param [in] f_val float object to be converted to fp16_t
/// @brief Assignment operator converting float to fp16_t; lvalue targets only
/// @return Return reference to this object holding the result from f_val
TagFp16 &operator=(const float32_t f_val) &;
/// @ingroup fp16_t math evaluation operator
/// @param [in] d_val double object to be converted to fp16_t
/// @brief Assignment operator converting double to fp16_t; lvalue targets only
/// @return Return reference to this object holding the result from d_val
TagFp16 &operator=(const float64_t d_val) &;
/// @ingroup fp16_t math evaluation operator
/// @param [in] i_val int32_t object to be converted to fp16_t
/// @brief Assignment operator converting int32_t to fp16_t; lvalue targets only
/// @return Return reference to this object holding the result from i_val
TagFp16 &operator=(const int32_t i_val) &;
/// @ingroup fp16_t math conversion
/// @brief Explicit conversion operator from fp16_t to float/fp32
/// @return Return float/fp32 value of fp16_t
explicit operator float32_t() const;
/// @ingroup fp16_t math conversion
/// @brief Explicit conversion operator from fp16_t to double/fp64
/// @return Return double/fp64 value of fp16_t
explicit operator float64_t() const;
/// @ingroup fp16_t math conversion
/// @brief Explicit conversion operator from fp16_t to int8_t
/// @return Return int8_t value of fp16_t
explicit operator int8_t() const;
/// @ingroup fp16_t math conversion
/// @brief Explicit conversion operator from fp16_t to uint8_t
/// @return Return uint8_t value of fp16_t
explicit operator uint8_t() const;
/// @ingroup fp16_t math conversion
/// @brief Explicit conversion operator from fp16_t to int16_t
/// @return Return int16_t value of fp16_t
explicit operator int16_t() const;
/// @ingroup fp16_t math conversion
/// @brief Explicit conversion operator from fp16_t to uint16_t
/// @return Return uint16_t value of fp16_t
/// @note NOTE(review): presumably a numeric conversion rather than a read of the raw bits in
/// val - confirm in the definition.
explicit operator uint16_t() const;
/// @ingroup fp16_t math conversion
/// @brief Explicit conversion operator from fp16_t to int32_t
/// @return Return int32_t value of fp16_t
explicit operator int32_t() const;
/// @ingroup fp16_t math conversion
/// @brief Explicit conversion operator from fp16_t to uint32_t
/// @return Return uint32_t value of fp16_t
explicit operator uint32_t() const;
/// @ingroup fp16_t math conversion
/// @brief Explicit conversion operator from fp16_t to int64_t
/// @return Return int64_t value of fp16_t
explicit operator int64_t() const;
/// @ingroup fp16_t math conversion
/// @brief Explicit conversion operator from fp16_t to uint64_t
/// @return Return uint64_t value of fp16_t
explicit operator uint64_t() const;
/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to float/fp32
/// @return Return float/fp32 value of fp16_t
float32_t ToFloat() const;
/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to double/fp64
/// @return Return double/fp64 value of fp16_t
float64_t ToDouble() const;
/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to int8_t
/// @return Return int8_t value of fp16_t
int8_t ToInt8() const;
/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to uint8_t
/// @return Return uint8_t value of fp16_t
uint8_t ToUInt8() const;
/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to int16_t
/// @return Return int16_t value of fp16_t
int16_t ToInt16() const;
/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to uint16_t
/// @return Return uint16_t value of fp16_t
uint16_t ToUInt16() const;
/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to int32_t
/// @return Return int32_t value of fp16_t
int32_t ToInt32() const;
/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to uint32_t
/// @return Return uint32_t value of fp16_t
uint32_t ToUInt32() const;
};
// fp16_t must occupy exactly the same number of bytes as its uint16_t storage.
static_assert(sizeof(fp16_t) == sizeof(uint16_t), "sizeof fp16_t must be 2");
/// @ingroup fp16_t public method
/// @param [in] val 16-bit value to decompose (presumably the raw fp16 bit pattern; the
///                 definition is not visible in this header - confirm)
/// @param [out] s receives the sign (passed by non-const reference)
/// @param [out] e receives the exponent as a signed 16-bit value (passed by non-const reference)
/// @param [out] m receives the mantissa (passed by non-const reference)
/// @brief Extract the sign, exponent and mantissa of a fp16_t object
/// @note Declaration only; NOTE(review): whether e is returned biased or unbiased and whether m
/// includes the hidden bit must be checked in the definition.
void ExtractFp16(const uint16_t val, uint16_t &s, int16_t &e, uint16_t &m);
/// @ingroup fp16_t public method
/// @param [in] negative whether the mantissa belongs to a negative number
/// @param [in|out] man mantissa to be negated in place
/// @brief Replace a mantissa by its two's complement (bitwise inverse plus one) when negative
///        is set; leave it untouched otherwise
template <typename T>
void ReverseMan(const bool negative, T &man) {
  if (!negative) {
    return;
  }
  // Two's complement: invert every bit, then add one.
  const auto inverted = ~man;
  man = inverted + 1U;
}
/// @ingroup fp16_t public method
/// @param [in] e_a exponent of one fp16_t/float number
/// @param [in] m_a mantissa of one fp16_t/float number
/// @param [in] e_b exponent of another fp16_t/float number
/// @param [in] m_b mantissa of another fp16_t/float number
/// @brief Choose the mantissa whose exponent is the smaller of the two, i.e. the one that has
///        to be shifted right to align with the other
/// @return Return m_b when e_a is greater than e_b, otherwise m_a (including when they are equal)
template <typename T>
auto MinMan(const int16_t e_a, T m_a, const int16_t e_b, T m_b) -> T {
  if (e_a > e_b) {
    return m_b;
  }
  return m_a;
}
/// @ingroup fp16_t public method
/// @param [in] man mantissa to operate on (unsigned type whose top bit is treated as a sign bit)
/// @param [in] shift number of bit positions to shift right; values <= 0 leave man unchanged
/// @brief Right shift a mantissa while replicating its most significant bit, i.e. an arithmetic
///        right shift carried out on an unsigned representation
/// @return Return the shifted mantissa; once shift reaches the bit width of T the result is
///         saturated (all ones if the top bit was set, otherwise zero)
template <typename T>
auto RightShift(T man, const int16_t shift) -> T {
  constexpr int32_t kBitWidth = static_cast<int32_t>(sizeof(T) * 8U);  // one byte has 8 bits
  constexpr T kTopBitMask = static_cast<T>(static_cast<T>(1U) << (kBitWidth - 1));
  // A negative count used to drive the countdown loop through ~2^31 iterations and then into
  // signed overflow; treat it (and zero) as "no shift".
  if (shift <= 0) {
    return man;
  }
  // After kBitWidth single-bit steps the value no longer changes, so larger counts are clamped.
  const int32_t steps = (static_cast<int32_t>(shift) < kBitWidth) ? static_cast<int32_t>(shift) : kBitWidth;
  for (int32_t i = 0; i < steps; ++i) {
    // Keep the top bit in place and shift everything one position to the right.
    man = static_cast<T>((man & kTopBitMask) | (man >> 1U));
  }
  return man;
}
/// @ingroup fp16_t public method
/// @param [in] e_a exponent of one temp fp16_t number
/// @param [in] m_a mantissa of one temp fp16_t number
/// @param [in] e_b exponent of another temp fp16_t number
/// @param [in] m_b mantissa of another temp fp16_t number
/// @brief Get the mantissa sum of two temp fp16_t numbers after aligning their exponents: the
///        mantissa with the smaller exponent is passed through RightShift by the exponent
///        difference before being added. T support types: uint16_t/uint32_t/uint64_t
/// @return Return mantissa sum
template <typename T>
auto GetManSum(const int16_t e_a, const T &m_a, const int16_t e_b, const T &m_b) -> T {
  // Equal exponents: the mantissas are already aligned.
  if (e_a == e_b) {
    return m_a + m_b;
  }
  const int16_t exp_gap = static_cast<int16_t>(std::abs(static_cast<int32_t>(e_a - e_b)));
  const bool a_has_larger_exp = (e_a > e_b);
  // Only the operand with the smaller exponent is shifted.
  T aligned = a_has_larger_exp ? m_b : m_a;
  aligned = RightShift(aligned, exp_gap);
  T total = 0U;
  if (a_has_larger_exp) {
    total = m_a + aligned;
  } else {
    total = aligned + m_b;
  }
  return total;
}
/// @ingroup fp16_t public method
/// @param [in] bit0 whether the last preserved bit is 1 before round
/// @param [in] bit1 whether the highest discarded bit is 1
/// @param [in] bitLeft whether any discarded bit below the highest one is 1
/// @param [in] man mantissa of a fp16_t or float number, support types: uint16_t/uint32_t/uint64_t
/// @param [in] shift number of low bits to discard; must be smaller than the bit width of the
///                   (promoted) type of man
/// @brief Round fp16_t or float mantissa to nearest value, ties to even
/// @return Return man shifted right by shift, plus one when the discarded part is above one half
///         (bit1 && bitLeft) or exactly one half with an odd preserved value (bit1 && bit0)
template <typename T>
auto ManRoundToNearest(const bool bit0, const bool bit1, const bool bitLeft, T man, const uint16_t shift = 0U) -> T {
  const T round_up = static_cast<T>((bit1 && (bitLeft || bit0)) ? 1U : 0U);
  // Keep the arithmetic in T: narrowing the shifted mantissa to 32 bits first would drop the
  // upper half of a uint64_t mantissa.
  return static_cast<T>(static_cast<T>(man >> shift) + round_up);
}
/// @ingroup fp16_t public method
/// @param [in] man mantissa of a float number, support types: uint16_t/uint32_t/uint64_t
/// @brief Get the bit length of a mantissa, i.e. the 1-based position of its highest set bit
/// @return Return bit length of man; 0 when man is 0
template <typename T>
int16_t GetManBitLength(T man) {
  int16_t bit_count = 0;
  // Drop one low bit per step until nothing is left; the number of steps is the bit length.
  for (; man != 0U; man >>= 1U) {
    bit_count++;
  }
  return bit_count;
}
} // namespace ge
#endif // GE_COMMON_FP16_T_H_