/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/**
* NOTE: Portions of this code were AI-generated and have been
* technically reviewed for functional accuracy and security.
*/
/*!
* \file ndtri_kernel.h
* \brief Ndtri kernel implementation (arch35 / Ascend950)
*
* Formula:
* y = ndtri(p) = sqrt(2) * erf^{-1}(2p - 1)
* Aligned with PyTorch torch.special.ndtri / SciPy scipy.special.ndtri (Cephes algorithm).
*
* Compute flow (single tile):
* Step 0: cast input to fp32 (fp16/bf16: CAST_NONE, fp32: ReinterpretCast)
* Step 1: build maskTail / maskNeg / maskSpecial
* Step 2: pSafe = clamp(p, FLT_MIN, 1 - FLT_MIN)
* Step 3: yTail = cal_tail(pSafe, maskNeg)
* Step 4: yCenter = cal_p0(pSafe)
* Step 5: y = select(maskTail, yTail, yCenter)
* Step 6: y = select(maskSpecial, ySpecial, y)
* Step 7: cast output to T (fp16/bf16: CAST_RINT, fp32: ReinterpretCast)
*
* Scope of iteration 2 (this integration):
* - FP32 / FP16 / BF16 x aligned/unaligned: all 6 TilingKeys have real implementations
* - FP16 path: Cast fp16->fp32 -> unified algorithm -> Cast fp32->fp16 (spike P-2 verified bit-exact)
* - BF16 path: Cast bf16->fp32 -> unified algorithm -> Cast fp32->bf16 (spike P-3 verified bit-exact)
* - Unaligned path: DataCopyPad handles the tail block (DataCopyPad is used anyway, so it is naturally compatible)
*
* TilingKey matrix: {fp32, fp16, bf16} x {aligned, unaligned} = 6 keys
*/
#ifndef NDTRI_KERNEL_H
#define NDTRI_KERNEL_H
#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
#include "ndtri_tiling_data.h"
#include "ndtri_tiling_key.h"
#include "ndtri_coeffs.h"
#include "ndtri_compute.h"
namespace NsNdtri {
using namespace AscendC;
/**
 * Tiled element-wise ndtri kernel object.
 *
 * Init() derives this core's slice from the tiling data and sets up all local
 * buffers; Process() then walks the slice tile by tile through
 * CopyIn -> Compute -> CopyOut.
 *
 * @tparam T        Element type of the input/output tensors in global memory.
 * @tparam K_ALIGN  Tiling-key flag; it is not referenced by any member
 *                  declared or defined in this header.
 */
template <typename T, int K_ALIGN>
class Ndtri {
    // Number of buffers behind each of the input and output queues.
    static constexpr int32_t BUFFER_NUM = 2;
    // True when T is float; Compute() then copies instead of casting.
    static constexpr bool IS_FP32 = AscendC::IsSameType<T, float>::value;
    // Element granularity that Compute() rounds its working length up to.
    static constexpr int32_t CMP_ALIGN_ELEM = 64;

    // Rounds len up to the next multiple of CMP_ALIGN_ELEM.
    __aicore__ inline static int32_t AlignCmpLen(int32_t len)
    {
        const int32_t chunkCount = (len + CMP_ALIGN_ELEM - 1) / CMP_ALIGN_ELEM;
        return chunkCount * CMP_ALIGN_ELEM;
    }

public:
    __aicore__ inline Ndtri() = default;

    // Binds the global tensors and allocates every queue/buffer from tilingData.
    __aicore__ inline void Init(GM_ADDR self, GM_ADDR out, const NdtriTilingData* tilingData);

    // Runs the CopyIn/Compute/CopyOut loop over this core's slice.
    __aicore__ inline void Process();

private:
    // Loads tile `progress` (currentNum elements) from selfGm into the input queue.
    __aicore__ inline void CopyIn(int64_t progress, int64_t currentNum);

    // Evaluates ndtri on the dequeued tile and enqueues the result.
    __aicore__ inline void Compute(int64_t currentNum);

    // Stores tile `progress` (currentNum elements) from the output queue to outGm.
    __aicore__ inline void CopyOut(int64_t progress, int64_t currentNum);

    // Fills maskTail / maskNeg / maskSpecial from p, using scratch as workspace.
    __aicore__ inline void BuildMasks(const LocalTensor<float>& p,
                                      const LocalTensor<uint8_t>& maskTail,
                                      const LocalTensor<uint8_t>& maskNeg,
                                      const LocalTensor<uint8_t>& maskSpecial,
                                      const LocalTensor<float>& scratch,
                                      int32_t len);

