/**

 * Copyright (c) 2026 Huawei Technologies Co., Ltd.

 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of

 * CANN Open Software License Agreement Version 2.0 (the "License").

 * Please refer to the License for details. You may not use this file except in compliance with the License.

 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,

 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.

 * See LICENSE in the root of the software repository for the full text of the License.

 */



/*!

 * \file atan.h

 * \brief Tile-level element-wise atan (TAtan) and atan2 (TAtan2) vector operators.

 */



#ifndef TILEOP_TILE_OPERATOR_ATAN__H

#define TILEOP_TILE_OPERATOR_ATAN__H

#ifdef __DAV_V220

#define ATAN_SYNC_V pipe_barrier(PIPE_V)

#else

#define ATAN_SYNC_V

#endif



#include "pto_tile.h"

#include "utils/layout.h"

#include "utils/tile_tensor.h"



#include <cmath>



/*!
 * \brief Element-wise arctangent of one tile: dst = atan(src).
 *
 * The routine is a fixed sequence of pto vector instructions separated by
 * ATAN_SYNC_V. The arithmetic described in the comments below is read off the
 * pto op names (TABS, TDIV, TSEL, ...). In particular it assumes that
 * TSEL(out, mask, a, b, scratch) yields `a` where the mask is set and `b`
 * elsewhere, with the last operand used only as scratch -- TODO confirm
 * against the pto documentation.
 *
 * \param dst   result tile; also used as scratch before the final value is written
 * \param src   input tile; only read
 * \param tmp1  scratch tile, overwritten
 * \param tmp2  scratch tile, overwritten
 * \param cmp   mask tile written by TCMPS and consumed by TSEL, overwritten
 */
template <typename DATA, typename CMP>
TILEOP void AtanCalc(DATA dst, DATA src, DATA tmp1, DATA tmp2, CMP cmp)
{
    // Polynomial coefficients, lowest order first; the last one is applied first (Horner order).
    constexpr float kCoef[] = {-0.333329409, 0.199887753, -0.141718030, 0.105184801,
                               -0.0725297481, 0.0398497507, -0.0143969795, 0.00245002890};
    constexpr int kLastCoef = 7;
    constexpr float kHalfPi = 1.570796326794896619;

    // tmp1 = |src|, dst = 1  ->  tmp2 = 1 / |src|, cmp = (|src| > 1).
    pto::TABS(tmp1, src);
    pto::TEXPANDS(dst, 1.0);
    ATAN_SYNC_V;
    pto::TDIV(tmp2, dst, tmp1);
    pto::TCMPS(cmp, tmp1, 1.0, pto::CmpMode::GT);
    ATAN_SYNC_V;

    // Reduced argument z (kept in tmp2): 1 / |src| where |src| > 1, otherwise |src|.
    // dst only serves as the scratch operand here; its value of 1 is no longer needed.
    pto::TSEL(tmp2, cmp, tmp2, tmp1, dst);
    ATAN_SYNC_V;

    // tmp1 = z^2
    pto::TMUL(tmp1, tmp2, tmp2);
    ATAN_SYNC_V;

    // Horner evaluation in z^2: dst = z^2 * (c0 + z^2 * (c1 + ... + z^2 * c7)).
    pto::TMULS(dst, tmp1, kCoef[kLastCoef]);
    ATAN_SYNC_V;
    for (int idx = kLastCoef - 1; idx >= 0; --idx) {
        pto::TADDS(dst, dst, kCoef[idx]);
        ATAN_SYNC_V;
        pto::TMUL(dst, dst, tmp1);
        ATAN_SYNC_V;
    }

    // dst = z + z * polynomial, the approximation for the reduced argument.
    pto::TMUL(dst, dst, tmp2);
    ATAN_SYNC_V;
    pto::TADD(dst, dst, tmp2);
    ATAN_SYNC_V;

    // Undo the reciprocal: where |src| > 1 (mask still in cmp) use pi/2 - dst.
    pto::TNEG(tmp1, dst);
    ATAN_SYNC_V;
    pto::TADDS(tmp1, tmp1, kHalfPi);
    ATAN_SYNC_V;
    pto::TSEL(dst, cmp, tmp1, dst, tmp2);
    ATAN_SYNC_V;

    // Restore the sign: keep dst where src >= 0, otherwise take -dst.
    pto::TNEG(tmp1, dst);
    pto::TCMPS(cmp, src, 0.0, pto::CmpMode::GE);
    ATAN_SYNC_V;
    pto::TSEL(dst, cmp, dst, tmp1, tmp2);
    ATAN_SYNC_V;
}



