/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/**
 * NOTE: Portions of this code were AI-generated and have been
 * technically reviewed for functional accuracy and security
 */

/*!
 * \file asinh_grad.h
 * \brief AsinhGrad kernel class definition (arch35)
 *
 * Computes: z = 2 * dy * exp(y) / (exp(2y) + 1)
 * Equivalent to: z = dy / cosh(y)
 *
 * Template parameters:
 *   - T: data type (float / half / bfloat16_t)
 *   - BUFFER_MODE: 0=single buffer, 1=double buffer
 *
 * Iteration 3: FP32 + FP16 + BF16 (upgrade to FP32 compute, cast back)
 */
#ifndef ASINH_GRAD_H
#define ASINH_GRAD_H

#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
#include "asinh_grad_tiling_data.h"
#include "asinh_grad_tiling_key.h"

namespace NsAsinhGrad {

using namespace AscendC;

template <typename T, int BUFFER_MODE>
class AsinhGrad {
    static constexpr int32_t BUFFER_NUM = BUFFER_MODE ? 2 : 1;
    static constexpr bool NEED_CAST = !std::is_same_v<T, float>;

public:
    __aicore__ inline AsinhGrad() {}

    __aicore__ inline void Init(GM_ADDR y, GM_ADDR dy, GM_ADDR z,
                                const AsinhGradTilingData* tilingData);
    __aicore__ inline void Process();

private:
    __aicore__ inline void CopyIn(int64_t progress, int64_t currentNum);
    __aicore__ inline void Compute(int64_t currentNum);
    __aicore__ inline void CopyOut(int64_t progress, int64_t currentNum);

private:
    TPipe pipe;
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueY;
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueDy;
    TQue<QuePosition::VECOUT, BUFFER_NUM> outQueZ;

    // FP16/BF16 path: FP32 intermediate buffers for cast-up computation (VECCALC)
    TBuf<QuePosition::VECCALC> tmpYFp32_;
    TBuf<QuePosition::VECCALC> tmpDyFp32_;
    TBuf<QuePosition::VECCALC> tmpExpY_;
    TBuf<QuePosition::VECCALC> tmpExp2Y_;
    TBuf<QuePosition::VECCALC> tmpZFp32_;

