#ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H
#define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H
#include "sqrt_80_bit_long_double.h"
#include "src/__support/CPP/bit.h"
#include "src/__support/CPP/type_traits.h"
#include "src/__support/FPUtil/FEnvImpl.h"
#include "src/__support/FPUtil/FPBits.h"
#include "src/__support/FPUtil/dyadic_float.h"
#include "src/__support/common.h"
#include "src/__support/macros/config.h"
#include "src/__support/uint128.h"
#include "hdr/fenv_macros.h"
namespace LIBC_NAMESPACE_DECL {
namespace fputil {
namespace internal {
// Compile-time trait telling sqrt() below whether a type must be routed to the
// dedicated x86::sqrt implementation instead of the generic algorithm.
// The primary template answers "no" for every type.
template <typename T> struct SpecialLongDouble {
static constexpr bool VALUE = false;
};
#if defined(LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80)
// Only when LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80 is defined is long double
// flagged as special.
template <> struct SpecialLongDouble<long double> {
static constexpr bool VALUE = true;
};
#endif
// Shifts `mantissa` left until its highest set bit sits at bit index
// FPBits<T>::FRACTION_LEN, and lowers `exponent` by the same amount so the
// (exponent, mantissa) pair keeps describing the same value.
//
// Both arguments are updated in place.
// NOTE(review): the shift would be negative if `mantissa` had a bit set above
// index FRACTION_LEN; the only caller in this file passes a subnormal mantissa
// - confirm no other caller relies on that case.
template <typename T>
LIBC_INLINE void normalize(int &exponent,
                           typename FPBits<T>::StorageType &mantissa) {
  // Total width of the storage word, in bits.
  constexpr int STORAGE_BITS =
      8 * static_cast<int>(sizeof(typename FPBits<T>::StorageType));
  // Number of leading zero bits the word has once its top set bit is at index
  // FRACTION_LEN.
  constexpr int TARGET_LEADING_ZEROS =
      STORAGE_BITS - 1 - FPBits<T>::FRACTION_LEN;

  const int shift = cpp::countl_zero(mantissa) - TARGET_LEADING_ZEROS;
  exponent -= shift;
  mantissa <<= shift;
}
#ifdef LIBC_TYPES_LONG_DOUBLE_IS_FLOAT64
// long double uses 64-bit storage in this configuration, so normalization is
// forwarded unchanged to the double implementation.
template <>
LIBC_INLINE void normalize<long double>(int &exponent, uint64_t &mantissa) {
  normalize<double>(exponent, mantissa);
}
#elif !defined(LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80)
// long double with 128-bit storage: the leading-zero count is taken on one
// 64-bit half at a time.
//
// The constants place the highest set bit at bit index 112 of the 128-bit
// word: a bit in the upper half must end with 15 leading zeros in that half
// (index 48 + 64 = 112), and a bit in the lower half additionally has to
// cross the 64-bit boundary (64 - 15 = 49).
// NOTE(review): presumably 112 is the fraction length of this long double
// format - confirm against FPBits<long double>.
template <>
LIBC_INLINE void normalize<long double>(int &exponent, UInt128 &mantissa) {
  const uint64_t upper_half = static_cast<uint64_t>(mantissa >> 64);

  int shift = 0;
  if (upper_half != 0) {
    shift = cpp::countl_zero(upper_half) - 15;
  } else {
    shift = cpp::countl_zero(static_cast<uint64_t>(mantissa)) + 49;
  }

  exponent -= shift;
  mantissa <<= shift;
}
#endif
}
// Square root of `x`, returned as OutType (which may be narrower than InType,
// never wider).
//
// When both types are flagged by internal::SpecialLongDouble the call is
// forwarded to x86::sqrt. Otherwise the root is built one bit at a time with
// integer shift/compare/subtract steps, extended by a rounding bit and a
// sticky bit, and handed to DyadicFloat for the final conversion to OutType.
//
// Special inputs:
//   +Inf, +0, -0, NaN -> the input itself, converted to OutType.
//   any other negative value (including -Inf) -> quiet NaN.
template <typename OutType, typename InType>
LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
                                 cpp::is_floating_point_v<InType> &&
                                 sizeof(OutType) <= sizeof(InType),
                             OutType>
sqrt(InType x) {
  if constexpr (internal::SpecialLongDouble<OutType>::VALUE &&
                internal::SpecialLongDouble<InType>::VALUE) {
    // Both types are flagged as special: use the dedicated implementation.
    return x86::sqrt(x);
  } else {
    using OutFPBits = FPBits<OutType>;
    using InFPBits = FPBits<InType>;
    using InStorageType = typename InFPBits::StorageType;
    // Dyadic float whose width is the input storage width rounded up to a
    // power of two.
    using DyadicFloat =
        DyadicFloat<cpp::bit_ceil(static_cast<size_t>(InFPBits::STORAGE_LEN))>;

    // Position of the implicit leading bit of a normal mantissa.
    constexpr InStorageType ONE = InStorageType(1) << InFPBits::FRACTION_LEN;
    constexpr auto FLT_NAN = OutFPBits::quiet_nan().get_val();

    InFPBits bits(x);

    // Inputs whose square root is the input itself.
    if (bits.is_nan() || bits.is_zero() || bits == InFPBits::inf(Sign::POS))
      return static_cast<OutType>(x);

    // Everything negative that reaches this point has no real square root.
    // NOTE(review): no floating-point exception is raised and errno is not
    // touched on this path - confirm that is the intended contract.
    if (bits.is_neg())
      return FLT_NAN;

    int exp = bits.get_exponent();
    InStorageType mant = bits.get_mantissa();

    // Step 1a: bring the leading bit to the ONE position. Subnormals are
    // shifted up (after bumping the exponent by one so that it refers to the
    // ONE bit); normal numbers just get the implicit bit inserted.
    if (bits.is_subnormal()) {
      ++exp;
      internal::normalize<InType>(exp, mant);
    } else {
      mant |= ONE;
    }

    // Step 1b: force the exponent to be even so it can be halved exactly; the
    // dropped factor of two is moved into the mantissa, which now lies in
    // [ONE, 4 * ONE).
    if ((exp & 1) != 0) {
      --exp;
      mant <<= 1;
    }

    // Step 2: digit-by-digit root extraction. `root` starts at ONE and gains
    // one lower bit per iteration; `rem` is the running remainder, doubled
    // each step. A bit is accepted when the doubled remainder is at least
    // 2 * root + bit, in which case that amount is subtracted.
    InStorageType root = ONE;
    InStorageType rem = mant - ONE;
    InStorageType bit = ONE >> 1;
    while (bit != 0) {
      rem <<= 1;
      const InStorageType trial = (root << 1) + bit;
      if (rem >= trial) {
        rem -= trial;
        root += bit;
      }
      bit >>= 1;
    }

    // Step 3: one extra step below the last mantissa bit. `root` gains two
    // low bits: bit 1 is the rounding bit decided here, bit 0 is the sticky
    // bit recording whether any remainder is left.
    rem <<= 2;
    root <<= 2;
    const InStorageType round_trial = root + 1;
    if (rem >= round_trial) {
      rem -= round_trial;
      root |= 2;
    }
    if (rem != 0)
      root |= 1;

    // `root` carries FRACTION_LEN + 2 bits below its leading bit, so its
    // least significant bit has weight 2^(exp / 2 - 2 - FRACTION_LEN).
    const int result_exp = (exp >> 1) - 2 - InFPBits::FRACTION_LEN;
    DyadicFloat result(Sign::POS, result_exp, root);
    // NOTE(review): the second template argument (true) presumably asks the
    // conversion to signal floating-point exceptions - confirm against
    // DyadicFloat::as.
    return result.template as<OutType, true>();
  }
}
}
}
#endif