/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/* Generated By CANNBot */

/*!
 * \file cosh.h
 * \brief Cosh 算子 kernel 类定义(标准 Ascend C kernel,arch35 / Ascend950PR,DESIGN §3.4)
 *
 * 数学契约: y = cosh(x) = (e^x + e^(-x)) / 2 (spec math_semantics.formula: y = np.cosh(x))
 *
 * NPU 实现采用 exp 分解数值流程(REQUIREMENTS §4.2,与数学契约等价,全 fp32 计算):
 *   Step1: ax = abs(x)         // cosh 为偶函数,取绝对值降低大值溢出风险
 *   Step2: x1 = ax * 0.5       // x/2
 *   Step3: x2 = ax * (-1.5)    // -1.5x
 *   Step4: e1 = exp(x1)        // exp(x/2)
 *   Step5: e2 = exp(x2)        // exp(-1.5x)
 *   Step6: x3 = e1 + e2        // exp(x/2) + exp(-1.5x)
 *   Step7: x4 = x3 * 0.5       // 0.5 * (...)
 *   Step8: y  = x4 * e1        // 0.5*(exp(x/2)+exp(-1.5x))*exp(x/2) = cosh(x)
 *
 * 模板参数 T:输入/输出 dtype。
 *   - half / bfloat16_t:WithCast 路径,入口 Cast(CAST_NONE) 升 fp32、出口 Cast(CAST_RINT) 还原。
 *   - float:WithoutCast 路径,fp32 直通,省去两次 Cast。
 *   分支由 needCast = is_same_v<T,half> || is_same_v<T,bfloat16_t> 编译期裁剪。
 *
 * 三级流水(CopyIn -> Compute -> CopyOut)+ double buffer(BUFFER_NUM=2)。
 * 8 步全 fp32,使用 3 份互不重叠的 fp32 VECCALC 缓冲(ax / e1 / work):缓冲之间无别名复用,单个缓冲内允许原位运算(如 Exp(e1, e1))。
 * DataCopyPad 处理非 32B 对齐尾块。
 */
#ifndef OPS_COSH_OP_KERNEL_ARCH35_COSH_H_
#define OPS_COSH_OP_KERNEL_ARCH35_COSH_H_

#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
#include "cosh_tiling_data.h"
#include "cosh_tiling_key.h"

namespace NsCosh {

using namespace AscendC;

/*!
 * \brief Element-wise cosh kernel class, templated on the input/output element type T.
 *
 * Usage: call Init() once to bind the global-memory tensors and allocate the on-chip
 * buffers, then call Process() to run the CopyIn -> Compute -> CopyOut loop over the
 * slice of the tensor assigned to the current core.
 */
template <typename T>
class KernelCosh {
    // half / bfloat16_t inputs are up-converted to fp32 for the computation;
    // any other T (float) is computed on directly without a Cast.
    static constexpr bool needCast = std::is_same_v<T, half> || std::is_same_v<T, bfloat16_t>;
    static constexpr int32_t BUFFER_NUM = 2;  // number of buffers per queue (double buffering)

public:
    __aicore__ inline KernelCosh() {}

    // Binds x / y global buffers for this core and allocates all queue and scratch buffers.
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, const CoshTilingData* tilingData);
    // Iterates over this core's slice in tiles of at most ubFactor_ elements.
    __aicore__ inline void Process();

private:
    // Loads tile `progress` (currentNum elements) from xGm_ into the input queue.
    __aicore__ inline void CopyIn(int64_t progress, int64_t currentNum);
    // Computes cosh on the dequeued input tile and enqueues the result tile.
    __aicore__ inline void Compute(int64_t currentNum);
    // Stores the dequeued result tile to yGm_ at tile index `progress`.
    __aicore__ inline void CopyOut(int64_t progress, int64_t currentNum);

private:
    TPipe pipe_;
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueue_;
    TQue<QuePosition::VECOUT, BUFFER_NUM> outQueue_;
    // fp32 scratch buffers (VECCALC position, one allocation each, not part of the queue pipeline)
    TBuf<QuePosition::VECCALC> axBuf_;    // ax = abs(x), after up-conversion to fp32 when needed
    TBuf<QuePosition::VECCALC> e1Buf_;    // x1 -> e1 = exp(x/2); kept alive until Step 8
    TBuf<QuePosition::VECCALC> workBuf_;  // x2 -> e2 -> x3 -> x4 -> y

    GlobalTensor<T> xGm_;
    GlobalTensor<T> yGm_;

