/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
 * \file tan_simt.h
* \brief tan SIMT kernel implementation
*/
#ifndef TAN_SIMT_H
#define TAN_SIMT_H
#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
#include "tan_tiling_data.h"
#include "tan_tiling_key.h"
#include "simt_api/asc_simt.h"
#include "simt_api/math_functions.h"
namespace NsTan {
using namespace AscendC;
// ---- Argument-reduction constants -----------------------------------------
// 2/pi: multiplied with x and rounded to obtain the quadrant index k.
static constexpr float TWO_OVER_PI = 0.6366197723675814f;
// pi/2 as a single float (not referenced by the code visible in this header).
static constexpr float PI_OVER_2 = 1.5707963267948966f;
// Inputs with |x| >= this value are not reduced; the kernel emits its
// "special" result for them instead.
static constexpr float LARGE_THRESHOLD = 1e7f;
// Three-part split of pi/2: PIO2_1 + PIO2_2 + PIO2_3 == pi/2.
// PIO2_1 and PIO2_2 have only a few significant bits so that k * PIO2_x stays
// exactly representable for moderate k; PIO2_3 carries the remaining tail.
static constexpr float PIO2_1 = 1.5f;
static constexpr float PIO2_2 = 0.0703125f;
static constexpr float PIO2_3 = 4.83826794896619e-04f;

// ---- Polynomial coefficients ----------------------------------------------
// Taylor coefficients of tan(r) about 0, used as
//   tan(r) ~= r + r^3 * (C0 + C1 r^2 + C2 r^4 + ... + C6 r^12)
static constexpr float C0 = 0.3333333333f;  // 1/3
static constexpr float C1 = 0.1333333333f;  // 2/15
static constexpr float C2 = 0.0539682540f;  // 17/315
// 62/2835. The previous value 0.0218253968 (= 11/504) was not the series
// coefficient and cost several ulp of accuracy for mid-range |r|.
static constexpr float C3 = 0.0218694885f;
static constexpr float C4 = 0.0088632353f;  // 1382/155925
static constexpr float C5 = 0.0035921281f;  // 21844/6081075
static constexpr float C6 = 0.0014558335f;  // 929569/638512875
/*!
 * \brief Thread count used for a launch whose index type is IDX_T.
 *
 * Evaluates to 1024 when IDX_T is 4 bytes wide and to 512 for every other
 * width. The value is fed both to LAUNCH_BOUND on the kernel and to the Dim3
 * launch argument in Process().
 */
template <typename IDX_T>
static constexpr uint32_t THREAD_NUM = (sizeof(IDX_T) != 4) ? 512 : 1024;
/*!
 * \brief SIMT kernel: element-wise tangent of `input` written to `output`.
 *
 * Work distribution: a thread starts at blockIdx * threadNum + threadIdx and
 * advances by threadNum * blockNum while the index is below totalElements.
 *
 * Per element, evaluated in float:
 *   1. k = round(x * 2/pi), r = x - k * pi/2 using the three-part split
 *      PIO2_1 / PIO2_2 / PIO2_3 (fmaf recovers the rounding error of k * PIO2_3).
 *   2. tan(r) ~= r + r^3 * P(r^2), P being the degree-6 polynomial C0..C6.
 *   3. Even k -> tan(r); odd k -> -1 / tan(r).
 * Elements that are NaN, infinite, or satisfy |x| >= LARGE_THRESHOLD are not
 * reduced; they receive ASCRT_INF_F / ASCRT_INF_F (presumably a quiet NaN -
 * confirm against the ASCRT_INF_F definition).
 *
 * \tparam T      Element type of the buffers; must convert to and from float.
 * \tparam IDX_T  Signed integer type used for indexing (int32_t or int64_t in
 *                the launches visible in this file).
 * \param totalElements  Number of elements to process.
 * \param input          Source buffer, read at [0, totalElements).
 * \param output         Destination buffer, written at [0, totalElements).
 */
template <typename T, typename IDX_T>
__simt_vf__ __aicore__ LAUNCH_BOUND(THREAD_NUM<IDX_T>)
inline void OpTanSimtKernel(IDX_T totalElements, __gm__ T* input, __gm__ T* output)
{
    // Launch geometry does not change inside the loop; read it once.
    const IDX_T threadNum = static_cast<IDX_T>(AscendC::Simt::GetThreadNum());
    const IDX_T stride = threadNum * static_cast<IDX_T>(AscendC::Simt::GetBlockNum());
    const IDX_T first = static_cast<IDX_T>(AscendC::Simt::GetBlockIdx()) * threadNum +
                        static_cast<IDX_T>(AscendC::Simt::GetThreadIdx());

    // Value stored for inputs that are not range-reduced.
    const float special_result = ASCRT_INF_F / ASCRT_INF_F;

    for (IDX_T index = first; index < totalElements;) {
        const float x = static_cast<float>(input[index]);
        const bool is_special = (fabsf(x) >= LARGE_THRESHOLD) || isinf(x) || isnan(x);

        // Special lanes run the same arithmetic on a harmless stand-in so the
        // float -> int32 conversion below never sees NaN/inf/out-of-range
        // values (which would be undefined behaviour). Their result is
        // discarded by the final select.
        const float xr = is_special ? 0.0f : x;

        // Quadrant index and reduced argument r = xr - k * pi/2.
        const float fn = roundf(xr * TWO_OVER_PI);
        const int32_t k = static_cast<int32_t>(fn);
        const float t1 = xr - fn * PIO2_1;
        const float t2 = t1 - fn * PIO2_2;
        const float s3 = fn * PIO2_3;
        const float e3 = fmaf(fn, PIO2_3, -s3);  // rounding error of fn * PIO2_3
        const float r = (t2 - s3) - e3;

        // Horner evaluation of P(r^2).
        const float r2 = r * r;
        const float r3 = r2 * r;
        float p = C6;
        p = fmaf(p, r2, C5);
        p = fmaf(p, r2, C4);
        p = fmaf(p, r2, C3);
        p = fmaf(p, r2, C2);
        p = fmaf(p, r2, C1);
        p = fmaf(p, r2, C0);
        const float tan_r = r + r3 * p;

        // Odd quadrants: tan(r + pi/2) = -1 / tan(r).
        const bool k_odd = (k & 1) != 0;
        const float result = k_odd ? (-1.0f / tan_r) : tan_r;

        output[index] = static_cast<T>(is_special ? special_result : result);

        // Advance without ever forming index + stride when that would pass
        // totalElements: for a signed IDX_T near its maximum the addition
        // could overflow, wrap negative, and re-enter the loop out of bounds.
        if (totalElements - index <= stride) {
            break;
        }
        index += stride;
    }
}
/*!
 * \brief Launches the tan SIMT kernel over the buffers described by the tiling data.
 *
 * The element count is read from tilingData->totalElements. Counts that fit
 * in int32_t use the int32_t-indexed kernel instantiation with
 * THREAD_NUM<int32_t> threads; larger counts use the int64_t-indexed
 * instantiation with THREAD_NUM<int64_t> threads.
 *
 * \tparam T          Element type the raw buffers are reinterpreted as.
 * \param input       Address of the source buffer.
 * \param output      Address of the destination buffer.
 * \param tilingData  Tiling parameters; only totalElements is read here.
 */
template <typename T>
__aicore__ inline void Process(GM_ADDR input, GM_ADDR output, const TanTilingData* tilingData)
{
    int64_t elemCount = tilingData->totalElements;
    __gm__ T* src = reinterpret_cast<__gm__ T*>(input);
    __gm__ T* dst = reinterpret_cast<__gm__ T*>(output);

    // Counts beyond the int32_t range need 64-bit indexing inside the kernel.
    if (elemCount > static_cast<int64_t>(INT32_MAX)) {
        AscendC::Simt::VF_CALL<OpTanSimtKernel<T, int64_t>>(
            AscendC::Simt::Dim3(THREAD_NUM<int64_t>),
            elemCount, src, dst);
        return;
    }

    AscendC::Simt::VF_CALL<OpTanSimtKernel<T, int32_t>>(
        AscendC::Simt::Dim3(THREAD_NUM<int32_t>),
        static_cast<int32_t>(elemCount), src, dst);
}
}
#endif