/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
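/*!
 * \file digamma_common_basic_impl.h
 * \brief Shared constants, scratch-buffer descriptor and mask/select helpers for the Digamma implementation.
 */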
#if !defined(__ASCENDC_INCLUDE_INTERNAL_HEADERS__)
#pragma message( \
"impl/adv_api/detail/math/digamma/digamma_common_basic_impl.h is an internal header file and must not be used directly. Functions or variables defined in this file may be removed in the future. Please use \"#include \"adv_api/math/digamma.h\"\" and use public functions or variables defined in interface headers files.")
#define __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#define __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_MATH_DIGAMMA_DIGAMMA_COMMON_BASIC_IMPL_H__
#endif
#ifndef IMPL_MATH_DIGAMMA_DIGAMMA_COMMON_BASIC_IMPL_H
#define IMPL_MATH_DIGAMMA_DIGAMMA_COMMON_BASIC_IMPL_H
#include <cstdint>
#include "kernel_basic_intf.h"
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 3510 || __NPU_ARCH__ == 5102)
#include "digamma_3510_impl.h"
#elif defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
#include "digamma_v220_impl.h"
#elif defined(__NPU_ARCH__) && __NPU_ARCH__ == 2002
#include "digamma_v200_impl.h"
#endif
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201 || __NPU_ARCH__ == 2002)
namespace AscendC {
// File-local numeric constants used by the Digamma helper routines.
namespace {
// Equals -2^23. Used as the exclusive lower bound of the negative-integer test in DigammaGenNegIntMask.
// NOTE(review): presumably chosen because every float32 at or below this value is already integral - confirm.
constexpr float MIN_NEG_WITH_FLOAT = -8388608.0;
// pi and -pi, used as multipliers in DigammaNegativeRange.
constexpr float DIGAMMA_PI = 3.141592653589793238f;
constexpr float DIGAMMA_NEG_PI = -3.141592653589793238f;
// NOTE(review): the three *_CALC_PROC values are not used in this file; presumably the number of splitSize-sized
// working buffers each data-type / reuse variant needs - confirm against the callers.
constexpr uint32_t DIGAMMA_FLOAT_NOREUSE_CALC_PROC = 7;
constexpr uint32_t DIGAMMA_FLOAT_REUSE_CALC_PROC = 6;
constexpr uint32_t DIGAMMA_HALF_CALC_PROC = 8;
// NOTE(review): not used in this file; presumably an iteration bound of a loop in the arch-specific code - confirm.
constexpr size_t DIGAMMA_MAX_LOOP = 5;
// Numerically 691/32760, 1/132, 1/240, 1/252, 1/120 and 1/12 (in that order).
// NOTE(review): these match the coefficient magnitudes of the asymptotic expansion of digamma for large arguments;
// the consuming code is outside this file - confirm.
constexpr float posCalcConst[] = {2.10927960927960927961e-2, 7.57575757575757575758e-3, 4.16666666666666666667e-3,
3.96825396825396825397e-3, 8.33333333333333333333e-3, 8.33333333333333333333e-2};
// The values 1..9 and 1..2 as floats.
// NOTE(review): presumably the integer shifts applied to the argument by the float and half code paths - confirm.
constexpr float tmp1CalcConst[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0};
constexpr float tmp1HalfCalcConst[] = {1.0, 2.0};
// NOTE(review): not used in this file; judging by the name, presumably polynomial coefficients of a pi*cot(pi*x)
// related approximation - confirm against the consuming code.
constexpr float picotCalcConst[] = {
0.00326538085938f, 0.0242919921875f, 0.053466796875f, 0.133377909660f, 0.333332300186f};
}
/*!
 * \brief Working set shared by the Digamma helper routines: scratch tensors, mask buffers, repeat parameters and
 *        the number of elements handled per split.
 *
 * The tensor members and splitSize are assigned by the DigammaInitParams overloads. unaryParams and binaryParams
 * are not assigned anywhere in this file.
 */
struct DigammaParams {
// Empty user-provided constructor: no member is explicitly initialised here (splitSize included).
__aicore__ DigammaParams() {}
// First slot of the temporary buffer; passed as the third argument of CosCompute / SinCompute in
// DigammaNegativeRange (presumably their scratch space - confirm).
LocalTensor<float> result;
// General-purpose float scratch slots. tmpCal1 receives the rounded input in DigammaGenNegIntMask, tmpCal2 is the
// minuend and tmpCal3 the subtrahend of the final Sub in DigammaNegativeRange.
LocalTensor<float> tmpCal1;
LocalTensor<float> tmpCal2;
LocalTensor<float> tmpCal3;
// Aliases the source tensor when the float DigammaInitParams overload is instantiated with isReuseSource = true.
LocalTensor<float> tmpCal4;
// Only assigned by the half DigammaInitParams overload.
LocalTensor<float> tmpCal5;
// Holds the broadcast scalar for comparisons / SetCmpMask, and doubles as scratch in DigammaNegativeRange.
LocalTensor<float> tmpScalar;
// Mask buffers placed after the float slots; mask1 and mask2 hold intermediate comparison results.
LocalTensor<uint8_t> mask;
LocalTensor<uint8_t> mask1;
LocalTensor<uint8_t> mask2;
// Repeat/stride parameters forwarded unchanged to the unary and binary vector intrinsics.
UnaryRepeatParams unaryParams;
BinaryRepeatParams binaryParams;
// Number of elements processed per split; also the element count of every scratch slot.
uint32_t splitSize;
};
#pragma begin_pipe(V)
/*!
 * \brief Compare every element of src with a single scalar and write the comparison result to mask.
 * \param [out] mask    receives the result of the comparison
 * \param [in]  src     float tensor whose elements are tested
 * \param [in]  params  provides tmpScalar (overwritten with the scalar), splitSize and binaryParams
 * \param [in]  scalar  right-hand side of the comparison
 * \param [in]  cmpMode comparison operator, evaluated as (src cmpMode scalar)
 * \note The intrinsics are called with isSetMask = false / MASK_PLACEHOLDER; presumably the vector mask must already
 *       be configured by the caller - confirm against the API documentation.
 */
__aicore__ inline void DigammaGenCompareMask(
    const LocalTensor<uint8_t>& mask, const LocalTensor<float>& src, DigammaParams& params, const float scalar,
    CMPMODE cmpMode)
{
    // Compare takes two tensor operands, so the scalar is first written into the tmpScalar tensor.
    Duplicate<float, false>(params.tmpScalar, scalar, MASK_PLACEHOLDER, 1, DEFAULT_BLK_STRIDE, DEFAULT_REPEAT_STRIDE);
    PipeBarrier<PIPE_V>();

    // Round the byte length of one split up to a whole number of ONE_REPEAT_BYTE_SIZE units.
    const auto splitBytes = params.splitSize * sizeof(float);
    const uint8_t repeatCount = DivCeil(splitBytes, ONE_REPEAT_BYTE_SIZE);

    Compare<float, uint8_t, false>(
        mask, src, params.tmpScalar, cmpMode, MASK_PLACEHOLDER, repeatCount, params.binaryParams);
    PipeBarrier<PIPE_V>();
}
/*!
 * \brief Build the mask of elements that are negative, greater than MIN_NEG_WITH_FLOAT and equal to their own
 *        rounded value (i.e. negative integers inside that range).
 * \param [out] mask    receives the final mask
 * \param [in]  src     float tensor to classify
 * \param [in]  params  working set; mask1, mask2, tmpCal1 and tmpScalar are overwritten
 * \param [in]  scalar  not used by this function
 * \note On return the vector mask is set to (0, params.splitSize).
 */
__aicore__ inline void DigammaGenNegIntMask(
    const LocalTensor<uint8_t>& mask, const LocalTensor<float>& src, DigammaParams& params, const float scalar)
{
    // The mask buffers are combined as 16-bit words; create those views once.
    const auto rangeBits = params.mask1.ReinterpretCast<uint16_t>();
    const auto scratchBits = params.mask2.ReinterpretCast<uint16_t>();
    const auto resultBits = mask.ReinterpretCast<uint16_t>();
    // Count of uint16_t words for splitSize mask bits (ConstCeil presumably rounds up - confirm).
    const auto maskWordCount = ConstCeil(params.splitSize, sizeof(uint16_t) * ONE_BYTE_BIT_SIZE);

    // Step 1: mask1 <- (src < 0) AND (src > MIN_NEG_WITH_FLOAT).
    DigammaGenCompareMask(params.mask1, src, params, 0.0f, CMPMODE::LT);
    DigammaGenCompareMask(params.mask2, src, params, MIN_NEG_WITH_FLOAT, CMPMODE::GT);
    SetVectorMask<float>(0, maskWordCount);
    And<uint16_t, false>(rangeBits, rangeBits, scratchBits, MASK_PLACEHOLDER, 1, params.binaryParams);
    PipeBarrier<PIPE_V>();
    SetVectorMask<float>(0, params.splitSize);

    // Step 2: mask2 <- (src == tmpCal1), where tmpCal1 is filled by DigammaCast with CAST_ROUND
    // (DigammaCast is defined elsewhere; presumably it writes src rounded to an integral float - confirm).
    DigammaCast(params.tmpCal1, src, RoundMode::CAST_ROUND);
    const uint8_t repeatCount = DivCeil(params.splitSize * sizeof(float), ONE_REPEAT_BYTE_SIZE);
    Compare<float, uint8_t, false>(
        params.mask2, src, params.tmpCal1, CMPMODE::EQ, MASK_PLACEHOLDER, repeatCount, params.binaryParams);
    PipeBarrier<PIPE_V>();

    // Step 3: mask <- mask1 AND mask2.
    SetVectorMask<float>(0, maskWordCount);
    And<uint16_t, false>(resultBits, rangeBits, scratchBits, MASK_PLACEHOLDER, 1, params.binaryParams);
    PipeBarrier<PIPE_V>();
    SetVectorMask<float>(0, params.splitSize);
}
/*!
 * \brief Build the mask of elements x of src with min <= x < max.
 * \param [out] mask    receives the combined mask
 * \param [in]  src     float tensor to classify
 * \param [in]  params  working set; mask1, mask2 and tmpScalar are overwritten
 * \param [in]  min     inclusive lower bound
 * \param [in]  max     exclusive upper bound
 * \note On return the vector mask is set to (0, params.splitSize).
 */
__aicore__ inline void DigammaGenRangeMask(
    const LocalTensor<uint8_t>& mask, const LocalTensor<float>& src, DigammaParams& params, const float min,
    const float max)
{
    // mask1 <- (src < max), mask2 <- (src >= min).
    DigammaGenCompareMask(params.mask1, src, params, max, CMPMODE::LT);
    DigammaGenCompareMask(params.mask2, src, params, min, CMPMODE::GE);

    // The two partial masks are intersected as 16-bit words.
    const auto belowMaxBits = params.mask1.ReinterpretCast<uint16_t>();
    const auto atLeastMinBits = params.mask2.ReinterpretCast<uint16_t>();
    const auto inRangeBits = mask.ReinterpretCast<uint16_t>();
    // Count of uint16_t words for splitSize mask bits (ConstCeil presumably rounds up - confirm).
    const auto maskWordCount = ConstCeil(params.splitSize, sizeof(uint16_t) * ONE_BYTE_BIT_SIZE);

    SetVectorMask<float>(0, maskWordCount);
    And<uint16_t, false>(inRangeBits, belowMaxBits, atLeastMinBits, MASK_PLACEHOLDER, 1, params.binaryParams);
    PipeBarrier<PIPE_V>();

    // Switch the vector mask back to the element count of one split.
    SetVectorMask<float>(0, params.splitSize);
}
/*!
 * \brief Build the mask of elements for which neither (x < 0) nor (x >= 0) holds; with IEEE comparison semantics
 *        these are the NaN elements.
 * \param [out] mask    receives the resulting mask
 * \param [in]  src     float tensor to classify
 * \param [in]  params  working set; mask1, mask2 and tmpScalar are overwritten
 * \note On return the vector mask is set to (0, params.splitSize).
 */
__aicore__ inline void DigammaGenNanMask(
    const LocalTensor<uint8_t>& mask, const LocalTensor<float>& src, DigammaParams& params)
{
    // mask1 <- (src < 0), mask2 <- (src >= 0).
    DigammaGenCompareMask(params.mask1, src, params, 0.0f, CMPMODE::LT);
    DigammaGenCompareMask(params.mask2, src, params, 0.0f, CMPMODE::GE);

    // Bitwise work on the masks is done on 16-bit words.
    const auto notNegativeBits = params.mask1.ReinterpretCast<uint16_t>();
    const auto notNonNegativeBits = params.mask2.ReinterpretCast<uint16_t>();
    const auto nanBits = mask.ReinterpretCast<uint16_t>();
    // Count of uint16_t words for splitSize mask bits (ConstCeil presumably rounds up - confirm).
    const auto maskWordCount = ConstCeil(params.splitSize, sizeof(uint16_t) * ONE_BYTE_BIT_SIZE);

    SetVectorMask<float>(0, maskWordCount);
    // Invert both partial masks in place; the two Not operations touch different buffers.
    Not<uint16_t, false>(notNegativeBits, notNegativeBits, MASK_PLACEHOLDER, 1, params.unaryParams);
    Not<uint16_t, false>(notNonNegativeBits, notNonNegativeBits, MASK_PLACEHOLDER, 1, params.unaryParams);
    PipeBarrier<PIPE_V>();

    // mask <- NOT(src < 0) AND NOT(src >= 0).
    And<uint16_t, false>(nanBits, notNegativeBits, notNonNegativeBits, MASK_PLACEHOLDER, 1, params.binaryParams);
    PipeBarrier<PIPE_V>();

    // Switch the vector mask back to the element count of one split.
    SetVectorMask<float>(0, params.splitSize);
}
/*!
 * \brief Accumulate a masked copy of src into dst: tmp is produced by Select from src under mask, then
 *        dst = tmp + dst.
 * \param [in,out] dst    accumulator; read and overwritten by the final Add
 * \param [in]     src    values to be merged in
 * \param [in]     mask   selection mask passed to Select
 * \param [out]    tmp    scratch tensor that receives the Select result
 * \param [in]     params working set; tmpScalar is overwritten with 0.0f, binaryParams is forwarded
 * \note Presumably Select keeps src where the mask bit is set and uses the value registered through SetCmpMask
 *       (0.0f) elsewhere, so unselected elements leave dst unchanged - confirm against the API documentation.
 */
__aicore__ inline void DigammaSelect(
const LocalTensor<float>& dst, const LocalTensor<float>& src, const LocalTensor<uint8_t>& mask,
const LocalTensor<float>& tmp, DigammaParams& params)
{
// Write 0.0f into tmpScalar and register that tensor via SetCmpMask before the Select.
Duplicate<float, false>(params.tmpScalar, 0.0f, MASK_PLACEHOLDER, 1, DEFAULT_BLK_STRIDE, DEFAULT_REPEAT_STRIDE);
PipeBarrier<PIPE_V>();
SetCmpMask<float>(params.tmpScalar);
PipeBarrier<PIPE_V>();
// tmp <- Select(mask, src).
Select<float, uint8_t>(tmp, mask, src, 1, params.binaryParams);
PipeBarrier<PIPE_V>();
// dst <- tmp + dst.
Add<float, false>(dst, tmp, dst, MASK_PLACEHOLDER, 1, params.binaryParams);
PipeBarrier<PIPE_V>();
}
/*!
 * \brief Compute dst = tmpCal2 - PI * cos(PI * (src - floor(src))) / sin(-PI * src).
 * \param [out]    dst    receives the result
 * \param [in,out] src    input values; overwritten in place with -PI * src
 * \param [in,out] params working set; tmpCal2 must already hold the minuend, tmpScalar and tmpCal3 are overwritten,
 *                        and result is handed to CosCompute / SinCompute
 * \note The formula assumes DigammaCast with CAST_FLOOR writes floor(src) and that CosCompute / SinCompute write
 *       cos / sin of their second argument into their first; all three are defined outside this file - confirm.
 */
__aicore__ inline void DigammaNegativeRange(
const LocalTensor<float>& dst, const LocalTensor<float>& src, DigammaParams& params)
{
// tmpScalar <- src - DigammaCast(src, CAST_FLOOR), then scaled by PI.
DigammaCast(params.tmpScalar, src, RoundMode::CAST_FLOOR);
Sub<float, false>(params.tmpScalar, src, params.tmpScalar, MASK_PLACEHOLDER, 1, params.binaryParams);
PipeBarrier<PIPE_V>();
Muls<float, false>(params.tmpScalar, params.tmpScalar, DIGAMMA_PI, MASK_PLACEHOLDER, 1, params.unaryParams);
PipeBarrier<PIPE_V>();
// tmpCal3 <- CosCompute(tmpScalar); params.result is passed as the third argument (presumably scratch - confirm).
CosCompute<float>(params.tmpCal3, params.tmpScalar, params.result, params.splitSize, true);
// src <- -PI * src, in place: the caller's tensor contents are modified.
Muls<float, false>(src, src, DIGAMMA_NEG_PI, MASK_PLACEHOLDER, 1, params.unaryParams);
PipeBarrier<PIPE_V>();
// tmpScalar <- SinCompute(src), reusing tmpScalar now that the cosine argument is no longer needed.
SinCompute<float>(params.tmpScalar, src, params.result, params.splitSize, true);
// tmpCal3 <- PI * tmpCal3 / tmpScalar.
Muls<float, false>(params.tmpCal3, params.tmpCal3, DIGAMMA_PI, MASK_PLACEHOLDER, 1, params.unaryParams);
PipeBarrier<PIPE_V>();
Div<float, false>(params.tmpCal3, params.tmpCal3, params.tmpScalar, MASK_PLACEHOLDER, 1, params.binaryParams);
PipeBarrier<PIPE_V>();
// dst <- tmpCal2 - tmpCal3.
Sub<float, false>(dst, params.tmpCal2, params.tmpCal3, MASK_PLACEHOLDER, 1, params.binaryParams);
PipeBarrier<PIPE_V>();
}
/*!
 * \brief Lay out the DigammaParams working set inside tmp for half-precision input.
 *
 * Seven consecutive float slots of splitSize elements are taken from tmp (result, tmpCal1..tmpCal5, tmpScalar),
 * followed by three uint8_t mask buffers of splitSize elements each (mask, mask1, mask2).
 *
 * \tparam isReuseSource not used by this overload
 * \param [in]  tmp       temporary buffer that backs every slot
 * \param [in]  splitSize element count of one slot; stored in params.splitSize
 * \param [in]  src       not used by this overload (selects the overload by its element type)
 * \param [out] params    working set to fill in
 * \note operator[] is assumed to return a view that starts at the given element offset - confirm.
 * \note NOTE(review): result is sized to splitSize * 4 here, whereas the float overload uses splitSize - confirm
 *       that the difference is intended.
 */
template <bool isReuseSource = false>
__aicore__ inline void DigammaInitParams(
    const LocalTensor<float>& tmp, const uint32_t& splitSize, const LocalTensor<half>& src, DigammaParams& params)
{
    // Float slots: each one starts splitSize elements after the previous one.
    params.result = tmp;
    params.tmpCal1 = tmp[splitSize];
    params.tmpCal2 = params.tmpCal1[splitSize];
    params.tmpCal3 = params.tmpCal2[splitSize];
    params.tmpCal4 = params.tmpCal3[splitSize];
    params.tmpCal5 = params.tmpCal4[splitSize];
    params.tmpScalar = params.tmpCal5[splitSize];

    // Mask buffers: the region after tmpScalar viewed as bytes, split into three parts of splitSize bytes.
    const auto maskRegion = params.tmpScalar[splitSize];
    params.mask = maskRegion.ReinterpretCast<uint8_t>();
    params.mask1 = params.mask[splitSize];
    params.mask2 = params.mask1[splitSize];

    // Sizes are assigned only after every view has been derived.
    params.result.SetSize(splitSize * 4);
    params.tmpCal1.SetSize(splitSize);
    params.tmpCal2.SetSize(splitSize);
    params.tmpCal3.SetSize(splitSize);
    params.tmpCal4.SetSize(splitSize);
    params.tmpCal5.SetSize(splitSize);
    params.tmpScalar.SetSize(splitSize);
    params.mask.SetSize(splitSize);
    params.mask1.SetSize(splitSize);
    params.mask2.SetSize(splitSize);

    params.splitSize = splitSize;
}
/*!
 * \brief Lay out the DigammaParams working set inside tmp for float input.
 *
 * Float slots of splitSize elements are taken from tmp in the order result, tmpCal1, tmpCal2, tmpCal3,
 * [tmpCal4,] tmpScalar, followed by three uint8_t mask buffers of splitSize elements each (mask, mask1, mask2).
 * tmpCal5 is not assigned by this overload.
 *
 * \tparam isReuseSource when true, tmpCal4 aliases src and tmp provides one float slot fewer
 * \param [in]  tmp       temporary buffer that backs the slots
 * \param [in]  splitSize element count of one slot; stored in params.splitSize
 * \param [in]  src       source tensor; used as tmpCal4 only when isReuseSource is true
 * \param [out] params    working set to fill in
 * \note operator[] is assumed to return a view that starts at the given element offset - confirm.
 */
template <bool isReuseSource = false>
__aicore__ inline void DigammaInitParams(
    const LocalTensor<float>& tmp, const uint32_t& splitSize, const LocalTensor<float>& src, DigammaParams& params)
{
    // Slots that are always taken from tmp.
    params.result = tmp;
    params.tmpCal1 = params.result[splitSize];
    params.tmpCal2 = params.tmpCal1[splitSize];
    params.tmpCal3 = params.tmpCal2[splitSize];

    // The slot right after tmpCal3 becomes tmpScalar when src is reused as tmpCal4, and tmpCal4 otherwise.
    const auto slotAfterCal3 = params.tmpCal3[splitSize];
    if constexpr (isReuseSource) {
        params.tmpCal4 = src;
        params.tmpScalar = slotAfterCal3;
    } else {
        params.tmpCal4 = slotAfterCal3;
        params.tmpScalar = params.tmpCal4[splitSize];
    }

    // Mask buffers: the region after tmpScalar viewed as bytes, split into three parts of splitSize bytes.
    const auto maskRegion = params.tmpScalar[splitSize];
    params.mask = maskRegion.ReinterpretCast<uint8_t>();
    params.mask1 = params.mask[splitSize];
    params.mask2 = params.mask1[splitSize];

    // Sizes are assigned only after every view has been derived.
    params.result.SetSize(splitSize);
    params.tmpCal1.SetSize(splitSize);
    params.tmpCal2.SetSize(splitSize);
    params.tmpCal3.SetSize(splitSize);
    params.tmpCal4.SetSize(splitSize);
    params.tmpScalar.SetSize(splitSize);
    params.mask.SetSize(splitSize);
    params.mask1.SetSize(splitSize);
    params.mask2.SetSize(splitSize);

    params.splitSize = splitSize;
}
#pragma end_pipe
}
#endif
#endif
#if defined(__UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_MATH_DIGAMMA_DIGAMMA_COMMON_BASIC_IMPL_H__)
#undef __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#undef __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_MATH_DIGAMMA_DIGAMMA_COMMON_BASIC_IMPL_H__
#endif