/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/* Generated By CANNBot */

/*!
 * \file sinh_dag.h
 * \brief Sinh DAG definition + SinhCustomVF (PTX dual-branch: Taylor + offset-exp)
 *
 * Algorithm (PTX dual-branch strategy plus an overflow guard):
 *   Branch 1 (|x| < 1):   polynomial         sinh(x) = x + x^3 * P(x^2)
 *   Branch 2 (|x| >= 1):  offset-exp         t = exp(|x| - ln2) = e^|x| / 2,
 *                                            sinh(x) = sign(x) * (t - 0.25/t)
 *   Guard    (|x| >= 90): overflow guard     return sign(x) * inf
 */

#ifndef SINH_DAG_H
#define SINH_DAG_H

// Host compilation: mock __aicore__ (Kernel compiler has built-in definition)
#ifndef __CCE_AICORE__
#ifndef __aicore__
#define __aicore__
#endif
#endif

#include "atvoss/util/dag.h"
#include "atvoss/util/vec.h"
#include "atvoss/util/placeholder.h"

using namespace Ops::Base;

namespace NsSinh {

// =============================================================================
// SinhCustomVF: PTX-style dual-branch sinh computation (MicroAPI Tier 3)
//
// Computes dst[i] = sinh(src[i]) for the first `count` elements. Every lane
// evaluates both branches and the overflow guard; the results are blended with
// Select:
//   |x| <  1  : x + x^3 * P(x^2),  P(y) = ((c5*y + c4)*y + c3)*y + c2
//   |x| >= 1  : sign(x) * (t - 0.25/t),  t = exp(|x| - ln2) = e^|x| / 2
//               (the ln2 offset keeps t finite up to |x| ~ 89.4 in float32)
//   |x| >= 90 : sign(x) * inf
//
// The DAG variants in this file instantiate T = float; half/bfloat16 inputs are
// cast up/down around this op by the DAG layer. On host builds
// (__CCE_AICORE__ undefined) the constructor body is compiled out.
// =============================================================================
template <typename T>
struct SinhCustomVF : public Vec::ElemwiseUnaryOP<T, T> {
    /// @param dst   output tensor; receives sinh(src) for the first `count` elements
    /// @param src   input tensor
    /// @param count number of elements to process
    __aicore__ inline SinhCustomVF(LocalTensor<T>& dst, LocalTensor<T>& src, uint32_t count)
    {
#ifdef __CCE_AICORE__
        // Polynomial coefficients, written as hex-float literals so they are
        // bit-exact with the PTX constants quoted alongside. c3..c5 deviate
        // slightly from the plain Taylor terms 1/120, 1/5040, 1/362880
        // (presumably a minimax fit on [0, 1) — confirm against the PTX source).
        constexpr T c2 = T(0x1.555556p-3f);   // 0x3E2AAAAB ~ 1.6666667e-1 (1/6)
        constexpr T c3 = T(0x1.111134p-7f);   // 0x3C08889A ~ 8.3333496e-3 (near 1/120)
        constexpr T c4 = T(0x1.9ffe92p-13f);  // 0x394FFF49 ~ 1.9836159e-4 (near 1/5040)
        constexpr T c5 = T(0x1.7a15b4p-19f);  // 0x363D0ADA ~ 2.8169510e-6 (near 1/362880)

        constexpr T one_val = T(1.0f);        // polynomial / exp branch split point
        constexpr T ninety_val = T(90.0f);    // overflow-guard threshold
        constexpr T zero_val = T(0.0f);
        constexpr T quarter_val = T(0.25f);   // 0.25 / t == e^(-|x|) / 2
        constexpr T neg_ln2 = T(-0.6931471805599453f);
        const T inf_val = T(__builtin_huge_valf());  // +infinity

        // Elements held by one vector register.
        uint32_t VL = AscendC::VECTOR_REG_WIDTH / sizeof(T);
        // NOTE(review): the uint16_t trip count limits `count` to 65535 * VL
        // elements; presumably guaranteed by the UB tile size — confirm in tiling.
        uint16_t loopNum = CeilDivision(count, VL);

        static constexpr AscendC::MicroAPI::DivSpecificMode divMode = {
            AscendC::MicroAPI::MaskMergeMode::ZEROING, true  // high precision
        };

        __VEC_SCOPE__
        {
            __ubuf__ T* srcAddr = (__ubuf__ T*)src.GetPhyAddr();
            __ubuf__ T* dstAddr = (__ubuf__ T*)dst.GetPhyAddr();

            AscendC::MicroAPI::RegTensor<T, AscendC::MicroAPI::RegTraitNumOne> reg_x;
            AscendC::MicroAPI::RegTensor<T, AscendC::MicroAPI::RegTraitNumOne> reg_abs;
            AscendC::MicroAPI::RegTensor<T, AscendC::MicroAPI::RegTraitNumOne> reg_x2;
            AscendC::MicroAPI::RegTensor<T, AscendC::MicroAPI::RegTraitNumOne> reg_poly;
            AscendC::MicroAPI::RegTensor<T, AscendC::MicroAPI::RegTraitNumOne> reg_temp;
            AscendC::MicroAPI::RegTensor<T, AscendC::MicroAPI::RegTraitNumOne> reg_exp;
            AscendC::MicroAPI::RegTensor<T, AscendC::MicroAPI::RegTraitNumOne> reg_result;
            AscendC::MicroAPI::MaskReg mask;
            AscendC::MicroAPI::MaskReg cmpMask0;  // lanes with x < 0
            AscendC::MicroAPI::MaskReg cmpMask1;  // reused: |x| < 1, then |x| >= 90

            for (uint16_t i = 0; i < loopNum; i++) {
                // Active-lane mask for this chunk (`count` is the remaining-element budget).
                mask = AscendC::MicroAPI::UpdateMask<T, AscendC::MicroAPI::RegTraitNumOne>(count);
                AscendC::MicroAPI::DataCopy(reg_x, (__ubuf__ T*)(srcAddr + i * VL));

                // |x| drives both the branch split and the overflow guard.
                AscendC::MicroAPI::Abs(reg_abs, reg_x, mask);

                // === Polynomial branch: sinh(x) = x + x^3 * P(x^2) ===
                AscendC::MicroAPI::Mul(reg_x2, reg_x, reg_x, mask);           // x2 = x*x

                // Horner: P(x2) = ((c5 * x2 + c4) * x2 + c3) * x2 + c2
                AscendC::MicroAPI::Duplicate(reg_poly, c5);
                AscendC::MicroAPI::Mul(reg_poly, reg_poly, reg_x2, mask);     // c5*x2
                AscendC::MicroAPI::Adds(reg_poly, reg_poly, c4, mask);        // + c4
                AscendC::MicroAPI::Mul(reg_poly, reg_poly, reg_x2, mask);     // * x2
                AscendC::MicroAPI::Adds(reg_poly, reg_poly, c3, mask);        // + c3
                AscendC::MicroAPI::Mul(reg_poly, reg_poly, reg_x2, mask);     // * x2
                AscendC::MicroAPI::Adds(reg_poly, reg_poly, c2, mask);        // + c2

                AscendC::MicroAPI::Mul(reg_temp, reg_x2, reg_x, mask);        // x3 = x2 * x
                AscendC::MicroAPI::Mul(reg_poly, reg_poly, reg_temp, mask);   // x3 * P(x2)
                AscendC::MicroAPI::Add(reg_poly, reg_x, reg_poly, mask);      // x + x3*P(x2)
                // reg_poly: polynomial result, already carrying the sign of x.

                // === Exp branch: offset exp so t stays finite for larger |x| ===
                AscendC::MicroAPI::Adds(reg_temp, reg_abs, neg_ln2, mask);    // |x| - ln2
                AscendC::MicroAPI::Exp(reg_exp, reg_temp, mask);              // t = e^|x| / 2
                AscendC::MicroAPI::Duplicate(reg_temp, quarter_val);
                AscendC::MicroAPI::Div<T, &divMode>(reg_temp, reg_temp, reg_exp, mask);  // 0.25/t = e^-|x| / 2
                AscendC::MicroAPI::Sub(reg_exp, reg_exp, reg_temp, mask);     // sinh(|x|)

                // Restore sign(x) on the exp result before blending, because the
                // polynomial result is already signed.
                AscendC::MicroAPI::CompareScalar<T, CMPMODE::LT>(cmpMask0, reg_x, zero_val, mask);
                AscendC::MicroAPI::Neg(reg_temp, reg_exp, mask);
                AscendC::MicroAPI::Select(reg_exp, reg_temp, reg_exp, cmpMask0);

                // === Branch selection: |x| < 1 ? polynomial : exp ===
                AscendC::MicroAPI::CompareScalar<T, CMPMODE::LT>(cmpMask1, reg_abs, one_val, mask);
                AscendC::MicroAPI::Select(reg_result, reg_poly, reg_exp, cmpMask1);

                // === Overflow guard: |x| >= 90 ? sign(x) * inf : result ===
                AscendC::MicroAPI::CompareScalar<T, CMPMODE::GE>(cmpMask1, reg_abs, ninety_val, mask);
                AscendC::MicroAPI::Duplicate(reg_temp, inf_val);                  // +inf
                AscendC::MicroAPI::Neg(reg_exp, reg_temp, mask);                  // -inf (reg_exp reused)
                AscendC::MicroAPI::Select(reg_temp, reg_exp, reg_temp, cmpMask0); // sign(x) * inf
                AscendC::MicroAPI::Select(reg_result, reg_temp, reg_result, cmpMask1);

                AscendC::MicroAPI::DataCopy((__ubuf__ T*)(dstAddr + i * VL), reg_result, mask);
            }
        }
#endif
    }
};

// =============================================================================
// DAG variant 1: float32 direct computation (no Cast)
//
// Pipeline: CopyIn(In0) -> SinhCustomVF<float> -> CopyOut(Out0).
// The compute node is fixed to float and no Cast nodes are inserted, so this
// variant is intended for float32 tensors; half/bfloat16 go through SinhWithCast.
// =============================================================================
template <typename T>
struct SinhWithoutCast {
    // Stage 1: load the input tile bound to placeholder In0.
    using OpCopyIn = Bind<Vec::CopyIn<T>,
                          Placeholder::In0<T>>;
    // Stage 2: evaluate sinh in float32.
    using OpSinh = Bind<SinhCustomVF<float>,
                        OpCopyIn>;
    // Stage 3: store the result to placeholder Out0.
    using OpCopyOut = Bind<Vec::CopyOut<T>,
                           Placeholder::Out0<T>,
                           OpSinh>;

