#include "mlir/ExecutionEngine/Float16bits.h"
#ifdef MLIR_FLOAT16_DEFINE_FUNCTIONS
#include <cmath>
#include <cstring>
namespace {

// The helpers below manipulate the IEEE-754 binary32 encoding of `float`
// directly, so `float` must be exactly as wide as uint32_t.
static_assert(sizeof(float) == sizeof(uint32_t),
              "float16 conversions require a 32-bit float");

// Returns the raw bit pattern of `value`. std::memcpy is used rather than
// reading an inactive union member, which is undefined behavior in ISO C++.
uint32_t f32ToBits(float value) {
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  return bits;
}

// Reinterprets the raw 32-bit pattern `bits` as a float (inverse of
// f32ToBits).
float bitsToF32(uint32_t bits) {
  float value;
  std::memcpy(&value, &bits, sizeof(value));
  return value;
}

// Number of explicit mantissa bits in a float.
const uint32_t kF32MantiBits = 23;
// Mantissa width difference between float (23 bits) and f16 (10 bits).
const uint32_t kF32HalfMantiBitDiff = 13;
// Distance between the float sign bit (bit 31) and the f16 sign bit (bit 15).
const uint32_t kF32HalfBitDiff = 16;
// Float bit pattern of 2^-14, the smallest normal f16 magnitude.
const uint32_t kF32MagicBits = 113u << kF32MantiBits;
// Difference of the float (127) and f16 (15) exponent biases, positioned in
// the float exponent field.
const uint32_t kF32HalfExpAdjust = (127 - 15) << kF32MantiBits;

// Converts `floatValue` to the bit pattern of the nearest f16 value, rounding
// ties to even. Magnitudes that round beyond the largest finite f16 (65504)
// become infinity, and every NaN becomes the quiet NaN 0x7e00 (payload
// dropped). The sign of the input is always preserved.
uint16_t float2half(float floatValue) {
  // Float bit patterns of +infinity, 2^16 (at or above which the result is
  // always f16 infinity/NaN) and 0.5 (the subnormal alignment magic).
  const uint32_t infBits = 255u << kF32MantiBits;
  const uint32_t f16MaxBits = (127u + 16u) << kF32MantiBits;
  const uint32_t denormMagicBits =
      ((127u - 15u) + (kF32MantiBits - 10u) + 1u) << kF32MantiBits;
  const uint32_t signMask = 0x80000000u;

  // Work on the magnitude and re-attach the sign at the end. All arithmetic is
  // on uint32_t, so the wrap-around relied on below is well defined.
  uint32_t bits = f32ToBits(floatValue);
  const uint32_t sign = bits & signMask;
  bits ^= sign;

  uint16_t halfValue;
  if (bits >= f16MaxBits) {
    // Inf or NaN input, or a finite value too large for f16.
    halfValue = static_cast<uint16_t>(bits > infBits ? 0x7e00u : 0x7c00u);
  } else if (bits < kF32MagicBits) {
    // The result is an f16 subnormal or zero. Adding 0.5 aligns the 10 f16
    // mantissa bits with the bottom of the float mantissa. The float addition
    // itself performs the rounding (to nearest even under the default
    // rounding mode).
    const float aligned = bitsToF32(bits) + bitsToF32(denormMagicBits);
    halfValue = static_cast<uint16_t>(f32ToBits(aligned) - denormMagicBits);
  } else {
    // Normal f16 result. 0xc8000fff equals ((15 - 127) << 23) + 0xfff modulo
    // 2^32: it rebiases the exponent and adds the rounding bias. Also adding
    // the lsb of the kept mantissa makes exact ties round to even.
    const uint32_t mantOdd = (bits >> kF32HalfMantiBitDiff) & 1u;
    bits += 0xc8000fffU;
    bits += mantOdd;
    halfValue = static_cast<uint16_t>(bits >> kF32HalfMantiBitDiff);
  }

  return static_cast<uint16_t>(halfValue | (sign >> kF32HalfBitDiff));
}

// Converts the f16 bit pattern `halfValue` to float. Every finite or infinite
// f16 value is exactly representable as a float, so the conversion is exact.
float half2float(uint16_t halfValue) {
  // f16 exponent mask, moved to where the exponent lands after the shift below.
  const uint32_t shiftedExp = 0x7c00u << kF32HalfMantiBitDiff;

  // Move exponent and mantissa into float position, then rebias the exponent.
  uint32_t bits = static_cast<uint32_t>(halfValue & 0x7fffu)
                  << kF32HalfMantiBitDiff;
  const uint32_t exp = shiftedExp & bits;
  bits += kF32HalfExpAdjust;

  if (exp == shiftedExp) {
    // Inf/NaN: rebias once more so that the float exponent becomes all ones.
    bits += kF32HalfExpAdjust;
  } else if (exp == 0) {
    // Zero/subnormal: build 2^-14 * (1 + m / 2^10), then subtract 2^-14,
    // leaving exactly m * 2^-24.
    bits += 1u << kF32MantiBits;
    bits = f32ToBits(bitsToF32(bits) - bitsToF32(kF32MagicBits));
  }

  // Sign bit (bit 15 -> bit 31). This is done in unsigned arithmetic so the
  // shift into bit 31 never involves a signed int.
  bits |= static_cast<uint32_t>(halfValue & 0x8000u) << kF32HalfBitDiff;
  return bitsToF32(bits);
}

// A bfloat16 is the upper 16 bits of a float, so the mantissas differ by 16
// bits.
const uint32_t kF32BfMantiBitDiff = 16;

// Converts `floatValue` to the bit pattern of the nearest bfloat16 value,
// rounding ties to even. NaNs map to the quiet NaN 0x7fc0 / 0xffc0, keeping
// the input's sign. Overflow carries into the exponent and yields infinity.
uint16_t float2bfloat(float floatValue) {
  // NaNs are handled up front. Otherwise the rounding addition below could
  // turn a NaN into infinity, or carry a NaN's payload into the sign bit.
  if (std::isnan(floatValue))
    return std::signbit(floatValue) ? 0xFFC0 : 0x7FC0;

  uint32_t bits = f32ToBits(floatValue);
  // Round to nearest even on the 16 discarded bits: bias by 0x7fff, plus 1
  // when the kept lsb is odd, so that exact ties resolve to the even neighbor.
  const uint32_t lsb = (bits >> kF32BfMantiBitDiff) & 1u;
  bits += 0x7fffu + lsb;
  return static_cast<uint16_t>(bits >> kF32BfMantiBitDiff);
}

// Converts the bfloat16 bit pattern `bfloatBits` to float. This is exact: the
// pattern becomes the upper half of the float, and the low mantissa bits are
// zero.
float bfloat2float(uint16_t bfloatBits) {
  return bitsToF32(static_cast<uint32_t>(bfloatBits) << kF32BfMantiBitDiff);
}

} // namespace
// Constructs an f16 holding the half-precision encoding of `f` produced by
// float2half (round-to-nearest-even; overflow -> infinity; NaN -> quiet NaN
// 0x7e00 with the input's sign).
f16::f16(float f) : bits(float2half(f)) {}
// Constructs a bf16 holding the bfloat16 encoding of `f` produced by
// float2bfloat (round-to-nearest-even on the upper 16 bits of the float;
// NaN -> quiet NaN 0x7fc0/0xffc0 with the input's sign).
bf16::bf16(float f) : bits(float2bfloat(f)) {}
std::ostream &operator<<(std::ostream &os, const f16 &f) {
os << half2float(f.bits);
return os;
}
std::ostream &operator<<(std::ostream &os, const bf16 &d) {
os << bfloat2float(d.bits);
return os;
}
// Bitwise equality on the stored encodings, not IEEE comparison. As a result
// f16(0.0f) and f16(-0.0f) compare unequal, and two NaNs compare equal when
// their encodings match.
bool operator==(const f16 &lhs, const f16 &rhs) {
  return lhs.bits == rhs.bits;
}