/*!
 * \brief Reads the five shape dimensions of dst's layout into dstShape.
 *
 * \param dst       object exposing GetLayout(); the returned layout must provide
 *                  GetShapeDim<dim, MAX_DIMS>()
 * \param dstShape  output array; entries DIM_1ST through DIM_5TH are written, so
 *                  the caller must provide storage covering those indices
 */
template <typename DST>
TILEOP void AtanGetShape(DST dst, size_t dstShape[])
{
    const auto layout = dst.GetLayout();

    // Leading dimensions.
    dstShape[DIM_1ST] = layout.template GetShapeDim<DIM_1ST, MAX_DIMS>();
    dstShape[DIM_2ND] = layout.template GetShapeDim<DIM_2ND, MAX_DIMS>();
    dstShape[DIM_3RD] = layout.template GetShapeDim<DIM_3RD, MAX_DIMS>();

    // Trailing two dimensions.
    dstShape[DIM_4TH] = layout.template GetShapeDim<DIM_4TH, MAX_DIMS>();
    dstShape[DIM_5TH] = layout.template GetShapeDim<DIM_5TH, MAX_DIMS>();
}



#define OP_TILE_OP_ATAN TAtan
/*!
 * \brief Element-wise arctangent of a tensor: dst = atan(src).
 *
 * Loops over the first three dimensions of dst and runs AtanCalc once per
 * trailing (DIM_4TH x DIM_5TH) tile.
 *
 * Scratch layout inside tmp for each tile, in units of DST elements and
 * starting at GenTileOffset(dst, offset) * 3:
 *   [ tmp1 | tmp2 | compare mask ], tmp1 and tmp2 being tileH * tileW elements each.
 * The mask tile holds ceil(W / 8) valid bytes per row (presumably one mask bit
 * per element -- TODO confirm) and its static row width is rounded up to a
 * multiple of 32 bytes.
 *
 * \param dst  destination tensor; its layout defines the iteration space
 * \param tmp  scratch buffer addressed through tmp.GetAddr(); must be large
 *             enough for the layout described above
 * \param src  source tensor, addressed with the same tile offsets as dst
 */
