/**
 * Copyright (c) 2025-2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file cum_operation.h
 * \brief
 */

#ifndef TILEOP_TILE_OPERATOR_CUM_OPERATION__H
#define TILEOP_TILE_OPERATOR_CUM_OPERATION__H
#include "utils/layout.h"
#include "utils/tile_tensor.h"
#include <array>

template <typename T0, typename T1, unsigned tileH, unsigned tileW, unsigned dstTypeSize, int is_sum>
TILEOP void CumOperationTool(T0 dst, T1 src, uint64_t tmpStride)
{
    using tmpTileDefine =
        pto::Tile<pto::TileType::Vec, typename T0::Type, tileH, tileW, pto::BLayout::RowMajor, -1, -1>;
    tmpTileDefine tmpDstTile(tileH, tileW), tmpSrcTile(tileH, tileW);
    pto::TASSIGN(tmpDstTile, (uint64_t)(dst.GetAddr() + tmpStride));
    pto::TASSIGN(tmpSrcTile, (uint64_t)(src.GetAddr() + tmpStride));
    pto::TMOV(tmpDstTile, tmpSrcTile);

#pragma clang loop unroll(disable)
    for (LoopVar i = 1; i < tileH;) {
#ifdef __DAV_V220
        pipe_barrier(PIPE_V);
#endif
        using TileDefine =
            pto::Tile<pto::TileType::Vec, typename T0::Type, tileH, tileW, pto::BLayout::RowMajor, -1, -1>;
        TileDefine src1Tile(tileH - i, tileW), src0Tile(tileH - i, tileW), dstTile(tileH - i, tileW);
        pto::TASSIGN(src0Tile, (uint64_t)(src.GetAddr() + tmpStride));
        pto::TASSIGN(src1Tile, (uint64_t)(src.GetAddr() + tmpStride + i * tileW * dstTypeSize));
        pto::TASSIGN(dstTile, (uint64_t)(dst.GetAddr() + tmpStride + i * tileW * dstTypeSize));
        if (is_sum) {
            pto::TADD(dstTile, src0Tile, src1Tile);
        } else {
            pto::TMUL(dstTile, src0Tile, src1Tile);
        }
#ifdef __DAV_V220
        pipe_barrier(PIPE_V);
#endif
        pto::TADDS(src1Tile, dstTile, 0);

        i = i * 2;
    }
}

template <typename T0, int is_sum>
TILEOP void CumOperationScalarTool(
    __ubuf__ typename T0::Type* dstAddr, __ubuf__ typename T0::Type* srcAddr, uint64_t stride, uint64_t lastShape,
    uint64_t idx)
{
    if (is_sum) {
        for (LoopVar tmpIdx = 0; tmpIdx < lastShape - idx; tmpIdx++) {
            dstAddr[stride + tmpIdx + idx] = srcAddr[stride + tmpIdx] + srcAddr[stride + tmpIdx + idx];
        }
    } else {
        for (LoopVar tmpIdx = 0; tmpIdx < lastShape - idx; tmpIdx++) {
            dstAddr[stride + tmpIdx + idx] = srcAddr[stride + tmpIdx] * srcAddr[stride + tmpIdx + idx];
        }
    }
    for (LoopVar tmpIdx = 0; tmpIdx < lastShape - idx; tmpIdx++) {
        srcAddr[stride + tmpIdx + idx] = dstAddr[stride + tmpIdx + idx];
    }
}

