#ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H
#define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H
#include "src/__support/CPP/bit.h"
#include "src/__support/CPP/limits.h"
#include "src/__support/CPP/type_traits.h"
#include "src/__support/FPUtil/BasicOperations.h"
#include "src/__support/FPUtil/FPBits.h"
#include "src/__support/FPUtil/dyadic_float.h"
#include "src/__support/FPUtil/rounding_mode.h"
#include "src/__support/big_int.h"
#include "src/__support/macros/attributes.h"
#include "src/__support/macros/config.h"
#include "src/__support/macros/optimization.h"
#include "hdr/fenv_macros.h"
namespace LIBC_NAMESPACE_DECL {
namespace fputil {
namespace generic {
// Declaration of the generic fused multiply-add: computes x * y + z from
// InType operands and returns an OutType result, where OutType must not be
// larger (by sizeof) than InType. The primary template is defined later in
// this header; a full specialization for float operands/result follows
// immediately below.
template <typename OutType, typename InType>
LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
cpp::is_floating_point_v<InType> &&
sizeof(OutType) <= sizeof(InType),
OutType>
fma(InType x, InType y, InType z);
// Specialization of fma for float operands and a float result.
//
// The sum prod + z is formed in double and then converted to float. Because
// the double addition may itself round, converting it straight to float could
// double-round; the block below detects that case and nudges the last bit of
// the double sum so the float conversion sees correct sticky information.
//
// NOTE(review): the error term `t` below is exact only under
// round-to-nearest (ties-to-even); in directed rounding modes this
// specialization is presumably not guaranteed to be correctly rounded —
// TODO confirm before relying on it there.
template <> LIBC_INLINE float fma<float>(float x, float y, float z) {
// Exact: the product of two 24-bit float significands fits in double's
// 53-bit significand.
double prod = static_cast<double>(x) * static_cast<double>(y);
double z_d = static_cast<double>(z);
// This addition may round (kept as a separate statement from the product).
double sum = prod + z_d;
fputil::FPBits<double> bit_prod(prod), bitz(z_d), bit_sum(sum);
// Infinite, NaN and zero sums are returned without correction.
if (!(bit_sum.is_inf_or_nan() || bit_sum.is_zero())) {
// Rounding error of `sum` (Fast2Sum ordering: subtract the operand with the
// larger exponent first), so that sum - t == prod + z_d exactly under
// round-to-nearest. t == 0 means the double sum was exact.
fputil::FPBits<double> t(
(bit_prod.get_biased_exponent() >= bitz.get_biased_exponent())
? ((bit_sum.get_val() - bit_prod.get_val()) - bitz.get_val())
: ((bit_sum.get_val() - bitz.get_val()) - bit_prod.get_val()));
// Only act when the low 28 (= 52 - 23 - 1) mantissa bits, i.e. the bits below
// float's rounding bit, are all zero: then the float conversion cannot see
// that the exact value differs from `sum`. Because those bits are zero,
// +1/-1 on the mantissa cannot carry into the exponent.
if (!t.is_zero() && ((bit_sum.get_mantissa() & 0xfff'ffffULL) == 0)) {
// Opposite signs: the exact value is farther from zero than `sum`, so
// bump its magnitude up by one ulp.
if (bit_sum.sign() != t.sign())
bit_sum.set_mantissa(bit_sum.get_mantissa() + 1);
// Same sign: the exact value is closer to zero, so step the magnitude
// down by one ulp. A zero mantissa (sum is a power of two) is left alone;
// the exact value then lies just below that power of two and rounds to it
// in float anyway.
else if (bit_sum.get_mantissa())
bit_sum.set_mantissa(bit_sum.get_mantissa() - 1);
}
}
return static_cast<float>(bit_sum.get_val());
}
namespace internal {

// Shifts `mant` right by `shift_length` bits and returns the sticky bit:
// true iff at least one non-zero bit was shifted out of `mant`.
//
// `shift_length` must be non-negative (in this file it is always a positive
// exponent difference). Shifts of at least the full width of T are handled
// explicitly, since shifting a built-in integer by >= its width is undefined.
template <typename T>
LIBC_INLINE cpp::enable_if_t<is_unsigned_integral_or_big_int_v<T>, bool>
shift_mantissa(int shift_length, T &mant) {
  if (shift_length >= cpp::numeric_limits<T>::digits) {
    // Every bit is shifted out; the sticky bit is set only if something
    // non-zero was actually lost (a zero mantissa loses nothing).
    bool sticky_bits = (mant != 0);
    mant = 0;
    return sticky_bits;
  }
  // Bits that fall off the right end are exactly those selected by `mask`.
  T mask = (T(1) << shift_length) - 1;
  bool sticky_bits = (mant & mask) != 0;
  mant >>= shift_length;
  return sticky_bits;
}

} // namespace internal
// Generic fused multiply-add: returns x * y + z computed exactly and rounded
// once to OutType (which may be narrower than InType).
//
// After NaN / zero / infinity special cases, each operand is split into an
// integer significand and a biased exponent. The product of the significands
// is formed exactly in a wide unsigned integer (TmpResultType), z is aligned
// to it, and the exact sum or difference (plus a sticky bit for any bits
// shifted out during alignment) is rounded once through DyadicFloat.
template <typename OutType, typename InType>
LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
                                 cpp::is_floating_point_v<InType> &&
                                 sizeof(OutType) <= sizeof(InType),
                             OutType>
fma(InType x, InType y, InType z) {
  using OutFPBits = FPBits<OutType>;
  using OutStorageType = typename OutFPBits::StorageType;
  using InFPBits = FPBits<InType>;
  using InStorageType = typename InFPBits::StorageType;

  // Significand width of InType including the leading bit.
  constexpr int IN_EXPLICIT_MANT_LEN = InFPBits::FRACTION_LEN + 1;
  // The exact product of two such significands fits in PROD_LEN bits.
  constexpr size_t PROD_LEN = 2 * IN_EXPLICIT_MANT_LEN;
  // One more bit for a possible carry from the addition, rounded up to a
  // power of two for the big-integer type.
  constexpr size_t TMP_RESULT_LEN = cpp::bit_ceil(PROD_LEN + 1);
  using TmpResultType = UInt<TMP_RESULT_LEN>;
  using DyadicFloat = DyadicFloat<TMP_RESULT_LEN>;

  InFPBits x_bits(x), y_bits(y), z_bits(z);

  if (LIBC_UNLIKELY(x_bits.is_nan() || y_bits.is_nan() || z_bits.is_nan())) {
    if (x_bits.is_nan() || y_bits.is_nan()) {
      if (x_bits.is_signaling_nan() || y_bits.is_signaling_nan() ||
          z_bits.is_signaling_nan())
        raise_except_if_required(FE_INVALID);

      // Propagate the first quiet NaN among x, y, z, keeping the most
      // significant payload bits that fit in OutType.
      if (x_bits.is_quiet_nan()) {
        InStorageType x_payload = x_bits.get_mantissa();
        x_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
        return OutFPBits::quiet_nan(x_bits.sign(),
                                    static_cast<OutStorageType>(x_payload))
            .get_val();
      }

      if (y_bits.is_quiet_nan()) {
        InStorageType y_payload = y_bits.get_mantissa();
        y_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
        return OutFPBits::quiet_nan(y_bits.sign(),
                                    static_cast<OutStorageType>(y_payload))
            .get_val();
      }

      if (z_bits.is_quiet_nan()) {
        InStorageType z_payload = z_bits.get_mantissa();
        z_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
        return OutFPBits::quiet_nan(z_bits.sign(),
                                    static_cast<OutStorageType>(z_payload))
            .get_val();
      }

      return OutFPBits::quiet_nan().get_val();
    }
    // Only z is NaN: fall through. Both the zero-factor shortcut and the
    // MAX_BIASED_EXPONENT check below return x * y + z, which propagates it.
  }

  // If x or y is zero, x * y is an exact signed zero, so x * y + z is exact
  // in InType and gets rounded at most once (by the conversion to OutType).
  // z == 0 must NOT take this shortcut: x * y would be rounded to InType and
  // then again to OutType (double rounding when OutType is narrower), and a
  // product that underflows to -0 would turn into +0 after adding +0.
  if (LIBC_UNLIKELY(x == 0 || y == 0))
    return static_cast<OutType>(x * y + z);

  int x_exp = 0;
  int y_exp = 0;
  int z_exp = 0;

  // Scale subnormal operands by 2^FRACTION_LEN (compensated in *_exp) so that
  // every non-zero finite operand below has its leading bit set. A zero z
  // stays zero under this scaling.
  constexpr InStorageType IMPLICIT_MASK =
      InFPBits::SIG_MASK - InFPBits::FRACTION_MASK;
  constexpr InType DENORMAL_SCALING =
      InFPBits::create_value(
          Sign::POS, InFPBits::FRACTION_LEN + InFPBits::EXP_BIAS, IMPLICIT_MASK)
          .get_val();

  if (LIBC_UNLIKELY(InFPBits(x).is_subnormal())) {
    x_exp -= InFPBits::FRACTION_LEN;
    x *= DENORMAL_SCALING;
  }

  if (LIBC_UNLIKELY(InFPBits(y).is_subnormal())) {
    y_exp -= InFPBits::FRACTION_LEN;
    y *= DENORMAL_SCALING;
  }

  if (LIBC_UNLIKELY(InFPBits(z).is_subnormal())) {
    z_exp -= InFPBits::FRACTION_LEN;
    z *= DENORMAL_SCALING;
  }

  x_bits = InFPBits(x);
  y_bits = InFPBits(y);
  z_bits = InFPBits(z);
  const Sign z_sign = z_bits.sign();
  Sign prod_sign = (x_bits.sign() == y_bits.sign()) ? Sign::POS : Sign::NEG;
  x_exp += x_bits.get_biased_exponent();
  y_exp += y_bits.get_biased_exponent();
  z_exp += z_bits.get_biased_exponent();

  // Infinite operands (or a NaN z): the hardware expression yields the IEEE
  // result, including FE_INVALID for inf - inf.
  if (LIBC_UNLIKELY(x_exp == InFPBits::MAX_BIASED_EXPONENT ||
                    y_exp == InFPBits::MAX_BIASED_EXPONENT ||
                    z_exp == InFPBits::MAX_BIASED_EXPONENT))
    return static_cast<OutType>(x * y + z);

  InStorageType x_mant = x_bits.get_explicit_mantissa();
  InStorageType y_mant = y_bits.get_explicit_mantissa();
  TmpResultType z_mant = z_bits.get_explicit_mantissa();
  // Exact product of the significands.
  TmpResultType prod_mant = TmpResultType(x_mant) * y_mant;
  // Biased exponent of the least significant bit of prod_mant.
  int prod_lsb_exp =
      x_exp + y_exp - (InFPBits::EXP_BIAS + 2 * InFPBits::FRACTION_LEN);

  // z == 0 (x and y are finite and non-zero here): the exact result is the
  // product itself, so round it once with its own sign. This bypasses the
  // alignment / sticky-bit logic below, which is written for a non-zero z.
  if (LIBC_UNLIKELY(z_bits.is_zero())) {
    DyadicFloat prod(prod_sign, prod_lsb_exp - InFPBits::EXP_BIAS, prod_mant);
    return prod.template as<OutType, true>();
  }

  // Pre-shift z left so its leading bit sits at bit PROD_LEN, above the
  // product's bits, giving it low-order room before alignment.
  constexpr int RESULT_MIN_LEN = PROD_LEN - InFPBits::FRACTION_LEN;
  z_mant <<= RESULT_MIN_LEN;
  // Biased exponent of the least significant bit of z_mant.
  int z_lsb_exp = z_exp - (InFPBits::FRACTION_LEN + RESULT_MIN_LEN);

  // Align both significands to the larger lsb exponent, remembering whether
  // non-zero bits were shifted out and which operand lost them.
  bool sticky_bits = false;
  bool z_shifted = false;

  if (prod_lsb_exp < z_lsb_exp) {
    sticky_bits = internal::shift_mantissa(z_lsb_exp - prod_lsb_exp, prod_mant);
    prod_lsb_exp = z_lsb_exp;
  } else if (z_lsb_exp < prod_lsb_exp) {
    z_shifted = true;
    sticky_bits = internal::shift_mantissa(prod_lsb_exp - z_lsb_exp, z_mant);
  }

  if (prod_sign == z_sign) {
    prod_mant += z_mant;
  } else {
    // Subtract the smaller magnitude from the larger. If the subtrahend lost
    // bits during alignment, its true value is slightly larger than stored:
    // subtract one extra unit here and let the sticky bit (ORed in below)
    // place the exact result strictly inside the resulting one-unit gap.
    if (prod_mant >= z_mant) {
      if (z_shifted && sticky_bits) {
        ++z_mant;
      }
      prod_mant -= z_mant;
    } else {
      if (!z_shifted && sticky_bits) {
        ++prod_mant;
      }
      prod_mant = z_mant - prod_mant;
      prod_sign = z_sign;
    }
  }

  // Exact cancellation: the sign of zero depends on the rounding mode.
  if (prod_mant == 0) {
    if (quick_get_round() == FE_DOWNWARD)
      prod_sign = Sign::NEG;
    else
      prod_sign = Sign::POS;
  }

  DyadicFloat result(prod_sign, prod_lsb_exp - InFPBits::EXP_BIAS, prod_mant);
  result.mantissa |= static_cast<unsigned int>(sticky_bits);
  return result.template as<OutType, true>();
}
}
}
}
#endif