    int64_t blockLength_ = 0;  // number of elements handled by this core
    int64_t ubFactor_ = 0;     // number of elements processed per loop iteration (tile size)
};

/*!
 * \brief Bind this core's slice of the input/output tensors and allocate on-chip buffers.
 *
 * The core with block index i starts at element i * blockFactor and handles blockFactor
 * elements, unless fewer than that remain before totalNum, in which case it handles
 * only the remainder.
 *
 * \param x          Global-memory address of the input tensor (elements of type T).
 * \param y          Global-memory address of the output tensor (elements of type T).
 * \param tilingData Tiling parameters (totalNum, blockFactor, ubFactor); dereferenced
 *                   without a null check.
 */
template <typename T>
__aicore__ inline void KernelCosh<T>::Init(GM_ADDR x, GM_ADDR y, const CoshTilingData* tilingData)
{
    const auto blockFactor = tilingData->blockFactor;
    // Element offset of this core's slice inside the whole tensor.
    const auto gmOffset = blockFactor * AscendC::GetBlockIdx();

    // Elements between this core's start offset and the end of the tensor.
    const int64_t elemsLeft = tilingData->totalNum - gmOffset;
    if (elemsLeft > blockFactor) {
        blockLength_ = blockFactor;
    } else {
        blockLength_ = elemsLeft;
    }
    ubFactor_ = tilingData->ubFactor;

    xGm_.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(x) + gmOffset, blockLength_);
    yGm_.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(y) + gmOffset, blockLength_);

    const auto ioTileBytes = ubFactor_ * sizeof(T);
    const auto fp32TileBytes = ubFactor_ * sizeof(float);

    // Input / output queues hold tiles in the original dtype T, BUFFER_NUM buffers each.
    pipe_.InitBuffer(inQueue_, BUFFER_NUM, ioTileBytes);
    pipe_.InitBuffer(outQueue_, BUFFER_NUM, ioTileBytes);
    // Three separate fp32 scratch buffers, one tile each.
    pipe_.InitBuffer(axBuf_, fp32TileBytes);
    pipe_.InitBuffer(e1Buf_, fp32TileBytes);
    pipe_.InitBuffer(workBuf_, fp32TileBytes);
}

/*!
 * \brief Load one input tile from global memory into the input queue.
 *
 * \param progress   Tile index; the tile starts at element progress * ubFactor_ of xGm_.
 * \param currentNum Number of valid elements in this tile.
 */
template <typename T>
__aicore__ inline void KernelCosh<T>::CopyIn(int64_t progress, int64_t currentNum)
{
    const uint32_t byteLen = static_cast<uint32_t>(currentNum * sizeof(T));
    // Single burst of byteLen bytes; source/destination strides and the reserved field are zero.
    AscendC::DataCopyExtParams burstParams{1, byteLen, 0, 0, 0};
    // Padding disabled: no left/right padding and padding value 0.
    AscendC::DataCopyPadExtParams<T> noPadding{false, 0, 0, 0};

    AscendC::LocalTensor<T> inTile = inQueue_.template AllocTensor<T>();
    // DataCopyPad takes the length in bytes, so tiles that are not a multiple of 32 bytes are allowed.
    AscendC::DataCopyPad(inTile, xGm_[progress * ubFactor_], burstParams, noPadding);
    inQueue_.EnQue(inTile);
}

/*!
 * \brief Compute y = cosh(x) for one tile, entirely in fp32.
 *
 * Uses the decomposition
 *   cosh(x) = 0.5 * (exp(|x|/2) + exp(-1.5*|x|)) * exp(|x|/2)
 * so only two Exp calls are needed. Taking |x| first is valid because cosh is even, and
 * it keeps exp(-1.5*|x|) bounded by 1 so the sum cannot turn into an inf * 0 product.
 *
 * For half / bfloat16_t the input is cast up to fp32 and the result is cast back with
 * round-to-nearest-even. For float the last multiplication writes directly into the
 * output tile, so no separate copy pass is required.
 *
 * \param currentNum Number of valid elements in the tile being processed.
 */
