/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/* Generated By CANNBot */

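/*!
 * \file acosh.h
 * \brief Acosh operator kernel class (arch35 / Ascend950).
 *
 * Aligned with DESIGN.md v2.1 §3.5:
 *   - Numerically stable 13-step acosh formula, evaluated in FP32.
 *   - FP32 input is processed directly; FP16/BF16 input is widened with CAST_NONE, run through
 *     the 13 FP32 steps, and narrowed back with CAST_RINT.
 *   - Double buffering on the input/output queues (BUFFER_NUM = 2).
 *   - Buffer reuse: dataTBuf (steps 1-6, then reused for steps 10-12),
 *     dataRBuf (steps 3-7 and 11, then reused for steps 8/9/12), logTmpBuf (steps 13a/b/c).
 *   - Log (natural base) is called with three arguments (no sharedTmpBuffer); the framework takes
 *     its scratch space automatically from the UB left unclaimed by InitBuffer.
 *
 * Scope: iteration 1 delivered the FP32 main path; the FP16/BF16 path (the `else` branch of
 * Compute) is implemented as well (iteration 2).
 */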
#ifndef NSACOSH_ACOSH_H
#define NSACOSH_ACOSH_H

#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
#include "acosh_tiling_data.h"
#include "acosh_tiling_key.h"

namespace NsAcosh {
using namespace AscendC;

// ============================================================
// FP32 constants used by the 13-step pipeline (requirement §8.1 mandates these exact FP32
// literals; do not re-spell them).
// ============================================================
constexpr float CONST_NEG_ONE     = -1.0f;
constexpr float CONST_ONE         =  1.0f;
// Lower clip for s = u - 1 (requirement §8.1); keeps the step-12 divisor away from zero.
// 1.0e-45f is below the normal FP32 range and rounds to the smallest positive subnormal (~1.4e-45).
constexpr float CONST_S_MIN       =  1.0e-45f;
// Upper clip for s (requirement §8.1). Note: this is FLT_MAX * 1e-4 (…e34), not FLT_MAX (…e38).
constexpr float CONST_S_MAX       =  3.4028235e34f;
// ln(2) as an FP32 literal (~0.6931); per the original notes this follows requirement v1.2 §8.1
// and DESIGN v2.1 §3.5.1 / §5.1 R5 (corrected value). Added to ln(x) to form ln(2x).
constexpr float CONST_LN2_ADD     =  0.693147180559945286227f;

// Double-buffer depth for the input/output queues, fixed at 2 (per the original rationale: the
// 13 steps include Log/Sqrt/Div and are compute-heavy, so double buffering pays off).
static constexpr int32_t BUFFER_NUM = 2;

// Acosh kernel: element-wise acosh over this core's slice of a GM tensor of T.
template <typename T>
class Acosh {
public:
    __aicore__ inline Acosh() {}

    // Binds this core's GM slice of `self`/`out` (offset blockFactor * blockIdx) and allocates UB buffers.
    __aicore__ inline void Init(GM_ADDR self, GM_ADDR out, const AcoshTilingData* tilingData);
    // Runs CopyIn -> Compute -> CopyOut over the slice in tiles of at most ubFactor_ elements.
    __aicore__ inline void Process();

private:
    // GM -> UB copy of tile `progress` (currentNum elements) into inputQue.
    __aicore__ inline void CopyIn(int64_t progress, int64_t currentNum);
    // Dequeues an input tile, evaluates acosh, and enqueues the output tile.
    __aicore__ inline void Compute(int64_t currentNum);
    // UB -> GM copy of tile `progress` (currentNum elements) from outputQue.
    __aicore__ inline void CopyOut(int64_t progress, int64_t currentNum);

    // 13-step FP32 evaluation. xFp32 and yFp32 may be the same tensor (on the FP16/BF16 path both
    // are fp32WorkBuf).
    __aicore__ inline void ComputeFp32Pipeline(LocalTensor<float>& xFp32,
                                                LocalTensor<float>& yFp32,
                                                int64_t count);

private:
    TPipe pipe;
    TQue<QuePosition::VECIN,  BUFFER_NUM> inputQue;     // GM -> UB input queue (elements of T)
    TQue<QuePosition::VECOUT, BUFFER_NUM> outputQue;    // UB -> GM output queue (elements of T)
    // FP32 work buffer: on the FP16/BF16 path it plays all three roles xFp32 == yFp32 == fp32Work;
    // on the FP32 path it is unused.
    TBuf<QuePosition::VECCALC> fp32WorkBuf;
    // Scratch buffers reused across the 13 steps.
    TBuf<QuePosition::VECCALC> dataTBuf;
    TBuf<QuePosition::VECCALC> dataRBuf;
    // Holds log(x) + ln(2) from steps 13a/13b until step 13c reads it.
    // Note: the natural-log API `Log(dst, src, calCount)` takes no sharedTmpBuffer; the scratch
    //     space it needs internally is taken by the framework from the UB left unclaimed by
    //     InitBuffer. Per the original design notes, host tiling already subtracts that size
    //     (via GetLogMaxMinTmpSize) when computing ubFactor, so enough UB remains.
    TBuf<QuePosition::VECCALC> logTmpBuf;

    GlobalTensor<T> inputGM;    // this core's slice of the input
    GlobalTensor<T> outputGM;   // this core's slice of the output

    int64_t blockLength_ = 0;   // elements this core processes (may be < blockFactor, e.g. on the tail core)
    int64_t ubFactor_    = 0;   // elements processed per UB loop iteration
};

// ============================================================
// Init: bind this core's GM slice and allocate the UB buffers.
//
// UB buffers (counted as buffers, not bytes):
//   - inputQue / outputQue : BUFFER_NUM (= 2) slots each, ubFactor_ elements of T
//   - dataTBuf / dataRBuf / logTmpBuf : ubFactor_ floats each
//   - fp32WorkBuf : ubFactor_ floats, FP16/BF16 path only
//   => 7 buffers on the FP32 path, 8 on the FP16/BF16 path.
// Log()'s implicit scratch space comes from the UB left unallocated here.
// ============================================================
template <typename T>
__aicore__ inline void Acosh<T>::Init(GM_ADDR self, GM_ADDR out, const AcoshTilingData* tilingData)
{
    const int64_t blockIdx    = AscendC::GetBlockIdx();
    const int64_t blockFactor = tilingData->blockFactor;
    // Element offset of this core's slice. With blockIdx <= usedCoreNum - 1 (host tiling contract)
    // the product stays <= totalNum, so it cannot overflow int64_t.
    const int64_t blockOffset = blockFactor * blockIdx;

    // Elements owned by this core: a full blockFactor, or the shorter tail on the last core.
    // Clamp at 0 so a core launched beyond the populated range gets an empty slice rather than
    // passing a negative length to SetGlobalBuffer; Process() then returns immediately.
    int64_t remainder = tilingData->totalNum - blockOffset;
    if (remainder < 0) {
        remainder = 0;
    }
    blockLength_ = (remainder > blockFactor) ? blockFactor : remainder;
    ubFactor_    = tilingData->ubFactor;

    inputGM.SetGlobalBuffer((__gm__ T*)self + blockOffset, blockLength_);
    outputGM.SetGlobalBuffer((__gm__ T*)out + blockOffset, blockLength_);

    pipe.InitBuffer(inputQue,    BUFFER_NUM, ubFactor_ * sizeof(T));
    pipe.InitBuffer(outputQue,   BUFFER_NUM, ubFactor_ * sizeof(T));
    // The FP32 work area is only needed on the FP16/BF16 path (xFp32 == yFp32 == fp32Work).
    // On the FP32 path Compute() runs the steps directly on xLocal/yLocal, so skipping this
    // allocation saves ubFactor_ * 4 bytes of UB.
    if constexpr (!std::is_same_v<T, float>) {
        pipe.InitBuffer(fp32WorkBuf, ubFactor_ * sizeof(float));
    }
    // Scratch buffers reused across the 13 steps: data_t / data_r.
    pipe.InitBuffer(dataTBuf,    ubFactor_ * sizeof(float));
    pipe.InitBuffer(dataRBuf,    ubFactor_ * sizeof(float));
    // Holds log(x) + ln(2) from step 13b. This is NOT Log's implicit tmp buffer; that one is
    // taken automatically by the framework from the remaining UB.
    pipe.InitBuffer(logTmpBuf,   ubFactor_ * sizeof(float));
}

// ============================================================
// Process: walk this core's slice in tiles of ubFactor_ elements (the last tile may be shorter),
// running CopyIn -> Compute -> CopyOut for each.
// ============================================================
template <typename T>
__aicore__ inline void Acosh<T>::Process()
{
    // An idle core has no data. A non-positive ubFactor_ (bad tiling) would otherwise divide by
    // zero in the loop-count computation below.
    if (blockLength_ <= 0 || ubFactor_ <= 0) {
        return;
    }
    const int64_t loopCount = (blockLength_ + ubFactor_ - 1) / ubFactor_;
    for (int64_t i = 0; i < loopCount; i++) {
        // Full tiles except possibly the last, which takes whatever remains.
        const int64_t currentNum = (i == (loopCount - 1)) ? (blockLength_ - ubFactor_ * i) : ubFactor_;
        CopyIn(i, currentNum);
        Compute(currentNum);
        CopyOut(i, currentNum);
    }
}

// ============================================================
// CopyIn: move tile `progress` (currentNum elements of T) from this core's GM input slice into a
// freshly allocated inputQue slot. DataCopyPad is used so an unaligned tail tile copies correctly.
// ============================================================
template <typename T>
__aicore__ inline void Acosh<T>::CopyIn(int64_t progress, int64_t currentNum)
{
    // One contiguous burst of currentNum * sizeof(T) bytes, no strides.
    AscendC::DataCopyExtParams inParams;
    inParams.blockCount = 1;
    inParams.srcStride  = 0;
    inParams.dstStride  = 0;
    inParams.blockLen   = static_cast<uint32_t>(currentNum * sizeof(T));
    // No padding requested (isPad = false, zero left/right padding, pad value 0).
    AscendC::DataCopyPadExtParams<T> padParams{false, 0, 0, 0};

    const int64_t gmOffset = progress * ubFactor_;
    LocalTensor<T> inTile = inputQue.template AllocTensor<T>();
    AscendC::DataCopyPad(inTile, inputGM[gmOffset], inParams, padParams);
    inputQue.EnQue(inTile);
}

// ============================================================
// CopyOut: take the next finished tile from outputQue and write its currentNum elements of T to
// position `progress` of this core's GM output slice, then release the queue slot.
// ============================================================
template <typename T>
__aicore__ inline void Acosh<T>::CopyOut(int64_t progress, int64_t currentNum)
{
    LocalTensor<T> outTile = outputQue.template DeQue<T>();

    // One contiguous burst of currentNum * sizeof(T) bytes, no strides.
    AscendC::DataCopyExtParams outParams;
    outParams.blockCount = 1;
    outParams.srcStride  = 0;
    outParams.dstStride  = 0;
    outParams.blockLen   = static_cast<uint32_t>(currentNum * sizeof(T));

    const int64_t gmOffset = progress * ubFactor_;
    AscendC::DataCopyPad(outputGM[gmOffset], outTile, outParams);
    outputQue.FreeTensor(outTile);
}

// ============================================================
// Compute: dequeue one input tile, evaluate acosh, enqueue the output tile.
//   - T == float: the 13 FP32 steps read the input tile and write the output tile directly.
//   - T == half / bfloat16_t: widen into fp32WorkBuf with CAST_NONE, run the steps in place
//     (input and output alias fp32WorkBuf), then narrow into the output tile with CAST_RINT.
//     Per the original notes (Ascend950PR/DT Cast table): half/bfloat16_t -> float supports only
//     CAST_NONE; float -> half/bfloat16_t supports CAST_RINT but not CAST_NONE.
// ============================================================
template <typename T>
__aicore__ inline void Acosh<T>::Compute(int64_t currentNum)
{
    LocalTensor<T> srcTile = inputQue.template DeQue<T>();
    LocalTensor<T> dstTile = outputQue.template AllocTensor<T>();

    if constexpr (!std::is_same_v<T, float>) {
        // Reduced-precision input: do all arithmetic in the FP32 work buffer.
        LocalTensor<float> work = fp32WorkBuf.Get<float>();
        AscendC::Cast(work, srcTile, AscendC::RoundMode::CAST_NONE, currentNum);
        ComputeFp32Pipeline(work, work, currentNum);   // input and output alias the same buffer
        AscendC::Cast(dstTile, work, AscendC::RoundMode::CAST_RINT, currentNum);
    } else {
        // FP32 input: tiles are already float, no conversion needed.
        ComputeFp32Pipeline(srcTile, dstTile, currentNum);
    }

    outputQue.template EnQue<T>(dstTile);
    inputQue.FreeTensor(srcTile);
}

// ============================================================
// ComputeFp32Pipeline: the 13-step FP32 acosh evaluation of DESIGN v2.1 §3.5.2.
//
// Math (t = x - 1, r = t + sqrt(x^2 - 1), u = 1 + r = x + sqrt(x^2 - 1)):
//   acosh(x) = ln(u) = ln(1 + r).
//   The steps compute ln(u) * r / (u - 1). In exact arithmetic r == u - 1 and the ratio is 1,
//   but in floating point it compensates for the rounding of u = 1 + r (the classic
//   log1p-via-log trick). The result is then capped by ln(x) + ln(2) (= ln(2x), the large-x
//   asymptote of acosh) with an element-wise Min.
//
// Parameters:
//   xFp32 - FP32 input values; read only by step 13a and step 1.
//   yFp32 - FP32 output values. May alias xFp32: the first write to yFp32 (step 2) comes after
//           the last read of xFp32 (step 1). The FP16/BF16 path passes the same tensor for both.
//   count - number of elements; narrowed to uint32_t for the vector API calls.
//
// Buffer lifetimes:
//   yFp32 : written at steps 2, 7, 13c; read at steps 4, 8, 10.
//   dataT : written at step 1, read at steps 2/3/6; rewritten at step 10, then steps 11/12,
//           read at step 13c.
//   dataR : written at step 3, updated at steps 4-6, read at steps 7/11; rewritten at step 8,
//           then steps 9a/9b, read at step 12.
//   logTmp: written at steps 13a/13b, read at step 13c.
//
// Edge behavior traced from the code (assuming IEEE-754 inf semantics on the vector unit):
//   - x == 1: t, r and s are all 0; the lower clip keeps the step-12 divisor nonzero and the
//     result is 0.
//   - very large x: t * t overflows to +inf, so u and ln(u) * r become +inf; the upper clip keeps
//     s finite (avoiding inf / inf = NaN), and step 13c's Min then selects ln(x) + ln(2).
//   - x < 1: step 5 takes sqrt of a negative value; the resulting output is not specified here.
// NOTE(review): CONST_S_MIN (1.0e-45f) is a subnormal float. If the hardware flushes subnormals
//   to zero, x == 1 would compute 0 / 0 in step 12 — confirm.
// ============================================================
template <typename T>
__aicore__ inline void Acosh<T>::ComputeFp32Pipeline(
    LocalTensor<float>& xFp32,
    LocalTensor<float>& yFp32,
    int64_t count)
{
    LocalTensor<float> dataT  = dataTBuf.Get<float>();   // holds data_t, later the log-based quotient
    LocalTensor<float> dataR  = dataRBuf.Get<float>();   // holds data_r, later s = u - 1
    LocalTensor<float> logTmp = logTmpBuf.Get<float>();  // holds log(x) + ln(2)
    uint32_t n = static_cast<uint32_t>(count);

    // -------- Steps 13a/13b, done first: logTmp = log(x) + ln(2) --------
    // Done up front because on the FP16/BF16 path xFp32 aliases yFp32, which is overwritten
    // from step 2 on. These two steps only read xFp32, so this is safe.
    AscendC::Log(logTmp, xFp32, n);                       // logTmp = ln(x)
    AscendC::Adds(logTmp, logTmp, CONST_LN2_ADD, n);      // logTmp = ln(x) + ln(2)

    // -------- Step 1: data_t = x - 1 --------
    AscendC::Adds(dataT, xFp32, CONST_NEG_ONE, n);

    // -------- Step 2: yFp32 = 2 * (x - 1) (xFp32 is not read after this, so aliasing is safe) --------
    AscendC::Add(yFp32, dataT, dataT, n);

    // -------- Step 3: dataR = data_t * data_t = (x-1)^2 --------
    AscendC::Mul(dataR, dataT, dataT, n);

    // -------- Step 4: dataR += yFp32 -> dataR = (x-1)^2 + 2(x-1), equal to x^2 - 1 in exact arithmetic --------
    AscendC::Add(dataR, dataR, yFp32, n);

    // -------- Step 5: dataR = sqrt(dataR) = sqrt(x^2 - 1) --------
    AscendC::Sqrt(dataR, dataR, n);

    // -------- Step 6: dataR = data_t + dataR = (x-1) + sqrt(x^2 - 1)  (= r) --------
    AscendC::Add(dataR, dataT, dataR, n);

    // -------- Step 7: yFp32 = dataR + 1 = u = x + sqrt(x^2 - 1) --------
    AscendC::Adds(yFp32, dataR, CONST_ONE, n);

    // -------- Step 10: dataT = log(u) (dataT's step-1 contents are dead after step 6) --------
    AscendC::Log(dataT, yFp32, n);

    // -------- Step 11: dataT = log(u) * dataR = log(u) * r --------
    AscendC::Mul(dataT, dataT, dataR, n);

    // -------- Step 8: dataR = u - 1 = s, unclipped (dataR's r is dead after step 11) --------
    AscendC::Adds(dataR, yFp32, CONST_NEG_ONE, n);

    // -------- Step 9a: dataR = max(s, CONST_S_MIN), lower clip so the step-12 divisor is nonzero --------
    AscendC::Maxs(dataR, dataR, CONST_S_MIN, n);

    // -------- Step 9b: dataR = min(s, CONST_S_MAX), upper clip so an infinite s becomes finite --------
    AscendC::Mins(dataR, dataR, CONST_S_MAX, n);

    // -------- Step 12: dataT = dataT / dataR = log(u) * r / clip(s) --------
    AscendC::Div(dataT, dataT, dataR, n);

    // -------- Step 13c: yFp32 = min(res, log(x) + ln(2)), large-argument correction --------
    AscendC::Min(yFp32, dataT, logTmp, n);
}

} // namespace NsAcosh
#endif // NSACOSH_ACOSH_H