/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file binary_scalar.h
 * \brief
 */

#ifndef TILEOP_TILE_OPERATOR_BINARY_SCALAR__H
#define TILEOP_TILE_OPERATOR_BINARY_SCALAR__H
#include "binary.h"

template <BinaryScalarOp op, auto PrecisionType = 0, typename LastUse, typename T0, typename T1, typename Scalar>
TILEOP void BinaryScalarComputeImpl(T0 dst, T1 src0, Scalar src1)
{
    constexpr auto n1 = Std::tuple_element<DIM_1ST, LastUse>::type::value;
    constexpr auto n2 = Std::tuple_element<DIM_2ND, LastUse>::type::value;
    if constexpr (op == BinaryScalarOp::ADD) {
        PTO_WITH_LAST_USE(pto::TADDS(dst, src0, src1), n1, n2);
        return;
    }

    if constexpr (op == BinaryScalarOp::SUB) {
        if constexpr (std::is_same<Scalar, half>::value) {
            PTO_WITH_LAST_USE(
                pto::TADDS(dst, src0, static_cast<half>(static_cast<float>(-1) * static_cast<float>(src1))), n1, n2);
        } else {
            PTO_WITH_LAST_USE(pto::TADDS(dst, src0, -src1), n1, n2);
        }
        return;
    }

    if constexpr (op == BinaryScalarOp::MUL) {
        PTO_WITH_LAST_USE(pto::TMULS(dst, src0, src1), n1, n2);
        return;
    }

    if constexpr (op == BinaryScalarOp::DIV) {
        PTO_WITH_LAST_USE(pto::TDIVS<PrecisionType>(dst, src0, src1), n1, n2);
        return;
    }

    if constexpr (op == BinaryScalarOp::MAX) {
        PTO_WITH_LAST_USE(pto::TMAXS(dst, src0, src1), n1, n2);
        return;
    }

    if constexpr (op == BinaryScalarOp::MIN) {
        PTO_WITH_LAST_USE(pto::TMINS(dst, src0, src1), n1, n2);
        return;
    }

    if constexpr (op == BinaryScalarOp::BITWISEAND) {
        pto::TANDS(dst, src0, src1);
        return;
    }

    if constexpr (op == BinaryScalarOp::BITWISEOR) {
        pto::TORS(dst, src0, src1);
        return;
    }

    if constexpr (op == BinaryScalarOp::MOD) {
        pto::TFMODS<PrecisionType>(dst, src0, src1);
        return;
    }

    if constexpr (op == BinaryScalarOp::LRELU) {
        pto::TLRELU(dst, src0, src1);
        return;
    }
}

template <BinaryScalarOp op, auto PrecisionType = 0, typename LastUse, typename T0, typename T1, typename Scalar>
TILEOP void BinaryScalarCompute(T0 dst, T1 src0, Scalar src1)
{
    const auto dstLayout = dst.GetLayout();
    auto shape0 = dstLayout.template GetShapeDim<DIM_1ST, MAX_DIMS>();
    auto shape1 = dstLayout.template GetShapeDim<DIM_2ND, MAX_DIMS>();
    auto shape2 = dstLayout.template GetShapeDim<DIM_3RD, MAX_DIMS>();

    auto dstTile = PtoTile<T0>(dst);
    auto src0Tile = PtoTile<T1>(src0);
    for (LoopVar n0Index = 0; n0Index < shape0; ++n0Index) {
        for (LoopVar n1Index = 0; n1Index < shape1; ++n1Index) {
            for (LoopVar n2Index = 0; n2Index < shape2; ++n2Index) {
                auto tileOffsets = TileOffset(n0Index, n1Index, n2Index);
                dstTile.Assign(dst, tileOffsets);
                src0Tile.Assign(src0, tileOffsets);
                BinaryScalarComputeImpl<op, PrecisionType, LastUse>(dstTile.Data(), src0Tile.Data(), src1);
            }
        }
    }
}

