/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file reduce.h
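 * \brief Tile-level reduction operators (sum/max/min/prod/argmax/argmin) over the last axis ("Single"
 *        variants) or over a chosen axis ("Line" variants), built on pto:: tile primitives.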
 */

#ifndef TILEOP_TILE_OPERATOR_REDUCE__H
#define TILEOP_TILE_OPERATOR_REDUCE__H
#include "pto_tile.h"
#include "utils/layout.h"
#include "utils/tile_tensor.h"

/**
 * \brief Dispatches a row-wise reduction of one 2-D tile to the matching pto:: primitive.
 *
 * SUM/MAX/MIN/PROD are issued through PTO_WITH_LAST_USE together with the first three last-use flags
 * taken from LastUse; ARGMAX/ARGMIN are issued directly without that wrapper. Any other op emits no call.
 *
 * \tparam op      Reduction kind selected at compile time.
 * \tparam LastUse Tuple-like type whose DIM_1ST/DIM_2ND/DIM_3RD elements expose a `value` member.
 * \param dst Destination tile.
 * \param src Source tile.
 * \param tmp Scratch tile passed through to the primitive.
 */
template <ReduceOp op, typename LastUse, typename T0, typename T1, typename T2>
TILEOP void ReduceComputeImpl(T0 dst, T1 src, T2 tmp)
{
    // Compile-time last-use flags forwarded to PTO_WITH_LAST_USE.
    constexpr auto flag1st = Std::tuple_element<DIM_1ST, LastUse>::type::value;
    constexpr auto flag2nd = Std::tuple_element<DIM_2ND, LastUse>::type::value;
    constexpr auto flag3rd = Std::tuple_element<DIM_3RD, LastUse>::type::value;

    if constexpr (op == ReduceOp::ARGMAX) {
        // Arg-reductions are issued without the last-use wrapper.
        pto::TROWARGMAX(dst, src, tmp);
    } else if constexpr (op == ReduceOp::ARGMIN) {
        pto::TROWARGMIN(dst, src, tmp);
    } else if constexpr (op == ReduceOp::SUM) {
        PTO_WITH_LAST_USE(pto::TROWSUM(dst, src, tmp), flag1st, flag2nd, flag3rd);
    } else if constexpr (op == ReduceOp::MAX) {
        PTO_WITH_LAST_USE(pto::TROWMAX(dst, src, tmp), flag1st, flag2nd, flag3rd);
    } else if constexpr (op == ReduceOp::MIN) {
        PTO_WITH_LAST_USE(pto::TROWMIN(dst, src, tmp), flag1st, flag2nd, flag3rd);
    } else if constexpr (op == ReduceOp::PROD) {
        PTO_WITH_LAST_USE(pto::TROWPROD(dst, src, tmp), flag1st, flag2nd, flag3rd);
    }
}

/**
 * \brief Row-wise reduction over the last axis of tensors viewed as 5-D.
 *
 * Dimensions 0..2 of `dst` are iterated explicitly. For each (n0, n1, n2) position a destination tile of
 * valid size [dstShape3 x dstShape4] and a source tile of valid size [srcShape3 x srcShape4] are bound to
 * the corresponding addresses and reduced with ReduceComputeImpl<op, LastUse>.
 * Nothing is done when any of the five source dimensions is zero.
 *
 * \tparam op      Reduction kind forwarded to ReduceComputeImpl.
 * \tparam LastUse Last-use flags forwarded to ReduceComputeImpl.
 * \param dst Destination tensor.
 * \param src Source tensor.
 * \param tmp Scratch tensor; its base address is rebound to the scratch tile on every iteration.
 */
template <ReduceOp op, typename LastUse, typename T0, typename T1, typename T2>
TILEOP void ReduceLastAxisCompute(T0 dst, T1 src, T2 tmp)
{
    constexpr auto tmpTileH = TileOp::GetTensorTileShapeDim<T2, 3, 5>();
    constexpr auto tmpTileW = TileOp::GetTensorTileShapeDim<T2, 4, 5>();
    using TmpTileDefine = pto::Tile<
        pto::TileType::Vec, typename T2::Type, tmpTileH, tmpTileW, pto::BLayout::RowMajor, tmpTileH, tmpTileW>;
    TmpTileDefine tmpTile;

    constexpr size_t expectSize = 5;
    const auto dstLayout = dst.GetLayout();
    auto dstShape0 = dstLayout.template GetShapeDim<0, expectSize>();
    auto dstShape1 = dstLayout.template GetShapeDim<1, expectSize>();
    auto dstShape2 = dstLayout.template GetShapeDim<2, expectSize>();
    auto dstShape3 = dstLayout.template GetShapeDim<3, expectSize>();
    auto dstShape4 = dstLayout.template GetShapeDim<4, expectSize>();
    auto dstStride0 = dstLayout.template GetStrideDim<0, expectSize>();
    auto dstStride1 = dstLayout.template GetStrideDim<1, expectSize>();
    auto dstStride2 = dstLayout.template GetStrideDim<2, expectSize>();
    constexpr auto dstTileH = TileOp::GetTensorTileShapeDim<T0, 3, 5>();
    constexpr auto dstTileW = TileOp::GetTensorTileShapeDim<T0, 4, 5>();

    const auto srcLayout = src.GetLayout();
    auto srcShape0 = srcLayout.template GetShapeDim<0, expectSize>();
    auto srcShape1 = srcLayout.template GetShapeDim<1, expectSize>();
    auto srcShape2 = srcLayout.template GetShapeDim<2, expectSize>();
    auto srcShape3 = srcLayout.template GetShapeDim<3, expectSize>();
    auto srcShape4 = srcLayout.template GetShapeDim<4, expectSize>();
    // An empty source has nothing to reduce.
    if (srcShape0 == 0 || srcShape1 == 0 || srcShape2 == 0 || srcShape3 == 0 || srcShape4 == 0) {
        return;
    }
    auto srcStride0 = srcLayout.template GetStrideDim<0, expectSize>();
    auto srcStride1 = srcLayout.template GetStrideDim<1, expectSize>();
    auto srcStride2 = srcLayout.template GetStrideDim<2, expectSize>();
    constexpr auto srcTileH = TileOp::GetTensorTileShapeDim<T1, 3, 5>();
    constexpr auto srcTileW = TileOp::GetTensorTileShapeDim<T1, 4, 5>();
    constexpr auto srcTypeSize = sizeof(typename T1::Type);
    constexpr auto dstTypeSize = sizeof(typename T0::Type);

    // Tile types depend only on template parameters, so they are declared once outside the loops.
    // A single-column destination uses a column-major tile; otherwise row-major.
    using DstTileDefine = typename std::conditional<
        (dstTileW == 1),
        pto::Tile<pto::TileType::Vec, typename T0::Type, dstTileH, dstTileW, pto::BLayout::ColMajor, -1, -1>,
        pto::Tile<pto::TileType::Vec, typename T0::Type, dstTileH, dstTileW, pto::BLayout::RowMajor, -1, -1> >::type;
    using SrcTileDefine =
        pto::Tile<pto::TileType::Vec, typename T1::Type, srcTileH, srcTileW, pto::BLayout::RowMajor, -1, -1>;

    for (LoopVar n0Index = 0; n0Index < dstShape0; ++n0Index) {
        for (LoopVar n1Index = 0; n1Index < dstShape1; ++n1Index) {
            for (LoopVar n2Index = 0; n2Index < dstShape2; ++n2Index) {
                DstTileDefine dstTile(dstShape3, dstShape4);
                SrcTileDefine srcTile(srcShape3, srcShape4);
                // Element offsets of the current (n0, n1, n2) slice, converted to bytes below.
                auto dstOffset = n0Index * dstStride0 + n1Index * dstStride1 + n2Index * dstStride2;
                auto srcOffset = n0Index * srcStride0 + n1Index * srcStride1 + n2Index * srcStride2;
                pto::TASSIGN(dstTile, (uint64_t)(dst.GetAddr() + dstOffset * dstTypeSize));
                pto::TASSIGN(srcTile, (uint64_t)(src.GetAddr() + srcOffset * srcTypeSize));
                pto::TASSIGN(tmpTile, (uint64_t)(tmp.GetAddr()));
                ReduceComputeImpl<op, LastUse>(dstTile, srcTile, tmpTile);
            }
        }
    }
}

#define OP_TILE_OP_ROWSUMSINGLE TRowSumSingle
/**
 * \brief Row-wise sum over the last axis of `src` into `dst`, using `tmp` as scratch.
 *
 * \tparam LastUse Last-use flags forwarded to the underlying PTO call (defaults to all zero).
 */
template <typename LastUse = LastUse3Dim<0, 0, 0>, typename T0, typename T1, typename T2>
TILEOP void TRowSumSingle(T0 dst, T1 src, T2 tmp)
{
    ReduceLastAxisCompute<ReduceOp::SUM, LastUse, T0, T1, T2>(dst, src, tmp);
}

/**
 * \brief Row-wise arg-reduction (value + index) over the last axis of tensors viewed as 5-D.
 *
 * Dimensions 0..2 of `dstValue` are iterated explicitly. For each (n0, n1, n2) position the value tile,
 * index tile and source tile are bound to their addresses, and pto::TROWARGMAX (op == ROWARGMAXWITHVALUE)
 * or pto::TROWARGMIN (op == ROWARGMINWITHVALUE) is issued. Nothing is done when any of the five source
 * dimensions is zero.
 *
 * \param dstValue Destination tensor for the reduced values.
 * \param dstIndex Destination tensor for the indices. Its element type (T1::Type) may differ in size from
 *                 the value type, so it is addressed through its own strides and element size.
 * \param src      Source tensor.
 * \param tmp      Scratch tensor; its base address is rebound to the scratch tile on every iteration.
 */
template <ReduceOp op, typename T0, typename T1, typename T2, typename T3>
TILEOP void ArgReduceLastAxisCompute(T0 dstValue, T1 dstIndex, T2 src, T3 tmp)
{
    constexpr auto tmpTileH = TileOp::GetTensorTileShapeDim<T3, 3, 5>();
    constexpr auto tmpTileW = TileOp::GetTensorTileShapeDim<T3, 4, 5>();
    // NOTE(review): the scratch tile takes its extents from tmp (T3) but its element type from src (T2),
    // whereas ColArgReduce types its scratch tile with T3::Type. Confirm which one is intended.
    using TmpTileDefine = pto::Tile<
        pto::TileType::Vec, typename T2::Type, tmpTileH, tmpTileW, pto::BLayout::RowMajor, tmpTileH, tmpTileW>;
    TmpTileDefine tmpTile;

    constexpr size_t expectSize = 5;
    const auto dstLayout = dstValue.GetLayout();
    auto dstShape0 = dstLayout.template GetShapeDim<0, expectSize>();
    auto dstShape1 = dstLayout.template GetShapeDim<1, expectSize>();
    auto dstShape2 = dstLayout.template GetShapeDim<2, expectSize>();
    auto dstShape3 = dstLayout.template GetShapeDim<3, expectSize>();
    auto dstShape4 = dstLayout.template GetShapeDim<4, expectSize>();
    auto dstStride0 = dstLayout.template GetStrideDim<0, expectSize>();
    auto dstStride1 = dstLayout.template GetStrideDim<1, expectSize>();
    auto dstStride2 = dstLayout.template GetStrideDim<2, expectSize>();
    constexpr auto dstTileH = TileOp::GetTensorTileShapeDim<T0, 3, 5>();
    constexpr auto dstTileW = TileOp::GetTensorTileShapeDim<T0, 4, 5>();

    const auto srcLayout = src.GetLayout();
    auto srcShape0 = srcLayout.template GetShapeDim<0, expectSize>();
    auto srcShape1 = srcLayout.template GetShapeDim<1, expectSize>();
    auto srcShape2 = srcLayout.template GetShapeDim<2, expectSize>();
    auto srcShape3 = srcLayout.template GetShapeDim<3, expectSize>();
    auto srcShape4 = srcLayout.template GetShapeDim<4, expectSize>();
    // An empty source has nothing to reduce.
    if (srcShape0 == 0 || srcShape1 == 0 || srcShape2 == 0 || srcShape3 == 0 || srcShape4 == 0) {
        return;
    }
    auto srcStride0 = srcLayout.template GetStrideDim<0, expectSize>();
    auto srcStride1 = srcLayout.template GetStrideDim<1, expectSize>();
    auto srcStride2 = srcLayout.template GetStrideDim<2, expectSize>();
    constexpr auto srcTileH = TileOp::GetTensorTileShapeDim<T2, 3, 5>();
    constexpr auto srcTileW = TileOp::GetTensorTileShapeDim<T2, 4, 5>();
    constexpr auto srcTypeSize = sizeof(typename T2::Type);
    constexpr auto dstTypeSize = sizeof(typename T0::Type);

    // The index tensor is addressed with its own strides and element size. Reusing the value tensor's
    // byte offset would misplace every tile after the first whenever sizeof(T1::Type) != sizeof(T0::Type).
    const auto idxLayout = dstIndex.GetLayout();
    auto idxStride0 = idxLayout.template GetStrideDim<0, expectSize>();
    auto idxStride1 = idxLayout.template GetStrideDim<1, expectSize>();
    auto idxStride2 = idxLayout.template GetStrideDim<2, expectSize>();
    constexpr auto idxTypeSize = sizeof(typename T1::Type);

    // Tile types depend only on template parameters, so they are declared once outside the loops.
    // Single-column destinations use column-major tiles; otherwise row-major.
    using DstValueTileDefine = typename std::conditional<
        (dstTileW == 1),
        pto::Tile<pto::TileType::Vec, typename T0::Type, dstTileH, dstTileW, pto::BLayout::ColMajor, -1, -1>,
        pto::Tile<pto::TileType::Vec, typename T0::Type, dstTileH, dstTileW, pto::BLayout::RowMajor, -1, -1> >::type;
    using DstIndexTileDefine = typename std::conditional<
        (dstTileW == 1),
        pto::Tile<pto::TileType::Vec, typename T1::Type, dstTileH, dstTileW, pto::BLayout::ColMajor, -1, -1>,
        pto::Tile<pto::TileType::Vec, typename T1::Type, dstTileH, dstTileW, pto::BLayout::RowMajor, -1, -1> >::type;
    using SrcTileDefine =
        pto::Tile<pto::TileType::Vec, typename T2::Type, srcTileH, srcTileW, pto::BLayout::RowMajor, -1, -1>;

    for (LoopVar n0Index = 0; n0Index < dstShape0; ++n0Index) {
        for (LoopVar n1Index = 0; n1Index < dstShape1; ++n1Index) {
            for (LoopVar n2Index = 0; n2Index < dstShape2; ++n2Index) {
                DstValueTileDefine dstValueTile(dstShape3, dstShape4);
                DstIndexTileDefine dstIndexTile(dstShape3, dstShape4);
                SrcTileDefine srcTile(srcShape3, srcShape4);
                auto dstOffset = n0Index * dstStride0 + n1Index * dstStride1 + n2Index * dstStride2;
                auto idxOffset = n0Index * idxStride0 + n1Index * idxStride1 + n2Index * idxStride2;
                auto srcOffset = n0Index * srcStride0 + n1Index * srcStride1 + n2Index * srcStride2;
                pto::TASSIGN(dstValueTile, (uint64_t)(dstValue.GetAddr() + dstOffset * dstTypeSize));
                pto::TASSIGN(dstIndexTile, (uint64_t)(dstIndex.GetAddr() + idxOffset * idxTypeSize));
                pto::TASSIGN(srcTile, (uint64_t)(src.GetAddr() + srcOffset * srcTypeSize));
                pto::TASSIGN(tmpTile, (uint64_t)(tmp.GetAddr()));
                if constexpr (op == ReduceOp::ROWARGMAXWITHVALUE) {
                    pto::TROWARGMAX(dstValueTile, dstIndexTile, srcTile, tmpTile);
                } else if constexpr (op == ReduceOp::ROWARGMINWITHVALUE) {
                    pto::TROWARGMIN(dstValueTile, dstIndexTile, srcTile, tmpTile);
                }
            }
        }
    }
}

#define OP_TILE_OP_ROWARGMAXWITHVALUESINGLE TRowArgMaxWithValueSingle
/**
 * \brief Row-wise arg-max over the last axis via pto::TROWARGMAX, filling `dstValue` and `dstIndex`.
 *
 * \tparam LastUse Accepted for signature compatibility; it is not forwarded.
 */
template <typename LastUse = LastUse4Dim<0, 0, 0, 0>, typename T0, typename T1, typename T2, typename T3>
TILEOP void TRowArgMaxWithValueSingle(T0 dstValue, T1 dstIndex, T2 src, T3 tmp)
{
    ArgReduceLastAxisCompute<ReduceOp::ROWARGMAXWITHVALUE, T0, T1, T2, T3>(dstValue, dstIndex, src, tmp);
}

#define OP_TILE_OP_ROWARGMINWITHVALUESINGLE TRowArgMinWithValueSingle
/**
 * \brief Row-wise arg-min over the last axis via pto::TROWARGMIN, filling `dstValue` and `dstIndex`.
 *
 * \tparam LastUse Accepted for signature compatibility; it is not forwarded.
 */
template <typename LastUse = LastUse4Dim<0, 0, 0, 0>, typename T0, typename T1, typename T2, typename T3>
TILEOP void TRowArgMinWithValueSingle(T0 dstValue, T1 dstIndex, T2 src, T3 tmp)
{
    ArgReduceLastAxisCompute<ReduceOp::ROWARGMINWITHVALUE, T0, T1, T2, T3>(dstValue, dstIndex, src, tmp);
}

#define OP_TILE_OP_ROWMAXSINGLE TRowMaxSingle
/**
 * \brief Row-wise max over the last axis of `src` into `dst`, using `tmp` as scratch.
 *
 * \tparam LastUse Last-use flags forwarded to the underlying PTO call (defaults to all zero).
 */
template <typename LastUse = LastUse3Dim<0, 0, 0>, typename T0, typename T1, typename T2>
TILEOP void TRowMaxSingle(T0 dst, T1 src, T2 tmp)
{
    ReduceLastAxisCompute<ReduceOp::MAX, LastUse, T0, T1, T2>(dst, src, tmp);
}

#define OP_TILE_OP_ROWMINSINGLE TRowMinSingle
/**
 * \brief Row-wise min over the last axis of `src` into `dst`, using `tmp` as scratch.
 *
 * \tparam LastUse Last-use flags forwarded to the underlying PTO call (defaults to all zero).
 */
template <typename LastUse = LastUse3Dim<0, 0, 0>, typename T0, typename T1, typename T2>
TILEOP void TRowMinSingle(T0 dst, T1 src, T2 tmp)
{
    ReduceLastAxisCompute<ReduceOp::MIN, LastUse, T0, T1, T2>(dst, src, tmp);
}

#define OP_TILE_OP_ROWPRODSINGLE TRowProdSingle
/**
 * \brief Row-wise product over the last axis of `src` into `dst`, using `tmp` as scratch.
 *
 * \tparam LastUse Last-use flags forwarded to the underlying PTO call (defaults to all zero).
 */
template <typename LastUse = LastUse3Dim<0, 0, 0>, typename T0, typename T1, typename T2>
TILEOP void TRowProdSingle(T0 dst, T1 src, T2 tmp)
{
    ReduceLastAxisCompute<ReduceOp::PROD, LastUse, T0, T1, T2>(dst, src, tmp);
}

/**
 * \brief Max/min/prod reduction of `src` along the `axis`-th dimension of its 5-D view (no scratch needed).
 *
 * Every dimension in 0..3 other than `axis` is iterated explicitly; the `axis` loop runs once. For each
 * position a destination tile of valid size [dstShape[axis] x dstShape[4]] and a source tile of valid size
 * [srcShape[axis] x srcShape[4]] are bound, and pto::TCOLMAX / TCOLMIN / TCOLPROD reduces the tile rows.
 *
 * \tparam op   ReduceOp::MAX, MIN or PROD; any other value emits no reduction call.
 * \tparam axis Dimension (of the 5-D view) being reduced.
 * \param dst Destination tensor.
 * \param src Source tensor.
 */
template <ReduceOp op, int axis, typename T0, typename T1>
TILEOP void TRowMaxMinProdLineDynamic(T0 dst, T1 src)
{
    constexpr auto srcShapeSize = Std::tuple_size<typename T1::Shape>::value;
    constexpr auto dstShapeSize = Std::tuple_size<typename T0::Shape>::value;
    constexpr auto dstTileH = TileOp::GetTensorTileShapeDim<T0, axis + dstShapeSize - 5>();
    constexpr auto dstTileW =
        TileOp::GetAnyAxisMergeResult<axis + dstShapeSize - 3, dstShapeSize, typename T0::TileShape>();
    constexpr auto srcTileH = TileOp::GetTensorTileShapeDim<T1, axis + srcShapeSize - 5>();
    constexpr auto srcTileW =
        TileOp::GetAnyAxisMergeResult<axis + srcShapeSize - 3, srcShapeSize, typename T1::TileShape>();
    using DstTileDefine =
        pto::Tile<pto::TileType::Vec, typename T0::Type, dstTileH, dstTileW, pto::BLayout::RowMajor, -1, -1>;
    using SrcTileDefine =
        pto::Tile<pto::TileType::Vec, typename T1::Type, srcTileH, srcTileW, pto::BLayout::RowMajor, -1, -1>;
    constexpr size_t expectSize = 5;
    // Each tensor is addressed with its own element size (consistent with ColReduceWithTmpImp), so the
    // byte offsets stay correct even if dst and src element types ever differ in width.
    constexpr auto srcTypeSize = sizeof(typename T1::Type);
    constexpr auto dstTypeSize = sizeof(typename T0::Type);
    const auto dstLayout = dst.GetLayout();
    const auto srcLayout = src.GetLayout();
    size_t dstShape[] = {
        static_cast<size_t>(dstLayout.template GetShapeDim<0, expectSize>()),
        static_cast<size_t>(dstLayout.template GetShapeDim<1, expectSize>()),
        static_cast<size_t>(dstLayout.template GetShapeDim<2, expectSize>()),
        static_cast<size_t>(dstLayout.template GetShapeDim<3, expectSize>()),
        static_cast<size_t>(dstLayout.template GetShapeDim<4, expectSize>())};
    size_t dstStride[] = {
        static_cast<size_t>(dstLayout.template GetStrideDim<0, expectSize>()),
        static_cast<size_t>(dstLayout.template GetStrideDim<1, expectSize>()),
        static_cast<size_t>(dstLayout.template GetStrideDim<2, expectSize>()),
        static_cast<size_t>(dstLayout.template GetStrideDim<3, expectSize>())};
    size_t srcShape[] = {
        static_cast<size_t>(srcLayout.template GetShapeDim<0, expectSize>()),
        static_cast<size_t>(srcLayout.template GetShapeDim<1, expectSize>()),
        static_cast<size_t>(srcLayout.template GetShapeDim<2, expectSize>()),
        static_cast<size_t>(srcLayout.template GetShapeDim<3, expectSize>()),
        static_cast<size_t>(srcLayout.template GetShapeDim<4, expectSize>())};
    size_t srcStride[] = {
        static_cast<size_t>(srcLayout.template GetStrideDim<0, expectSize>()),
        static_cast<size_t>(srcLayout.template GetStrideDim<1, expectSize>()),
        static_cast<size_t>(srcLayout.template GetStrideDim<2, expectSize>()),
        static_cast<size_t>(srcLayout.template GetStrideDim<3, expectSize>())};
    // The reduced dimension is covered by the tile rows, so its loop collapses to a single iteration.
    for (LoopVar n0Index = 0, n0Size = (axis == 0 ? (size_t)1 : dstShape[0]); n0Index < n0Size; ++n0Index) {
        for (LoopVar n1Index = 0, n1Size = (axis == 1 ? (size_t)1 : dstShape[1]); n1Index < n1Size; ++n1Index) {
            for (LoopVar n2Index = 0, n2Size = (axis == 2 ? (size_t)1 : dstShape[2]); n2Index < n2Size; ++n2Index) {
                for (LoopVar n3Index = 0, n3Size = (axis == 3 ? (size_t)1 : dstShape[3]); n3Index < n3Size; ++n3Index) {
                    DstTileDefine dstTile(dstShape[axis], dstShape[4]);
                    SrcTileDefine srcTile(srcShape[axis], srcShape[4]);
                    auto dstOffset = n0Index * dstStride[0] + n1Index * dstStride[1] + n2Index * dstStride[2] +
                                     n3Index * dstStride[3];
                    auto srcOffset = n0Index * srcStride[0] + n1Index * srcStride[1] + n2Index * srcStride[2] +
                                     n3Index * srcStride[3];
                    pto::TASSIGN(dstTile, (uint64_t)(dst.GetAddr() + dstOffset * dstTypeSize));
                    pto::TASSIGN(srcTile, (uint64_t)(src.GetAddr() + srcOffset * srcTypeSize));
                    if constexpr (op == ReduceOp::MAX) {
                        pto::TCOLMAX(dstTile, srcTile);
                    } else if constexpr (op == ReduceOp::MIN) {
                        pto::TCOLMIN(dstTile, srcTile);
                    } else if constexpr (op == ReduceOp::PROD) {
                        pto::TCOLPROD(dstTile, srcTile);
                    }
                }
            }
        }
    }
}

#define OP_TILE_OP_ROWMAXLINE TRowMaxLine
/**
 * \brief Max-reduction of `src` along the `axis`-th dimension of its 5-D view into `dst`.
 */
template <int axis, typename T0, typename T1>
TILEOP void TRowMaxLine(T0 dst, T1 src)
{
    TRowMaxMinProdLineDynamic<ReduceOp::MAX, axis, T0, T1>(dst, src);
}

#define OP_TILE_OP_ROWMINLINE TRowMinLine
/**
 * \brief Min-reduction of `src` along the `axis`-th dimension of its 5-D view into `dst`.
 */
template <int axis, typename T0, typename T1>
TILEOP void TRowMinLine(T0 dst, T1 src)
{
    TRowMaxMinProdLineDynamic<ReduceOp::MIN, axis, T0, T1>(dst, src);
}

#define OP_TILE_OP_ROWPRODLINE TRowProdLine
/**
 * \brief Product-reduction of `src` along the `axis`-th dimension of its 5-D view into `dst`.
 */
template <int axis, typename T0, typename T1>
TILEOP void TRowProdLine(T0 dst, T1 src)
{
    TRowMaxMinProdLineDynamic<ReduceOp::PROD, axis, T0, T1>(dst, src);
}

/**
 * \brief Reduction of `src` along the `axis`-th dimension of its 5-D view, using a scratch tile.
 *
 * Every dimension in 0..3 other than `axis` is iterated explicitly; the `axis` loop runs once. For each
 * position a destination tile of valid size [dstShape[axis] x dstShape[4]] and a source tile of valid size
 * [srcShape[axis] x srcShape[4]] are bound and reduced with pto::TCOLSUM (op == SUM), pto::TCOLARGMAX
 * (op == ARGMAX) or pto::TCOLARGMIN (op == ARGMIN).
 *
 * \tparam DstTileDefine  pto::Tile type for the destination, chosen by the caller.
 * \tparam SrcTileDefine  pto::Tile type for the source, chosen by the caller.
 * \tparam TmpTileDefine  pto::Tile type for the scratch buffer, chosen by the caller.
 * \param dst Destination tensor.
 * \param src Source tensor.
 * \param tmp Scratch tensor; only its base address is used, rebound to the scratch tile on every iteration.
 */
template <
    int axis, ReduceOp op, typename DstTileDefine, typename SrcTileDefine, typename TmpTileDefine, typename T0,
    typename T1, typename T2>
TILEOP void ColReduceWithTmpImp(T0 dst, T1 src, T2 tmp)
{
    constexpr size_t expectSize = 5;
    constexpr auto srcTypeSize = sizeof(typename T1::Type);
    constexpr auto dstTypeSize = sizeof(typename T0::Type);
    const auto dstLayout = dst.GetLayout();
    const auto srcLayout = src.GetLayout();
    size_t dstShape[] = {
        static_cast<size_t>(dstLayout.template GetShapeDim<0, expectSize>()),
        static_cast<size_t>(dstLayout.template GetShapeDim<1, expectSize>()),
        static_cast<size_t>(dstLayout.template GetShapeDim<2, expectSize>()),
        static_cast<size_t>(dstLayout.template GetShapeDim<3, expectSize>()),
        static_cast<size_t>(dstLayout.template GetShapeDim<4, expectSize>())};
    size_t dstStride[] = {
        static_cast<size_t>(dstLayout.template GetStrideDim<0, expectSize>()),
        static_cast<size_t>(dstLayout.template GetStrideDim<1, expectSize>()),
        static_cast<size_t>(dstLayout.template GetStrideDim<2, expectSize>()),
        static_cast<size_t>(dstLayout.template GetStrideDim<3, expectSize>())};
    size_t srcShape[] = {
        static_cast<size_t>(srcLayout.template GetShapeDim<0, expectSize>()),
        static_cast<size_t>(srcLayout.template GetShapeDim<1, expectSize>()),
        static_cast<size_t>(srcLayout.template GetShapeDim<2, expectSize>()),
        static_cast<size_t>(srcLayout.template GetShapeDim<3, expectSize>()),
        static_cast<size_t>(srcLayout.template GetShapeDim<4, expectSize>())};
    size_t srcStride[] = {
        static_cast<size_t>(srcLayout.template GetStrideDim<0, expectSize>()),
        static_cast<size_t>(srcLayout.template GetStrideDim<1, expectSize>()),
        static_cast<size_t>(srcLayout.template GetStrideDim<2, expectSize>()),
        static_cast<size_t>(srcLayout.template GetStrideDim<3, expectSize>())};
    // The reduced dimension is covered by the tile rows, so its loop collapses to a single iteration.
    for (LoopVar n0Index = 0, n0Size = (axis == 0 ? (size_t)1 : dstShape[0]); n0Index < n0Size; ++n0Index) {
        for (LoopVar n1Index = 0, n1Size = (axis == 1 ? (size_t)1 : dstShape[1]); n1Index < n1Size; ++n1Index) {
            for (LoopVar n2Index = 0, n2Size = (axis == 2 ? (size_t)1 : dstShape[2]); n2Index < n2Size; ++n2Index) {
                for (LoopVar n3Index = 0, n3Size = (axis == 3 ? (size_t)1 : dstShape[3]); n3Index < n3Size; ++n3Index) {
                    DstTileDefine dstTile(dstShape[axis], dstShape[4]);
                    SrcTileDefine srcTile(srcShape[axis], srcShape[4]);
                    TmpTileDefine tmpTile;
                    auto dstOffset = n0Index * dstStride[0] + n1Index * dstStride[1] + n2Index * dstStride[2] +
                                     n3Index * dstStride[3];
                    auto srcOffset = n0Index * srcStride[0] + n1Index * srcStride[1] + n2Index * srcStride[2] +
                                     n3Index * srcStride[3];
                    pto::TASSIGN(dstTile, (uint64_t)(dst.GetAddr() + dstOffset * dstTypeSize));
                    pto::TASSIGN(srcTile, (uint64_t)(src.GetAddr() + srcOffset * srcTypeSize));
                    pto::TASSIGN(tmpTile, (uint64_t)(tmp.GetAddr()));
                    if constexpr (op == ReduceOp::SUM) {
                        pto::TCOLSUM(dstTile, srcTile, tmpTile, true);
                    } else if constexpr (op == ReduceOp::ARGMAX) {
                        pto::TCOLARGMAX(dstTile, srcTile, tmpTile);
                    } else if constexpr (op == ReduceOp::ARGMIN) {
                        pto::TCOLARGMIN(dstTile, srcTile, tmpTile);
                    }
                }
            }
        }
    }
}

/**
 * \brief Derives the pto::Tile types for an `axis`-wise reduction and forwards to ColReduceWithTmpImp.
 *
 * Tile rows take the tile extent of dimension `axis + rank - 5`. Tile columns come from
 * TileOp::GetAnyAxisMergeResult starting at dimension `axis + rank - 3`. The scratch tile's static
 * extents are the last two tile dimensions of `tmp`.
 */
template <int axis, ReduceOp op, typename T0, typename T1, typename T2>
TILEOP void ColReduceWithTmp(T0 dst, T1 src, T2 tmp)
{
    // Destination tile type (dynamic valid shape).
    constexpr auto dstRank = Std::tuple_size<typename T0::Shape>::value;
    constexpr auto dstRows = TileOp::GetTensorTileShapeDim<T0, axis + dstRank - 5>();
    constexpr auto dstCols = TileOp::GetAnyAxisMergeResult<axis + dstRank - 3, dstRank, typename T0::TileShape>();
    using DstTileDefine =
        pto::Tile<pto::TileType::Vec, typename T0::Type, dstRows, dstCols, pto::BLayout::RowMajor, -1, -1>;

    // Source tile type (dynamic valid shape).
    constexpr auto srcRank = Std::tuple_size<typename T1::Shape>::value;
    constexpr auto srcRows = TileOp::GetTensorTileShapeDim<T1, axis + srcRank - 5>();
    constexpr auto srcCols = TileOp::GetAnyAxisMergeResult<axis + srcRank - 3, srcRank, typename T1::TileShape>();
    using SrcTileDefine =
        pto::Tile<pto::TileType::Vec, typename T1::Type, srcRows, srcCols, pto::BLayout::RowMajor, -1, -1>;

    // Scratch tile type (fully static valid shape).
    constexpr auto tmpRank = Std::tuple_size<typename T2::Shape>::value;
    constexpr auto tmpRows = TileOp::GetTensorTileShapeDim<T2, tmpRank - 2>();
    constexpr auto tmpCols = TileOp::GetTensorTileShapeDim<T2, tmpRank - 1>();
    using TmpTileDefine = pto::Tile<
        pto::TileType::Vec, typename T2::Type, tmpRows, tmpCols, pto::BLayout::RowMajor, tmpRows, tmpCols>;

    ColReduceWithTmpImp<axis, op, DstTileDefine, SrcTileDefine, TmpTileDefine, T0, T1, T2>(dst, src, tmp);
}

#define OP_TILE_OP_ROWSUMLINE TRowSumLine
/**
 * \brief Sum-reduction of `src` along the `axis`-th dimension of its 5-D view into `dst`, with `tmp` as scratch.
 */
template <int axis, typename T0, typename T1, typename T2>
TILEOP void TRowSumLine(T0 dst, T1 src, T2 tmp)
{
    ColReduceWithTmp<axis, ReduceOp::SUM>(dst, src, tmp);
}

#define OP_TILE_OP_ROWARGMAXLINE TRowArgMaxLine
/**
 * \brief Arg-max reduction of `src` along the `axis`-th dimension of its 5-D view into `dst`, with `tmp` as scratch.
 */
template <int axis, typename T0, typename T1, typename T2>
TILEOP void TRowArgMaxLine(T0 dst, T1 src, T2 tmp)
{
    ColReduceWithTmp<axis, ReduceOp::ARGMAX>(dst, src, tmp);
}

/**
 * \brief Arg-reduction (value + index) of `src` along the `axis`-th dimension of its 5-D view.
 *
 * Every dimension in 0..3 other than `axis` is iterated explicitly; the `axis` loop runs once. For each
 * position the value, index, source and scratch tiles are bound, and pto::TCOLARGMAX
 * (op == COLARGMAXWITHVALUE) or pto::TCOLARGMIN (op == COLARGMINWITHVALUE) is issued.
 *
 * \param dstValue Destination tensor for the reduced values; its layout supplies the tile valid shapes.
 * \param dstIndex Destination tensor for the indices. Its element type (T1::Type) may differ in size from
 *                 the value type, so it is addressed through its own strides and element size.
 * \param src      Source tensor.
 * \param tmp      Scratch tensor; only its base address is used, rebound on every iteration.
 */
template <
    int axis, ReduceOp op, typename DstValueTileDefine, typename DstIndexTileDefine, typename SrcTileDefine, typename TmpTileDefine, typename T0,
    typename T1, typename T2, typename T3>
TILEOP void ColArgReduceImp(T0 dstValue, T1 dstIndex, T2 src, T3 tmp)
{
    constexpr size_t expectSize = 5;
    constexpr auto srcTypeSize = sizeof(typename T2::Type);
    constexpr auto dstTypeSize = sizeof(typename T0::Type);
    constexpr auto idxTypeSize = sizeof(typename T1::Type);
    const auto dstLayout = dstValue.GetLayout();
    const auto idxLayout = dstIndex.GetLayout();
    const auto srcLayout = src.GetLayout();
    size_t dstShape[] = {
        static_cast<size_t>(dstLayout.template GetShapeDim<0, expectSize>()),
        static_cast<size_t>(dstLayout.template GetShapeDim<1, expectSize>()),
        static_cast<size_t>(dstLayout.template GetShapeDim<2, expectSize>()),
        static_cast<size_t>(dstLayout.template GetShapeDim<3, expectSize>()),
        static_cast<size_t>(dstLayout.template GetShapeDim<4, expectSize>())};
    size_t dstStride[] = {
        static_cast<size_t>(dstLayout.template GetStrideDim<0, expectSize>()),
        static_cast<size_t>(dstLayout.template GetStrideDim<1, expectSize>()),
        static_cast<size_t>(dstLayout.template GetStrideDim<2, expectSize>()),
        static_cast<size_t>(dstLayout.template GetStrideDim<3, expectSize>())};
    // The index tensor is addressed with its own strides. Reusing the value tensor's byte offset would
    // misplace every tile after the first whenever sizeof(T1::Type) != sizeof(T0::Type).
    size_t idxStride[] = {
        static_cast<size_t>(idxLayout.template GetStrideDim<0, expectSize>()),
        static_cast<size_t>(idxLayout.template GetStrideDim<1, expectSize>()),
        static_cast<size_t>(idxLayout.template GetStrideDim<2, expectSize>()),
        static_cast<size_t>(idxLayout.template GetStrideDim<3, expectSize>())};
    size_t srcShape[] = {
        static_cast<size_t>(srcLayout.template GetShapeDim<0, expectSize>()),
        static_cast<size_t>(srcLayout.template GetShapeDim<1, expectSize>()),
        static_cast<size_t>(srcLayout.template GetShapeDim<2, expectSize>()),
        static_cast<size_t>(srcLayout.template GetShapeDim<3, expectSize>()),
        static_cast<size_t>(srcLayout.template GetShapeDim<4, expectSize>())};
    size_t srcStride[] = {
        static_cast<size_t>(srcLayout.template GetStrideDim<0, expectSize>()),
        static_cast<size_t>(srcLayout.template GetStrideDim<1, expectSize>()),
        static_cast<size_t>(srcLayout.template GetStrideDim<2, expectSize>()),
        static_cast<size_t>(srcLayout.template GetStrideDim<3, expectSize>())};
    // The reduced dimension is covered by the tile rows, so its loop collapses to a single iteration.
    for (LoopVar n0Index = 0, n0Size = (axis == 0 ? (size_t)1 : dstShape[0]); n0Index < n0Size; ++n0Index) {
        for (LoopVar n1Index = 0, n1Size = (axis == 1 ? (size_t)1 : dstShape[1]); n1Index < n1Size; ++n1Index) {
            for (LoopVar n2Index = 0, n2Size = (axis == 2 ? (size_t)1 : dstShape[2]); n2Index < n2Size; ++n2Index) {
                for (LoopVar n3Index = 0, n3Size = (axis == 3 ? (size_t)1 : dstShape[3]); n3Index < n3Size; ++n3Index) {
                    DstValueTileDefine dstValueTile(dstShape[axis], dstShape[4]);
                    DstIndexTileDefine dstIndexTile(dstShape[axis], dstShape[4]);
                    SrcTileDefine srcTile(srcShape[axis], srcShape[4]);
                    TmpTileDefine tmpTile;
                    auto dstOffset = n0Index * dstStride[0] + n1Index * dstStride[1] + n2Index * dstStride[2] +
                                     n3Index * dstStride[3];
                    auto idxOffset = n0Index * idxStride[0] + n1Index * idxStride[1] + n2Index * idxStride[2] +
                                     n3Index * idxStride[3];
                    auto srcOffset = n0Index * srcStride[0] + n1Index * srcStride[1] + n2Index * srcStride[2] +
                                     n3Index * srcStride[3];
                    pto::TASSIGN(dstValueTile, (uint64_t)(dstValue.GetAddr() + dstOffset * dstTypeSize));
                    pto::TASSIGN(dstIndexTile, (uint64_t)(dstIndex.GetAddr() + idxOffset * idxTypeSize));
                    pto::TASSIGN(srcTile, (uint64_t)(src.GetAddr() + srcOffset * srcTypeSize));
                    pto::TASSIGN(tmpTile, (uint64_t)(tmp.GetAddr()));
                    if constexpr (op == ReduceOp::COLARGMAXWITHVALUE) {
                        pto::TCOLARGMAX(dstValueTile, dstIndexTile, srcTile, tmpTile);
                    } else if constexpr (op == ReduceOp::COLARGMINWITHVALUE) {
                        pto::TCOLARGMIN(dstValueTile, dstIndexTile, srcTile, tmpTile);
                    }
                }
            }
        }
    }
}

/**
 * \brief Derives the pto::Tile types for an `axis`-wise arg-reduction and forwards to ColArgReduceImp.
 *
 * Value and index tiles share the destination's tile extents (dimension `axis + rank - 5` for rows,
 * TileOp::GetAnyAxisMergeResult from `axis + rank - 3` for columns) but use their own element types. The
 * scratch tile's static extents are the last two tile dimensions of `tmp`.
 */
template <int axis, ReduceOp op, typename T0, typename T1, typename T2, typename T3>
TILEOP void ColArgReduce(T0 dstValue, T1 dstIndex, T2 src, T3 tmp)
{
    // Destination value/index tile types (dynamic valid shape).
    constexpr auto dstRank = Std::tuple_size<typename T0::Shape>::value;
    constexpr auto dstRows = TileOp::GetTensorTileShapeDim<T0, axis + dstRank - 5>();
    constexpr auto dstCols = TileOp::GetAnyAxisMergeResult<axis + dstRank - 3, dstRank, typename T0::TileShape>();
    using DstValueTileDefine =
        pto::Tile<pto::TileType::Vec, typename T0::Type, dstRows, dstCols, pto::BLayout::RowMajor, -1, -1>;
    using DstIndexTileDefine =
        pto::Tile<pto::TileType::Vec, typename T1::Type, dstRows, dstCols, pto::BLayout::RowMajor, -1, -1>;

    // Source tile type (dynamic valid shape).
    constexpr auto srcRank = Std::tuple_size<typename T2::Shape>::value;
    constexpr auto srcRows = TileOp::GetTensorTileShapeDim<T2, axis + srcRank - 5>();
    constexpr auto srcCols = TileOp::GetAnyAxisMergeResult<axis + srcRank - 3, srcRank, typename T2::TileShape>();
    using SrcTileDefine =
        pto::Tile<pto::TileType::Vec, typename T2::Type, srcRows, srcCols, pto::BLayout::RowMajor, -1, -1>;

    // Scratch tile type (fully static valid shape).
    constexpr auto tmpRank = Std::tuple_size<typename T3::Shape>::value;
    constexpr auto tmpRows = TileOp::GetTensorTileShapeDim<T3, tmpRank - 2>();
    constexpr auto tmpCols = TileOp::GetTensorTileShapeDim<T3, tmpRank - 1>();
    using TmpTileDefine = pto::Tile<
        pto::TileType::Vec, typename T3::Type, tmpRows, tmpCols, pto::BLayout::RowMajor, tmpRows, tmpCols>;

    ColArgReduceImp<axis, op, DstValueTileDefine, DstIndexTileDefine, SrcTileDefine, TmpTileDefine, T0, T1, T2, T3>(
        dstValue, dstIndex, src, tmp);
}

#define OP_TILE_OP_ROWARGMAXWITHVALUELINE TRowArgMaxWithValueLine
/**
 * \brief Arg-max along the `axis`-th dimension of the 5-D view, filling `dstValue` and `dstIndex`.
 */
template <int axis, typename T0, typename T1, typename T2, typename T3>
TILEOP void TRowArgMaxWithValueLine(T0 dstValue, T1 dstIndex, T2 src, T3 tmp)
{
    ColArgReduce<axis, ReduceOp::COLARGMAXWITHVALUE>(dstValue, dstIndex, src, tmp);
}

#define OP_TILE_OP_ROWARGMINWITHVALUELINE TRowArgMinWithValueLine
/**
 * \brief Arg-min along the `axis`-th dimension of the 5-D view, filling `dstValue` and `dstIndex`.
 */
template <int axis, typename T0, typename T1, typename T2, typename T3>
TILEOP void TRowArgMinWithValueLine(T0 dstValue, T1 dstIndex, T2 src, T3 tmp)
{
    ColArgReduce<axis, ReduceOp::COLARGMINWITHVALUE>(dstValue, dstIndex, src, tmp);
}
#endif