    // Fills ySpecial with the results used for elements flagged by maskSpecial.
    __aicore__ inline void BuildSpecialY(const LocalTensor<float>& ySpecial,
                                         const LocalTensor<float>& p,
                                         const LocalTensor<float>& scratch,
                                         int32_t len);

private:
    TPipe pipe;

    // Input / output tile queues, BUFFER_NUM buffers each.
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueSelf;
    TQue<QuePosition::VECOUT, BUFFER_NUM> outQueY;

    // fp32 working copy of the input tile and fp32 result tile.
    TBuf<TPosition::VECCALC> pBuf;
    TBuf<TPosition::VECCALC> yBuf;

    // fp32 scratch tensors used by Compute() and the helpers it calls.
    TBuf<TPosition::VECCALC> tmpBuf0;
    TBuf<TPosition::VECCALC> tmpBuf1;
    TBuf<TPosition::VECCALC> tmpBuf2;
    TBuf<TPosition::VECCALC> tmpBuf3;
    TBuf<TPosition::VECCALC> tmpBuf4;
    TBuf<TPosition::VECCALC> tmpBuf5;
    TBuf<TPosition::VECCALC> tmpBuf6;
    TBuf<TPosition::VECCALC> tmpBuf7;
    TBuf<TPosition::VECCALC> tmpBuf8;
    TBuf<TPosition::VECCALC> tmpBuf9;
    TBuf<TPosition::VECCALC> tmpBuf10;

    // uint8_t mask storage (tail / neg / special / helper scratch).
    TBuf<TPosition::VECCALC> maskBuf0;
    TBuf<TPosition::VECCALC> maskBuf1;
    TBuf<TPosition::VECCALC> maskBuf2;
    TBuf<TPosition::VECCALC> maskBuf3;

    // Global-memory views, already offset to this core's slice by Init().
    GlobalTensor<T> selfGm;
    GlobalTensor<T> outGm;

