/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/**
 * NOTE: Portions of this code were AI-generated and have been
 * technically reviewed for functional accuracy and security
 */
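/*!
 * \file ndtri_compute.h
 * \brief Tensorized implementation (FP32 domain) of the Cephes ndtri piecewise rational approximation.
 *
 * Module breakdown (aligned with detailed design §4.4; design name -> function in this file):
 *   - _polevl(x)       -> PolEvl      : Horner evaluation of the polynomial P(x)
 *   - _plevl(x)        -> PlEvl       : variant Q(x) = x^n + ... whose leading coefficient is an implicit 1
 *   - polevl_plevl(x)  -> PolEvlPlEvl : rational function P(x) / Q(x)
 *   - cal_p0(p)        -> CalP0       : central region, y = sqrt(2π) * pm * (1 + pm^2 * P0(z)/Q0(z)),
 *                                       with pm = p - 0.5 and z = pm^2
 *   - cal_sub(q)       -> CalSub      : tail, x = sqrt(-2 ln q), x0 = x - ln(x)/x
 *   - cal_p12(x)       -> CalP12      : tail correction 1/x * P12(1/x)/Q12(1/x) (merged by a mask on x<8 / x>=8)
 *   - cal_tail(pSafe)  -> CalTail     : tail, y_tail = sign * (x0 - cal_p12(x))
 *
 * All functions work in the FP32 domain. The UB allocation of every input/output LocalTensor is managed
 * by the caller (the kernel body); this file only composes the computation logic.
 */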

#ifndef NDTRI_COMPUTE_H_
#define NDTRI_COMPUTE_H_

#include "kernel_operator.h"
#include "ndtri_coeffs.h"

namespace NsNdtri {

using namespace AscendC;

// ---------------------------------------------------------------
// PolEvl: Horner evaluation following the Cephes polevl convention, where
// coefs[0] is the coefficient of the highest-order term:
//   P(x) = coefs[0]*x^(n-1) + coefs[1]*x^(n-2) + ... + coefs[n-1]
//   ans = coefs[0]; for i = 1..n-1: ans = ans*x + coefs[i]
//
// Parameters:
//   - dst     : output, receives P(x)
//   - x       : input
//   - coefs   : n coefficients, highest order first
//   - n       : number of coefficients; coefs[0] is read unconditionally
//   - scratch : working buffer that holds the intermediate product dst * x
//   - len     : element count handed to every vector op
//
// Design note:
//   The plain Mul + Adds form is kept on purpose because it reads most
//   clearly; any fusion is left to the compiler.
// ---------------------------------------------------------------
__aicore__ inline void PolEvl(
    const LocalTensor<float>& dst,
    const LocalTensor<float>& x,
    const float* coefs, int n,
    const LocalTensor<float>& scratch,
    int32_t len)
{
    // Seed the accumulator with the highest-order coefficient.
    int term = 0;
    Duplicate(dst, coefs[term], len);

    // Every remaining coefficient contributes one step: dst = dst * x + coefs[term].
    while (++term < n) {
        Mul(scratch, dst, x, len);
        Adds(dst, scratch, coefs[term], len);
    }
}

// ---------------------------------------------------------------
// PlEvl: Horner evaluation following the Cephes p1evl convention, where the
// leading coefficient is 1 and is not stored in coefs:
//   Q(x) = x^n + coefs[0]*x^(n-1) + coefs[1]*x^(n-2) + ... + coefs[n-1]
//   ans = 1; for i = 0..n-1: ans = ans*x + coefs[i]
//
// Parameters:
//   - dst     : output, receives Q(x)
//   - x       : input
//   - coefs   : the n explicit coefficients (the implicit leading 1 excluded)
//   - n       : number of explicit coefficients
//   - scratch : working buffer that holds the intermediate product dst * x
//   - len     : element count handed to every vector op
//
// Uses the same Mul + Adds form as PolEvl.
// ---------------------------------------------------------------
__aicore__ inline void PlEvl(
    const LocalTensor<float>& dst,
    const LocalTensor<float>& x,
    const float* coefs, int n,
    const LocalTensor<float>& scratch,
    int32_t len)
{
    // Start from 1.0, the coefficient of the implicit x^n term.
    Duplicate(dst, 1.0f, len);

    // One Horner step per explicit coefficient: dst = dst * x + coefs[term].
    int term = 0;
    while (term < n) {
        Mul(scratch, dst, x, len);
        Adds(dst, scratch, coefs[term], len);
        ++term;
    }
}

// ---------------------------------------------------------------
// PolEvlPlEvl: rational function R(x) = P(x) / Q(x)
// (P has no constraint on its leading term; Q has an implicit leading 1.)
//
// Parameters:
//   - dst             : output, receives P(x) / Q(x)
//   - x               : input
//   - coefsP / nP     : coefficients of P and their count (PolEvl convention)
//   - coefsQ / nQ     : explicit coefficients of Q and their count (PlEvl convention)
//   - tmpP / tmpQ     : hold the intermediate results P(x) / Q(x)
//   - scratch         : Horner scratch, shared by both evaluations, which run
//                       one after the other
//   - len             : element count handed to every vector op
// ---------------------------------------------------------------
__aicore__ inline void PolEvlPlEvl(
    const LocalTensor<float>& dst,
    const LocalTensor<float>& x,
    const float* coefsP, int nP,
    const float* coefsQ, int nQ,
    const LocalTensor<float>& tmpP,
    const LocalTensor<float>& tmpQ,
    const LocalTensor<float>& scratch,
    int32_t len)
{
    const LocalTensor<float>& numer = tmpP;
    const LocalTensor<float>& denom = tmpQ;

    // Numerator first, then denominator.
    PolEvl(numer, x, coefsP, nP, scratch, len);
    PlEvl(denom, x, coefsQ, nQ, scratch, len);

    // dst = P(x) / Q(x)
    Div(dst, numer, denom, len);
}

// ---------------------------------------------------------------
// CalP0: central region
//   y = sqrt(2π) * (pm + pm^3 * R(z))
//     = sqrt(2π) * pm * (1 + z * R(z))
//   where pm = p - 0.5, z = pm^2, R(z) = P0(z) / Q0(z)
//
// Buffer convention (supplied by the caller, each of size len * sizeof(float)):
//   - y       : output; also holds R(z) and the partial products on the way
//   - p       : input
//   - tmpPm   : pm; must not share storage with y, because pm is read again
//               after y has been overwritten with R(z)
//   - tmpZ    : z
//   - tmpP    : P0(z) result
//   - tmpQ    : Q0(z) result
//   - scratch : Horner scratch
// ---------------------------------------------------------------
__aicore__ inline void CalP0(
    const LocalTensor<float>& y,
    const LocalTensor<float>& p,
    const LocalTensor<float>& tmpPm,
    const LocalTensor<float>& tmpZ,
    const LocalTensor<float>& tmpP,
    const LocalTensor<float>& tmpQ,
    const LocalTensor<float>& scratch,
    int32_t len)
{
    // Number of coefficients consumed from LIST_P0 / LIST_Q0.
    constexpr int kNumP0 = 5;
    constexpr int kNumQ0 = 8;

    const LocalTensor<float>& pm = tmpPm;
    const LocalTensor<float>& z = tmpZ;

    // pm = p - 0.5, z = pm * pm
    Adds(pm, p, -0.5f, len);
    Mul(z, pm, pm, len);

    // y = R(z) = P0(z) / Q0(z)   (y is used as temporary storage for R)
    PolEvlPlEvl(y, z, LIST_P0, kNumP0, LIST_Q0, kNumQ0, tmpP, tmpQ, scratch, len);

    // y = 1 + z * R
    Mul(y, y, z, len);
    Adds(y, y, 1.0f, len);

    // y = sqrt(2π) * pm * (1 + z * R)
    Mul(y, y, pm, len);
    Muls(y, y, NDTRI_SQRT_2PI, len);
}

// ---------------------------------------------------------------
// CalSub: tail-region base terms
//   x  = sqrt(-2 ln q)
//   x0 = x - ln(x) / x
//
// Expected input range is q ∈ (0, e^-2]. Per the design notes the caller
// (the tail flow) builds q as select(maskNeg, 1 - pSafe, pSafe) and relies
// on the clamp applied to pSafe to keep q > 0; nothing is checked here.
//
// Buffer convention:
//   - x0   : output x0
//   - xOut : output x (consumed afterwards by the tail correction)
//   - q    : input
//   - tmp  : working buffer
// ---------------------------------------------------------------
__aicore__ inline void CalSub(
    const LocalTensor<float>& x0,
    const LocalTensor<float>& xOut,
    const LocalTensor<float>& q,
    const LocalTensor<float>& tmp,
    int32_t len)
{
    const LocalTensor<float>& work = tmp;

    // work = -2 * ln(q)
    Ln(work, q, len);
    Muls(work, work, -2.0f, len);

    // xOut = sqrt(-2 ln q); left untouched from here on since it is an output
    Sqrt(xOut, work, len);

    // work = ln(x) / x
    Ln(work, xOut, len);
    Div(work, work, xOut, len);

    // x0 = x - ln(x) / x
    Sub(x0, xOut, work, len);
}

// ---------------------------------------------------------------
// CalP12: tail-region correction
//   z    = 1 / x
//   r1   = P1(z) / Q1(z)   (chosen where x <  8)
//   r2   = P2(z) / Q2(z)   (chosen where x >= 8)
//   corr = z * select(x < 8, r1, r2)
//
// Both rational functions are evaluated for every element; the mask only
// decides which result is kept.
//
// Buffer convention:
//   - corr     : output
//   - x        : input, x = sqrt(-2 ln q)
//   - tmpZ     : z = 1/x
//   - tmpR1    : P1(z)/Q1(z) (without the z factor)
//   - tmpR2    : P2(z)/Q2(z) (without the z factor)
//   - tmpP     : Horner result of the numerator polynomial
//   - tmpQ     : Horner result of the denominator polynomial
//   - scratch  : Horner scratch
//   - maskX    : uint8 mask buffer
//
// ISSUE-001: the len passed by the caller must be a multiple of 64
// (256-byte alignment for FP32); this is guaranteed by lenAligned in the
// kernel layer.
// ---------------------------------------------------------------
__aicore__ inline void CalP12(
    const LocalTensor<float>& corr,
    const LocalTensor<float>& x,
    const LocalTensor<float>& tmpZ,
    const LocalTensor<float>& tmpR1,
    const LocalTensor<float>& tmpR2,
    const LocalTensor<float>& tmpP,
    const LocalTensor<float>& tmpQ,
    const LocalTensor<float>& scratch,
    const LocalTensor<uint8_t>& maskX,
    int32_t len)
{
    // Coefficient counts shared by the (P1, Q1) and (P2, Q2) pairs.
    constexpr int kNumP = 9;
    constexpr int kNumQ = 8;

    const LocalTensor<float>& z = tmpZ;

    // z = 1 / x, built as a tensor of ones divided by x
    Duplicate(z, 1.0f, len);
    Div(z, z, x, len);

    // tmpR1 = P1(z) / Q1(z), tmpR2 = P2(z) / Q2(z)
    PolEvlPlEvl(tmpR1, z, LIST_P1, kNumP, LIST_Q1, kNumQ, tmpP, tmpQ, scratch, len);
    PolEvlPlEvl(tmpR2, z, LIST_P2, kNumP, LIST_Q2, kNumQ, tmpP, tmpQ, scratch, len);

    // maskX marks x < NDTRI_X_BOUNDARY: take tmpR1 there, tmpR2 elsewhere
    CompareScalar(maskX, x, NDTRI_X_BOUNDARY, CMPMODE::LT, len);
    Select(corr, maskX, tmpR1, tmpR2, SELMODE::VSEL_TENSOR_TENSOR_MODE, len);

    // corr = z * selected ratio
    Mul(corr, corr, z, len);
}

// ---------------------------------------------------------------
// CalTail: complete tail-region flow
//   q      = select(maskNeg, 1 - pSafe, pSafe)
//   x      = sqrt(-2 ln q)
//   x0     = x - ln(x)/x
//   corr   = CalP12(x)
//   base   = x0 - corr         (Cephes source: x0 -= ...)
//   y_tail = select(maskNeg, +base, -base)
//
// Buffer convention:
//   - yTail     : output
//   - pSafe     : input (per the design notes, already clamped to [FLT_MIN, 1-FLT_MIN])
//   - maskNeg   : mask of p >= 0.5 (set -> q = 1 - pSafe and the result keeps +base)
//   - tmpQ      : holds 1 - pSafe, then q; once CalSub has consumed q it is
//                 reused as the Horner scratch of CalP12
//   - tmpX      : x
//   - tmpX0     : x0, later overwritten in place with base
//   - tmpCorr   : corr, later overwritten with -base
//   - tmp1..5   : five fp32 scratch buffers (used by CalSub / CalP12)
//   - maskX     : uint8 scratch mask
//
// NOTE(review): in FP32, 1 - FLT_MIN rounds to 1.0f, so an upper clamp at
// that value would still allow q = 1 - pSafe = 0 here (ln(0) downstream);
// confirm that the kernel masks or otherwise handles that case.
// ---------------------------------------------------------------
__aicore__ inline void CalTail(
    const LocalTensor<float>& yTail,
    const LocalTensor<float>& pSafe,
    const LocalTensor<uint8_t>& maskNeg,
    const LocalTensor<float>& tmpQ,
    const LocalTensor<float>& tmpX,
    const LocalTensor<float>& tmpX0,
    const LocalTensor<float>& tmpCorr,
    const LocalTensor<float>& tmp1,  // CalSub tmp / CalP12 tmpZ
    const LocalTensor<float>& tmp2,  // CalP12 tmpR1
    const LocalTensor<float>& tmp3,  // CalP12 tmpR2
    const LocalTensor<float>& tmp4,  // CalP12 tmpP
    const LocalTensor<float>& tmp5,  // CalP12 tmpQ (the Horner scratch is the tmpQ parameter, not tmp5)
    const LocalTensor<uint8_t>& maskX,
    int32_t len)
{
    // Step 1: q = select(maskNeg, 1 - pSafe, pSafe)
    //   tmpQ = 1 - pSafe, computed as (-1 * pSafe) + 1
    Muls(tmpQ, pSafe, -1.0f, len);
    Adds(tmpQ, tmpQ, 1.0f, len);
    // Select (in place on tmpQ): maskNeg=1 -> tmpQ (1 - pSafe), maskNeg=0 -> pSafe
    Select(tmpQ, maskNeg, tmpQ, pSafe,
           SELMODE::VSEL_TENSOR_TENSOR_MODE, len);

    // Step 2: x = sqrt(-2 ln q), x0 = x - ln(x)/x
    //   CalSub uses tmp1 as its working buffer
    CalSub(tmpX0, tmpX, tmpQ, tmp1, len);

    // Step 3: corr = CalP12(x)
    //   The Horner passes inside CalP12 need tmpP / tmpQ / scratch: tmp4 / tmp5 / tmpQ are reused.
    //   Note: tmpQ is no longer needed at this point (q was consumed in Step 2),
    //   so it can safely serve as the scratch buffer.
    CalP12(tmpCorr, tmpX,
           /*tmpZ  */tmp1,
           /*tmpR1 */tmp2,
           /*tmpR2 */tmp3,
           /*tmpP  */tmp4,
           /*tmpQ  */tmp5,
           /*scratch*/tmpQ,
           maskX, len);

    // Step 4: base = x0 - corr (stored back into tmpX0)
    Sub(tmpX0, tmpX0, tmpCorr, len);

    // Step 5: apply the sign:
    //   - p <  0.5 (maskNeg=0)  -> y_tail = -base
    //   - p >= 0.5 (maskNeg=1)  -> y_tail = +base
    Muls(tmpCorr, tmpX0, -1.0f, len);  // -base is stored in tmpCorr (corr is no longer needed)
    // Select: maskNeg=1 -> +base, maskNeg=0 -> -base
    Select(yTail, maskNeg, tmpX0, tmpCorr,
           SELMODE::VSEL_TENSOR_TENSOR_MODE, len);
}

} // namespace NsNdtri

#endif  // NDTRI_COMPUTE_H_