* Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
*/
#ifndef OMNI_SIMD_REDUCE_FUNC_H
#define OMNI_SIMD_REDUCE_FUNC_H
#include <limits>
#include <cstring>
#include "simd/simd.h"
#include "libboundscheck/include/securec.h"
namespace simd {
/*
 * Reduction operations understood by the Reduce* helpers in this header.
 * The explicit values spell out the default numbering (Sum = 0, Max = 1, Min = 2).
 */
enum class ReduceFunc {
    Sum = 0, /* arithmetic total of the selected values */
    Max = 1, /* largest selected value */
    Min = 2, /* smallest selected value */
};
/*
 * Loads one full vector (Lanes(d) elements) starting at inPtr into a vector of OUT lanes.
 * When IN and OUT are the same type the data is loaded directly; otherwise each element is
 * static_cast to OUT into a temporary buffer first, and that buffer is loaded.
 * inPtr must point to at least Lanes(d) readable elements.
 */
template <typename D, typename IN, typename OUT>
OMNI_INLINE VFromD<D> LoadFromDiffType(D d, const IN *OMNI_RESTRICT inPtr)
{
    if constexpr (std::is_same_v<IN, OUT>) {
        return LoadU(d, inPtr);
    } else {
        // Lane count is only needed for the conversion path; size_t matches Lanes()'s result type.
        const size_t laneCount = Lanes(d);
        OUT converted[laneCount];
        for (size_t lane = 0; lane < laneCount; ++lane) {
            converted[lane] = static_cast<OUT>(inPtr[lane]);
        }
        return LoadU(d, converted);
    }
}
/*
 * Gathers one full vector through a dictionary index: lane j receives
 * static_cast<OUT>(inPtr[indexMap[j]]) for j in [0, Lanes(d)).
 * indexMap must hold at least Lanes(d) entries, each a valid index into inPtr.
 */
template <typename D, typename IN, typename OUT>
OMNI_INLINE VFromD<D> LoadFromDic(D d, const IN *OMNI_RESTRICT inPtr, const int32_t *OMNI_RESTRICT indexMap)
{
    const size_t laneCount = Lanes(d);
    OUT gathered[laneCount];
    // size_t index avoids the signed/unsigned comparison against Lanes().
    for (size_t lane = 0; lane < laneCount; ++lane) {
        gathered[lane] = static_cast<OUT>(inPtr[indexMap[lane]]);
    }
    return LoadU(d, gathered);
}
/*
 * INIT(Rf) declares a local lambda `Init` that prepares the accumulator for a reduction of kind Rf.
 * It must be expanded in a scope that already declares `d`, `OUT`, `limitValue`, `reduceVec`, `func`
 * and `reduce`; the lambda captures them by reference and must then be called as `Init()`.
 *   Sum: limitValue = 0; func/reduce keep the values the caller initialised them with.
 *   Max: limitValue = lowest representable OUT; func = Max, reduce = ReduceMax.
 *   Min: limitValue = largest representable OUT; func = Min, reduce = ReduceMin.
 * In every case reduceVec is broadcast from limitValue. The callers in this header also use
 * limitValue to pad partially filled vectors, so it must not change the reduction's result.
 */
#define INIT(Rf)                                                                                \
    auto Init = [&]() {                                                                         \
        if constexpr (Rf == ReduceFunc::Sum) {                                                  \
            limitValue = 0;                                                                     \
            reduceVec = Set(d, limitValue);                                                     \
        } else if constexpr (Rf == ReduceFunc::Max) {                                           \
            /* lowest(), not min(): for floating-point OUT, min() is the smallest positive */   \
            /* normal value and would wrongly win over inputs that are all negative.       */   \
            limitValue = std::numeric_limits<OUT>::lowest();                                    \
            reduceVec = Set(d, limitValue);                                                     \
            func = Max;                                                                         \
            reduce = ReduceMax;                                                                 \
        } else if constexpr (Rf == ReduceFunc::Min) {                                           \
            limitValue = std::numeric_limits<OUT>::max();                                       \
            reduceVec = Set(d, limitValue);                                                     \
            func = Min;                                                                         \
            reduce = ReduceMin;                                                                 \
        } else {                                                                                \
            throw std::runtime_error("Do not support reduce func");                             \
        }                                                                                       \
    }
/*
 * Reduces the first `size` elements of `array`, converted from IN to OUT, with the operation
 * selected by `reduceFunc`.
 * Full vectors of Lanes(d) elements are folded first. The remaining size % Lanes(d) elements are
 * copied into a buffer pre-filled with limitValue (the padding value set up by INIT) and folded as
 * one final vector.
 * @param array input values; must hold at least `size` elements
 * @param size number of elements to reduce (0 is allowed and yields limitValue)
 * @return the horizontally reduced OUT value
 */
template <typename IN, typename OUT, ReduceFunc reduceFunc> OUT Reduce(const IN *OMNI_RESTRICT array, size_t size)
{
    using D = ScalableTag<OUT>;
    D d;
    auto reduceVec = Zero(d);
    OUT limitValue;
    using Func = VFromD<D> (*)(VFromD<D>, VFromD<D>);
    using Reduce = TFromD<D> (*)(D, VFromD<D>);
    Func func = Add;
    Reduce reduce = ReduceSum;
    INIT(reduceFunc);
    Init();
    const auto lanes = Lanes(d);
    const auto k = size - (size % lanes); // end of the last full vector
    for (size_t i = 0; i < k; i += lanes) {
        auto vec = LoadFromDiffType<decltype(d), IN, OUT>(d, array + i);
        reduceVec = func(reduceVec, vec);
    }
    // Tail: place the leftover elements at the front of a padded buffer.
    const size_t tail = size - k;
    OUT buf[lanes];
    std::fill(buf, buf + lanes, limitValue);
    if constexpr (std::is_same_v<IN, OUT>) {
        if (tail > 0) {
            // destMax is the real capacity of buf so securec's bounds check is meaningful.
            memcpy_s(buf, lanes * sizeof(OUT), array + k, tail * sizeof(OUT));
        }
    } else {
        for (size_t j = 0; j < tail; j++) {
            buf[j] = static_cast<OUT>(array[k + j]);
        }
    }
    auto vec = LoadU(d, buf);
    reduceVec = func(reduceVec, vec);
    return reduce(d, reduceVec);
}
/*
 * Null-aware variant of Reduce.
 * nulls[i] holds one byte per row. A row whose byte is 0 contributes array[i], converted to OUT.
 * Any non-zero byte marks the row as null, and it is replaced by limitValue so it does not affect
 * the result.
 * @param array input values; must hold at least `size` elements
 * @param nulls null flags; must hold at least `size` bytes
 * @param size number of rows to reduce
 * @return the horizontally reduced OUT value
 */
template <typename IN, typename OUT, ReduceFunc reduceFunc>
OUT ReduceWithNulls(const IN *OMNI_RESTRICT array, const uint8_t *OMNI_RESTRICT nulls, size_t size)
{
    using D = ScalableTag<OUT>;
    D d;
    auto reduceVec = Zero(d);
    CappedTag<uint8_t, D::kPrivateLanes> du;
    OUT limitValue;
    using Func = VFromD<D> (*)(VFromD<D>, VFromD<D>);
    using Reduce = TFromD<D> (*)(D, VFromD<D>);
    Func func = Add;
    Reduce reduce = ReduceSum;
    INIT(reduceFunc);
    Init();
    const auto lanes = Lanes(d);
    const auto k = size - (size % lanes); // end of the last full vector
    for (size_t i = 0; i < k; i += lanes) {
        auto nullVec = Zero(d);
        if constexpr (std::is_same_v<OUT, double>) {
            // Widen the null bytes as integers, then reinterpret them in the double lane type.
            ScalableTag<uint64_t> dt;
            nullVec = BitCast(d, PromoteTo(dt, LoadU(du, nulls + i)));
        } else {
            nullVec = PromoteTo(d, LoadU(du, nulls + i));
        }
        auto mask = Eq(nullVec, Zero(d)); // true where the row is not null
        auto vec = IfThenElse(mask, LoadFromDiffType<decltype(d), IN, OUT>(d, array + i), Set(d, limitValue));
        reduceVec = func(reduceVec, vec);
    }
    // Tail: leftover rows go to the front of padded buffers; padding rows are marked non-null
    // but carry limitValue, so they do not affect the result.
    const size_t tail = size - k;
    OUT buf[lanes];
    uint8_t nullBuf[lanes];
    std::fill(buf, buf + lanes, limitValue);
    std::fill(nullBuf, nullBuf + lanes, 0);
    if constexpr (std::is_same_v<IN, OUT>) {
        if (tail > 0) {
            // destMax is the real buffer capacity so securec's bounds check is meaningful.
            memcpy_s(buf, lanes * sizeof(OUT), array + k, tail * sizeof(OUT));
            memcpy_s(nullBuf, lanes * sizeof(uint8_t), nulls + k, tail * sizeof(uint8_t));
        }
    } else {
        for (size_t j = 0; j < tail; j++) {
            buf[j] = static_cast<OUT>(array[k + j]);
            nullBuf[j] = nulls[k + j];
        }
    }
    auto nullVec = Zero(d);
    if constexpr (std::is_same_v<OUT, double>) {
        ScalableTag<uint64_t> dt;
        nullVec = BitCast(d, PromoteTo(dt, LoadU(du, nullBuf)));
    } else {
        nullVec = PromoteTo(d, LoadU(du, nullBuf));
    }
    auto mask = Eq(nullVec, Zero(d));
    auto vec = IfThenElse(mask, LoadU(d, buf), Set(d, limitValue));
    reduceVec = func(reduceVec, vec);
    return reduce(d, reduceVec);
}
/*
 * Dictionary variant of Reduce: row i contributes static_cast<OUT>(array[indexMap[i]]).
 * Full vectors are gathered with LoadFromDic. The leftover size % Lanes(d) rows are gathered into
 * a buffer pre-filled with limitValue and folded as one final vector.
 * @param array dictionary values addressed by indexMap
 * @param indexMap per-row indices into array; must hold at least `size` entries
 * @param size number of rows to reduce
 * @return the horizontally reduced OUT value
 */
template <typename IN, typename OUT, ReduceFunc reduceFunc>
OUT ReduceWithDic(const IN *OMNI_RESTRICT array, const int32_t *OMNI_RESTRICT indexMap, size_t size)
{
    using D = ScalableTag<OUT>;
    D d;
    auto reduceVec = Zero(d);
    OUT limitValue;
    using Func = VFromD<D> (*)(VFromD<D>, VFromD<D>);
    using Reduce = TFromD<D> (*)(D, VFromD<D>);
    Func func = Add;
    Reduce reduce = ReduceSum;
    INIT(reduceFunc);
    Init();
    const auto lanes = Lanes(d);
    const auto k = size - (size % lanes); // end of the last full vector
    for (size_t i = 0; i < k; i += lanes) {
        auto vec = LoadFromDic<decltype(d), IN, OUT>(d, array, indexMap + i);
        reduceVec = func(reduceVec, vec);
    }
    // Tail: leftover rows go to the front of a buffer padded with limitValue.
    const size_t tail = size - k;
    OUT buf[lanes];
    std::fill(buf, buf + lanes, limitValue);
    for (size_t j = 0; j < tail; j++) {
        buf[j] = static_cast<OUT>(array[indexMap[k + j]]);
    }
    auto vec = LoadU(d, buf);
    reduceVec = func(reduceVec, vec);
    return reduce(d, reduceVec);
}
/*
 * Dictionary + null-aware variant of Reduce.
 * Row i contributes static_cast<OUT>(array[indexMap[i]]) when nulls[i] == 0. Any non-zero null byte
 * excludes the row, which is replaced by limitValue.
 * The tail uses the same "non-zero means null" rule as the vectorised body; previously it only
 * skipped rows whose byte was exactly 1.
 * @param array dictionary values addressed by indexMap
 * @param indexMap per-row indices into array; must hold at least `size` entries
 * @param nulls null flags; must hold at least `size` bytes
 * @param size number of rows to reduce
 * @return the horizontally reduced OUT value
 */
template <typename IN, typename OUT, ReduceFunc reduceFunc>
OUT ReduceWithDicAndNulls(const IN *OMNI_RESTRICT array, const int32_t *OMNI_RESTRICT indexMap,
    const uint8_t *OMNI_RESTRICT nulls, size_t size)
{
    using D = ScalableTag<OUT>;
    D d;
    auto reduceVec = Zero(d);
    CappedTag<uint8_t, D::kPrivateLanes> du;
    OUT limitValue;
    using Func = VFromD<D> (*)(VFromD<D>, VFromD<D>);
    using Reduce = TFromD<D> (*)(D, VFromD<D>);
    Func func = Add;
    Reduce reduce = ReduceSum;
    INIT(reduceFunc);
    Init();
    const auto lanes = Lanes(d);
    const auto k = size - (size % lanes); // end of the last full vector
    for (size_t i = 0; i < k; i += lanes) {
        auto nullVec = Zero(d);
        if constexpr (std::is_same_v<OUT, double>) {
            // Widen the null bytes as integers, then reinterpret them in the double lane type.
            ScalableTag<uint64_t> dt;
            nullVec = BitCast(d, PromoteTo(dt, LoadU(du, nulls + i)));
        } else {
            nullVec = PromoteTo(d, LoadU(du, nulls + i));
        }
        auto mask = Eq(nullVec, Zero(d)); // true where the row is not null
        auto vec = IfThenElse(mask, LoadFromDic<decltype(d), IN, OUT>(d, array, indexMap + i), Set(d, limitValue));
        reduceVec = func(reduceVec, vec);
    }
    // Tail: non-null leftover rows go to the front of a buffer padded with limitValue.
    const size_t tail = size - k;
    OUT buf[lanes];
    std::fill(buf, buf + lanes, limitValue);
    for (size_t j = 0; j < tail; j++) {
        if (nulls[k + j] == 0) {
            buf[j] = static_cast<OUT>(array[indexMap[k + j]]);
        }
    }
    auto vec = LoadU(d, buf);
    reduceVec = func(reduceVec, vec);
    return reduce(d, reduceVec);
}
/*
 * Reduces `rowCount` values at `ptr` with `op` and merges the partial result into *res_:
 * Sum adds it, Max/Min keep the larger/smaller of the partial result and *res_.
 * Afterwards FLAG_HANDLER::Update is called once with the full rowCount.
 */
template <typename IN, typename OUT, typename FLAG, typename FLAG_HANDLER, ReduceFunc op>
void ReduceExternal(OUT *res_, FLAG &flag_, const IN *__restrict ptr, const size_t rowCount)
{
    const OUT partial = Reduce<IN, OUT, op>(ptr, rowCount);
    if constexpr (op == ReduceFunc::Max) {
        *res_ = std::max(partial, *res_);
    } else if constexpr (op == ReduceFunc::Sum) {
        *res_ += partial;
    } else {
        *res_ = std::min(partial, *res_);
    }
    FLAG_HANDLER::Update(flag_, rowCount);
}
/*
 * Null-aware counterpart of ReduceExternal.
 * The partial result of ReduceWithNulls is merged into *res_ (Sum adds, Max/Min keep the
 * larger/smaller value). FLAG_HANDLER::Update(flag_, 1ULL) is then called once for every row
 * whose null byte in `condition` is 0.
 */
template <typename IN, typename OUT, typename FLAG, typename FLAG_HANDLER, ReduceFunc op>
inline void ReduceWithNullsExternal(OUT *res_, FLAG &flag_, const IN *__restrict ptr, const size_t rowCount,
    const uint8_t *__restrict condition)
{
    if constexpr (op == ReduceFunc::Sum) {
        *res_ += ReduceWithNulls<IN, OUT, ReduceFunc::Sum>(ptr, condition, rowCount);
    } else if constexpr (op == ReduceFunc::Max) {
        *res_ = std::max(ReduceWithNulls<IN, OUT, ReduceFunc::Max>(ptr, condition, rowCount), *res_);
    } else {
        *res_ = std::min(ReduceWithNulls<IN, OUT, ReduceFunc::Min>(ptr, condition, rowCount), *res_);
    }
    // size_t index matches rowCount's type (no signed/unsigned comparison).
    for (size_t row = 0; row < rowCount; ++row) {
        if (condition[row] == 0) {
            FLAG_HANDLER::Update(flag_, 1ULL);
        }
    }
}
/*
 * Dictionary counterpart of ReduceExternal.
 * Row i reads ptr[indexMap[i]]. The partial result of ReduceWithDic is merged into *res_
 * (Sum adds, Max/Min keep the larger/smaller value), and FLAG_HANDLER::Update is called once
 * with the full rowCount.
 */
template <typename IN, typename OUT, typename FLAG, typename FLAG_HANDLER, ReduceFunc op>
void ReduceWithDicExternal(OUT *res_, FLAG &flag_, const IN *__restrict ptr, const size_t rowCount,
    const int32_t *__restrict indexMap)
{
    if constexpr (op == ReduceFunc::Max) {
        const OUT partial = ReduceWithDic<IN, OUT, ReduceFunc::Max>(ptr, indexMap, rowCount);
        *res_ = std::max(partial, *res_);
    } else if constexpr (op == ReduceFunc::Sum) {
        const OUT partial = ReduceWithDic<IN, OUT, ReduceFunc::Sum>(ptr, indexMap, rowCount);
        *res_ += partial;
    } else {
        const OUT partial = ReduceWithDic<IN, OUT, ReduceFunc::Min>(ptr, indexMap, rowCount);
        *res_ = std::min(partial, *res_);
    }
    FLAG_HANDLER::Update(flag_, rowCount);
}
/*
 * Dictionary + null-aware counterpart of ReduceExternal.
 * The partial result of ReduceWithDicAndNulls is merged into *res_ (Sum adds, Max/Min keep the
 * larger/smaller value). FLAG_HANDLER::Update(flag_, 1) is then called once for every row whose
 * null byte in `condition` is 0.
 */
template <typename IN, typename OUT, typename FLAG, typename FLAG_HANDLER, ReduceFunc op>
void ReduceWithDicAndNullsExternal(OUT *res_, FLAG &flag_, const IN *__restrict ptr, const size_t rowCount,
    const uint8_t *__restrict condition, const int32_t *__restrict indexMap)
{
    if constexpr (op == ReduceFunc::Sum) {
        *res_ += ReduceWithDicAndNulls<IN, OUT, ReduceFunc::Sum>(ptr, indexMap, condition, rowCount);
    } else if constexpr (op == ReduceFunc::Max) {
        *res_ = std::max(ReduceWithDicAndNulls<IN, OUT, ReduceFunc::Max>(ptr, indexMap, condition, rowCount), *res_);
    } else {
        *res_ = std::min(ReduceWithDicAndNulls<IN, OUT, ReduceFunc::Min>(ptr, indexMap, condition, rowCount), *res_);
    }
    // size_t index matches rowCount's type (no signed/unsigned comparison).
    for (size_t row = 0; row < rowCount; ++row) {
        if (condition[row] == 0) {
            FLAG_HANDLER::Update(flag_, 1);
        }
    }
}
}
#endif