    // Scheduling: a single output node, memory optimisation level 2.
    using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
    using Outputs = Elems<OpCopyOut>;
    using OpDag = DAGSch<Outputs, void, MemCfg>;
};

// =============================================================================
// DAG variant 2: half/bfloat16 with Cast promotion to float32
//
// Pipeline: CopyIn(In0) -> Cast T->float -> SinhCustomVF<float>
//           -> Cast float->T -> CopyOut(Out0).
// NOTE(review): the third Vec::Cast argument (0 on the way in, 1 on the way
// out) presumably selects the rounding mode — confirm in atvoss Vec::Cast.
// =============================================================================
template <typename T>
struct SinhWithCast {
    // Stage 1: load the input tile bound to placeholder In0.
    using OpCopyIn = Bind<Vec::CopyIn<T>,
                          Placeholder::In0<T>>;
    // Stage 2: widen T to float32.
    using CastIn = Bind<Vec::Cast<float, T, 0>,
                        OpCopyIn>;
    // Stage 3: evaluate sinh in float32.
    using OpSinh = Bind<SinhCustomVF<float>,
                        CastIn>;
    // Stage 4: narrow float32 back to T.
    using CastOut = Bind<Vec::Cast<T, float, 1>,
                         OpSinh>;
    // Stage 5: store the result to placeholder Out0.
    using OpCopyOut = Bind<Vec::CopyOut<T>,
                           Placeholder::Out0<T>,
                           CastOut>;

    // Scheduling: a single output node, memory optimisation level 2.
    using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
    using Outputs = Elems<OpCopyOut>;
    using OpDag = DAGSch<Outputs, void, MemCfg>;
};

} // namespace NsSinh

#endif  // SINH_DAG_H