#define OP_TILE_OP_ADDS TAddS
template <typename LastUse = LastUse2Dim<0, 0>, typename Scalar, typename T0, typename T1>
TILEOP void TAddS(T0 dst, T1 src0, Scalar src1)
{
    BinaryScalarCompute<BinaryScalarOp::ADD, 0, LastUse>(dst, src0, src1);
}

#define OP_TILE_OP_SUBS TSubS
template <typename LastUse = LastUse2Dim<0, 0>, typename Scalar, typename T0, typename T1>
TILEOP void TSubS(T0 dst, T1 src0, Scalar src1)
{
    BinaryScalarCompute<BinaryScalarOp::SUB, 0, LastUse>(dst, src0, src1);
}

#define OP_TILE_OP_MULS TMulS
template <typename LastUse = LastUse2Dim<0, 0>, typename Scalar, typename T0, typename T1>
TILEOP void TMulS(T0 dst, T1 src0, Scalar src1)
{
    BinaryScalarCompute<BinaryScalarOp::MUL, 0, LastUse>(dst, src0, src1);
}

#define OP_TILE_OP_DIVS TDivS
template <
    auto PrecisionType = pto::DivAlgorithm::DEFAULT, typename LastUse = LastUse2Dim<0, 0>, typename Scalar, typename T0,
    typename T1>
TILEOP void TDivS(T0 dst, T1 src0, Scalar src1)
{
    BinaryScalarCompute<BinaryScalarOp::DIV, PrecisionType, LastUse>(dst, src0, src1);
}

#define OP_TILE_OP_MAXS TMaxS
template <typename LastUse = LastUse2Dim<0, 0>, typename Scalar, typename T0, typename T1>
TILEOP void TMaxS(T0 dst, T1 src0, Scalar src1)
{
    BinaryScalarCompute<BinaryScalarOp::MAX, 0, LastUse>(dst, src0, src1);
}

#define OP_TILE_OP_MINS TMinS
template <typename LastUse = LastUse2Dim<0, 0>, typename Scalar, typename T0, typename T1>
TILEOP void TMinS(T0 dst, T1 src0, Scalar src1)
{
    BinaryScalarCompute<BinaryScalarOp::MIN, 0, LastUse>(dst, src0, src1);
}

#define OP_TILE_OP_LRELU TLReLU
template <typename LastUse = LastUse2Dim<0, 0>, typename Scalar, typename T0, typename T1>
TILEOP void TLReLU(T0 dst, T1 src0, Scalar src1)
{
    BinaryScalarCompute<BinaryScalarOp::LRELU, 0, LastUse>(dst, src0, src1);
}

#define OP_TILE_OP_BITWISEANDS TBitwiseAndS
template <typename LastUse = LastUse2Dim<0, 0>, typename Scalar, typename T0, typename T1>
TILEOP void TBitwiseAndS(T0 dst, T1 src0, Scalar src1)
{
    BinaryScalarCompute<BinaryScalarOp::BITWISEAND, 0, LastUse>(dst, src0, src1);
}

#define OP_TILE_OP_BITWISEORS TBitwiseOrS
template <typename LastUse = LastUse2Dim<0, 0>, typename Scalar, typename T0, typename T1>
TILEOP void TBitwiseOrS(T0 dst, T1 src0, Scalar src1)
{
    BinaryScalarCompute<BinaryScalarOp::BITWISEOR, 0, LastUse>(dst, src0, src1);
}

TILEOP int gcds(int a, int b)
{
    if (a < 0) {
        a = 0 - a;
    }
    if (b < 0) {
        b = 0 - b;
    }
    while (a % b != 0) {
        int c = a % b;
        a = b;
        b = c;
    }
    return b;
}