template <typename DST, typename TMP, typename SRC>
TILEOP void TAtan(DST dst, TMP tmp, SRC src)
{
    constexpr int64_t NUM_3 = 3;
    constexpr int64_t NUM_8 = 8;
    constexpr int64_t NUM_32 = 32;
    constexpr size_t dstDtypeSize = sizeof(typename DST::Type);

    size_t dstShape[MAX_DIMS];
    AtanGetShape(dst, dstShape);

    constexpr auto tileH = TileOp::GetTensorTileShapeDim<DST, DIM_4TH, MAX_DIMS>();
    constexpr auto tileW = TileOp::GetTensorTileShapeDim<DST, DIM_5TH, MAX_DIMS>();
    // Element count of one data tile; loop-invariant, so computed once.
    constexpr auto tileElems = tileH * tileW;
    constexpr auto cmpTileW = ((tileW + NUM_8 - 1) / NUM_8 + NUM_32 - 1) / NUM_32 * NUM_32;
    auto cmpSize = (dstShape[DIM_5TH] + NUM_8 - 1) / NUM_8;

    // The mask tile is constructed with dstShape[DIM_4TH] valid rows, so its static
    // row count must be tileH (same declaration as the mask tile in TAtan2), not 1.
    using CmpTileDefine = pto::Tile<pto::TileType::Vec, uint8_t, tileH, cmpTileW, pto::BLayout::RowMajor, -1, -1>;

    auto dstTile = PtoTile<DST>(dst);
    auto srcTile = PtoTile<SRC>(src);
    // The scratch tiles reuse dst's tile description; their addresses are rebound per iteration.
    auto tmp1Tile = PtoTile<DST>(dst);
    auto tmp2Tile = PtoTile<DST>(dst);
    CmpTileDefine cmpTile(dstShape[DIM_4TH], cmpSize);

    for (LoopVar n0Index = 0; n0Index < dstShape[DIM_1ST]; ++n0Index) {
        for (LoopVar n1Index = 0; n1Index < dstShape[DIM_2ND]; ++n1Index) {
            for (LoopVar n2Index = 0; n2Index < dstShape[DIM_3RD]; ++n2Index) {
                auto dstOffset = TileOffset(n0Index, n1Index, n2Index);
                dstTile.Assign(dst, dstOffset);
                srcTile.Assign(src, dstOffset);

                // Three consecutive scratch regions per tile: tmp1, tmp2, mask.
                auto tmp1Offset = GenTileOffset(dst, dstOffset) * NUM_3;
                auto tmp2Offset = tmp1Offset + tileElems;
                auto cmpOffset = tmp2Offset + tileElems;
                tmp1Tile.Assign(tmp.GetAddr(), tmp1Offset);
                tmp2Tile.Assign(tmp.GetAddr(), tmp2Offset);
                // TASSIGN takes an address, so the element offset is scaled by the element size here.
                pto::TASSIGN(cmpTile, tmp.GetAddr() + cmpOffset * dstDtypeSize);

                AtanCalc(dstTile.Data(), srcTile.Data(), tmp1Tile.Data(), tmp2Tile.Data(), cmpTile);
            }
        }
    }
}



/*!
 * \brief Builds a mask from the sign bit of srcF.
 *
 * srcF is converted to half precision into dstH. The 16-bit words of dstU are
 * then reduced to their top bit and OR-ed with 0x4000, and finally dstH is
 * compared against 0 with GE. This only works when dstH and dstU view the same
 * storage; the visible caller (TAtan2) assigns both to the same address.
 *
 * With IEEE binary16, 0x4000 / 0xC000 encode +2.0 / -2.0, so the comparison
 * presumably yields "sign bit of srcF is clear", which also separates -0 from
 * +0 -- TODO confirm the half format and the intent.
 *
 * \param dstH  half-precision view that receives the converted values
 * \param srcF  source tile; only read
 * \param dstU  16-bit unsigned view, expected to alias dstH
 * \param tmpU  16-bit unsigned scratch tile, overwritten
 * \param cmp   output mask written by TCMPS
 */
template <typename HDST, typename FSRC, typename UDST, typename UTMP, typename CMP>
TILEOP void Atan2Cast(HDST dstH, FSRC srcF, UDST dstU, UTMP tmpU, CMP cmp)
{
    constexpr uint16_t kSignBit = 0x8000u;
    constexpr uint16_t kMagnitudeBits = 0x4000u;

    // Convert the source into the half view.
    pto::TCVT(dstH, srcF, pto::RoundMode::CAST_NONE);
    ATAN_SYNC_V;

    // Keep only the top bit of every 16-bit word ...
    pto::TANDS(tmpU, dstU, kSignBit);
    ATAN_SYNC_V;

    // ... and attach a fixed non-zero bit pattern below it.
    pto::TORS(dstU, tmpU, kMagnitudeBits);
    ATAN_SYNC_V;

    // Mask = (rebuilt half value >= 0).
    pto::TCMPS(cmp, dstH, 0.0, pto::CmpMode::GE);
    ATAN_SYNC_V;
}