template <int axis, int is_sum, typename T0, typename T1>
TILEOP void TCumOperation(T0 dst, T1 src)
{
    constexpr size_t expectSize = 5;
    constexpr auto shapeSize = Std::tuple_size<typename T0::Shape>::value;
    constexpr auto dstTypeSize = sizeof(typename T0::Type);
    const auto dstLayout = dst.GetLayout();
    auto n0DstStride = dstLayout.template GetStrideDim<0, expectSize>();
    auto n1DstStride = dstLayout.template GetStrideDim<1, expectSize>();
    auto n2DstStride = dstLayout.template GetStrideDim<2, expectSize>();
    auto n3DstStride = dstLayout.template GetStrideDim<3, expectSize>();
    auto n0DstShape = dstLayout.template GetShapeDim<0, expectSize>();
    auto n1DstShape = dstLayout.template GetShapeDim<1, expectSize>();
    auto n2DstShape = dstLayout.template GetShapeDim<2, expectSize>();
    auto n3DstShape = dstLayout.template GetShapeDim<3, expectSize>();
    auto n4DstShape = dstLayout.template GetShapeDim<4, expectSize>();
    constexpr auto dst1RawShape = Std::tuple_element<shapeSize - 1, typename T0::TileShape>::type::value;

    if constexpr (axis == 0) {
        constexpr auto tileH = Std::tuple_element<DIM_1ST, typename T0::TileShape>::type::value;
        constexpr auto tileW = TileOp::GetNonFirstAxisMergeResult<shapeSize, typename T0::TileShape>();
        uint64_t tmpStride = 0;
        CumOperationTool<T0, T1, tileH, tileW, dstTypeSize, is_sum>(dst, src, tmpStride);
        return;
    } else if constexpr (axis == 1) {
        int loops = n0DstShape;
        constexpr auto tileH = Std::tuple_element<shapeSize - 4, typename T0::TileShape>::type::value;
        constexpr auto dst2RawShape = Std::tuple_element<shapeSize - 2, typename T0::TileShape>::type::value;
        constexpr auto dst3RawShape = Std::tuple_element<shapeSize - 3, typename T0::TileShape>::type::value;
        constexpr int tileW = dst3RawShape * dst2RawShape * dst1RawShape;

        for (LoopVar loop = 0; loop < loops; loop++) {
            uint64_t tmpStride = loop * n0DstStride * dstTypeSize;
            CumOperationTool<T0, T1, tileH, tileW, dstTypeSize, is_sum>(dst, src, tmpStride);
        }
        return;
    } else if constexpr (axis == 2) {
        constexpr auto dst2RawShape = Std::tuple_element<shapeSize - 2, typename T0::TileShape>::type::value;
        constexpr auto dst3RawShape = Std::tuple_element<shapeSize - 3, typename T0::TileShape>::type::value;
        constexpr int tileH = dst3RawShape;
        constexpr int tileW = dst2RawShape * dst1RawShape;

        for (LoopVar j = 0; j < n0DstShape; j++) {
            for (LoopVar k = 0; k < n1DstShape; k++) {
                uint64_t tmpStride = (j * n0DstStride + k * n1DstStride) * dstTypeSize;
                CumOperationTool<T0, T1, tileH, tileW, dstTypeSize, is_sum>(dst, src, tmpStride);
            }
        }
        return;
    } else if constexpr (axis == 3) {
        constexpr auto dst2RawShape = Std::tuple_element<shapeSize - 2, typename T0::TileShape>::type::value;
        constexpr int tileH = dst2RawShape;
        constexpr int tileW = dst1RawShape;

        for (LoopVar m = 0; m < n0DstShape; m++) {
            for (LoopVar j = 0; j < n1DstShape; j++) {
                for (LoopVar k = 0; k < n2DstShape; k++) {
                    uint64_t tmpStride = (m * n0DstStride + j * n1DstStride + k * n2DstStride) * dstTypeSize;
                    CumOperationTool<T0, T1, tileH, tileW, dstTypeSize, is_sum>(dst, src, tmpStride);
                }
            }
        }
        return;
    } else {
        set_flag(PIPE_V, PIPE_S, EVENT_ID7);
        wait_flag(PIPE_V, PIPE_S, EVENT_ID7);
        auto srcAddr = (__ubuf__ typename T1::Type*)((uint64_t)(src.GetAddr()));
        auto dstAddr = (__ubuf__ typename T0::Type*)((uint64_t)(dst.GetAddr()));

        for (LoopVar n = 0; n < n0DstShape; n++) {
            for (LoopVar j = 0; j < n1DstShape; j++) {
                for (LoopVar k = 0; k < n2DstShape; k++) {
                    for (LoopVar m = 0; m < n3DstShape; m++) {
                        int tmpStride = n * n0DstStride + j * n1DstStride + k * n2DstStride + m * n3DstStride;
                        dstAddr[tmpStride] = srcAddr[tmpStride];
                        for (LoopVar i = 1; i < n4DstShape;) {
                            CumOperationScalarTool<T0, is_sum>(dstAddr, srcAddr, tmpStride, n4DstShape, i);
                            i = i * 2;
                        }
                    }
                }
            }
        }
        set_flag(PIPE_S, PIPE_V, EVENT_ID7);
        wait_flag(PIPE_S, PIPE_V, EVENT_ID7);
    }
}
#endif