#define OP_TILE_OP_GCDS TGcdS
template <typename Scalar, typename T0, typename T1>
TILEOP void TGcdS(T0 dst, T1 src0, Scalar src1)
{
    const auto dstLayout = dst.GetLayout();
    auto shape0 = dstLayout.template GetShapeDim<DIM_1ST, MAX_DIMS>();
    auto shape1 = dstLayout.template GetShapeDim<DIM_2ND, MAX_DIMS>();
    auto shape2 = dstLayout.template GetShapeDim<DIM_3RD, MAX_DIMS>();
    auto shape3 = dstLayout.template GetShapeDim<DIM_4TH, MAX_DIMS>();
    auto shape4 = dstLayout.template GetShapeDim<DIM_5TH, MAX_DIMS>();
    auto dstStride0 = dstLayout.template GetStrideDim<DIM_1ST, MAX_DIMS>();
    auto dstStride1 = dstLayout.template GetStrideDim<DIM_2ND, MAX_DIMS>();
    auto dstStride2 = dstLayout.template GetStrideDim<DIM_3RD, MAX_DIMS>();
    auto dstStride3 = dstLayout.template GetStrideDim<DIM_4TH, MAX_DIMS>();
    constexpr auto dstTileH = TileOp::GetTensorTileShapeDim<T0, 3, 5>();
    constexpr auto dstTileW = TileOp::GetTensorTileShapeDim<T0, 4, 5>();
    auto src0Addr = (__ubuf__ typename T1::Type*)((uint64_t)(src0.GetAddr()));
    auto dstAddr = (__ubuf__ typename T0::Type*)((uint64_t)(dst.GetAddr()));

    set_flag(PIPE_V, PIPE_S, EVENT_ID7);
    wait_flag(PIPE_V, PIPE_S, EVENT_ID7);
    for (LoopVar n = 0; n < shape0; n++) {
        for (LoopVar j = 0; j < shape1; j++) {
            for (LoopVar k = 0; k < shape2; k++) {
                for (LoopVar m = 0; m < shape3; m++) {
                    for (LoopVar i = 0; i < shape4; i++) {
                        int tmpStride = n * dstStride0 + j * dstStride1 + k * dstStride2 + m * dstStride3 + i;
                        dstAddr[tmpStride] = gcds(src0Addr[tmpStride], src1);
                    }
                }
            }
        }
    }
    set_flag(PIPE_S, PIPE_V, EVENT_ID7);
    wait_flag(PIPE_S, PIPE_V, EVENT_ID7);
}

#define OP_TILE_OP_MODS TModS
template <
    auto PrecisionType = pto::FmodSAlgorithm::DEFAULT, typename LastUse = LastUse2Dim<0, 0>, typename Scalar,
    typename T0, typename T1>
TILEOP void TModS(T0 dst, T1 src0, Scalar src1)
{
    BinaryScalarCompute<BinaryScalarOp::MOD, PrecisionType, LastUse>(dst, src0, src1);
}

template <BinaryScalarOp op, auto PrecisionType = 0, typename T0, typename T1, typename Scalar, typename T2>
TILEOP void BinaryScalarTmpComputeImpl(T0 dst, T1 src0, Scalar src1, T2 tmp)
{
    if constexpr (op == BinaryScalarOp::BITWISEXOR) {
        pto::TXORS(dst, src0, src1, tmp);
        return;
    }
    if constexpr (op == BinaryScalarOp::REM) {
        pto::TREMS<PrecisionType>(dst, src0, src1, tmp);
        return;
    }
    if constexpr (op == BinaryScalarOp::POW) {
        pto::TPOWS<PrecisionType>(dst, src0, src1, tmp);
        return;
    }
}

