/**
 * Copyright (c) 2025-2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file compare.h
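 * \brief Tile-level element-wise compare operators (TCompare, tensor-tensor and tensor-scalar) and their helpers.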
 */

#ifndef TILEOP_TILE_OPERATOR_COMPARE__H
#define TILEOP_TILE_OPERATOR_COMPARE__H
#include "utils/layout.h"
#include "utils/tile_tensor.h"

/**
 * \brief Raw scratch-buffer pointers used by TCompare; populated by InitCompareTmpBuffers.
 *
 * NOTE(review): T is the tile-tensor type (its element type is T::Type), so the three T* members are typed as
 * pointer-to-tile rather than pointer-to-element. In this file they are only converted to integer addresses for
 * pto::TASSIGN and are never dereferenced.
 */
template <typename T>
struct CompareTmpBuffers {
    // Compare-mask scratch: TCMP/TCMPS write here in mode 0, and TSEL reads it as its selection mask.
    __ubuf__ uint8_t* vcmpBitResult;
    // Small scratch region handed to TSEL (bound to startAddrUBTile).
    __ubuf__ uint8_t* startAddrUB;
    // Filled with 0 as a TSEL operand; later reused as the half-precision temp for 32-bit sources.
    __ubuf__ T* zeroCondition;
    // Filled with 1 as a TSEL operand.
    __ubuf__ T* oneCondition;
    // TSEL output, converted afterwards into the uint8 destination.
    __ubuf__ T* vselResult;
};

template <typename T>
struct CompareTileTypes {
    static constexpr int64_t COUNT_MAX = 1024;
    using SrcTile = pto::Tile<pto::TileType::Vec, typename T::Type, 1, COUNT_MAX, pto::BLayout::RowMajor, -1, -1>;
    using DstTile = pto::Tile<pto::TileType::Vec, uint8_t, 1, COUNT_MAX, pto::BLayout::RowMajor, -1, -1>;
    using CmpTile = pto::Tile<pto::TileType::Vec, uint8_t, 1, COUNT_MAX, pto::BLayout::RowMajor, -1, -1>;
    using TmpTile = pto::Tile<pto::TileType::Vec, half, 1, COUNT_MAX, pto::BLayout::RowMajor, -1, -1>;
};

/**
 * \brief Carves the compare scratch buffer `tmpbuf` into the sub-regions used by TCompare.
 *
 * With countNum = 4096 / sizeof(T::Type), the regions are laid out from tmpbuf.GetAddr() upwards. Every region
 * after the first starts on a 32-byte boundary.
 *   vcmpBitResult : ceil(countNum / 8) bytes
 *   startAddrUB   : 32 bytes
 *   zeroCondition : countNum * sizeof(T::Type) bytes
 *   oneCondition  : countNum * sizeof(T::Type) bytes
 *   vselResult    : starts at the next boundary; this function reserves no size past it.
 * NOTE(review): tmpbuf must be large enough to also hold vselResult (presumably another
 * countNum * sizeof(T::Type) bytes). That is not checked here.
 *
 * \tparam T     Tile-tensor type of the compare sources (T::Type is the element type).
 * \param tmpbuf Scratch tile tensor; only its base address (GetAddr) is used.
 * \return       Pointers to each scratch region.
 */
template <typename T, typename TTmp>
TILEOP CompareTmpBuffers<T> InitCompareTmpBuffers(TTmp tmpbuf)
{
    constexpr uint64_t countNum = 4096 / sizeof(typename T::Type);
    constexpr uint64_t elemRegionBytes = countNum * sizeof(typename T::Type);
    constexpr uint64_t vcmpBitsSize = (countNum + 7) / 8;
    // The alignment and its mask must be as wide as uintptr_t. A 32-bit mask (~(uint32_t)31) zero-extends to
    // 0x00000000FFFFFFE0 when combined with a 64-bit address and would clear the address's upper half.
    constexpr uintptr_t ALIGNMENT = 32;
    constexpr uintptr_t ALIGN_MASK = ~(ALIGNMENT - 1);

    uintptr_t cursor = static_cast<uintptr_t>(tmpbuf.GetAddr());
    CompareTmpBuffers<T> buffers;

    buffers.vcmpBitResult = reinterpret_cast<__ubuf__ uint8_t*>(cursor);
    cursor = (cursor + vcmpBitsSize + ALIGNMENT - 1) & ALIGN_MASK;

    buffers.startAddrUB = reinterpret_cast<__ubuf__ uint8_t*>(cursor);
    cursor = (cursor + ALIGNMENT + ALIGNMENT - 1) & ALIGN_MASK;

    buffers.zeroCondition = reinterpret_cast<__ubuf__ T*>(cursor);
    cursor = (cursor + elemRegionBytes + ALIGNMENT - 1) & ALIGN_MASK;

    buffers.oneCondition = reinterpret_cast<__ubuf__ T*>(cursor);
    cursor = (cursor + elemRegionBytes + ALIGNMENT - 1) & ALIGN_MASK;

    buffers.vselResult = reinterpret_cast<__ubuf__ T*>(cursor);

    return buffers;
}

/**
 * \brief Rank-5 shape/stride snapshot of a compare's source and destination, filled by ExtractLayoutInfo.
 *
 * Dims 0-3 are walked as outer loops. Dim 4 is the innermost row, which TCompare treats as contiguous and
 * processes in chunks. TCompare multiplies strides by the element size before adding them to a base address,
 * so strides are in elements.
 */
struct CompareLayoutInfo {
    // Source extents of dims 0..4.
    size_t shape0;
    size_t shape1;
    size_t shape2;
    size_t shape3;
    size_t shape4;
    // Destination extent of dim 4 (not read by TCompare in this file).
    size_t dstShape;
    // Source strides of dims 0..3.
    size_t stride0;
    size_t stride1;
    size_t stride2;
    size_t stride3;
    // Destination strides of dims 0..3.
    size_t dstStride0;
    size_t dstStride1;
    size_t dstStride2;
    size_t dstStride3;
};

/**
 * \brief Reads the rank-5 shapes and strides of `src` and `dst` into a CompareLayoutInfo.
 *
 * \param src  Source tile tensor; supplies all five shape dims and the strides of dims 0..3.
 * \param dst  Destination tile tensor; supplies its dim-4 extent and the strides of dims 0..3.
 * \return     The populated layout snapshot.
 */
template <typename T, typename TDst>
TILEOP CompareLayoutInfo ExtractLayoutInfo(const T& src, const TDst& dst)
{
    // Both layouts are queried as rank-5 views.
    constexpr size_t rank = 5;
    const auto srcView = src.GetLayout();
    const auto dstView = dst.GetLayout();

    CompareLayoutInfo info;

    // Outer dims: source extent, source stride and destination stride, one dim at a time.
    info.shape0 = srcView.template GetShapeDim<0, rank>();
    info.stride0 = srcView.template GetStrideDim<0, rank>();
    info.dstStride0 = dstView.template GetStrideDim<0, rank>();

    info.shape1 = srcView.template GetShapeDim<1, rank>();
    info.stride1 = srcView.template GetStrideDim<1, rank>();
    info.dstStride1 = dstView.template GetStrideDim<1, rank>();

    info.shape2 = srcView.template GetShapeDim<2, rank>();
    info.stride2 = srcView.template GetStrideDim<2, rank>();
    info.dstStride2 = dstView.template GetStrideDim<2, rank>();

    info.shape3 = srcView.template GetShapeDim<3, rank>();
    info.stride3 = srcView.template GetStrideDim<3, rank>();
    info.dstStride3 = dstView.template GetStrideDim<3, rank>();

    // Innermost dim: only the extents are recorded.
    info.shape4 = srcView.template GetShapeDim<4, rank>();
    info.dstShape = dstView.template GetShapeDim<4, rank>();

    return info;
}

/**
 * \brief Issues one element-wise TCMP of `src0Tile` against `src1Tile` into `dst0`.
 *
 * \tparam cmpOp   Integer value of a pto::CmpMode; must be EQ, NE, LT, LE, GT or GE. Any other value is rejected at
 *                 compile time rather than silently issuing no instruction and leaving dst0 unwritten.
 * \param dst0     Destination tile for the comparison result.
 * \param src0Tile Left-hand operand tile.
 * \param src1Tile Right-hand operand tile.
 */
template <int64_t cmpOp, typename DstTileType, typename SrcTileType>
TILEOP void ExecuteCompare(DstTileType& dst0, SrcTileType& src0Tile, SrcTileType& src1Tile)
{
    constexpr pto::CmpMode cmpMode = static_cast<pto::CmpMode>(cmpOp);
    static_assert(
        cmpMode == pto::CmpMode::EQ || cmpMode == pto::CmpMode::NE || cmpMode == pto::CmpMode::LT ||
            cmpMode == pto::CmpMode::LE || cmpMode == pto::CmpMode::GT || cmpMode == pto::CmpMode::GE,
        "ExecuteCompare: cmpOp must be one of pto::CmpMode::{EQ, NE, LT, LE, GT, GE}");
    pto::TCMP(dst0, src0Tile, src1Tile, cmpMode);
}

/**
 * \brief Issues one element-wise TCMPS of `srcTile` against the scalar `scalarVal` into `dst0`.
 *
 * \tparam cmpOp    Integer value of a pto::CmpMode; must be EQ, NE, LT, LE, GT or GE. Any other value is rejected at
 *                  compile time rather than silently issuing no instruction and leaving dst0 unwritten.
 * \param dst0      Destination tile for the comparison result.
 * \param srcTile   Tile operand (left-hand side).
 * \param scalarVal Scalar operand (right-hand side).
 */
template <int64_t cmpOp, typename DstTileType, typename SrcTileType, typename TVal>
TILEOP void ExecuteCompareScalar(DstTileType& dst0, SrcTileType& srcTile, TVal scalarVal)
{
    constexpr pto::CmpMode cmpMode = static_cast<pto::CmpMode>(cmpOp);
    static_assert(
        cmpMode == pto::CmpMode::EQ || cmpMode == pto::CmpMode::NE || cmpMode == pto::CmpMode::LT ||
            cmpMode == pto::CmpMode::LE || cmpMode == pto::CmpMode::GT || cmpMode == pto::CmpMode::GE,
        "ExecuteCompareScalar: cmpOp must be one of pto::CmpMode::{EQ, NE, LT, LE, GT, GE}");
    pto::TCMPS(dst0, srcTile, scalarVal, cmpMode);
}

/**
 * \brief Mode-0 epilogue: turns the compare mask in `cmpResTile` into uint8 values in `bitResTile`.
 *
 * It fills zeroConditionTile with 0 and oneConditionTile with 1, then uses TSEL with cmpResTile as the mask to
 * choose between them into vselResultTile. The selected values are converted to the uint8 destination. 16-bit
 * sources convert directly. 32-bit sources first narrow into a half tile (tmpTile) placed over the zeroCondition
 * scratch, which is no longer read after TSEL.
 * NOTE(review): element sizes other than 2 or 4 bytes write nothing to bitResTile.
 *
 * \param bitResTile        Destination uint8 tile.
 * \param cmpResTile        Compare mask produced by TCMP/TCMPS.
 * \param vselResultTile    TSEL output in the source element type.
 * \param oneConditionTile  Scratch tile filled with 1.
 * \param zeroConditionTile Scratch tile filled with 0.
 * \param startAddrUBTile   Auxiliary scratch tile passed to TSEL.
 * \param tmpTile           Half-precision temp, bound to `zeroCondition` for 32-bit sources.
 * \param zeroCondition     Address of the zeroCondition scratch region.
 */
template <
    typename T, typename SrcTileType, typename DstTileType, typename CmpTileType, typename TmpTileType,
    typename AddrUBTileType>
TILEOP void PostProcessMode0(
    DstTileType& bitResTile, CmpTileType& cmpResTile, SrcTileType& vselResultTile, SrcTileType& oneConditionTile,
    SrcTileType& zeroConditionTile, AddrUBTileType& startAddrUBTile, TmpTileType& tmpTile, __ubuf__ T* zeroCondition)
{
    constexpr size_t elemBytes = sizeof(typename T::Type);
    constexpr bool isSixteenBit = (elemBytes == 2);
    constexpr bool isThirtyTwoBit = (elemBytes == 4);

    // Select operands: a tile of zeros and a tile of ones.
    pto::TEXPANDS(zeroConditionTile, 0.000000e+00f);
    pto::TEXPANDS(oneConditionTile, 1.000000e+00f);
#ifdef __DAV_V220
    pipe_barrier(PIPE_V);
#endif
    // Pick 1 or 0 per element according to the compare mask.
    pto::TSEL(vselResultTile, cmpResTile, oneConditionTile, zeroConditionTile, startAddrUBTile);
#ifdef __DAV_V220
    pipe_barrier(PIPE_V);
#endif
    if constexpr (isSixteenBit) {
        // One conversion straight to the uint8 destination.
        pto::TCVT(bitResTile, vselResultTile, pto::RoundMode::CAST_NONE);
    } else if constexpr (isThirtyTwoBit) {
        // Narrow to half in the (now free) zeroCondition region, then convert half -> uint8.
        const uint64_t halfScratchAddr = reinterpret_cast<uint64_t>(zeroCondition);
        pto::TASSIGN(tmpTile, halfScratchAddr);
        pto::TCVT(tmpTile, vselResultTile, pto::RoundMode::CAST_NONE);
#ifdef __DAV_V220
        pipe_barrier(PIPE_V);
#endif
        pto::TCVT(bitResTile, tmpTile, pto::RoundMode::CAST_NONE);
    }
}

/**
 * \brief Computes the element offsets of row (n0, n1, n2, n3) in the source and the destination.
 *
 * \param info       Layout snapshot providing the dim 0..3 strides for both tensors.
 * \param n0Index..n3Index  Indices along dims 0..3.
 * \param srcOffset  [out] sum of n_k * info.stride_k over k = 0..3.
 * \param dstOffset  [out] sum of n_k * info.dstStride_k over k = 0..3.
 */
TILEOP void CalcOffsets(
    const CompareLayoutInfo& info, size_t n0Index, size_t n1Index, size_t n2Index, size_t n3Index, size_t& srcOffset,
    size_t& dstOffset)
{
    // Accumulate both offsets dimension by dimension, outermost first.
    size_t srcAcc = n0Index * info.stride0;
    size_t dstAcc = n0Index * info.dstStride0;

    srcAcc += n1Index * info.stride1;
    dstAcc += n1Index * info.dstStride1;

    srcAcc += n2Index * info.stride2;
    dstAcc += n2Index * info.dstStride2;

    srcAcc += n3Index * info.stride3;
    dstAcc += n3Index * info.dstStride3;

    dstOffset = dstAcc;
    srcOffset = srcAcc;
}

/**
 * \brief Sizes the per-chunk result/scratch tiles and binds each one to its address.
 *
 * \param vselResultTile    Resized to 1 x shape4 and bound to buffers.vselResult.
 * \param startAddrUBTile   Bound to buffers.startAddrUB; its shape is left as the caller constructed it.
 * \param oneConditionTile  Resized to 1 x shape4 and bound to buffers.oneCondition.
 * \param zeroConditionTile Resized to 1 x shape4 and bound to buffers.zeroCondition.
 * \param bitResTile        Resized to 1 x dstShape and bound to dstAddr (the destination chunk).
 * \param cmpResTile        Resized to 1 x dstShape and bound to buffers.vcmpBitResult.
 * \param buffers           Scratch pointers from InitCompareTmpBuffers.
 * \param dstAddr           Address of the destination chunk.
 * \param shape4            Number of source elements in this chunk.
 * \param dstShape          Number of destination elements in this chunk.
 */
template <typename T, typename Types, typename TileStartAddrUB>
TILEOP void InitCommonTiles(
    typename Types::SrcTile& vselResultTile, TileStartAddrUB& startAddrUBTile,
    typename Types::SrcTile& oneConditionTile, typename Types::SrcTile& zeroConditionTile,
    typename Types::DstTile& bitResTile, typename Types::CmpTile& cmpResTile, const CompareTmpBuffers<T>& buffers,
    uint64_t dstAddr, size_t shape4, size_t dstShape)
{
    using SrcTile = typename Types::SrcTile;
    using DstTile = typename Types::DstTile;
    using CmpTile = typename Types::CmpTile;

    // Destination chunk.
    bitResTile = DstTile(1, dstShape);
    pto::TASSIGN(bitResTile, dstAddr);

    // Compare-mask scratch.
    cmpResTile = CmpTile(1, dstShape);
    pto::TASSIGN(cmpResTile, reinterpret_cast<uint64_t>(buffers.vcmpBitResult));

    // TSEL output and its two select operands.
    vselResultTile = SrcTile(1, shape4);
    pto::TASSIGN(vselResultTile, reinterpret_cast<uint64_t>(buffers.vselResult));

    oneConditionTile = SrcTile(1, shape4);
    pto::TASSIGN(oneConditionTile, reinterpret_cast<uint64_t>(buffers.oneCondition));

    zeroConditionTile = SrcTile(1, shape4);
    pto::TASSIGN(zeroConditionTile, reinterpret_cast<uint64_t>(buffers.zeroCondition));

    // Auxiliary TSEL scratch.
    pto::TASSIGN(startAddrUBTile, reinterpret_cast<uint64_t>(buffers.startAddrUB));
}

/**
 * \brief Element-wise comparison of two tile tensors, `src0 <cmpOp> src1`, written into `dst`.
 *
 * Both tensors are viewed as rank 5. Dims 0-3 are walked with the strides from ExtractLayoutInfo. Each dim-4 row
 * is processed in chunks of `elementsPerCount` elements, and the last chunk may be shorter. Every chunk reuses
 * the same scratch regions carved from `tmpbuf` by InitCompareTmpBuffers.
 *
 * \tparam cmpOp  Integer value of a pto::CmpMode (EQ, NE, LT, LE, GT or GE).
 * \tparam mode   0: TCMP writes a mask into scratch, and PostProcessMode0 converts it into one uint8 value per
 *                element of `dst`. Any other value: TCMP writes directly into `dst`, which advances by ceil(n / 8)
 *                elements per n-element chunk.
 * \param dst     Destination tile tensor.
 * \param src0    Left-hand operand; its layout drives source addressing.
 * \param src1    Right-hand operand. It is addressed with src0's strides, so it is assumed to share src0's layout.
 * \param tmpbuf  Scratch tile tensor for the compare temporaries.
 */
template <int64_t cmpOp, int64_t mode, typename T, typename TDst, typename TTmp>
TILEOP void TCompare(TDst dst, T src0, T src1, TTmp tmpbuf)
{
    auto info = ExtractLayoutInfo(src0, dst);
    auto buffers = InitCompareTmpBuffers<T>(tmpbuf);
    constexpr auto dstTypeSize = sizeof(typename TDst::Type);
    constexpr auto srcTypeSize = sizeof(typename T::Type);
    constexpr unsigned alignUint8 = 32;
    constexpr unsigned addressUsed = 4;
    using Types = CompareTileTypes<T>;
    using SrcTile = typename Types::SrcTile;
    using DstTile = typename Types::DstTile;
    using CmpTile = typename Types::CmpTile;
    using TmpTile = typename Types::TmpTile;
    using TileStartAddrUB = pto::Tile<pto::TileType::Vec, uint8_t, 1, alignUint8, pto::BLayout::RowMajor, -1, -1>;

    // Chunk length: 4096 bytes' worth of source elements, capped at the tiles' static column capacity.
    constexpr uint64_t countBy4096 = 4096 / sizeof(typename T::Type);
    constexpr uint64_t elementsPerCount =
        (countBy4096 < static_cast<uint64_t>(Types::COUNT_MAX)) ? countBy4096 : static_cast<uint64_t>(Types::COUNT_MAX);
    // Destination elements per full chunk: one per element in mode 0, one per 8 elements otherwise.
    constexpr uint64_t dstElementsPerCount = (mode == 0) ? elementsPerCount : ((elementsPerCount + 7) / 8);

    const uint64_t numCountPerLine = info.shape4 / elementsPerCount;
    const uint64_t elementsRemainPerLine = info.shape4 % elementsPerCount;
    // All full chunks, plus one trailing partial chunk when the row length is not a multiple of elementsPerCount.
    const uint64_t numChunksPerLine = numCountPerLine + ((elementsRemainPerLine != 0) ? 1 : 0);

    // Base addresses do not change inside the loops.
    const auto dstBase = dst.GetAddr();
    const auto src0Base = src0.GetAddr();
    const auto src1Base = src1.GetAddr();

    for (LoopVar n0Index = 0; n0Index < info.shape0; ++n0Index) {
        for (LoopVar n1Index = 0; n1Index < info.shape1; ++n1Index) {
            for (LoopVar n2Index = 0; n2Index < info.shape2; ++n2Index) {
                for (LoopVar n3Index = 0; n3Index < info.shape3; ++n3Index) {
                    size_t srcOffset, dstOffset;
                    CalcOffsets(info, n0Index, n1Index, n2Index, n3Index, srcOffset, dstOffset);

                    for (LoopVar j = 0; j < numChunksPerLine; ++j) {
                        // Every chunk except a trailing partial one holds exactly elementsPerCount elements.
                        const size_t curShape4 = (j < numCountPerLine) ? elementsPerCount : elementsRemainPerLine;
                        const size_t curDstShape = (mode == 0) ? curShape4 : ((curShape4 + 7) / 8);

                        const size_t curSrcOffset = srcOffset + j * elementsPerCount;
                        const size_t curDstOffset = dstOffset + j * dstElementsPerCount;
                        const uint64_t dstAddr = dstBase + curDstOffset * dstTypeSize;

                        SrcTile src0Tile(1, curShape4), src1Tile(1, curShape4);
                        SrcTile vselResultTile, oneConditionTile, zeroConditionTile;
                        DstTile bitResTile;
                        CmpTile cmpResTile;
                        TileStartAddrUB startAddrUBTile(1, addressUsed);
                        TmpTile tmpTile(1, curShape4);

                        InitCommonTiles<T, Types>(
                            vselResultTile, startAddrUBTile, oneConditionTile, zeroConditionTile, bitResTile,
                            cmpResTile, buffers, dstAddr, curShape4, curDstShape);

                        pto::TASSIGN(src0Tile, static_cast<uint64_t>(src0Base + curSrcOffset * srcTypeSize));
                        pto::TASSIGN(src1Tile, static_cast<uint64_t>(src1Base + curSrcOffset * srcTypeSize));

                        // DstTile and CmpTile are the same type, so one reference can name either target.
                        auto& dst0 = (mode == 0) ? cmpResTile : bitResTile;
                        ExecuteCompare<cmpOp>(dst0, src0Tile, src1Tile);

                        if constexpr (mode == 0) {
                            PostProcessMode0<T>(
                                bitResTile, cmpResTile, vselResultTile, oneConditionTile, zeroConditionTile,
                                startAddrUBTile, tmpTile, buffers.zeroCondition);
                        }
                    }
                }
            }
        }
    }
}

/**
 * \brief Element-wise comparison of a tile tensor against a scalar, `src <cmpOp> scalarVal`, written into `dst`.
 *
 * `src` is viewed as rank 5. Dims 0-3 are walked with the strides from ExtractLayoutInfo. Each dim-4 row is
 * processed in chunks of `elementsPerCount` elements, and the last chunk may be shorter. Every chunk reuses the
 * same scratch regions carved from `tmpbuf` by InitCompareTmpBuffers.
 *
 * \tparam cmpOp     Integer value of a pto::CmpMode (EQ, NE, LT, LE, GT or GE).
 * \tparam mode      0: TCMPS writes a mask into scratch, and PostProcessMode0 converts it into one uint8 value per
 *                   element of `dst`. Any other value: TCMPS writes directly into `dst`, which advances by
 *                   ceil(n / 8) elements per n-element chunk.
 * \param dst        Destination tile tensor.
 * \param src        Tile-tensor operand; its layout drives source addressing.
 * \param tmpbuf     Scratch tile tensor for the compare temporaries.
 * \param scalarVal  Scalar right-hand operand.
 */
template <int64_t cmpOp, int64_t mode, typename TVal, typename T, typename TDst, typename TTmp>
TILEOP void TCompare(TDst dst, T src, TTmp tmpbuf, TVal scalarVal)
{
    auto buffers = InitCompareTmpBuffers<T>(tmpbuf);
    auto info = ExtractLayoutInfo(src, dst);
    constexpr auto dstTypeSize = sizeof(typename TDst::Type);
    constexpr auto srcTypeSize = sizeof(typename T::Type);
    constexpr unsigned alignUint8 = 32;
    constexpr unsigned addressUsed = 4;
    using Types = CompareTileTypes<T>;
    using DstTile = typename Types::DstTile;
    using CmpTile = typename Types::CmpTile;
    using SrcTile = typename Types::SrcTile;
    using TmpTile = typename Types::TmpTile;
    using TileStartAddrUB = pto::Tile<pto::TileType::Vec, uint8_t, 1, alignUint8, pto::BLayout::RowMajor, -1, -1>;

    // Chunk length: 4096 bytes' worth of source elements, capped at the tiles' static column capacity.
    constexpr uint64_t countBy4096 = 4096 / sizeof(typename T::Type);
    constexpr uint64_t elementsPerCount =
        (countBy4096 < static_cast<uint64_t>(Types::COUNT_MAX)) ? countBy4096 : static_cast<uint64_t>(Types::COUNT_MAX);
    // Destination elements per full chunk: one per element in mode 0, one per 8 elements otherwise.
    constexpr uint64_t dstElementsPerCount = (mode == 0) ? elementsPerCount : ((elementsPerCount + 7) / 8);

    const uint64_t numCountPerLine = info.shape4 / elementsPerCount;
    const uint64_t elementsRemainPerLine = info.shape4 % elementsPerCount;
    // All full chunks, plus one trailing partial chunk when the row length is not a multiple of elementsPerCount.
    const uint64_t numChunksPerLine = numCountPerLine + ((elementsRemainPerLine != 0) ? 1 : 0);

    // Base addresses do not change inside the loops.
    const auto dstBase = dst.GetAddr();
    const auto srcBase = src.GetAddr();

    for (LoopVar n0Index = 0; n0Index < info.shape0; ++n0Index) {
        for (LoopVar n1Index = 0; n1Index < info.shape1; ++n1Index) {
            for (LoopVar n2Index = 0; n2Index < info.shape2; ++n2Index) {
                for (LoopVar n3Index = 0; n3Index < info.shape3; ++n3Index) {
                    size_t srcOffset, dstOffset;
                    CalcOffsets(info, n0Index, n1Index, n2Index, n3Index, srcOffset, dstOffset);

                    for (LoopVar j = 0; j < numChunksPerLine; ++j) {
                        // Every chunk except a trailing partial one holds exactly elementsPerCount elements.
                        const size_t curShape4 = (j < numCountPerLine) ? elementsPerCount : elementsRemainPerLine;
                        const size_t curDstShape = (mode == 0) ? curShape4 : ((curShape4 + 7) / 8);

                        const size_t curSrcOffset = srcOffset + j * elementsPerCount;
                        const size_t curDstOffset = dstOffset + j * dstElementsPerCount;
                        const uint64_t dstAddr = dstBase + curDstOffset * dstTypeSize;

                        SrcTile srcTile(1, curShape4);
                        SrcTile vselResultTile, oneConditionTile, zeroConditionTile;
                        DstTile bitResTile;
                        CmpTile cmpResTile;
                        TileStartAddrUB startAddrUBTile(1, addressUsed);
                        TmpTile tmpTile(1, curShape4);

                        InitCommonTiles<T, Types>(
                            vselResultTile, startAddrUBTile, oneConditionTile, zeroConditionTile, bitResTile,
                            cmpResTile, buffers, dstAddr, curShape4, curDstShape);

                        pto::TASSIGN(srcTile, static_cast<uint64_t>(srcBase + curSrcOffset * srcTypeSize));

                        // DstTile and CmpTile are the same type, so one reference can name either target.
                        auto& dst0 = (mode == 0) ? cmpResTile : bitResTile;
                        ExecuteCompareScalar<cmpOp>(dst0, srcTile, scalarVal);

                        if constexpr (mode == 0) {
                            PostProcessMode0<T>(
                                bitResTile, cmpResTile, vselResultTile, oneConditionTile, zeroConditionTile,
                                startAddrUBTile, tmpTile, buffers.zeroCondition);
                        }
                    }
                }
            }
        }
    }
}
#endif