template <typename T>
__aicore__ inline void KernelCosh<T>::Compute(int64_t currentNum)
{
    constexpr float HALF = 0.5f;
    constexpr float NEG_ONE_AND_HALF = -1.5f;

    AscendC::LocalTensor<T> xLocal = inQueue_.template DeQue<T>();
    AscendC::LocalTensor<T> yLocal = outQueue_.template AllocTensor<T>();

    AscendC::LocalTensor<float> ax = axBuf_.Get<float>();
    AscendC::LocalTensor<float> e1 = e1Buf_.Get<float>();
    AscendC::LocalTensor<float> work = workBuf_.Get<float>();

    const int32_t cnt = static_cast<int32_t>(currentNum);

    if constexpr (needCast) {
        // Step 0: widen T -> fp32 (exact, no rounding needed).
        AscendC::Cast(ax, xLocal, AscendC::RoundMode::CAST_NONE, cnt);
        // Step 1: ax = |ax| (in place).
        AscendC::Abs(ax, ax, cnt);
    } else {
        // Step 1 (fp32 input): xLocal is already float, take |x| straight into ax.
        AscendC::Abs(ax, xLocal, cnt);
    }

    AscendC::Muls(e1, ax, HALF, cnt);                // Step 2: x1 = ax * 0.5        -> e1
    AscendC::Muls(work, ax, NEG_ONE_AND_HALF, cnt);  // Step 3: x2 = ax * (-1.5)     -> work
    AscendC::Exp(e1, e1, cnt);                       // Step 4: e1 = exp(x1)         (in place)
    AscendC::Exp(work, work, cnt);                   // Step 5: e2 = exp(x2)         (in place)
    AscendC::Add(work, e1, work, cnt);               // Step 6: x3 = e1 + e2         -> work
    AscendC::Muls(work, work, HALF, cnt);            // Step 7: x4 = 0.5 * x3        (in place)

    if constexpr (needCast) {
        // Step 8: y = x4 * e1 in fp32, then narrow fp32 -> T with round-to-nearest-even.
        AscendC::Mul(work, work, e1, cnt);
        AscendC::Cast(yLocal, work, AscendC::RoundMode::CAST_RINT, cnt);
    } else {
        // Step 8 (fp32 output): write the product directly into the output tile.
        // This replaces the former Mul-into-work followed by Adds(yLocal, work, 0) copy;
        // Mul is an element-wise vector op with an explicit count, so a count that is not
        // a multiple of 8 floats is handled the same way the Adds copy handled it.
        AscendC::Mul(yLocal, work, e1, cnt);
    }

    outQueue_.template EnQue<T>(yLocal);
    inQueue_.FreeTensor(xLocal);
}

/*!
 * \brief Store one result tile from the output queue to global memory.
 *
 * \param progress   Tile index; the tile is written at element progress * ubFactor_ of yGm_.
 * \param currentNum Number of valid elements in this tile.
 */
template <typename T>
__aicore__ inline void KernelCosh<T>::CopyOut(int64_t progress, int64_t currentNum)
{
    const uint32_t byteLen = static_cast<uint32_t>(currentNum * sizeof(T));
    // Single burst of byteLen bytes; source/destination strides and the reserved field are zero.
    AscendC::DataCopyExtParams burstParams{1, byteLen, 0, 0, 0};

    AscendC::LocalTensor<T> outTile = outQueue_.template DeQue<T>();
    // Only byteLen bytes are written, so a partial tile does not overrun the output slice.
    AscendC::DataCopyPad(yGm_[progress * ubFactor_], outTile, burstParams);
    outQueue_.FreeTensor(outTile);
}

/*!
 * \brief Run the CopyIn -> Compute -> CopyOut loop over this core's slice.
 *
 * The slice of blockLength_ elements is split into tiles of ubFactor_ elements; the last
 * tile carries whatever is left. Does nothing when the slice is empty or the tile size
 * is not positive, which also avoids dividing by a zero ubFactor_.
 */
template <typename T>
__aicore__ inline void KernelCosh<T>::Process()
{
    // Nothing to do for an empty slice; a non-positive tile size would make the
    // loop-count division below invalid.
    if (ubFactor_ <= 0 || blockLength_ <= 0) {
        return;
    }

    const int64_t loopCount = (blockLength_ + ubFactor_ - 1) / ubFactor_;
    // Size of the final (possibly partial) tile.
    const int64_t tailNum = blockLength_ - ubFactor_ * (loopCount - 1);

    for (int64_t i = 0; i < loopCount; ++i) {
        const int64_t currentNum = (i + 1 == loopCount) ? tailNum : ubFactor_;
        CopyIn(i, currentNum);
        Compute(currentNum);
        CopyOut(i, currentNum);
    }
}

} // namespace NsCosh
#endif // OPS_COSH_OP_KERNEL_ARCH35_COSH_H_