template <BinaryScalarOp op, auto PrecisionType = 0, typename T0, typename T1, typename Scalar, typename T2>
TILEOP void BinaryScalarTmpCompute(T0 dst, T1 src0, Scalar src1, T2 tmp)
{
    const auto dstLayout = dst.GetLayout();
    auto shape0 = dstLayout.template GetShapeDim<DIM_1ST, MAX_DIMS>();
    auto shape1 = dstLayout.template GetShapeDim<DIM_2ND, MAX_DIMS>();
    auto shape2 = dstLayout.template GetShapeDim<DIM_3RD, MAX_DIMS>();

    auto dstTile = PtoTile<T0>(dst);
    auto src0Tile = PtoTile<T1>(src0);
    auto tmpTile = PtoTile<T2>(tmp);
    for (LoopVar n0Index = 0; n0Index < shape0; ++n0Index) {
        for (LoopVar n1Index = 0; n1Index < shape1; ++n1Index) {
            for (LoopVar n2Index = 0; n2Index < shape2; ++n2Index) {
                auto tileOffsets = TileOffset(n0Index, n1Index, n2Index);
                dstTile.Assign(dst, tileOffsets);
                src0Tile.Assign(src0, tileOffsets);
                tmpTile.Assign(tmp, tileOffsets);
                BinaryScalarTmpComputeImpl<op, PrecisionType>(dstTile.Data(), src0Tile.Data(), src1, tmpTile.Data());
            }
        }
    }
}

#define OP_TILE_OP_BITWISEXORS TBitwiseXorS
template <typename Scalar, typename T0, typename T1, typename T2>
TILEOP void TBitwiseXorS(T0 dst, T1 src0, Scalar src1, T2 tmp)
{
    BinaryScalarTmpCompute<BinaryScalarOp::BITWISEXOR, 0>(dst, src0, src1, tmp);
}

#define OP_TILE_OP_REMS TRemainderS
template <typename Scalar, auto PrecisionType = pto::RemSAlgorithm::DEFAULT, typename T0, typename T1, typename T2>
TILEOP void TRemainderS(T0 dst, T1 src0, Scalar src1, T2 tmp)
{
    BinaryScalarTmpCompute<BinaryScalarOp::REM, PrecisionType>(dst, src0, src1, tmp);
}

#define OP_TILE_OP_POWS TPowS
template <auto PrecisionType = pto::PowAlgorithm::DEFAULT, typename Scalar, typename T0, typename T1, typename T2>
TILEOP void TPowS(T0 dst, T1 src0, Scalar src1, T2 tmp)
{
    BinaryScalarTmpCompute<BinaryScalarOp::POW, PrecisionType>(dst, src0, src1, tmp);
}

#define OP_TILE_OP_REMRS TRemainderRS
template <typename Scalar, auto PrecisionType = pto::RemAlgorithm::DEFAULT, typename T0, typename T1, typename T2>
TILEOP void TRemainderRS(T0 dst, T1 src0, Scalar src1, T2 tmp)
{
    const auto dstLayout = dst.GetLayout();
    auto shape0 = dstLayout.template GetShapeDim<DIM_1ST, MAX_DIMS>();
    auto shape1 = dstLayout.template GetShapeDim<DIM_2ND, MAX_DIMS>();
    auto shape2 = dstLayout.template GetShapeDim<DIM_3RD, MAX_DIMS>();
    auto shape3 = dstLayout.template GetShapeDim<DIM_4TH, MAX_DIMS>();
    auto shape4 = dstLayout.template GetShapeDim<DIM_5TH, MAX_DIMS>();
    auto dstTile = PtoTile<T0>(dst);
    auto src0Tile = PtoTile<T1>(src0);
    constexpr auto tmpTileH = TileOp::GetTensorTileShapeDim<T0, 3, 5>();
    constexpr auto tmpTileW = TileOp::GetTensorTileShapeDim<T2, 4, 5>();
    using tmp0TileDefine =
        pto::Tile<pto::TileType::Vec, typename T2::Type, tmpTileH, tmpTileW, pto::BLayout::RowMajor, -1, -1>;
    using tmp1TileDefine =
        pto::Tile<pto::TileType::Vec, typename T2::Type, 2, tmpTileW, pto::BLayout::RowMajor, -1, -1>;
    tmp0TileDefine tmp0Tile(shape3, shape4);
    tmp1TileDefine tmp1Tile(2, shape4);

    for (LoopVar n0Index = 0; n0Index < shape0; ++n0Index) {
        for (LoopVar n1Index = 0; n1Index < shape1; ++n1Index) {
            for (LoopVar n2Index = 0; n2Index < shape2; ++n2Index) {
                auto tileOffsets = TileOffset(n0Index, n1Index, n2Index);
                dstTile.Assign(dst, tileOffsets);
                src0Tile.Assign(src0, tileOffsets);
                pto::TASSIGN(tmp0Tile, (uint64_t)(tmp.GetAddr()));
                pto::TASSIGN(tmp1Tile, (uint64_t)(tmp.GetAddr() + shape3 * tmpTileW * sizeof(typename T2::Type)));
                pto::TEXPANDS(tmp0Tile, src1);
#ifdef __DAV_V220
                pipe_barrier(PIPE_V);
#endif
                pto::TREM<PrecisionType>(dstTile.Data(), tmp0Tile, src0Tile.Data(), tmp1Tile);
            }
        }
    }
}

