/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/* Generated By CANNBot */

/*!
 * \file tan_simt.h
 * \brief tan SIMT kernel implementation
 */

#ifndef TAN_SIMT_H
#define TAN_SIMT_H

#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
#include "tan_tiling_data.h"
#include "tan_tiling_key.h"
#include "simt_api/asc_simt.h"
#include "simt_api/math_functions.h"

namespace NsTan {

using namespace AscendC;

// ---------------------------------------------------------------------------
// Constants
// ---------------------------------------------------------------------------
static constexpr float TWO_OVER_PI = 0.6366197723675814f;
static constexpr float PI_OVER_2 = 1.5707963267948966f;

// Inputs with |x| >= LARGE_THRESHOLD (as well as +-Inf and NaN) are treated as
// outside the supported range and produce NaN.
// NOTE: |x| < 1e7 admits |fn| = round(|x| * 2/pi) up to ~6.4e6. That is beyond the
// bounds under which the plain products fn * PIO2_1 (|fn| <= 2^22, i.e. |x| <~ 6.6e6)
// and fn * PIO2_2 (|fn| <= 2^20, i.e. |x| <~ 1.65e6) are exact, so the reduction must
// not rely on those products being exact over the whole admitted range.
static constexpr float LARGE_THRESHOLD = 1e7f;

// 3-part Cody-Waite split: pi/2 ~= PIO2_1 + PIO2_2 + PIO2_3.
// PIO2_1 = 1.5    (2 significant bits -> fn * PIO2_1 is exact for |fn| <= 2^22)
// PIO2_2 = 9/128  (4 significant bits -> fn * PIO2_2 is exact for |fn| <= 2^20)
// PIO2_3 = pi/2 - 1.5703125 rounded to float; its own representation error
//          (relative ~2^-24) grows in absolute terms with |fn|.
static constexpr float PIO2_1 = 1.5f;
static constexpr float PIO2_2 = 0.0703125f;
static constexpr float PIO2_3 = 4.83826794896619e-04f;

// Taylor coefficients of tan(r) = r + r^3 * (C0 + C1*r^2 + C2*r^4 + ... + C6*r^12),
// i.e. the series truncated after the r^15 term.
static constexpr float C0 = 0.3333333333f;       // 1/3
static constexpr float C1 = 0.1333333333f;       // 2/15
static constexpr float C2 = 0.0539682540f;       // 17/315
static constexpr float C3 = 0.0218694885f;       // 62/2835
static constexpr float C4 = 0.0088632355f;       // 1382/155925
static constexpr float C5 = 0.0035921281f;       // 21844/6081075
static constexpr float C6 = 0.0014558344f;       // 929569/638512875

// Rule 006: LAUNCH_BOUND is templated on the index width.
// tan has moderate register pressure (many live temporaries in the polynomial):
// 4-byte index types (e.g. uint32_t) launch 1024 threads, wider ones (e.g. uint64_t) 512.
template <typename IDX_T>
static constexpr uint32_t THREAD_NUM = (sizeof(IDX_T) == 4) ? 1024 : 512;

/**
 * \brief Element-wise tan: output[i] = tan(input[i]) for every i in [0, totalElements).
 *
 * Thread t of block b starts at b * threadNum + t and advances by
 * threadNum * blockNum until it reaches totalElements. Each element is converted
 * to float, evaluated in float32 and converted back to T on store.
 *
 * Evaluation:
 *   1. |x| >= LARGE_THRESHOLD, +-Inf and NaN map to NaN.
 *   2. 3-part Cody-Waite reduction: fn = round(x * 2/pi), r = x - fn * pi/2.
 *   3. tan(r) ~= r + r^3 * P(r^2), with P evaluated by Horner's rule.
 *   4. Quadrant fix-up: tan(x) = tan(r) for even fn, -1 / tan(r) for odd fn.
 *
 * \tparam T      element type of input/output (converted to/from float)
 * \tparam IDX_T  index type; must be able to represent totalElements
 * \param totalElements number of elements to process
 * \param input   input buffer in global memory
 * \param output  output buffer in global memory
 */
template <typename T, typename IDX_T>
__simt_vf__ __aicore__ LAUNCH_BOUND(THREAD_NUM<IDX_T>)
inline void OpTanSimtKernel(IDX_T totalElements, __gm__ T* input, __gm__ T* output)
{
    const IDX_T threadNum = static_cast<IDX_T>(AscendC::Simt::GetThreadNum());
    // Loop-invariant: total number of threads across all blocks.
    const IDX_T stride = threadNum * static_cast<IDX_T>(AscendC::Simt::GetBlockNum());
    IDX_T index = static_cast<IDX_T>(AscendC::Simt::GetBlockIdx()) * threadNum +
                  static_cast<IDX_T>(AscendC::Simt::GetThreadIdx());

    while (index < totalElements) {
        // Load the input and widen it to float32.
        float x = static_cast<float>(input[index]);

        // Special values: |x| >= LARGE_THRESHOLD, Inf or NaN -> NaN.
        float abs_x = fabsf(x);
        bool is_special = (abs_x >= LARGE_THRESHOLD) || isinf(x) || isnan(x);
        float special_result = ASCRT_INF_F / ASCRT_INF_F;  // Inf / Inf = NaN

        // Quadrant index fn = round(x * 2/pi). Special inputs use fn = 0 so the
        // float -> int32 conversion never sees an out-of-range value or NaN (that
        // conversion is undefined behavior); their result is discarded below anyway.
        float fn = is_special ? 0.0f : roundf(x * TWO_OVER_PI);
        int32_t k = static_cast<int32_t>(fn);

        // 3-part Cody-Waite reduction r = x - fn * (PIO2_1 + PIO2_2 + PIO2_3).
        // The first two steps use fmaf so fn * PIO2_1 and fn * PIO2_2 are never
        // rounded on their own: as plain products they are exact only for
        // |fn| <= 2^22 and |fn| <= 2^20, while |x| < LARGE_THRESHOLD admits |fn|
        // up to ~6.4e6.
        float t1 = fmaf(-fn, PIO2_1, x);
        float t2 = fmaf(-fn, PIO2_2, t1);
        // Third step: s3 + e3 == fn * PIO2_3 exactly (e3 is the rounding error of s3).
        float s3 = fn * PIO2_3;
        float e3 = fmaf(fn, PIO2_3, -s3);
        float reduced = (t2 - s3) - e3;
        // k == 0 needs no reduction; using x directly also preserves the sign of
        // -0 so that tan(-0) == -0 (x - (-0 * PIO2_1) would yield +0).
        float r = (k == 0) ? x : reduced;

        // Polynomial approximation tan(r) ~= r + r^3 * P(r^2),
        // P(s) = C0 + C1*s + ... + C6*s^6 evaluated with Horner's rule.
        float r2 = r * r;
        float r3 = r2 * r;

        float p = C6;
        p = fmaf(p, r2, C5);
        p = fmaf(p, r2, C4);
        p = fmaf(p, r2, C3);
        p = fmaf(p, r2, C2);
        p = fmaf(p, r2, C1);
        p = fmaf(p, r2, C0);

        float tan_r = r + r3 * p;

        // Rule 001: quadrant fix-up via select instead of if/else.
        // Odd quadrant: tan(x) = -1 / tan(r); even quadrant: tan(x) = tan(r).
        float odd_result = -1.0f / tan_r;
        float result = ((k & 1) != 0) ? odd_result : tan_r;

        // Rule 001: final select between the special-value result and the computed one.
        float final_result = is_special ? special_result : result;

        output[index] = static_cast<T>(final_result);

        // Advance only while the next index is still in range. Checking the
        // remaining count first means `index + stride` is never formed past
        // totalElements, so it cannot overflow IDX_T (signed overflow is UB and
        // could wrap to a negative index that still passes `index < totalElements`).
        if (totalElements - index <= stride) {
            break;
        }
        index += stride;
    }
}

/**
 * \brief Launches OpTanSimtKernel over all elements described by the tiling data.
 *
 * Uses a 32-bit unsigned index when the element count fits in INT32_MAX and a
 * 64-bit unsigned index otherwise. With uint32_t and totalElements <= INT32_MAX,
 * the kernel's `index + stride` stays below 2^32 as long as
 * threadNum * blockNum < 2^31. With int32_t it could overflow (undefined
 * behavior) once totalElements gets within one stride of INT32_MAX.
 *
 * \tparam T         element type of the input/output tensors
 * \param input      GM address of the input tensor
 * \param output     GM address of the output tensor
 * \param tilingData tiling data; only totalElements is read
 */
template <typename T>
__aicore__ inline void Process(GM_ADDR input, GM_ADDR output, const TanTilingData* tilingData)
{
    int64_t totalElements = tilingData->totalElements;
    if (totalElements <= 0) {
        // Nothing to compute; also keeps the unsigned casts below well-defined.
        return;
    }
    __gm__ T* input_gm = (__gm__ T*)input;
    __gm__ T* output_gm = (__gm__ T*)output;

    if (totalElements <= static_cast<int64_t>(INT32_MAX)) {
        AscendC::Simt::VF_CALL<OpTanSimtKernel<T, uint32_t>>(
            AscendC::Simt::Dim3(THREAD_NUM<uint32_t>),
            static_cast<uint32_t>(totalElements), input_gm, output_gm);
    } else {
        AscendC::Simt::VF_CALL<OpTanSimtKernel<T, uint64_t>>(
            AscendC::Simt::Dim3(THREAD_NUM<uint64_t>),
            static_cast<uint64_t>(totalElements), input_gm, output_gm);
    }
}

} // namespace NsTan

#endif // TAN_SIMT_H