    // Number of elements handled by this core (0 when the core has no work).
    int64_t blockLength_ = 0;
    // Number of elements per tile.
    int64_t ubLength_ = 0;
};
/**
 * Binds the global tensors to this core's slice and allocates all local buffers.
 *
 * The slice starts at blockFactor * blockIdx and holds at most blockFactor
 * elements; cores whose slice would start past totalNum get a length of 0.
 * A non-positive ubFactor falls back to a tile length of 1.
 *
 * @param self        Global-memory address of the input tensor.
 * @param out         Global-memory address of the output tensor.
 * @param tilingData  Tiling parameters (totalNum, blockFactor, ubFactor are read).
 */
template <typename T, int K_ALIGN>
__aicore__ inline void Ndtri<T, K_ALIGN>::Init(
    GM_ADDR self, GM_ADDR out,
    const NdtriTilingData* tilingData)
{
    const int64_t blockIdx = AscendC::GetBlockIdx();
    const int64_t offset = tilingData->blockFactor * blockIdx;
    const int64_t remainderLength = tilingData->totalNum - offset;

    blockLength_ = (remainderLength > tilingData->blockFactor) ?
        tilingData->blockFactor : remainderLength;
    if (blockLength_ < 0) {
        blockLength_ = 0;
    }

    ubLength_ = tilingData->ubFactor;
    if (ubLength_ <= 0) {
        ubLength_ = 1;
    }

    selfGm.SetGlobalBuffer((__gm__ T*)self + offset, blockLength_);
    outGm.SetGlobalBuffer((__gm__ T*)out + offset, blockLength_);

    // Compute() runs its vector ops over AlignCmpLen(len) elements, which can
    // exceed ubLength_ whenever ubLength_ is not a multiple of CMP_ALIGN_ELEM
    // (e.g. the fallback value 1). Size every fp32 / mask buffer for the
    // rounded-up length so the padded lanes stay inside their own buffer
    // instead of spilling into the neighbouring one. The I/O queues are only
    // accessed with the unpadded length, so they keep the exact tile size.
    // NOTE(review): when ubFactor is already a multiple of 64 this allocates
    // exactly what the previous code did; otherwise UB usage grows slightly.
    const int64_t computeLength =
        (ubLength_ + CMP_ALIGN_ELEM - 1) / CMP_ALIGN_ELEM * CMP_ALIGN_ELEM;

    const uint32_t ioBytes = static_cast<uint32_t>(ubLength_ * sizeof(T));
    const uint32_t calcBytes = static_cast<uint32_t>(computeLength * sizeof(float));
    // One bit per element, plus 32 bytes of slack as in the original sizing.
    const uint32_t maskBytes = static_cast<uint32_t>((computeLength + 7) / 8 + 32);

    pipe.InitBuffer(inQueSelf, BUFFER_NUM, ioBytes);
    pipe.InitBuffer(outQueY, BUFFER_NUM, ioBytes);

    pipe.InitBuffer(pBuf, calcBytes);
    pipe.InitBuffer(yBuf, calcBytes);
    pipe.InitBuffer(tmpBuf0, calcBytes);
    pipe.InitBuffer(tmpBuf1, calcBytes);
    pipe.InitBuffer(tmpBuf2, calcBytes);
    pipe.InitBuffer(tmpBuf3, calcBytes);
    pipe.InitBuffer(tmpBuf4, calcBytes);
    pipe.InitBuffer(tmpBuf5, calcBytes);
    pipe.InitBuffer(tmpBuf6, calcBytes);
    pipe.InitBuffer(tmpBuf7, calcBytes);
    pipe.InitBuffer(tmpBuf8, calcBytes);
    pipe.InitBuffer(tmpBuf9, calcBytes);
    pipe.InitBuffer(tmpBuf10, calcBytes);

    pipe.InitBuffer(maskBuf0, maskBytes);
    pipe.InitBuffer(maskBuf1, maskBytes);
    pipe.InitBuffer(maskBuf2, maskBytes);
    pipe.InitBuffer(maskBuf3, maskBytes);
}
/**
 * Processes this core's slice tile by tile.
 *
 * Every tile except the last holds ubLength_ elements; the last one holds
 * whatever remains of blockLength_. Does nothing when the slice is empty.
 */
template <typename T, int K_ALIGN>
__aicore__ inline void Ndtri<T, K_ALIGN>::Process()
{
    if (blockLength_ <= 0) {
        return;
    }

    const int64_t tileCount = (blockLength_ + ubLength_ - 1) / ubLength_;
    // Elements not yet handed to a tile; equals blockLength_ - ubLength_ * tile.
    int64_t remaining = blockLength_;
    for (int64_t tile = 0; tile < tileCount; ++tile, remaining -= ubLength_) {
        const bool isLastTile = (tile + 1 == tileCount);
        const int64_t tileElems = isLastTile ? remaining : ubLength_;

        CopyIn(tile, tileElems);
        Compute(tileElems);
        CopyOut(tile, tileElems);
    }
}
/**
 * Loads one tile from global memory into a freshly allocated input-queue tensor.
 *
 * @param progress    Tile index; the source offset is progress * ubLength_ elements.
 * @param currentNum  Number of elements in this tile.
 */
template <typename T, int K_ALIGN>
__aicore__ inline void Ndtri<T, K_ALIGN>::CopyIn(int64_t progress, int64_t currentNum)
{
    LocalTensor<T> tile = inQueSelf.template AllocTensor<T>();

    const int64_t srcOffset = progress * ubLength_;
    const uint32_t tileBytes = static_cast<uint32_t>(currentNum * sizeof(T));

    // A single transfer of tileBytes bytes; all remaining fields are zero.
    DataCopyExtParams xferParams{1, tileBytes, 0, 0, 0};
    // Padding disabled.
    DataCopyPadExtParams<T> noPadding{false, 0, 0, 0};

    DataCopyPad(tile, selfGm[srcOffset], xferParams, noPadding);
    inQueSelf.EnQue(tile);
}
/**
 * Builds the three selection masks used by Compute().
 *
 *   maskTail    : |p - 0.5| >= 0.5 - NDTRI_VAL_SUB
 *   maskSpecial : p <= 0, or p >= 1, or p differs from min(max(p, -inf), +inf)
 *                 (presumably the NaN inputs - confirm against the hardware
 *                 Maxs/Mins NaN behaviour)
 *   maskNeg     : p >= 0.5 (also used as a temporary while maskSpecial is built)
 *
 * @param p            fp32 input values.
 * @param maskTail     Output mask.
 * @param maskNeg      Output mask.
 * @param maskSpecial  Output mask.
 * @param scratch      fp32 workspace of at least len elements; overwritten.
 * @param len          Number of elements to process.
 */
template <typename T, int K_ALIGN>
__aicore__ inline void Ndtri<T, K_ALIGN>::BuildMasks(
    const LocalTensor<float>& p,
    const LocalTensor<uint8_t>& maskTail,
    const LocalTensor<uint8_t>& maskNeg,
    const LocalTensor<uint8_t>& maskSpecial,
    const LocalTensor<float>& scratch,
    int32_t len)
{
    // IEEE-754 single-precision bit patterns of -inf / +inf.
    constexpr uint32_t NEG_INF_BITS_U = 0xFF800000U;
    constexpr uint32_t POS_INF_BITS_U = 0x7F800000U;
    float negInf, posInf;
    {
        union { uint32_t u; float f; } cvt;
        cvt.u = NEG_INF_BITS_U; negInf = cvt.f;
        cvt.u = POS_INF_BITS_U; posInf = cvt.f;
    }

    // Init sizes the mask buffers at one bit per element, so a mask covering
    // `len` elements occupies ceil(len / 8) uint8_t entries. The bitwise Or
    // below works on uint8_t tensors and therefore must be given that byte
    // count; passing `len` would process 8x too many bytes and run past the
    // mask buffers.
    const int32_t maskLen = (len + 7) / 8;

    // maskTail: distance from 0.5 reaches the tail threshold.
    Adds(scratch, p, -0.5f, len);
    Abs(scratch, scratch, len);
    CompareScalar(maskTail, scratch,
                  0.5f - NDTRI_VAL_SUB, CMPMODE::GE, len);

    // maskSpecial: p <= 0 or p >= 1 (maskNeg is only a temporary here).
    CompareScalar(maskSpecial, p, 0.0f, CMPMODE::LE, len);
    CompareScalar(maskNeg, p, 1.0f, CMPMODE::GE, len);
    Or(maskSpecial, maskSpecial, maskNeg, maskLen);

    // Additionally flag elements that clamping to [-inf, +inf] does not
    // reproduce.
    Maxs(scratch, p, negInf, len);
    Mins(scratch, scratch, posInf, len);
    Compare(maskNeg, p, scratch, CMPMODE::NE, len);
    Or(maskSpecial, maskSpecial, maskNeg, maskLen);

    // Final content of maskNeg: p >= 0.5.
    CompareScalar(maskNeg, p, 0.5f, CMPMODE::GE, len);
}
/**
 * Fills ySpecial with the value to emit for special inputs:
 *   p == 0 -> -inf,  p == 1 -> +inf,  anything else -> NaN.
 *
 * Uses maskBuf3 as the comparison mask and overwrites scratch.
 *
 * @param ySpecial  Output tensor.
 * @param p         fp32 input values.
 * @param scratch   fp32 workspace of at least len elements.
 * @param len       Number of elements to process.
 */
template <typename T, int K_ALIGN>
__aicore__ inline void Ndtri<T, K_ALIGN>::BuildSpecialY(
    const LocalTensor<float>& ySpecial,
    const LocalTensor<float>& p,
    const LocalTensor<float>& scratch,
    int32_t len)
{
    // Reinterpret fixed IEEE-754 bit patterns as float constants.
    union FloatBits {
        uint32_t bits;
        float value;
    };
    FloatBits quietNan;
    quietNan.bits = 0x7FC00000U;
    FloatBits plusInf;
    plusInf.bits = 0x7F800000U;
    FloatBits minusInf;
    minusInf.bits = 0xFF800000U;

    LocalTensor<uint8_t> hitMask = maskBuf3.Get<uint8_t>();

    // Default for every lane: NaN.
    Duplicate(ySpecial, quietNan.value, len);

    // Lanes with p == 0 become -inf.
    CompareScalar(hitMask, p, 0.0f, CMPMODE::EQ, len);
    Duplicate(scratch, minusInf.value, len);
    Select(ySpecial, hitMask, scratch, ySpecial, SELMODE::VSEL_TENSOR_TENSOR_MODE, len);

    // Lanes with p == 1 become +inf.
    CompareScalar(hitMask, p, 1.0f, CMPMODE::EQ, len);
    Duplicate(scratch, plusInf.value, len);
    Select(ySpecial, hitMask, scratch, ySpecial, SELMODE::VSEL_TENSOR_TENSOR_MODE, len);
}
/**
 * Evaluates ndtri on one tile.
 *
 * Dequeues the input tile, converts it to fp32, computes the tail and centre
 * branches on a clamped copy of the input, merges them by mask, overrides the
 * special lanes, converts back to T and enqueues the result.
 *
 * All intermediate vector ops run over the tile length rounded up by
 * AlignCmpLen(); only the input conversion and the output conversion use the
 * exact length.
 *
 * @param currentNum  Number of valid elements in the tile.
 */
template <typename T, int K_ALIGN>
__aicore__ inline void Ndtri<T, K_ALIGN>::Compute(int64_t currentNum)
{
    LocalTensor<T> srcTile = inQueSelf.template DeQue<T>();
    LocalTensor<T> dstTile = outQueY.template AllocTensor<T>();

    const int32_t validLen = static_cast<int32_t>(currentNum);
    const int32_t paddedLen = AlignCmpLen(validLen);

    // fp32 working tensors.
    LocalTensor<float> prob = pBuf.Get<float>();
    LocalTensor<float> result = yBuf.Get<float>();
    LocalTensor<float> work0 = tmpBuf0.Get<float>();
    LocalTensor<float> work1 = tmpBuf1.Get<float>();
    LocalTensor<float> work2 = tmpBuf2.Get<float>();
    LocalTensor<float> work3 = tmpBuf3.Get<float>();
    LocalTensor<float> work4 = tmpBuf4.Get<float>();
    LocalTensor<float> clamped = tmpBuf5.Get<float>();
    LocalTensor<float> centerY = tmpBuf6.Get<float>();
    LocalTensor<float> tailY = tmpBuf7.Get<float>();
    LocalTensor<float> specialY = tmpBuf8.Get<float>();
    LocalTensor<float> work9 = tmpBuf9.Get<float>();
    LocalTensor<float> work10 = tmpBuf10.Get<float>();

    // Mask tensors.
    LocalTensor<uint8_t> tailMask = maskBuf0.Get<uint8_t>();
    LocalTensor<uint8_t> upperMask = maskBuf1.Get<uint8_t>();
    LocalTensor<uint8_t> specialMask = maskBuf2.Get<uint8_t>();
    LocalTensor<uint8_t> helperMask = maskBuf3.Get<uint8_t>();

    // Pre-fill with 0.5 so the padding lanes hold a defined value.
    if (validLen < paddedLen) {
        Duplicate(prob, 0.5f, paddedLen);
    }

    // Step 0: bring the input into fp32.
    if constexpr (IS_FP32) {
        LocalTensor<float> srcAsFloat = srcTile.template ReinterpretCast<float>();
        Adds(prob, srcAsFloat, 0.0f, validLen);
    } else {
        Cast(prob, srcTile, RoundMode::CAST_NONE, validLen);
    }

    // Step 1: selection masks.
    BuildMasks(prob, tailMask, upperMask, specialMask, work4, paddedLen);

    // Step 2: clamp to [NDTRI_SAFE_LO, 1 - NDTRI_SAFE_LO].
    Maxs(clamped, prob, NDTRI_SAFE_LO, paddedLen);
    Mins(clamped, clamped, 1.0f - NDTRI_SAFE_LO, paddedLen);

    // Step 3: tail branch. specialY and centerY are handed over here as extra
    // workspace; both are overwritten again further down.
    CalTail(tailY, clamped, upperMask,
            work0, work1, work2, work3, work4,
            specialY, centerY, work9, work10,
            helperMask, paddedLen);

    // Step 4: centre branch.
    CalP0(centerY, clamped, work0, work1, work2, work3, work4, paddedLen);

    // Step 5: merge tail / centre.
    Select(result, tailMask, tailY, centerY, SELMODE::VSEL_TENSOR_TENSOR_MODE, paddedLen);

    // Step 6: override special lanes.
    BuildSpecialY(specialY, prob, work4, paddedLen);
    Select(result, specialMask, specialY, result, SELMODE::VSEL_TENSOR_TENSOR_MODE, paddedLen);

    // Step 7: convert back to T.
    if constexpr (IS_FP32) {
        LocalTensor<float> dstAsFloat = dstTile.template ReinterpretCast<float>();
        Adds(dstAsFloat, result, 0.0f, validLen);
    } else {
        Cast(dstTile, result, RoundMode::CAST_RINT, validLen);
    }

    outQueY.template EnQue<T>(dstTile);
    inQueSelf.FreeTensor(srcTile);
}
/**
 * Stores one computed tile from the output queue back to global memory.
 *
 * @param progress    Tile index; the destination offset is progress * ubLength_ elements.
 * @param currentNum  Number of elements in this tile.
 */
template <typename T, int K_ALIGN>
__aicore__ inline void Ndtri<T, K_ALIGN>::CopyOut(int64_t progress, int64_t currentNum)
{
    LocalTensor<T> tile = outQueY.template DeQue<T>();

    const int64_t dstOffset = progress * ubLength_;
    const uint32_t tileBytes = static_cast<uint32_t>(currentNum * sizeof(T));

    // A single transfer of tileBytes bytes; all remaining fields are zero.
    DataCopyExtParams xferParams{1, tileBytes, 0, 0, 0};

    DataCopyPad(outGm[dstOffset], tile, xferParams);
    outQueY.FreeTensor(tile);
}
}
#endif