#define OP_TILE_OP_FLOORDIVS TFloorDivS
template <typename Scalar, typename T0, typename T1, typename T2>
TILEOP void TFloorDivS(T0 dst, T1 src0, Scalar src1, T2 tmp)
{
    static_assert(std::is_same_v<typename T1::Type, int32_t>);

    const auto dstLayout = dst.GetLayout();
    auto dstShape0 = dstLayout.template GetShapeDim<DIM_1ST, MAX_DIMS>();
    auto dstShape1 = dstLayout.template GetShapeDim<DIM_2ND, MAX_DIMS>();
    auto dstShape2 = dstLayout.template GetShapeDim<DIM_3RD, MAX_DIMS>();
    auto dstShape3 = dstLayout.template GetShapeDim<DIM_4TH, MAX_DIMS>();
    auto dstShape4 = dstLayout.template GetShapeDim<DIM_5TH, MAX_DIMS>();

    if (dstShape0 == 0 || dstShape1 == 0 || dstShape2 == 0 || dstShape3 == 0 || dstShape4 == 0) {
        return;
    }

    auto dstStride0 = dstLayout.template GetStrideDim<DIM_1ST, MAX_DIMS>();
    auto dstStride1 = dstLayout.template GetStrideDim<DIM_2ND, MAX_DIMS>();
    auto dstStride2 = dstLayout.template GetStrideDim<DIM_3RD, MAX_DIMS>();
    auto dstStride3 = dstLayout.template GetStrideDim<DIM_4TH, MAX_DIMS>();

    constexpr auto tileW = TileOp::GetTensorTileShapeDim<T0, DIM_5TH, MAX_DIMS>();
    constexpr auto dstTypeSize = sizeof(typename T0::Type);

    for (LoopVar n0Index = 0; n0Index < dstShape0; n0Index++) {
        for (LoopVar n1Index = 0; n1Index < dstShape1; n1Index++) {
            for (LoopVar n2Index = 0; n2Index < dstShape2; n2Index++) {
                for (LoopVar n3Index = 0; n3Index < dstShape3; n3Index++) {
                    auto offset =
                        n0Index * dstStride0 + n1Index * dstStride1 + n2Index * dstStride2 + n3Index * dstStride3;
#ifdef __DAV_V220
                    using IntTileDefine =
                        pto::Tile<pto::TileType::Vec, typename T0::Type, 1, tileW, pto::BLayout::RowMajor, -1, -1>;
                    using FloatTileDefine =
                        pto::Tile<pto::TileType::Vec, float, 1, tileW, pto::BLayout::RowMajor, -1, -1>;

                    IntTileDefine src0Tile(1, dstShape4);
                    IntTileDefine dstTile(1, dstShape4);
                    FloatTileDefine tmp0Tile(1, dstShape4);
                    FloatTileDefine tmp1Tile(1, dstShape4);

                    pto::TASSIGN(tmp0Tile, (uint64_t)(tmp.GetAddr()));
                    pto::TASSIGN(tmp1Tile, (uint64_t)(tmp.GetAddr() + tileW * dstTypeSize));
                    pto::TASSIGN(src0Tile, (uint64_t)(src0.GetAddr() + offset * dstTypeSize));
                    pto::TASSIGN(dstTile, (uint64_t)(dst.GetAddr() + offset * dstTypeSize));

                    pto::TCVT(tmp0Tile, src0Tile, pto::RoundMode::CAST_NONE, pto::SaturationMode::OFF);
                    pipe_barrier(PIPE_V);
                    pto::TDIVS(tmp1Tile, tmp0Tile, static_cast<float>(src1));
                    pipe_barrier(PIPE_V);
                    pto::TCVT(dstTile, tmp1Tile, pto::RoundMode::CAST_FLOOR);
                    pipe_barrier(PIPE_V);
#else
                    using DataTileDefine =
                        pto::Tile<pto::TileType::Vec, typename T0::Type, 1, tileW, pto::BLayout::RowMajor, -1, -1>;
                    using MaskTileDefine =
                        pto::Tile<pto::TileType::Vec, uint8_t, 1, tileW * 4, pto::BLayout::RowMajor, -1, -1>;

                    DataTileDefine src0Tile(1, dstShape4);
                    DataTileDefine dstTile(1, dstShape4);
                    DataTileDefine tmp0DataTile(1, dstShape4);
                    DataTileDefine tmp1DataTile(1, dstShape4);
                    DataTileDefine tmp2DataTile(1, dstShape4);

                    MaskTileDefine tmp0MaskTile(1, dstShape4);
                    MaskTileDefine tmp1MaskTile(1, dstShape4);

                    pto::TASSIGN(tmp0DataTile, (uint64_t)(tmp.GetAddr()));
                    pto::TASSIGN(tmp1DataTile, (uint64_t)(tmp.GetAddr() + tileW * dstTypeSize));
                    pto::TASSIGN(tmp2DataTile, (uint64_t)(tmp.GetAddr() + 2 * tileW * dstTypeSize));
                    pto::TASSIGN(src0Tile, (uint64_t)(src0.GetAddr() + offset * dstTypeSize));
                    pto::TASSIGN(dstTile, (uint64_t)(dst.GetAddr() + offset * dstTypeSize));

                    // Reuse the same tmp as packed mask storage
                    pto::TASSIGN(tmp0MaskTile, (uint64_t)(tmp.GetAddr()));
                    pto::TASSIGN(tmp1MaskTile, (uint64_t)(tmp.GetAddr() + tileW * dstTypeSize));

                    if (src1 == 0) {
                        constexpr int32_t pos = 0x7FFF7F7F, neg = 0x80008080;
                        pto::TCMPS(tmp0MaskTile, src0Tile, 0, CmpMode::LT);
                        pto::TSELS(dstTile, tmp0MaskTile, dstTile, tmp1DataTile, pos);
                        pto::TCMPS(tmp0MaskTile, src0Tile, 0, CmpMode::GE);
                        pto::TSELS(dstTile, tmp0MaskTile, dstTile, tmp1DataTile, neg);
                    } else {
                        uint8_t src1Mask = 0;
                        if (src1 < 0) {
                            src1Mask = 0xff;
                        }
                        pto::TCMPS(tmp0MaskTile, src0Tile, 0, CmpMode::LT);
                        pto::TXORS(tmp1MaskTile, tmp0MaskTile, src1Mask, dstTile); // packed mask of sign_differ
                        pto::TDIVS(dstTile, src0Tile, src1);                       // quot
                        pto::TMULS(tmp0DataTile, dstTile, -src1);
                        pto::TADD(tmp2DataTile, tmp0DataTile, src0Tile); // rem

                        pto::TCMPS(tmp0MaskTile, tmp2DataTile, 0, CmpMode::NE);
                        pto::TAND(tmp0MaskTile, tmp1MaskTile, tmp0MaskTile);
                        pto::TADDS(tmp2DataTile, dstTile, -1);
                        pto::TSEL(dstTile, tmp0MaskTile, tmp2DataTile, dstTile, tmp1DataTile);
                    }
#endif
                }
            }
        }
    }
}
#endif