    GlobalTensor<T> gmY_;
    GlobalTensor<T> gmDy_;
    GlobalTensor<T> gmZ_;
    int64_t blockLength_ = 0;
    int64_t ubLength_ = 0;
};

template <typename T, int BUFFER_MODE>
__aicore__ inline void AsinhGrad<T, BUFFER_MODE>::Init(
    GM_ADDR y, GM_ADDR dy, GM_ADDR z,
    const AsinhGradTilingData* tilingData)
{
    int64_t remainder = tilingData->totalNum - tilingData->blockFactor * GetBlockIdx();
    blockLength_ = (remainder > tilingData->blockFactor) ? tilingData->blockFactor : remainder;
    ubLength_ = tilingData->ubFactor;

    int64_t off = tilingData->blockFactor * GetBlockIdx();
    gmY_.SetGlobalBuffer((__gm__ T*)y + off, blockLength_);
    gmDy_.SetGlobalBuffer((__gm__ T*)dy + off, blockLength_);
    gmZ_.SetGlobalBuffer((__gm__ T*)z + off, blockLength_);

    pipe.InitBuffer(inQueY, BUFFER_NUM, ubLength_ * sizeof(T));
    pipe.InitBuffer(inQueDy, BUFFER_NUM, ubLength_ * sizeof(T));
    pipe.InitBuffer(outQueZ, BUFFER_NUM, ubLength_ * sizeof(T));

    if constexpr (NEED_CAST) {
        // FP16/BF16 path: need 5 FP32 intermediate buffers for cast-up computation
        pipe.InitBuffer(tmpYFp32_, ubLength_ * sizeof(float));
        pipe.InitBuffer(tmpDyFp32_, ubLength_ * sizeof(float));
        pipe.InitBuffer(tmpExpY_, ubLength_ * sizeof(float));
        pipe.InitBuffer(tmpExp2Y_, ubLength_ * sizeof(float));
        pipe.InitBuffer(tmpZFp32_, ubLength_ * sizeof(float));
    } else {
        // FP32 path: only 2 temporary buffers (expY, exp2Y); zL from outQueZ is reused
        pipe.InitBuffer(tmpExpY_, ubLength_ * sizeof(float));
        pipe.InitBuffer(tmpExp2Y_, ubLength_ * sizeof(float));
    }
}

template <typename T, int BUFFER_MODE>
__aicore__ inline void AsinhGrad<T, BUFFER_MODE>::CopyIn(
    int64_t progress, int64_t currentNum)
{
    LocalTensor<T> yL = inQueY.template AllocTensor<T>();
    LocalTensor<T> dyL = inQueDy.template AllocTensor<T>();

    DataCopyParams cp;
    cp.blockCount = 1;
    cp.blockLen = currentNum * sizeof(T);
    cp.srcStride = 0;
    cp.dstStride = 0;

    DataCopyPad(yL, gmY_[progress * ubLength_], cp, {false, 0, 0, 0});
    DataCopyPad(dyL, gmDy_[progress * ubLength_], cp, {false, 0, 0, 0});

    inQueY.EnQue(yL);
    inQueDy.EnQue(dyL);
}

template <typename T, int BUFFER_MODE>
__aicore__ inline void AsinhGrad<T, BUFFER_MODE>::Compute(int64_t currentNum)
{
    LocalTensor<T> yL = inQueY.template DeQue<T>();
    LocalTensor<T> dyL = inQueDy.template DeQue<T>();
    LocalTensor<T> zL = outQueZ.template AllocTensor<T>();

    if constexpr (NEED_CAST) {
        // FP16/BF16 path: cast up to FP32, compute, cast back
        LocalTensor<float> yFp32 = tmpYFp32_.Get<float>();
        LocalTensor<float> dyFp32 = tmpDyFp32_.Get<float>();
        LocalTensor<float> expY = tmpExpY_.Get<float>();
        LocalTensor<float> exp2Y = tmpExp2Y_.Get<float>();
        LocalTensor<float> zFp32 = tmpZFp32_.Get<float>();

        // Step 1: Cast input dtype -> FP32
        Cast(yFp32, yL, RoundMode::CAST_NONE, currentNum);
        Cast(dyFp32, dyL, RoundMode::CAST_NONE, currentNum);

        // Step 2: z = 2 * dy * exp(y) / (exp(2y) + 1) in FP32 domain
        Exp(expY, yFp32, currentNum);                    // expY = e^y
        Mul(exp2Y, expY, expY, currentNum);              // exp2Y = e^{2y} = (e^y)^2
        Adds(exp2Y, exp2Y, 1.0f, currentNum);            // exp2Y = e^{2y} + 1
        Mul(zFp32, dyFp32, expY, currentNum);            // zFp32 = dy * e^y
        Div(zFp32, zFp32, exp2Y, currentNum);            // zFp32 = (dy * e^y) / (e^{2y} + 1)
        Muls(zFp32, zFp32, 2.0f, currentNum);            // zFp32 = 2 * (...)

        // Step 3: Cast FP32 -> target dtype (CAST_RINT = banker's rounding)
        Cast(zL, zFp32, RoundMode::CAST_RINT, currentNum);
    } else {
        // FP32 main path: compute directly without cast
        LocalTensor<float> expY = tmpExpY_.Get<float>();
        LocalTensor<float> exp2Y = tmpExp2Y_.Get<float>();

        Exp(expY, yL, currentNum);                       // expY = e^y
        Mul(exp2Y, expY, expY, currentNum);              // exp2Y = e^{2y} = (e^y)^2
        Adds(exp2Y, exp2Y, 1.0f, currentNum);            // exp2Y = e^{2y} + 1
        Mul(zL, dyL, expY, currentNum);                  // zL = dy * e^y
        Div(zL, zL, exp2Y, currentNum);                  // zL = (dy * e^y) / (e^{2y} + 1)
        Muls(zL, zL, 2.0f, currentNum);                  // zL = 2 * (...)
    }

    outQueZ.template EnQue<T>(zL);
    inQueY.FreeTensor(yL);
    inQueDy.FreeTensor(dyL);
}

template <typename T, int BUFFER_MODE>
__aicore__ inline void AsinhGrad<T, BUFFER_MODE>::CopyOut(
    int64_t progress, int64_t currentNum)
{
    LocalTensor<T> zL = outQueZ.template DeQue<T>();

    DataCopyParams cp;
    cp.blockCount = 1;
    cp.blockLen = currentNum * sizeof(T);
    cp.srcStride = 0;
    cp.dstStride = 0;

    DataCopyPad(gmZ_[progress * ubLength_], zL, cp);
    outQueZ.FreeTensor(zL);
}

template <typename T, int BUFFER_MODE>
__aicore__ inline void AsinhGrad<T, BUFFER_MODE>::Process()
{
    int64_t loopCount = (blockLength_ + ubLength_ - 1) / ubLength_;
    for (int64_t i = 0; i < loopCount; i++) {
        int64_t curN = (i == loopCount - 1) ? (blockLength_ - ubLength_ * i) : ubLength_;
        CopyIn(i, curN);
        Compute(curN);
        CopyOut(i, curN);
    }
}

} // namespace NsAsinhGrad
#endif // ASINH_GRAD_H