/*!
 * \brief Quadrant correction and special-case handling that turns an atan
 *        result into an atan2-style result in dst.
 *
 * Inputs on entry, as set up by the visible caller (TAtan2): tmp1 holds the
 * AtanCalc result for the ratio computed by Atan2Div, and cmp holds the mask
 * produced by Atan2Cast from src0.
 *
 * The comments below assume TSEL(out, mask, a, b, scratch) yields `a` where
 * the mask is set and `b` elsewhere, with the last operand used only as
 * scratch -- TODO confirm against the pto documentation.
 *
 * \param dst   result tile; used once as TSEL scratch before it is written
 * \param src0  first atan2 operand (the numerator passed to Atan2Div by TAtan2); only read
 * \param src1  second atan2 operand (the denominator passed to Atan2Div by TAtan2); only read
 * \param tmp1  atan value on entry; overwritten later with constants
 * \param tmp2  scratch tile, overwritten
 * \param tmp3  scratch tile, overwritten
 * \param cmp   mask on entry (see above); overwritten by later comparisons
 */
template <typename DATA, typename CMP>

TILEOP void Atan2Sp(DATA dst, DATA src0, DATA src1, DATA tmp1, DATA tmp2, DATA tmp3, CMP cmp)

{

    constexpr float pi = 3.14159265358979323;

    // Despite the name, pi2 holds pi / 2.
    constexpr float pi2 = 1.570796326794896619;

    // Candidates for the src1 < 0 case: tmp2 = atan + pi, tmp3 = atan - pi.
    pto::TADDS(tmp2, tmp1, pi);

    pto::TSUBS(tmp3, tmp1, pi);

    ATAN_SYNC_V;

    // Pick +pi or -pi per element using the incoming mask; dst is only scratch here.
    pto::TSEL(tmp2, cmp, tmp2, tmp3, dst);

    ATAN_SYNC_V;

    pto::TCMPS(cmp, src1, 0.0, pto::CmpMode::LT);

    ATAN_SYNC_V;

    // dst = shifted value where src1 < 0, plain atan value otherwise.
    pto::TSEL(dst, cmp, tmp2, tmp1, tmp3);

    ATAN_SYNC_V;

    // Build the value used when src1 == 0: +pi/2 for src0 > 0, -pi/2 otherwise ...
    pto::TEXPANDS(tmp1, pi2);

    pto::TEXPANDS(tmp2, -pi2);

    pto::TCMPS(cmp, src0, 0.0, pto::CmpMode::GT);

    ATAN_SYNC_V;

    pto::TSEL(tmp1, cmp, tmp1, tmp2, tmp3);

    ATAN_SYNC_V;

    // ... and 0 where src0 == 0.
    pto::TEXPANDS(tmp2, 0.0);

    pto::TCMPS(cmp, src0, 0.0, pto::CmpMode::NE);

    ATAN_SYNC_V;

    pto::TSEL(tmp1, cmp, tmp1, tmp2, tmp3);

    ATAN_SYNC_V;

    // Where src1 == 0, replace dst with the value built above.
    pto::TCMPS(cmp, src1, 0.0, pto::CmpMode::NE);

    ATAN_SYNC_V;

    pto::TSEL(dst, cmp, dst, tmp1, tmp3);

    ATAN_SYNC_V;

    // NaN propagation: src0 == src0 is false only for NaN, so those elements become NaN.
    pto::TEXPANDS(tmp1, NAN);

    pto::TCMP(cmp, src0, src0, pto::CmpMode::EQ);

    ATAN_SYNC_V;

    pto::TSEL(dst, cmp, dst, tmp1, tmp3);

    ATAN_SYNC_V;

    // Same NaN check for src1.
    // NOTE(review): tmp1 was filled with NAN just above and is only passed as a TSEL
    // source since then, so this second fill looks redundant -- confirm TSEL leaves
    // its sources untouched before removing it.
    pto::TEXPANDS(tmp1, NAN);

    pto::TCMP(cmp, src1, src1, pto::CmpMode::EQ);

    ATAN_SYNC_V;

    pto::TSEL(dst, cmp, dst, tmp1, tmp3);

    ATAN_SYNC_V;

}