// Bitwise equality on the stored bfloat16 encodings. The same caveats as for
// f16 apply to signed zeros and NaNs.
bool operator==(const bf16 &lhs, const bf16 &rhs) {
  return lhs.bits == rhs.bits;
}
// ATTR_WEAK marks the bfloat16 conversion routines below as weak symbols when
// the compiler supports the `weak` attribute, so a strong definition of the
// same symbol linked in from elsewhere takes precedence. It expands to nothing
// when __has_attribute is unavailable, when `weak` is unsupported, or when
// targeting MinGW, Cygwin, or any _WIN32 platform.
#define ATTR_WEAK
#ifdef __has_attribute
#if __has_attribute(weak) && !defined(__MINGW32__) && !defined(__CYGWIN__) && \
!defined(_WIN32)
#undef ATTR_WEAK
#define ATTR_WEAK __attribute__((__weak__))
#endif
#endif
// BF16ABIType is the carrier type used to return a bfloat16 bit pattern from
// the C-linkage conversion routines below.
#if defined(__x86_64__) || defined(_M_X64)
// On x86-64 the pattern is returned as a float. NOTE(review): this presumably
// relies on the x86-64 psABI returning __bf16 in an SSE register, as it does a
// float, with the 16 payload bits in the low (little-endian) bytes. Confirm.
using BF16ABIType = float;
#else
// Fallback: a plain 16-bit integer. NOTE(review): some other ABIs (e.g.
// AArch64) may return __bf16 in a floating-point register instead. Confirm
// this matches the callers on each supported target.
using BF16ABIType = uint16_t;
#endif
// C-linkage float -> bfloat16 conversion. The result is float2bfloat's 16-bit
// pattern, copied into the low-addressed bytes of a zero-initialized
// BF16ABIType. It is declared with ATTR_WEAK, so where that expands to the weak
// attribute a strong definition of this symbol takes precedence at link time.
extern "C" BF16ABIType ATTR_WEAK __truncsfbf2(float f) {
  const uint16_t bfloatBits = float2bfloat(f);
  BF16ABIType abiValue{};
  std::memcpy(&abiValue, &bfloatBits, sizeof bfloatBits);
  return abiValue;
}
// C-linkage double -> bfloat16 conversion, returned in the same BF16ABIType
// carrier as __truncsfbf2. It is declared with ATTR_WEAK, so where that expands
// to the weak attribute a strong definition of this symbol takes precedence.
//
// Narrowing double -> float -> bfloat16 with round-to-nearest-even at both
// steps can round twice and give an incorrectly rounded result.
// Example: 1 + 2^-8 + 2^-30 narrows to exactly the bfloat16 tie 1 + 2^-8.
// That tie then rounds down to 1.0 instead of up to 1 + 2^-7.
//
// To avoid this, the intermediate float is rounded to odd: an inexact result
// is truncated toward zero and its lowest mantissa bit is forced to 1. float
// keeps far more than the two extra bits of precision that round-to-odd needs
// over bfloat16, and both types share the same exponent range. The final
// round-to-nearest-even step in __truncsfbf2 therefore produces the correctly
// rounded bfloat16.
extern "C" BF16ABIType ATTR_WEAK __truncdfbf2(double d) {
  float f = static_cast<float>(d);
  // NaN compares unequal to everything, so exclude it explicitly; it is passed
  // through unchanged and handled by __truncsfbf2.
  if (!std::isnan(d) && static_cast<double>(f) != d) {
    // Inexact: if rounding moved away from zero (including overflow to
    // infinity), step back one float ulp toward zero...
    if (std::fabs(static_cast<double>(f)) > std::fabs(d))
      f = std::nextafter(f, 0.0f);
    // ...then set the lowest bit as a sticky bit recording the discarded
    // precision.
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    bits |= 1u;
    std::memcpy(&f, &bits, sizeof(f));
  }
  return __truncsfbf2(f);
}
extern "C" void printF16(uint16_t bits) {
f16 f;
std::memcpy(&f, &bits, sizeof(f16));
std::cout << f;
}
extern "C" void printBF16(uint16_t bits) {
bf16 f;
std::memcpy(&f, &bits, sizeof(bf16));
std::cout << f;
}
#endif