#ifndef OMNI_BASE_H_
#define OMNI_BASE_H_
#include <stddef.h>
#include <stdint.h>
#if !defined(OMNI_NO_LIBCXX)
#include <ostream>
#endif
#if !defined(OMNI_NO_LIBCXX)
#ifndef __STDC_FORMAT_MACROS
#define __STDC_FORMAT_MACROS
#endif
#include <inttypes.h>
#endif
// Compiler helper macros. Every expansion below is a GCC/Clang extension; no alternative for other
// compilers is provided in this header.
// Pointer qualifier promising that the pointee is not accessed through another pointer.
#define OMNI_RESTRICT __restrict__
// `inline` combined with the always_inline attribute.
#define OMNI_INLINE inline __attribute__((always_inline))
// Applies the `flatten` attribute to the annotated function.
#define OMNI_FLATTEN __attribute__((flatten))
// Branch hint: tells the compiler that expr is expected to be false. Yields !!(expr).
#define OMNI_UNLIKELY(expr) __builtin_expect(!!(expr), 0)
// Stringizes tokens and feeds them to _Pragma, so a pragma can be emitted from inside a macro.
#define OMNI_PRAGMA(tokens) _Pragma(#tokens)
// Emits `#pragma GCC diagnostic <tokens>`.
#define OMNI_DIAGNOSTICS(tokens) OMNI_PRAGMA(GCC diagnostic tokens)
// Takes one argument per compiler family; only the gcc argument is used, msc is discarded.
#define OMNI_DIAGNOSTICS_OFF(msc, gcc) OMNI_DIAGNOSTICS(gcc)
// Marks a declaration as possibly unused so that no unused-entity warning is issued for it.
#define OMNI_MAYBE_UNUSED __attribute__((unused))
namespace simd {
// Currently a pass-through: yields ptr unchanged and ignores align.
#define OMNI_ASSUME_ALIGNED(ptr, align) (ptr)
// reinterpret_cast of ptr to the pointer type `type`, routed through OMNI_ASSUME_ALIGNED with the
// alignment of the pointee type.
#define OMNI_RCAST_ALIGNED(type, ptr) reinterpret_cast<type>(OMNI_ASSUME_ALIGNED((ptr), alignof(RemovePtr<type>)))
// Saves the current compiler options and then applies `#pragma GCC target targets_str`.
#define OMNI_PUSH_ATTRIBUTES(targets_str) OMNI_PRAGMA(GCC push_options) OMNI_PRAGMA(GCC target targets_str)
// Restores the options saved by the matching OMNI_PUSH_ATTRIBUTES.
#define OMNI_POP_ATTRIBUTES OMNI_PRAGMA(GCC pop_options)
// Declaration prefix used by the functions of this header: static, always-inline, flatten, maybe-unused.
#define OMNI_API static OMNI_INLINE OMNI_FLATTEN OMNI_MAYBE_UNUSED
// Two-level token pasting: the extra level lets macro arguments expand before they are concatenated.
#define OMNI_CONCAT_IMPL(a, b) a##b
#define OMNI_CONCAT(a, b) OMNI_CONCAT_IMPL(a, b)
// Macro min/max. Each argument can be evaluated twice, so do not pass expressions with side effects.
#define OMNI_MIN(a, b) ((a) < (b) ? (a) : (b))
#define OMNI_MAX(a, b) ((a) > (b) ? (a) : (b))
// NOTE(review): expands to the constant 1 and discards all arguments, so nothing is printed and the
// program is not terminated. Presumably a placeholder for a real abort routine -- confirm.
#define OMNI_ABORT(format, ...) 1
// Statement-like assertion with a message: when condition is false, OMNI_ABORT is expanded with the
// stringized condition and msg. The do/while(0) wrapper makes the macro usable as a single statement.
// Given the OMNI_ABORT definition in this block, a failing assertion currently has no effect.
#define OMNI_ASSERT_M(condition, msg) \
do { \
if (!(condition)) { \
OMNI_ABORT("Assert %s: %s", #condition, msg); \
} \
} while (0)
// Assertion without a message (passes an empty string to OMNI_ASSERT_M).
#define OMNI_ASSERT(condition) OMNI_ASSERT_M(condition, "")
// Copies exactly kBytes bytes from `from` to `to`; the byte count is a compile-time constant.
// Both pointers are OMNI_RESTRICT-qualified, i.e. the caller promises the two ranges do not overlap.
template <size_t kBytes, typename From, typename To>
OMNI_API void CopyBytes(const From *OMNI_RESTRICT from, To *OMNI_RESTRICT to)
{
    const void *const src = from;
    void *const dst = to;
    __builtin_memcpy(dst, src, kBytes);
}
// Copies num_of_bytes_to_copy bytes from `from` to `to`; the byte count is chosen at run time.
// The same no-overlap promise applies.
OMNI_API void CopyBytes(const void *OMNI_RESTRICT from, void *OMNI_RESTRICT to, size_t num_of_bytes_to_copy)
{
    static_cast<void>(__builtin_memcpy(to, from, num_of_bytes_to_copy));
}
// Copies one object's bytes from `from` to `to`. The two types may differ but must have the same
// size; a size mismatch is rejected at compile time.
template <typename From, typename To> OMNI_API void CopySameSize(const From *OMNI_RESTRICT from, To *OMNI_RESTRICT to)
{
    constexpr size_t kNumBytes = sizeof(From);
    static_assert(kNumBytes == sizeof(To), "");
    CopyBytes<kNumBytes>(from, to);
}
// Presumably the largest vector size in bytes that this header caters for (name-based -- confirm).
static constexpr OMNI_MAYBE_UNUSED size_t kMaxVectorSize = 16;
// Alignment specifier with the same literal value (16) as kMaxVectorSize.
#define OMNI_ALIGN_MAX alignas(16)
// Forward declaration only; no definition of float16_t is visible in this part of the header.
struct float16_t;
// Forward declaration; the definition of bfloat16_t appears further down in this header.
struct bfloat16_t;
// Fixed-width style names for the built-in floating-point types.
using float32_t = float;
using float64_t = double;
// Plain data types. Member packing is set to 1 byte for these declarations, and each struct
// additionally carries an explicit alignas for the object as a whole.
#pragma pack(push, 1)
// 128-bit unsigned value stored as two 64-bit halves: `lo` is the first member, `hi` the second.
struct alignas(16) uint128_t {
uint64_t lo;
uint64_t hi;
};
// 64-bit value paired with a 64-bit key. `value` is the first member, `key` the second.
struct alignas(16) K64V64 {
uint64_t value;
uint64_t key;
};
// 32-bit value paired with a 32-bit key. `value` is the first member, `key` the second.
struct alignas(8) K32V32 {
uint32_t value;
uint32_t key;
};
#pragma pack(pop)
// Ordering for uint128_t: the high halves decide; the low halves only break a tie.
static inline OMNI_MAYBE_UNUSED bool operator < (const uint128_t &a, const uint128_t &b)
{
    if (a.hi != b.hi) {
        return a.hi < b.hi;
    }
    return a.lo < b.lo;
}
// a > b holds exactly when b < a.
static inline OMNI_MAYBE_UNUSED bool operator > (const uint128_t &a, const uint128_t &b)
{
    const bool b_is_smaller = b < a;
    return b_is_smaller;
}
// Two uint128_t values are equal when both halves match.
static inline OMNI_MAYBE_UNUSED bool operator == (const uint128_t &a, const uint128_t &b)
{
    return !(a.lo != b.lo || a.hi != b.hi);
}
#if !defined(OMNI_NO_LIBCXX)
// Streams the value as "[hi=<hi>,lo=<lo>]". Only available when the C++ standard library is enabled.
static inline OMNI_MAYBE_UNUSED std::ostream &operator << (std::ostream &os, const uint128_t &n)
{
    os << "[hi=" << n.hi;
    os << ",lo=" << n.lo;
    os << "]";
    return os;
}
#endif
// K64V64 comparisons consider the key only; the value member never takes part.
static inline OMNI_MAYBE_UNUSED bool operator < (const K64V64 &a, const K64V64 &b)
{
    return b.key > a.key;
}
// a > b holds exactly when b < a (key comparison).
static inline OMNI_MAYBE_UNUSED bool operator > (const K64V64 &a, const K64V64 &b)
{
    const bool b_is_smaller = b < a;
    return b_is_smaller;
}
// Equality of keys; two entries with equal keys but different values compare equal.
static inline OMNI_MAYBE_UNUSED bool operator == (const K64V64 &a, const K64V64 &b)
{
    return !(a.key != b.key);
}
#if !defined(OMNI_NO_LIBCXX)
// Streams the entry as "[k=<key>,v=<value>]". Only available when the C++ standard library is enabled.
static inline OMNI_MAYBE_UNUSED std::ostream &operator << (std::ostream &os, const K64V64 &n)
{
    os << "[k=" << n.key;
    os << ",v=" << n.value;
    os << "]";
    return os;
}
#endif
// K32V32 comparisons consider the key only; the value member never takes part.
static inline OMNI_MAYBE_UNUSED bool operator < (const K32V32 &a, const K32V32 &b)
{
    return b.key > a.key;
}
// a > b holds exactly when b < a (key comparison).
static inline OMNI_MAYBE_UNUSED bool operator > (const K32V32 &a, const K32V32 &b)
{
    const bool b_is_smaller = b < a;
    return b_is_smaller;
}
// Equality of keys; two entries with equal keys but different values compare equal.
static inline OMNI_MAYBE_UNUSED bool operator == (const K32V32 &a, const K32V32 &b)
{
    return !(a.key != b.key);
}
#if !defined(OMNI_NO_LIBCXX)
// Streams the entry as "[k=<key>,v=<value>]".
// Guarded like the other stream operators of this header: <ostream> is only included when
// OMNI_NO_LIBCXX is not defined, so std::ostream must not be named otherwise.
static inline OMNI_MAYBE_UNUSED std::ostream &operator << (std::ostream &os, const K32V32 &n)
{
    return os << "[k=" << n.key << ",v=" << n.value << "]";
}
#endif
// SFINAE helper: EnableIfT<true> exposes `type` (void), EnableIfT<false> has no `type` member.
template <bool Condition> struct EnableIfT {};
template <> struct EnableIfT<true> {
using type = void;
};
// Names void when Condition is true; otherwise the alias is ill-formed, which removes the
// enclosing template from overload resolution.
template <bool Condition> using EnableIf = typename EnableIfT<Condition>::type;
// Type-equality trait: value is 1 when both template arguments are the same type, otherwise 0.
template <typename T, typename U> struct IsSameT {
    enum { value = 0 };
};
template <typename T> struct IsSameT<T, T> {
    enum { value = 1 };
};
// Function form of IsSameT, usable in constant expressions.
template <typename T, typename U> OMNI_API constexpr bool IsSame()
{
    return IsSameT<T, U>::value != 0;
}
// True when T is the same type as U1 or as U2.
template <typename T, typename U1, typename U2> OMNI_API constexpr bool IsSameEither()
{
    return IsSame<T, U1>() || IsSame<T, U2>();
}
// Compile-time type selection: the primary template yields Then, the Condition == false
// specialization yields Else.
template <bool Condition, typename Then, typename Else> struct IfT {
using type = Then;
};
template <class Then, class Else> struct IfT<false, Then, Else> {
using type = Else;
};
// If<C, A, B> is A when C is true and B otherwise.
template <bool Condition, typename Then, typename Else> using If = typename IfT<Condition, Then, Else>::type;
// Const-qualification trait: value is 1 for `const T`, otherwise 0.
template <typename T> struct IsConstT {
    enum { value = 0 };
};
template <typename T> struct IsConstT<const T> {
    enum { value = 1 };
};
// Function form of IsConstT, usable in constant expressions.
template <typename T> OMNI_API constexpr bool IsConst()
{
    return IsConstT<T>::value != 0;
}
// Strips a top-level const qualifier.
template <class T> struct RemoveConstT {
using type = T;
};
template <class T> struct RemoveConstT<const T> {
using type = T;
};
template <class T> using RemoveConst = typename RemoveConstT<T>::type;
// Strips a top-level volatile qualifier.
template <class T> struct RemoveVolatileT {
using type = T;
};
template <class T> struct RemoveVolatileT<volatile T> {
using type = T;
};
template <class T> using RemoveVolatile = typename RemoveVolatileT<T>::type;
// Strips an lvalue or rvalue reference.
template <class T> struct RemoveRefT {
using type = T;
};
template <class T> struct RemoveRefT<T &> {
using type = T;
};
template <class T> struct RemoveRefT<T &&> {
using type = T;
};
template <class T> using RemoveRef = typename RemoveRefT<T>::type;
// Removes the reference first, then volatile, then const.
template <class T> using RemoveCvRef = RemoveConst<RemoveVolatile<RemoveRef<T>>>;
// Maps a pointer type to its pointee type; non-pointer types are returned unchanged.
// The specializations for const / volatile / const volatile pointees yield the unqualified T,
// so the pointee's cv-qualifiers are dropped as well.
template <class T> struct RemovePtrT {
using type = T;
};
template <class T> struct RemovePtrT<T *> {
using type = T;
};
template <class T> struct RemovePtrT<const T *> {
using type = T;
};
template <class T> struct RemovePtrT<volatile T *> {
using type = T;
};
template <class T> struct RemovePtrT<const volatile T *> {
using type = T;
};
template <class T> using RemovePtr = typename RemovePtrT<T>::type;
// SFINAE constraint macros. Each expands to an unnamed template parameter of type EnableIf<cond> *
// defaulted to nullptr, so a template carrying the macro takes part in overload resolution only
// when cond is true.
// Constraints on the total size in bytes of kN elements of type T.
#define OMNI_IF_V_SIZE(T, kN, bytes) simd::EnableIf<kN * sizeof(T) == bytes> * = nullptr
#define OMNI_IF_V_SIZE_LE(T, kN, bytes) simd::EnableIf<kN * sizeof(T) <= bytes> * = nullptr
#define OMNI_IF_V_SIZE_GT(T, kN, bytes) simd::EnableIf<(kN * sizeof(T) > bytes)> * = nullptr
// Constraints on the element count kN.
#define OMNI_IF_LANES(kN, lanes) simd::EnableIf<(kN == lanes)> * = nullptr
#define OMNI_IF_LANES_LE(kN, lanes) simd::EnableIf<(kN <= lanes)> * = nullptr
#define OMNI_IF_LANES_GT(kN, lanes) simd::EnableIf<(kN > lanes)> * = nullptr
// Type-category constraints. OMNI_IF_NOT_UNSIGNED accepts everything IsSigned reports as signed,
// whereas OMNI_IF_SIGNED additionally excludes floating-point and special-float types.
#define OMNI_IF_UNSIGNED(T) simd::EnableIf<!simd::IsSigned<T>()> * = nullptr
#define OMNI_IF_NOT_UNSIGNED(T) simd::EnableIf<simd::IsSigned<T>()> * = nullptr
#define OMNI_IF_SIGNED(T) \
simd::EnableIf<simd::IsSigned<T>() && !simd::IsFloat<T>() && !simd::IsSpecialFloat<T>()> * = nullptr
#define OMNI_IF_FLOAT(T) simd::EnableIf<simd::IsFloat<T>()> * = nullptr
#define OMNI_IF_NOT_FLOAT(T) simd::EnableIf<!simd::IsFloat<T>()> * = nullptr
#define OMNI_IF_FLOAT3264(T) simd::EnableIf<simd::IsFloat3264<T>()> * = nullptr
#define OMNI_IF_NOT_FLOAT3264(T) simd::EnableIf<!simd::IsFloat3264<T>()> * = nullptr
#define OMNI_IF_SPECIAL_FLOAT(T) simd::EnableIf<simd::IsSpecialFloat<T>()> * = nullptr
#define OMNI_IF_NOT_SPECIAL_FLOAT(T) simd::EnableIf<!simd::IsSpecialFloat<T>()> * = nullptr
#define OMNI_IF_FLOAT_OR_SPECIAL(T) simd::EnableIf<simd::IsFloat<T>() || simd::IsSpecialFloat<T>()> * = nullptr
#define OMNI_IF_NOT_FLOAT_NOR_SPECIAL(T) simd::EnableIf<!simd::IsFloat<T>() && !simd::IsSpecialFloat<T>()> * = nullptr
#define OMNI_IF_INTEGER(T) simd::EnableIf<simd::IsInteger<T>()> * = nullptr
// Constraints on sizeof(T). For OMNI_IF_T_SIZE_ONE_OF, bit i of bit_array being set accepts
// sizeof(T) == i (e.g. (1 << 4) | (1 << 8) accepts 4- and 8-byte types).
#define OMNI_IF_T_SIZE(T, bytes) simd::EnableIf<sizeof(T) == (bytes)> * = nullptr
#define OMNI_IF_NOT_T_SIZE(T, bytes) simd::EnableIf<sizeof(T) != (bytes)> * = nullptr
#define OMNI_IF_T_SIZE_ONE_OF(T, bit_array) simd::EnableIf<((size_t{ 1 } << sizeof(T)) & (bit_array)) != 0> * = nullptr
#define OMNI_IF_T_SIZE_LE(T, bytes) simd::EnableIf<(sizeof(T) <= (bytes))> * = nullptr
#define OMNI_IF_T_SIZE_GT(T, bytes) simd::EnableIf<(sizeof(T) > (bytes))> * = nullptr
// Exact-type constraints; T is compared after removing reference and cv-qualifiers.
#define OMNI_IF_SAME(T, expected) simd::EnableIf<simd::IsSame<simd::RemoveCvRef<T>, expected>()> * = nullptr
#define OMNI_IF_NOT_SAME(T, expected) simd::EnableIf<!simd::IsSame<simd::RemoveCvRef<T>, expected>()> * = nullptr
#define OMNI_IF_SAME2(T, expected1, expected2) \
simd::EnableIf<simd::IsSameEither<simd::RemoveCvRef<T>, expected1, expected2>()> * = nullptr
// Shorthands for one specific lane type.
#define OMNI_IF_U8(T) OMNI_IF_SAME(T, uint8_t)
#define OMNI_IF_U16(T) OMNI_IF_SAME(T, uint16_t)
#define OMNI_IF_U32(T) OMNI_IF_SAME(T, uint32_t)
#define OMNI_IF_U64(T) OMNI_IF_SAME(T, uint64_t)
#define OMNI_IF_I8(T) OMNI_IF_SAME(T, int8_t)
#define OMNI_IF_I16(T) OMNI_IF_SAME(T, int16_t)
#define OMNI_IF_I32(T) OMNI_IF_SAME(T, int32_t)
#define OMNI_IF_I64(T) OMNI_IF_SAME(T, int64_t)
#define OMNI_IF_BF16(T) OMNI_IF_SAME(T, simd::bfloat16_t)
#define OMNI_IF_NOT_BF16(T) OMNI_IF_NOT_SAME(T, simd::bfloat16_t)
#define OMNI_IF_F16(T) OMNI_IF_SAME(T, simd::float16_t)
#define OMNI_IF_NOT_F16(T) OMNI_IF_NOT_SAME(T, simd::float16_t)
#define OMNI_IF_F32(T) OMNI_IF_SAME(T, float)
#define OMNI_IF_F64(T) OMNI_IF_SAME(T, double)
// Shorthands accepting the unsigned or the signed integer of one width.
#define OMNI_IF_UI8(T) OMNI_IF_SAME2(T, uint8_t, int8_t)
#define OMNI_IF_UI16(T) OMNI_IF_SAME2(T, uint16_t, int16_t)
#define OMNI_IF_UI32(T) OMNI_IF_SAME2(T, uint32_t, int32_t)
#define OMNI_IF_UI64(T) OMNI_IF_SAME2(T, uint64_t, int64_t)
// Accepts when the number of T lanes in one block of at most 16 bytes (N lanes in total) equals LANES.
#define OMNI_IF_LANES_PER_BLOCK(T, N, LANES) \
simd::EnableIf<OMNI_MIN(sizeof(T) * N, 16) / sizeof(T) == (LANES)> * = nullptr
// Empty tag type distinguished only by its integer argument; used for overload selection and as
// the result marker of the detection idioms below.
template <size_t N> struct SizeTag {};
// Computes the return type of DeclVal<T>().
template <class T> class DeclValT {
private:
// Viable only when U && can be formed; then yields U &&.
template <class U, class URef = U &&> static URef TryAddRValRef(int);
// Fallback for types where U && is ill-formed (e.g. void): yields U itself.
template <class U, class Arg> static U TryAddRValRef(Arg);
public:
using type = decltype(TryAddRValRef<T>(0));
// Always 1. Referencing it from the static_assert in DeclVal makes the assertion depend on T, so
// it only fires when the function body is instantiated.
enum {
kDisableDeclValEvaluation = 1
};
};
// Yields an expression of type T (as an rvalue reference where possible) for use inside
// unevaluated operands such as decltype. Using it in an evaluated context instantiates the body
// and triggers the static_assert.
template <class T> OMNI_API typename DeclValT<T>::type DeclVal() noexcept
{
static_assert(!DeclValT<T>::kDisableDeclValEvaluation, "DeclVal() cannot be used in an evaluated context");
}
// Array trait: value is 1 for arrays of unknown bound (T[]) and of known bound (T[N]), otherwise 0.
template <class T> struct IsArrayT {
    enum { value = 0 };
};
template <class T> struct IsArrayT<T[]> {
    enum { value = 1 };
};
template <class T, size_t N> struct IsArrayT<T[N]> {
    enum { value = 1 };
};
// Function form of IsArrayT, usable in constant expressions.
template <class T> static constexpr bool IsArray()
{
    return IsArrayT<T>::value != 0;
}
// Detects whether a value of type From converts implicitly to To, by checking whether a From
// expression can be passed to a function that takes a To parameter.
template <class From, class To> class IsConvertibleT {
private:
// Declared only; its parameter type is the conversion target.
template <class T> static simd::SizeTag<1> TestFuncWithToArg(T);
// Selected (returning SizeTag<1>) only when the call with a T argument is well-formed.
template <class T, class U>
static decltype(IsConvertibleT<T, U>::template TestFuncWithToArg<U>(DeclVal<T>())) TryConvTest(int);
// Fallback when the call above is ill-formed.
template <class T, class U, class Arg> static simd::SizeTag<0> TryConvTest(Arg);
public:
// True when both types are (possibly cv-qualified) void, or when To is not an array, passes the
// middle check and the conversion test succeeds.
// NOTE(review): the middle check (To equals its DeclVal type, or adding const changes To) looks
// like it filters out function types -- confirm.
enum {
value =
(IsSame<RemoveConst<RemoveVolatile<From>>, void>() && IsSame<RemoveConst<RemoveVolatile<To>>, void>()) ||
(!IsArray<To>() &&
(IsSame<To, decltype(DeclVal<To>())>() || !IsSame<const RemoveConst<To>, RemoveConst<To>>()) &&
IsSame<decltype(TryConvTest<From, To>(0)), simd::SizeTag<1>>())
};
};
// Function form of IsConvertibleT, usable in constant expressions.
template <class From, class To> OMNI_API constexpr bool IsConvertible()
{
return IsConvertibleT<From, To>::value;
}
// Detects whether static_cast<To>(from_value) is a well-formed expression.
template <class From, class To> class IsStaticCastableT {
private:
// Viable only when the static_cast in the defaulted template argument compiles.
template <class T, class U, class = decltype(static_cast<U>(DeclVal<T>()))>
static simd::SizeTag<1> TryStaticCastTest(int);
// Fallback when the cast is ill-formed.
template <class T, class U, class Arg> static simd::SizeTag<0> TryStaticCastTest(Arg);
public:
enum {
value = IsSame<decltype(TryStaticCastTest<From, To>(0)), simd::SizeTag<1>>()
};
};
// Function form of IsStaticCastableT, usable in constant expressions.
template <class From, class To> static constexpr bool IsStaticCastable()
{
return IsStaticCastableT<From, To>::value;
}
// SFINAE constraint: From must be static_cast-able to To. IsStaticCastable is not namespace-qualified
// in the expansion, so the macro must be used where that name is visible.
#define OMNI_IF_CASTABLE(From, To) simd::EnableIf<IsStaticCastable<From, To>()> * = nullptr
// SFINAE constraint: the result of `Native op T` must be static_cast-able back to Native.
#define OMNI_IF_OP_CASTABLE(op, T, Native) OMNI_IF_CASTABLE(decltype(DeclVal<Native>() op DeclVal<T>()), Native)
// Detects whether the expression `DeclVal<T>() = DeclVal<From>()` is well-formed.
template <class T, class From> class IsAssignableT {
private:
// Viable only when the assignment in the defaulted template argument compiles.
template <class T1, class T2, class = decltype(DeclVal<T1>() = DeclVal<T2>())>
static simd::SizeTag<1> TryAssignTest(int);
// Fallback when the assignment is ill-formed.
template <class T1, class T2, class Arg> static simd::SizeTag<0> TryAssignTest(Arg);
public:
enum {
value = IsSame<decltype(TryAssignTest<T, From>(0)), simd::SizeTag<1>>()
};
};
// Function form of IsAssignableT, usable in constant expressions.
template <class T, class From> static constexpr bool IsAssignable()
{
return IsAssignableT<T, From>::value;
}
// SFINAE constraint: a From value must be assignable to T. IsAssignable is not namespace-qualified
// in the expansion, so the macro must be used where that name is visible.
#define OMNI_IF_ASSIGNABLE(T, From) simd::EnableIf<IsAssignable<T, From>()> * = nullptr
// True when T, ignoring reference and cv-qualifiers, is one of the two 16-bit floating-point
// types (float16_t or bfloat16_t).
template <typename T> OMNI_API constexpr bool IsSpecialFloat()
{
    return IsSame<RemoveCvRef<T>, simd::float16_t>() || IsSame<RemoveCvRef<T>, simd::bfloat16_t>();
}
// True only for the eight fixed-width integer types (u)int8_t .. (u)int64_t. The primary template
// answers false; each accepted type has an explicit specialization. T is matched exactly: cv- or
// reference-qualified types hit the primary template.
template <class T> OMNI_API constexpr bool IsIntegerLaneType()
{
return false;
}
template <> OMNI_INLINE constexpr bool IsIntegerLaneType<int8_t>()
{
return true;
}
template <> OMNI_INLINE constexpr bool IsIntegerLaneType<uint8_t>()
{
return true;
}
template <> OMNI_INLINE constexpr bool IsIntegerLaneType<int16_t>()
{
return true;
}
template <> OMNI_INLINE constexpr bool IsIntegerLaneType<uint16_t>()
{
return true;
}
template <> OMNI_INLINE constexpr bool IsIntegerLaneType<int32_t>()
{
return true;
}
template <> OMNI_INLINE constexpr bool IsIntegerLaneType<uint32_t>()
{
return true;
}
template <> OMNI_INLINE constexpr bool IsIntegerLaneType<int64_t>()
{
return true;
}
template <> OMNI_INLINE constexpr bool IsIntegerLaneType<uint64_t>()
{
return true;
}
namespace detail {
// Integer test for an unqualified type T. The primary template accepts the fixed-width lane types
// as well as wchar_t, size_t, ptrdiff_t, intptr_t and uintptr_t; the explicit specializations below
// add bool, the character types and every standard integer type.
template <class T> static OMNI_INLINE constexpr bool IsNonCvInteger()
{
return IsIntegerLaneType<T>() || IsSame<T, wchar_t>() || IsSameEither<T, size_t, ptrdiff_t>() ||
IsSameEither<T, intptr_t, uintptr_t>();
}
template <> OMNI_INLINE constexpr bool IsNonCvInteger<bool>()
{
return true;
}
template <> OMNI_INLINE constexpr bool IsNonCvInteger<char>()
{
return true;
}
template <> OMNI_INLINE constexpr bool IsNonCvInteger<signed char>()
{
return true;
}
template <> OMNI_INLINE constexpr bool IsNonCvInteger<unsigned char>()
{
return true;
}
template <> OMNI_INLINE constexpr bool IsNonCvInteger<short>()
{
return true;
}
template <> OMNI_INLINE constexpr bool IsNonCvInteger<unsigned short>()
{
return true;
}
template <> OMNI_INLINE constexpr bool IsNonCvInteger<int>()
{
return true;
}
template <> OMNI_INLINE constexpr bool IsNonCvInteger<unsigned>()
{
return true;
}
template <> OMNI_INLINE constexpr bool IsNonCvInteger<long>()
{
return true;
}
template <> OMNI_INLINE constexpr bool IsNonCvInteger<unsigned long>()
{
return true;
}
template <> OMNI_INLINE constexpr bool IsNonCvInteger<long long>()
{
return true;
}
template <> OMNI_INLINE constexpr bool IsNonCvInteger<unsigned long long>()
{
return true;
}
// char8_t is only specialized when the compiler reports support for it.
#if defined(__cpp_char8_t) && __cpp_char8_t >= 201811L
template <> OMNI_INLINE constexpr bool IsNonCvInteger<char8_t>()
{
return true;
}
#endif
template <> OMNI_INLINE constexpr bool IsNonCvInteger<char16_t>()
{
return true;
}
template <> OMNI_INLINE constexpr bool IsNonCvInteger<char32_t>()
{
return true;
}
}
// True when T, after stripping reference and cv-qualifiers, is one of the integer types accepted
// by detail::IsNonCvInteger.
template <class T> OMNI_API constexpr bool IsInteger()
{
    using Unqualified = RemoveCvRef<T>;
    return detail::IsNonCvInteger<Unqualified>();
}
// Reinterprets the object representation of val as a value of type To by copying its bytes.
// From and To must have the same size; CopySameSize enforces this at compile time.
template <class To, class From> OMNI_API constexpr To BitCastScalar(const From &val)
{
    To converted;
    const From *const src = &val;
    To *const dst = &converted;
    CopySameSize(src, dst);
    return converted;
}
// 1-byte packing stays in effect until the matching pop that follows the bfloat16_t definition.
#pragma pack(push, 1)
namespace detail {
// Exposes `type` (= T) only when T is NOT a special float: the primary template, which is selected
// for special floats, is empty, so using the alias below with such a type is a substitution failure.
template <class T, class TVal = RemoveCvRef<T>, bool = IsSpecialFloat<TVal>()>
struct SpecialFloatUnwrapArithOpOperandT {};
template <class T, class TVal> struct SpecialFloatUnwrapArithOpOperandT<T, TVal, false> {
using type = T;
};
template <class T> using SpecialFloatUnwrapArithOpOperand = typename SpecialFloatUnwrapArithOpOperandT<T>::type;
// Identity mapping in the code visible here (no specializations are declared in this part of the
// header). NOTE(review): presumably a hook for mapping compiler-native 16-bit float types to the
// wrapper structs -- confirm.
template <class T, class TVal = RemoveCvRef<T>> struct NativeSpecialFloatToWrapperT {
using type = T;
};
template <class T> using NativeSpecialFloatToWrapper = typename NativeSpecialFloatToWrapperT<T>::type;
}
// 16-bit floating-point wrapper that stores only its raw bit pattern. Conversions to and from
// float are provided by the free functions F32FromBF16 / BF16FromF32 / BF16FromF64.
struct alignas(2) bfloat16_t {
// Raw bits, kept inside an anonymous union.
union {
uint16_t bits;
};
// Defaulted special members keep the type trivially copyable; default construction leaves `bits`
// uninitialized.
bfloat16_t() noexcept = default;
constexpr bfloat16_t(bfloat16_t &&) noexcept = default;
constexpr bfloat16_t(const bfloat16_t &) noexcept = default;
bfloat16_t &operator = (bfloat16_t &&arg) noexcept = default;
bfloat16_t &operator = (const bfloat16_t &arg) noexcept = default;
private:
// Private tag so that construction from raw bits is only reachable through FromBits and never
// happens through an implicit conversion from an integer.
struct BF16FromU16BitsTag {};
constexpr bfloat16_t(BF16FromU16BitsTag , uint16_t u16_bits) : bits(u16_bits) {}
public:
// Builds a bfloat16_t whose bit pattern is exactly `bits` (no numeric conversion).
static constexpr bfloat16_t FromBits(uint16_t bits)
{
return bfloat16_t(BF16FromU16BitsTag(), bits);
}
};
static_assert(sizeof(simd::bfloat16_t) == 2, "Wrong size of bfloat16_t");
// Ends the 1-byte packing region opened before the preceding detail namespace.
#pragma pack(pop)
// Widens a bfloat16_t to float: its 16 bits become the upper half of the 32-bit float pattern and
// the lower 16 bits are zero.
OMNI_API constexpr float F32FromBF16(bfloat16_t bf)
{
    const uint16_t bf_bits = BitCastScalar<uint16_t>(bf);
    const uint32_t f32_bits = static_cast<uint32_t>(static_cast<uint32_t>(bf_bits) << 16);
    return BitCastScalar<float>(f32_bits);
}
namespace detail {
// Rounding increment to add to a float bit pattern before it is truncated to its upper 16 bits.
// For finite inputs this is 0x7FFF plus the lowest bit that will be kept (round to nearest, ties
// to even); for infinities and NaNs it is 0 so that their pattern is not disturbed.
static OMNI_INLINE OMNI_MAYBE_UNUSED constexpr uint32_t F32BitsToBF16RoundIncr(const uint32_t f32_bits)
{
    const bool is_finite = (f32_bits & 0x7FFFFFFFu) < 0x7F800000u;
    const uint32_t lowest_kept_bit = (f32_bits >> 16) & 1u;
    return static_cast<uint32_t>(is_finite ? (0x7FFFu + lowest_kept_bit) : 0u);
}
// Converts a float bit pattern to a bfloat16 bit pattern: rounds, keeps the upper 16 bits and, for
// NaN inputs, additionally sets bit 6 of the result so that the outcome is still a NaN.
static OMNI_INLINE OMNI_MAYBE_UNUSED constexpr uint16_t F32BitsToBF16Bits(const uint32_t f32_bits)
{
    const uint32_t rounded_upper = (f32_bits + F32BitsToBF16RoundIncr(f32_bits)) >> 16;
    const bool is_nan = (f32_bits & 0x7FFFFFFFu) > 0x7F800000u;
    const uint32_t nan_bit = static_cast<uint32_t>(is_nan) << 6;
    return static_cast<uint16_t>(rounded_upper | nan_bit);
}
}
// Narrows a float to bfloat16_t using the rounding implemented by detail::F32BitsToBF16Bits.
OMNI_API constexpr bfloat16_t BF16FromF32(float f)
{
    const uint32_t f32_bits = BitCastScalar<uint32_t>(f);
    return bfloat16_t::FromBits(detail::F32BitsToBF16Bits(f32_bits));
}
// Narrows a double to bfloat16_t in two steps. The double is first reduced to a pattern that keeps
// the bits selected by 0xFFFFFFC000000000 and folds all lower bits into bit 38 (set when any of them
// is non-zero or bit 38 itself is set); only then is it converted to float and on to bfloat16_t.
// NOTE(review): this looks like a round-to-odd pre-step to avoid double rounding -- confirm.
OMNI_API constexpr bfloat16_t BF16FromF64(double f64)
{
    const uint64_t f64_bits = BitCastScalar<uint64_t>(f64);
    const uint64_t kept_bits = f64_bits & 0xFFFFFFC000000000ULL;
    const uint64_t sticky_bit = (f64_bits + 0x0000003FFFFFFFFFULL) & 0x0000004000000000ULL;
    const double reduced = BitCastScalar<double>(static_cast<uint64_t>(kept_bits | sticky_bit));
    return BF16FromF32(static_cast<float>(reduced));
}
// Comparison operators for bfloat16_t. Each one widens both operands to float and applies the
// corresponding float comparison, so NaN and signed-zero behavior follows that of float.
constexpr inline bool operator == (bfloat16_t lhs, bfloat16_t rhs) noexcept
{
    const float left = F32FromBF16(lhs);
    const float right = F32FromBF16(rhs);
    return left == right;
}
constexpr inline bool operator != (bfloat16_t lhs, bfloat16_t rhs) noexcept
{
    const float left = F32FromBF16(lhs);
    const float right = F32FromBF16(rhs);
    return left != right;
}
constexpr inline bool operator < (bfloat16_t lhs, bfloat16_t rhs) noexcept
{
    const float left = F32FromBF16(lhs);
    const float right = F32FromBF16(rhs);
    return left < right;
}
constexpr inline bool operator <= (bfloat16_t lhs, bfloat16_t rhs) noexcept
{
    const float left = F32FromBF16(lhs);
    const float right = F32FromBF16(rhs);
    return left <= right;
}
constexpr inline bool operator > (bfloat16_t lhs, bfloat16_t rhs) noexcept
{
    const float left = F32FromBF16(lhs);
    const float right = F32FromBF16(rhs);
    return left > right;
}
constexpr inline bool operator >= (bfloat16_t lhs, bfloat16_t rhs) noexcept
{
    const float left = F32FromBF16(lhs);
    const float right = F32FromBF16(rhs);
    return left >= right;
}
namespace detail {
// Per-type table of related types and category flags. Only the specializations below exist; the
// primary template is declared but not defined.
//   Unsigned / Signed / Float : types of the same size in the respective category
//   Wide / Narrow             : the type with twice / half the size
//   is_signed, is_float, is_bf16 : 0/1 flags; their sum selects the tag returned by TypeTag()
// A member is omitted where no such type is provided (for example int64_t has no Wide).
template <typename T> struct Relations;
template <> struct Relations<uint8_t> {
using Unsigned = uint8_t;
using Signed = int8_t;
using Wide = uint16_t;
enum {
is_signed = 0,
is_float = 0,
is_bf16 = 0
};
};
template <> struct Relations<int8_t> {
using Unsigned = uint8_t;
using Signed = int8_t;
using Wide = int16_t;
enum {
is_signed = 1,
is_float = 0,
is_bf16 = 0
};
};
template <> struct Relations<uint16_t> {
using Unsigned = uint16_t;
using Signed = int16_t;
using Float = float16_t;
using Wide = uint32_t;
using Narrow = uint8_t;
enum {
is_signed = 0,
is_float = 0,
is_bf16 = 0
};
};
template <> struct Relations<int16_t> {
using Unsigned = uint16_t;
using Signed = int16_t;
using Float = float16_t;
using Wide = int32_t;
using Narrow = int8_t;
enum {
is_signed = 1,
is_float = 0,
is_bf16 = 0
};
};
template <> struct Relations<uint32_t> {
using Unsigned = uint32_t;
using Signed = int32_t;
using Float = float;
using Wide = uint64_t;
using Narrow = uint16_t;
enum {
is_signed = 0,
is_float = 0,
is_bf16 = 0
};
};
template <> struct Relations<int32_t> {
using Unsigned = uint32_t;
using Signed = int32_t;
using Float = float;
using Wide = int64_t;
using Narrow = int16_t;
enum {
is_signed = 1,
is_float = 0,
is_bf16 = 0
};
};
template <> struct Relations<uint64_t> {
using Unsigned = uint64_t;
using Signed = int64_t;
using Float = double;
using Wide = uint128_t;
using Narrow = uint32_t;
enum {
is_signed = 0,
is_float = 0,
is_bf16 = 0
};
};
template <> struct Relations<int64_t> {
using Unsigned = uint64_t;
using Signed = int64_t;
using Float = double;
using Narrow = int32_t;
enum {
is_signed = 1,
is_float = 0,
is_bf16 = 0
};
};
template <> struct Relations<uint128_t> {
using Unsigned = uint128_t;
using Narrow = uint64_t;
enum {
is_signed = 0,
is_float = 0,
is_bf16 = 0
};
};
template <> struct Relations<float16_t> {
using Unsigned = uint16_t;
using Signed = int16_t;
using Float = float16_t;
using Wide = float;
enum {
is_signed = 1,
is_float = 1,
is_bf16 = 0
};
};
// The only entry with is_bf16 = 1; it has no Float member.
template <> struct Relations<bfloat16_t> {
using Unsigned = uint16_t;
using Signed = int16_t;
using Wide = float;
enum {
is_signed = 1,
is_float = 1,
is_bf16 = 1
};
};
template <> struct Relations<float> {
using Unsigned = uint32_t;
using Signed = int32_t;
using Float = float;
using Wide = double;
using Narrow = float16_t;
enum {
is_signed = 1,
is_float = 1,
is_bf16 = 0
};
};
template <> struct Relations<double> {
using Unsigned = uint64_t;
using Signed = int64_t;
using Float = double;
using Narrow = float;
enum {
is_signed = 1,
is_float = 1,
is_bf16 = 0
};
};
// Maps a size in bytes to the unsigned / signed / floating-point type of that size. Sizes without
// a specialization are not supported; a member is omitted where no such type is provided
// (no 1-byte Float; 16 bytes has only Unsigned).
template <size_t N> struct TypeFromSize;
template <> struct TypeFromSize<1> {
using Unsigned = uint8_t;
using Signed = int8_t;
};
template <> struct TypeFromSize<2> {
using Unsigned = uint16_t;
using Signed = int16_t;
using Float = float16_t;
};
template <> struct TypeFromSize<4> {
using Unsigned = uint32_t;
using Signed = int32_t;
using Float = float;
};
template <> struct TypeFromSize<8> {
using Unsigned = uint64_t;
using Signed = int64_t;
using Float = double;
};
template <> struct TypeFromSize<16> {
using Unsigned = uint128_t;
};
}
// Public aliases over detail::Relations. Using one with a type that has no Relations entry, or
// whose entry lacks the requested member, is ill-formed (a substitution failure in SFINAE contexts).
template <typename T> using MakeUnsigned = typename detail::Relations<T>::Unsigned;
template <typename T> using MakeSigned = typename detail::Relations<T>::Signed;
template <typename T> using MakeFloat = typename detail::Relations<T>::Float;
template <typename T> using MakeWide = typename detail::Relations<T>::Wide;
template <typename T> using MakeNarrow = typename detail::Relations<T>::Narrow;
// Public aliases over detail::TypeFromSize: the type of category X that is N bytes large.
template <size_t N> using UnsignedFromSize = typename detail::TypeFromSize<N>::Unsigned;
template <size_t N> using SignedFromSize = typename detail::TypeFromSize<N>::Signed;
template <size_t N> using FloatFromSize = typename detail::TypeFromSize<N>::Float;
// Category tags used for overload dispatch.
using UnsignedTag = SizeTag<0>;
using SignedTag = SizeTag<0x100>;
using FloatTag = SizeTag<0x200>;
using SpecialTag = SizeTag<0x300>;
// Returns the category tag of T: the Relations flags are summed and shifted left by 8, giving
// UnsignedTag (sum 0), SignedTag (1), FloatTag (2) or SpecialTag (3, bfloat16_t only).
template <typename T, class R = detail::Relations<T>>
constexpr auto TypeTag() -> simd::SizeTag<((R::is_signed + R::is_float + R::is_bf16) << 8)>
{
return simd::SizeTag<((R::is_signed + R::is_float + R::is_bf16) << 8)>();
}
using NonFloatTag = SizeTag<0x400>;
// Returns FloatTag when Relations<T>::is_float is set and NonFloatTag otherwise.
template <typename T, class R = detail::Relations<T>>
constexpr auto IsFloatTag() -> simd::SizeTag<(R::is_float ? 0x200 : 0x400)>
{
return simd::SizeTag<(R::is_float ? 0x200 : 0x400)>();
}
// True when T, ignoring reference and cv-qualifiers, is float or double.
template <typename T> OMNI_API constexpr bool IsFloat3264()
{
    return IsSame<RemoveCvRef<T>, float>() || IsSame<RemoveCvRef<T>, double>();
}
// True when T, ignoring reference and cv-qualifiers, is float16_t, float or double.
// bfloat16_t is deliberately not covered here; use IsSpecialFloat for it.
template <typename T> OMNI_API constexpr bool IsFloat()
{
    return IsFloat3264<T>() || IsSame<RemoveCvRef<T>, float16_t>();
}
// True when T can represent negative values. The generic test casts 0 and -1 to T and compares
// them: for unsigned types -1 converts to the largest value, so 0 > -1 is false.
template <typename T> OMNI_API constexpr bool IsSigned()
{
return static_cast<T>(0) > static_cast<T>(-1);
}
// Explicit answers for the class types, which the generic cast-based test cannot handle.
template <> constexpr bool IsSigned<float16_t>()
{
return true;
}
template <> constexpr bool IsSigned<bfloat16_t>()
{
return true;
}
template <> constexpr bool IsSigned<simd::uint128_t>()
{
return false;
}
template <> constexpr bool IsSigned<simd::K64V64>()
{
return false;
}
template <> constexpr bool IsSigned<simd::K32V32>()
{
return false;
}
// Maps integer types that are not one of the fixed-width lane types to the fixed-width type of
// the same size and signedness; every other type (lane types and non-integers) is left unchanged.
template <typename T, bool = IsInteger<T>() && !IsIntegerLaneType<T>()> struct MakeLaneTypeIfIntegerT {
using type = T;
};
template <typename T> struct MakeLaneTypeIfIntegerT<T, true> {
using type = simd::If<IsSigned<T>(), SignedFromSize<sizeof(T)>, UnsignedFromSize<sizeof(T)>>;
};
template <typename T> using MakeLaneTypeIfInteger = typename MakeLaneTypeIfIntegerT<T>::type;
// Largest value representable by the integer type T: all bits set for unsigned types, all bits
// except the top one for signed types.
template <typename T> OMNI_API constexpr T LimitsMax()
{
    static_assert(IsInteger<T>(), "Only for integer types");
    using TU = UnsignedFromSize<sizeof(T)>;
    const TU all_ones = static_cast<TU>(~TU(0));
    return IsSigned<T>() ? static_cast<T>(all_ones >> 1) : static_cast<T>(all_ones);
}
// Smallest value representable by the integer type T: 0 for unsigned types and -1 - LimitsMax for
// signed types.
template <typename T> OMNI_API constexpr T LimitsMin()
{
    static_assert(IsInteger<T>(), "Only for integer types");
    if (!IsSigned<T>()) {
        return static_cast<T>(0);
    }
    return static_cast<T>(static_cast<T>(-1) - LimitsMax<T>());
}
// Most negative finite value of T. The generic version forwards to LimitsMin and therefore only
// works for integer types; the floating-point types have explicit specializations.
template <typename T> OMNI_API constexpr T LowestValue()
{
return LimitsMin<T>();
}
// Bit pattern 0xFF7F: sign bit set, largest non-special exponent, all mantissa bits set.
template <> OMNI_INLINE constexpr bfloat16_t LowestValue<bfloat16_t>()
{
return bfloat16_t::FromBits(uint16_t{ 0xFF7Fu });
}
template <> OMNI_INLINE constexpr float LowestValue<float>()
{
return -3.402823466e+38F;
}
template <> OMNI_INLINE constexpr double LowestValue<double>()
{
return -1.7976931348623158e+308;
}
// Largest finite value of T. The generic version forwards to LimitsMax and therefore only works
// for integer types; the floating-point types have explicit specializations.
template <typename T> OMNI_API constexpr T HighestValue()
{
return LimitsMax<T>();
}
// Bit pattern 0x7F7F: sign bit clear, largest non-special exponent, all mantissa bits set.
template <> OMNI_INLINE constexpr bfloat16_t HighestValue<bfloat16_t>()
{
return bfloat16_t::FromBits(uint16_t{ 0x7F7Fu });
}
template <> OMNI_INLINE constexpr float HighestValue<float>()
{
return 3.402823466e+38F;
}
template <> OMNI_INLINE constexpr double HighestValue<double>()
{
return 1.7976931348623158e+308;
}
// Spacing between 1 and the next larger representable value of T. The generic version returns 1
// (the spacing of integers); the floating-point types have explicit specializations.
template <typename T> OMNI_API constexpr T Epsilon()
{
return 1;
}
// Bit pattern 0x3C00, i.e. 2^-7 (bfloat16_t has 7 mantissa bits).
template <> OMNI_INLINE constexpr bfloat16_t Epsilon<bfloat16_t>()
{
return bfloat16_t::FromBits(uint16_t{ 0x3C00u });
}
// 2^-23.
template <> OMNI_INLINE constexpr float Epsilon<float>()
{
return 1.192092896e-7f;
}
// 2^-52.
template <> OMNI_INLINE constexpr double Epsilon<double>()
{
return 2.2204460492503131e-16;
}
// Number of explicitly stored mantissa (fraction) bits of the floating-point type T. Only the
// specializations may be used: instantiating the primary template fails the static_assert, whose
// condition depends on T so that it is not evaluated before instantiation.
template <typename T> constexpr int MantissaBits()
{
static_assert(sizeof(T) == 0, "Only instantiate the specializations");
return 0;
}
template <> constexpr int MantissaBits<bfloat16_t>()
{
return 7;
}
template <> constexpr int MantissaBits<float16_t>()
{
return 10;
}
template <> constexpr int MantissaBits<float>()
{
return 23;
}
template <> constexpr int MantissaBits<double>()
{
return 52;
}
// Returns -(1 << (MantissaBits + 1)) in the signed integer type of T's size. Reinterpreted as
// unsigned this is the bit pattern of infinity shifted left by one (sign bit shifted out), which
// is what ScalarIsInf compares against.
template <typename T> constexpr MakeSigned<T> MaxExponentTimes2()
{
return -(MakeSigned<T>{ 1 } << (MantissaBits<T>() + 1));
}
// Mask with only the most significant bit of T's representation set.
template <typename T> constexpr MakeUnsigned<T> SignMask()
{
return MakeUnsigned<T>{ 1 } << (sizeof(T) * 8 - 1);
}
// Mask covering the exponent field: every bit from position MantissaBits upwards, with the sign
// bit removed. `~x + 1` is the two's-complement negation of (1 << MantissaBits).
template <typename T> constexpr MakeUnsigned<T> ExponentMask()
{
return (~(MakeUnsigned<T>{ 1 } << MantissaBits<T>()) + 1) & static_cast<MakeUnsigned<T>>(~SignMask<T>());
}
// Mask covering the MantissaBits lowest bits.
template <typename T> constexpr MakeUnsigned<T> MantissaMask()
{
return (MakeUnsigned<T>{ 1 } << MantissaBits<T>()) - 1;
}
// Returns 2^MantissaBits as a value of the floating-point type T. Only the specializations may be
// used; instantiating the primary template fails the static_assert.
template <typename T> OMNI_INLINE constexpr T MantissaEnd()
{
static_assert(sizeof(T) == 0, "Only instantiate the specializations");
return 0;
}
// Bit pattern 0x4300, i.e. 128 = 2^7.
template <> OMNI_INLINE constexpr bfloat16_t MantissaEnd<bfloat16_t>()
{
return bfloat16_t::FromBits(uint16_t{ 0x4300u });
}
// 2^23.
template <> OMNI_INLINE constexpr float MantissaEnd<float>()
{
return 8388608.0f;
}
// 2^52.
template <> OMNI_INLINE constexpr double MantissaEnd<double>()
{
return 4503599627370496.0;
}
// Number of exponent bits of T: total bits minus the sign bit minus the mantissa bits.
template <typename T> constexpr int ExponentBits()
{
return 8 * sizeof(T) - 1 - MantissaBits<T>();
}
// Largest value the exponent field can hold (all exponent bits set), as a signed integer.
template <typename T> constexpr MakeSigned<T> MaxExponentField()
{
return (MakeSigned<T>{ 1 } << ExponentBits<T>()) - 1;
}
// Generator for `op_func(T1 a, T2 b)` where the right operand T2 is a special-float wrapper and
// the left operand is an integer or float/double. The generated function applies `op` to a and
// b.native. It expects T2 to provide a `Native` member type and a `native` data member.
#define OMNI_RHS_SPECIAL_FLOAT_ARITH_OP(op, op_func, T2) \
template <typename T1, \
simd::EnableIf<simd::IsInteger<RemoveCvRef<T1>>() || simd::IsFloat3264<RemoveCvRef<T1>>()> * = nullptr, \
typename RawResultT = decltype(DeclVal<T1>() op DeclVal<T2::Native>()), \
typename ResultT = detail::NativeSpecialFloatToWrapper<RawResultT>, OMNI_IF_CASTABLE(RawResultT, ResultT)> \
static OMNI_INLINE constexpr ResultT op_func(T1 a, T2 b) noexcept \
{ \
return static_cast<ResultT>(a op b.native); \
}
// Generates both operand orders: the overload above plus the mirrored one in which the special
// float T1 is the left operand (using a.native).
#define OMNI_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(op, op_func, T1) \
OMNI_RHS_SPECIAL_FLOAT_ARITH_OP(op, op_func, T1) \
template <typename T2, \
simd::EnableIf<simd::IsInteger<RemoveCvRef<T2>>() || simd::IsFloat3264<RemoveCvRef<T2>>()> * = nullptr, \
typename RawResultT = decltype(DeclVal<T1::Native>() op DeclVal<T2>()), \
typename ResultT = detail::NativeSpecialFloatToWrapper<RawResultT>, OMNI_IF_CASTABLE(RawResultT, ResultT)> \
static OMNI_INLINE constexpr ResultT op_func(T1 a, T2 b) noexcept \
{ \
return static_cast<ResultT>(a.native op b); \
}
// NOTE(review): both generators are undefined immediately, without having been expanded in
// between, so they currently produce no code.
#undef OMNI_RHS_SPECIAL_FLOAT_ARITH_OP
#undef OMNI_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP
// Reads one bfloat16_t (two bytes) from ptr via a byte copy and returns it widened to float.
OMNI_API float F32FromBF16Mem(const void *ptr)
{
    bfloat16_t loaded;
    CopyBytes<sizeof(bfloat16_t)>(OMNI_ASSUME_ALIGNED(ptr, 2), &loaded);
    return F32FromBF16(loaded);
}
// ConvertScalarTo<TTo>(in): scalar conversion that also understands the 16-bit float wrappers.
// The overloads are kept mutually exclusive through their SFINAE constraints.
// Neither side is a special float and the types differ: plain static_cast.
template <typename TTo, typename TFrom, OMNI_IF_NOT_SPECIAL_FLOAT(TTo), OMNI_IF_NOT_SPECIAL_FLOAT(TFrom),
OMNI_IF_NOT_SAME(TTo, TFrom)>
OMNI_API constexpr TTo ConvertScalarTo(const TFrom in)
{
return static_cast<TTo>(in);
}
// To bfloat16_t from any non-special type other than double: converts via float.
template <typename TTo, typename TFrom, OMNI_IF_BF16(TTo), OMNI_IF_NOT_SPECIAL_FLOAT(TFrom),
OMNI_IF_NOT_SAME(TFrom, double)>
OMNI_API constexpr TTo ConvertScalarTo(const TFrom in)
{
return BF16FromF32(static_cast<float>(in));
}
// To bfloat16_t from double: uses the dedicated double path (BF16FromF64).
template <typename TTo, OMNI_IF_BF16(TTo)> OMNI_API constexpr TTo ConvertScalarTo(const double in)
{
return BF16FromF64(in);
}
// From float16_t to a non-special type: widens to float first.
// NOTE(review): F32FromF16 is not declared in the part of the header visible here; this overload
// can only be instantiated where that function is available.
template <typename TTo, typename TFrom, OMNI_IF_F16(TFrom), OMNI_IF_NOT_SPECIAL_FLOAT(TTo)>
OMNI_API constexpr TTo ConvertScalarTo(const TFrom in)
{
return static_cast<TTo>(F32FromF16(in));
}
// From bfloat16_t to a non-special type: widens to float first.
template <typename TTo, typename TFrom, OMNI_IF_BF16(TFrom), OMNI_IF_NOT_SPECIAL_FLOAT(TTo)>
OMNI_API constexpr TTo ConvertScalarTo(TFrom in)
{
return static_cast<TTo>(F32FromBF16(in));
}
// Same source and target type: returns the argument unchanged.
template <typename TTo> OMNI_API constexpr TTo ConvertScalarTo(TTo in)
{
return in;
}
// Integer division rounding up: the smallest q with q * b >= a, for positive b and non-negative a.
// The intermediate a + b - 1 must not overflow.
template <typename T1, typename T2> constexpr inline T1 DivCeil(T1 a, T2 b)
{
    const auto numerator = a + b - 1;
    return numerator / b;
}
// Rounds `what` up to the next multiple of `align` (align must be non-zero).
constexpr inline size_t RoundUpTo(size_t what, size_t align)
{
    const size_t num_blocks = (what + align - 1) / align;
    return num_blocks * align;
}
// Rounds `what` down to the previous multiple of `align` (align must be non-zero).
constexpr inline size_t RoundDownTo(size_t what, size_t align)
{
    const size_t num_blocks = what / align;
    return num_blocks * align;
}
namespace detail {
// Path that relies on the built-in >> operator. Used for unsigned types and for signed types whose
// built-in >> already replicates the sign bit.
template <class T> static OMNI_INLINE constexpr T ScalarShr(simd::UnsignedTag , T val, int shift_amt)
{
return static_cast<T>(val >> shift_amt);
}
// Emulates a sign-replicating right shift using unsigned arithmetic only: negative values are
// inverted, shifted as unsigned and inverted again, so the vacated upper bits end up set.
template <class T> static OMNI_INLINE constexpr T ScalarShr(simd::SignedTag , T val, int shift_amt)
{
using TU = MakeUnsigned<MakeLaneTypeIfInteger<T>>;
return static_cast<T>((val < 0) ? static_cast<TU>(~(static_cast<TU>(~static_cast<TU>(val)) >> shift_amt)) :
static_cast<TU>(static_cast<TU>(val) >> shift_amt));
}
}
// Right shift of an integer by shift_amt bits that replicates the sign bit for signed types.
// The tag is computed at compile time: the emulation (tag 0x100, SignedTag) is chosen only when T is
// signed and shifting LimitsMin right by (bit width - 1) with the built-in operator does not yield
// -1; otherwise the built-in operator is used directly (tag 0, UnsignedTag).
template <class T, OMNI_IF_INTEGER(RemoveCvRef<T>)> OMNI_API constexpr RemoveCvRef<T> ScalarShr(T val, int shift_amt)
{
using NonCvRefT = RemoveCvRef<T>;
return detail::ScalarShr(simd::SizeTag<(
(IsSigned<NonCvRefT>() && (LimitsMin<NonCvRefT>() >> (sizeof(T) * 8 - 1)) != static_cast<NonCvRefT>(-1)) ?
0x100 :
0)>(),
static_cast<NonCvRefT>(val), shift_amt);
}
// Bit-scan helpers built on the GCC/Clang count-zero builtins. As the names say, x must be
// non-zero: the builtins have no defined result for 0.
// Number of zero bits below the least-significant set bit of a 32-bit value.
OMNI_API size_t Num0BitsBelowLS1Bit_Nonzero32(const uint32_t x)
{
    const int trailing_zeros = __builtin_ctz(x);
    return static_cast<size_t>(trailing_zeros);
}
// Number of zero bits below the least-significant set bit of a 64-bit value.
OMNI_API size_t Num0BitsBelowLS1Bit_Nonzero64(const uint64_t x)
{
    const int trailing_zeros = __builtin_ctzll(x);
    return static_cast<size_t>(trailing_zeros);
}
// Number of zero bits above the most-significant set bit of a 32-bit value.
OMNI_API size_t Num0BitsAboveMS1Bit_Nonzero32(const uint32_t x)
{
    const int leading_zeros = __builtin_clz(x);
    return static_cast<size_t>(leading_zeros);
}
// Number of zero bits above the most-significant set bit of a 64-bit value.
OMNI_API size_t Num0BitsAboveMS1Bit_Nonzero64(const uint64_t x)
{
    const int leading_zeros = __builtin_clzll(x);
    return static_cast<size_t>(leading_zeros);
}
// Number of set bits in x, for integer types of 1, 2 or 4 bytes. The value is first converted to
// the unsigned type of its own size so that negative inputs are not sign-extended when widened.
template <class T, OMNI_IF_INTEGER(RemoveCvRef<T>),
    OMNI_IF_T_SIZE_ONE_OF(RemoveCvRef<T>, (1 << 1) | (1 << 2) | (1 << 4))>
OMNI_API size_t PopCount(T x)
{
    using TU = UnsignedFromSize<sizeof(RemoveCvRef<T>)>;
    const uint32_t widened = static_cast<uint32_t>(static_cast<TU>(x));
    return static_cast<size_t>(__builtin_popcountl(widened));
}
// Number of set bits in x, for 8-byte integer types.
template <class T, OMNI_IF_INTEGER(RemoveCvRef<T>), OMNI_IF_T_SIZE(RemoveCvRef<T>, 8)> OMNI_API size_t PopCount(T x)
{
    using TU = UnsignedFromSize<sizeof(RemoveCvRef<T>)>;
    const uint64_t widened = static_cast<uint64_t>(static_cast<TU>(x));
    return static_cast<size_t>(__builtin_popcountll(widened));
}
// floor(log2(x)) for x >= 1, computed by recursive halving. x must be positive: the recursion
// only stops when the value reaches exactly 1.
template <typename TI> constexpr size_t FloorLog2(TI x)
{
    if (x == TI{ 1 }) {
        return 0;
    }
    const TI halved = static_cast<TI>(x >> 1);
    return static_cast<size_t>(FloorLog2(halved) + 1);
}
// ceil(log2(x)) for x >= 1, expressed as FloorLog2(x - 1) + 1 with x == 1 handled separately.
template <typename TI> constexpr size_t CeilLog2(TI x)
{
    if (x == TI{ 1 }) {
        return 0;
    }
    const TI predecessor = static_cast<TI>(x - 1);
    return static_cast<size_t>(FloorLog2(predecessor) + 1);
}
// AddWithWraparound for float and double: converts the increment to T and adds it. No wraparound
// is applied to floating-point values.
template <typename T, typename T2, OMNI_IF_FLOAT(T), OMNI_IF_NOT_SPECIAL_FLOAT(T)>
OMNI_INLINE constexpr T AddWithWraparound(T t, T2 increment)
{
    const T delta = static_cast<T>(increment);
    return t + delta;
}
// AddWithWraparound for the 16-bit float wrappers: both operands are converted to float, added
// there, and the sum is converted back to T.
template <typename T, typename T2, OMNI_IF_SPECIAL_FLOAT(T)>
OMNI_INLINE constexpr T AddWithWraparound(T t, T2 increment)
{
    const float base = ConvertScalarTo<float>(t);
    const float delta = ConvertScalarTo<float>(increment);
    return ConvertScalarTo<T>(base + delta);
}
template <typename T, typename T2, OMNI_IF_NOT_FLOAT(T)> OMNI_INLINE constexpr T AddWithWraparound(T t, T2 n)
{
using TU = MakeUnsigned<T>;
return static_cast<T>(static_cast<TU>(
static_cast<unsigned long long>(static_cast<unsigned long long>(t) + static_cast<unsigned long long>(n)) &
uint64_t{ simd::LimitsMax<TU>() }));
}
// Full 64 x 64 -> 128-bit unsigned multiplication using the compiler's 128-bit integer type.
// Returns the lower 64 bits of the product and stores the upper 64 bits in *upper.
OMNI_API uint64_t Mul128(uint64_t a, uint64_t b, uint64_t *OMNI_RESTRICT upper)
{
    const __uint128_t wide = static_cast<__uint128_t>(a) * static_cast<__uint128_t>(b);
    *upper = static_cast<uint64_t>(wide >> 64);
    return static_cast<uint64_t>(wide);
}
// Signed counterpart: returns the lower 64 bits of the 128-bit product (reinterpreted as int64_t)
// and stores the upper 64 bits, which carry the sign, in *upper.
OMNI_API int64_t Mul128(int64_t a, int64_t b, int64_t *OMNI_RESTRICT upper)
{
    const __int128_t wide = static_cast<__int128_t>(a) * static_cast<__int128_t>(b);
    *upper = static_cast<int64_t>(wide >> 64);
    return static_cast<int64_t>(static_cast<uint64_t>(wide));
}
namespace detail {
// Floating-point absolute value: clears the sign bit in the bit representation.
template <typename T> static OMNI_INLINE constexpr T ScalarAbs(simd::FloatTag , T val)
{
using TU = MakeUnsigned<T>;
return BitCastScalar<T>(static_cast<TU>(BitCastScalar<TU>(val) & (~SignMask<T>())));
}
// Special floats (SpecialTag) are handled exactly like the other floating-point types.
template <typename T> static OMNI_INLINE constexpr T ScalarAbs(simd::SpecialTag , T val)
{
return ScalarAbs(simd::FloatTag(), val);
}
// Signed integers: the negation is carried out in the unsigned type, so it is well defined even
// for the most negative value (which maps to itself).
template <typename T> static OMNI_INLINE constexpr T ScalarAbs(simd::SignedTag , T val)
{
using TU = MakeUnsigned<T>;
return (val < T{ 0 }) ? static_cast<T>(TU{ 0 } - static_cast<TU>(val)) : val;
}
// Unsigned integers are returned unchanged.
template <typename T> static OMNI_INLINE constexpr T ScalarAbs(simd::UnsignedTag , T val)
{
return val;
}
}
// Absolute value of a scalar. The implementation is selected by the category tag of the value's
// lane type; integer types that are not lane types are first mapped to one of the same size.
template <typename T> OMNI_API constexpr RemoveCvRef<T> ScalarAbs(T val)
{
using TVal = MakeLaneTypeIfInteger<detail::NativeSpecialFloatToWrapper<RemoveCvRef<T>>>;
return detail::ScalarAbs(simd::TypeTag<TVal>(), static_cast<TVal>(val));
}
// True when val is a NaN: with the sign bit cleared, its bit pattern is greater than the pattern
// that has only the exponent bits set (which is infinity).
template <typename T> OMNI_API constexpr bool ScalarIsNaN(T val)
{
    using TF = detail::NativeSpecialFloatToWrapper<RemoveCvRef<T>>;
    using TU = MakeUnsigned<TF>;
    const TU magnitude_bits = BitCastScalar<TU>(ScalarAbs(val));
    return magnitude_bits > ExponentMask<TF>();
}
// True when val is positive or negative infinity: shifting the bit pattern left by one discards
// the sign, and the remainder must equal the all-ones exponent with a zero mantissa.
template <typename T> OMNI_API constexpr bool ScalarIsInf(T val)
{
    using TF = detail::NativeSpecialFloatToWrapper<RemoveCvRef<T>>;
    using TU = MakeUnsigned<TF>;
    const TU bits_without_sign = static_cast<TU>(BitCastScalar<TU>(static_cast<TF>(val)) << 1);
    const TU infinity_pattern = static_cast<TU>(MaxExponentTimes2<TF>());
    return bits_without_sign == infinity_pattern;
}
namespace detail {
// Floating-point case: finite means that, with the sign cleared, the bit pattern is below the
// pattern with all exponent bits set (i.e. neither infinity nor NaN).
template <typename T> static OMNI_INLINE constexpr bool ScalarIsFinite(simd::FloatTag , T val)
{
using TU = MakeUnsigned<T>;
return (BitCastScalar<TU>(simd::ScalarAbs(val)) < ExponentMask<T>());
}
// Non-floating-point values are always finite.
template <typename T> static OMNI_INLINE constexpr bool ScalarIsFinite(simd::NonFloatTag , T )
{
return true;
}
}
// True when val is neither infinite nor NaN. Dispatches on IsFloatTag of the value's lane type.
template <typename T> OMNI_API constexpr bool ScalarIsFinite(T val)
{
using TVal = MakeLaneTypeIfInteger<detail::NativeSpecialFloatToWrapper<RemoveCvRef<T>>>;
return detail::ScalarIsFinite(simd::IsFloatTag<TVal>(), static_cast<TVal>(val));
}
// Returns a value with the magnitude bits of magn and the sign bit of sign, assembled directly
// from the two bit patterns.
template <typename T> OMNI_API constexpr RemoveCvRef<T> ScalarCopySign(T magn, T sign)
{
    using TF = RemoveCvRef<detail::NativeSpecialFloatToWrapper<RemoveCvRef<T>>>;
    using TU = MakeUnsigned<TF>;
    const TU magnitude_bits = static_cast<TU>(BitCastScalar<TU>(static_cast<TF>(magn)) & (~SignMask<TF>()));
    const TU sign_bit = static_cast<TU>(BitCastScalar<TU>(static_cast<TF>(sign)) & SignMask<TF>());
    return BitCastScalar<TF>(static_cast<TU>(magnitude_bits | sign_bit));
}
// True when the most significant bit of val's representation is set. Integer types that are not
// lane types are first mapped to the lane type of the same size.
template <typename T> OMNI_API constexpr bool ScalarSignBit(T val)
{
    using TVal = MakeLaneTypeIfInteger<detail::NativeSpecialFloatToWrapper<RemoveCvRef<T>>>;
    using TU = MakeUnsigned<TVal>;
    const TU raw_bits = BitCastScalar<TU>(static_cast<TVal>(val));
    return (raw_bits & SignMask<TVal>()) != 0;
}
// Optimization barrier: an empty volatile asm statement that lists output as a read-write register
// operand and declares a memory clobber. The compiler therefore has to assume that the value is
// used and may have been modified, and cannot discard the computation that produced it.
template <class T> OMNI_API void PreventElision(T &&output)
{
asm volatile("" : "+r"(output) : : "memory");
}
// Hard-coded configuration: the values are set unconditionally, without any detection.
// Although written inside the namespace, macros are not scoped and stay visible to all code that
// follows the inclusion of this header.
#define OMNI_IS_LITTLE_ENDIAN 1
#define OMNI_ARCH_ARM_A64 1
// Expands to C++17 `if constexpr`.
#define OMNI_IF_CONSTEXPR if constexpr
}
#endif