/*!
 * \brief Computes dst = src0 / src1, then forces exact +1 / -1 where the
 *        operands compare equal / opposite.
 *
 * The comments below assume TSEL(out, mask, a, b, scratch) yields `a` where
 * the mask is set and `b` elsewhere, with the last operand used only as
 * scratch -- TODO confirm against the pto documentation.
 *
 * \param dst   quotient tile (output)
 * \param src0  numerator; only read
 * \param src1  denominator; only read
 * \param tmp1  scratch tile, receives -src0
 * \param tmp2  scratch tile, receives the constants 1.0 and -1.0
 * \param tmp3  scratch operand handed to TSEL
 * \param cmp   mask tile, overwritten
 */
template <typename DATA, typename CMP>

TILEOP void Atan2Div(DATA dst, DATA src0, DATA src1, DATA tmp1, DATA tmp2, DATA tmp3, CMP cmp)

{

    // These four instructions write different tiles (dst, cmp, tmp1, tmp2) and are issued without a barrier in between.
    pto::TDIV<pto::DivAlgorithm::HIGH_PRECISION>(dst, src0, src1);

    pto::TCMP(cmp, src0, src1, pto::CmpMode::NE);

    pto::TMULS(tmp1, src0, -1.0);

    pto::TEXPANDS(tmp2, 1.0);

    ATAN_SYNC_V;

    // Keep the quotient where src0 != src1, otherwise use 1.0.
    // NOTE(review): presumably this covers operand pairs whose quotient would not be
    // exactly 1 (e.g. inf / inf) -- confirm the intended cases.
    pto::TSEL(dst, cmp, dst, tmp2, tmp3);

    ATAN_SYNC_V;

    pto::TEXPANDS(tmp2, -1.0);

    pto::TCMP(cmp, tmp1, src1, pto::CmpMode::NE);

    ATAN_SYNC_V;

    // Keep dst where -src0 != src1, otherwise use -1.0.
    pto::TSEL(dst, cmp, dst, tmp2, tmp3);

    ATAN_SYNC_V;

}



#define OP_TILE_OP_ATAN2 TAtan2
/*!
 * \brief Element-wise two-argument arctangent of two tensors, written to dst.
 *
 * Loops over the first three dimensions of dst. For every trailing
 * (DIM_4TH x DIM_5TH) tile it runs, in this order:
 *   1. Atan2Div  - dst  = src0 / src1 with the +1 / -1 overrides
 *   2. AtanCalc  - tmp1 = atan of that ratio
 *   3. Atan2Cast - mask from src0, using dst's storage as 16-bit scratch
 *   4. Atan2Sp   - quadrant correction and special cases, final result in dst
 *
 * Scratch layout inside tmp for each tile, in units of DST elements and
 * starting at GenTileOffset(dst, offset) * 4:
 *   [ tmp1 | tmp2 | tmp3 | compare mask ], the first three being
 *   tileH * tileW elements each. A 16-bit unsigned view is additionally
 *   placed on the tmp2 region, and half / 16-bit unsigned views on dst's tile.
 *
 * \param dst   destination tensor; its layout defines the iteration space
 * \param src0  first operand (numerator of the ratio)
 * \param src1  second operand (denominator of the ratio)
 * \param tmp   scratch buffer addressed through tmp.GetAddr(); must be large
 *              enough for the layout described above
 */
template <typename DST, typename SRC0, typename SRC1, typename TMP>
TILEOP void TAtan2(DST dst, SRC0 src0, SRC1 src1, TMP tmp)
{
    constexpr int64_t kScratchRegions = 4;
    constexpr int64_t kMaskDivisor = 8;
    constexpr int64_t kB16Align = 16;
    constexpr int64_t kMaskAlign = 32;
    constexpr size_t elemBytes = sizeof(typename DST::Type);

    size_t shape[MAX_DIMS];
    AtanGetShape(dst, shape);

    constexpr auto tileH = TileOp::GetTensorTileShapeDim<DST, DIM_4TH, MAX_DIMS>();
    constexpr auto tileW = TileOp::GetTensorTileShapeDim<DST, DIM_5TH, MAX_DIMS>();
    constexpr auto tileElems = tileH * tileW;
    // Mask row width: ceil(tileW / 8) bytes, rounded up to a multiple of 32.
    constexpr auto maskTileW =
        ((tileW + kMaskDivisor - 1) / kMaskDivisor + kMaskAlign - 1) / kMaskAlign * kMaskAlign;
    // Row width of the 16-bit views: tileW rounded up to a multiple of 16.
    constexpr auto b16TileW = (tileW + kB16Align - 1) / kB16Align * kB16Align;

    using MaskTile = pto::Tile<pto::TileType::Vec, uint8_t, tileH, maskTileW, pto::BLayout::RowMajor, -1, -1>;
    using U16Tile = pto::Tile<pto::TileType::Vec, uint16_t, tileH, b16TileW, pto::BLayout::RowMajor, -1, -1>;
    using F16Tile = pto::Tile<pto::TileType::Vec, half, tileH, b16TileW, pto::BLayout::RowMajor, -1, -1>;

    auto outTile = PtoTile<DST>(dst);
    auto inTile0 = PtoTile<SRC0>(src0);
    auto inTile1 = PtoTile<SRC1>(src1);
    // The scratch tiles reuse dst's tile description; their addresses are rebound per iteration.
    auto scratchA = PtoTile<DST>(dst);
    auto scratchB = PtoTile<DST>(dst);
    auto scratchC = PtoTile<DST>(dst);

    MaskTile maskTile(shape[DIM_4TH], (shape[DIM_5TH] + kMaskDivisor - 1) / kMaskDivisor);
    U16Tile outAsU16(shape[DIM_4TH], shape[DIM_5TH]);
    U16Tile scratchBAsU16(shape[DIM_4TH], shape[DIM_5TH]);
    F16Tile outAsF16(shape[DIM_4TH], shape[DIM_5TH]);

    for (LoopVar i0 = 0; i0 < shape[DIM_1ST]; ++i0) {
        for (LoopVar i1 = 0; i1 < shape[DIM_2ND]; ++i1) {
            for (LoopVar i2 = 0; i2 < shape[DIM_3RD]; ++i2) {
                auto tileIndex = TileOffset(i0, i1, i2);
                outTile.Assign(dst, tileIndex);
                inTile0.Assign(src0, tileIndex);
                inTile1.Assign(src1, tileIndex);

                // 16-bit views placed on top of the current dst tile.
                auto elemOffset = GenTileOffset(dst, tileIndex);
                pto::TASSIGN(outAsU16, dst.GetAddr() + elemOffset * elemBytes);
                pto::TASSIGN(outAsF16, dst.GetAddr() + elemOffset * elemBytes);

                // Four consecutive scratch regions per tile: A, B, C, mask.
                auto offsetA = elemOffset * kScratchRegions;
                auto offsetB = offsetA + tileElems;
                auto offsetC = offsetB + tileElems;
                auto offsetMask = offsetC + tileElems;
                scratchA.Assign(tmp.GetAddr(), offsetA);
                scratchB.Assign(tmp.GetAddr(), offsetB);
                scratchC.Assign(tmp.GetAddr(), offsetC);
                pto::TASSIGN(scratchBAsU16, tmp.GetAddr() + offsetB * elemBytes);
                pto::TASSIGN(maskTile, tmp.GetAddr() + offsetMask * elemBytes);

                // dst = ratio of the operands.
                Atan2Div(outTile.Data(), inTile0.Data(), inTile1.Data(),
                    scratchA.Data(), scratchB.Data(), scratchC.Data(), maskTile);
                // scratchA = atan(ratio).
                AtanCalc(scratchA.Data(), outTile.Data(), scratchB.Data(), scratchC.Data(), maskTile);
                // Mask derived from src0; dst's storage is reused as scratch through the 16-bit views.
                Atan2Cast(outAsF16, inTile0.Data(), outAsU16, scratchBAsU16, maskTile);
                // Final result written to dst.
                Atan2Sp(outTile.Data(), inTile0.Data(), inTile1.Data(),
                    scratchA.Data(), scratchB.Data(), scratchC.Data(), maskTile);
            }
        }
    }
}



#undef ATAN_SYNC_V

#endif