/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file lerp_f16_nddma_with_loops.h
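 * \brief Lerp kernel for float16 start/end/weight/y (max 8 dims in UB) whose NDDMA loads are issued in loops.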
 */
#ifndef ASCENDC_LERP_F16_NDDMA_WITH_LOOPS_H_
#define ASCENDC_LERP_F16_NDDMA_WITH_LOOPS_H_

#include "kernel_operator.h"
#include "atvoss/util/broadcast_utils.h"

namespace Lerp {
using AscendC::GlobalTensor;
using AscendC::LocalTensor;
using AscendC::TBuf;
using AscendC::TPipe;
using AscendC::TQue;
using AscendC::MicroAPI::MaskReg;
using AscendC::MicroAPI::RegTensor;

// start is float16, end is float16, weight is float16, y is float16, max dims in ub is 8 and nddma needs loops
class LerpF16NddmaWithLoops {
   public:
    __aicore__ inline LerpF16NddmaWithLoops(){};
    __aicore__ inline void Init(GM_ADDR start, GM_ADDR end, GM_ADDR weight, GM_ADDR y, GM_ADDR workspace,
                                const LerpTilingData* tilingDataPtr, TPipe* pipePtr)
    {
        pipePtr_ = pipePtr;
        tilingDataPtr_ = tilingDataPtr;
        inputGmStart_.SetGlobalBuffer((__gm__ half*)start);
        inputGmEnd_.SetGlobalBuffer((__gm__ half*)end);
        inputGmWeight_.SetGlobalBuffer((__gm__ half*)weight);
        outputGmY_.SetGlobalBuffer((__gm__ half*)y);
        constexpr int64_t DOUBLE_BUFFER = 2;
        int64_t BUFFER_SIZE_0 = tilingDataPtr_->elemNum * sizeof(half);
        pipePtr_->InitBuffer(queIn0_, DOUBLE_BUFFER, BUFFER_SIZE_0);
        pipePtr_->InitBuffer(queIn1_, DOUBLE_BUFFER, BUFFER_SIZE_0);
        pipePtr_->InitBuffer(queIn2_, DOUBLE_BUFFER, BUFFER_SIZE_0);
        pipePtr_->InitBuffer(queOut0_, DOUBLE_BUFFER, BUFFER_SIZE_0);
    }

    __aicore__ inline void Process()
    {
        int64_t ubLoopNum = AscendC::GetBlockIdx() == AscendC::GetBlockNum() - 1 ? tilingDataPtr_->blockTail
                                                                                 : tilingDataPtr_->blockFormer;
        int64_t axesIndices[Ops::Base::BROADCAST_MAX_DIMS] = {0};
        Ops::Base::BroadcastGetAxesIndices(axesIndices, tilingDataPtr_->blockFormer * AscendC::GetBlockIdx(),
                                tilingDataPtr_->outputDims, tilingDataPtr_->ubSplitAxis,
                                tilingDataPtr_->dimProductBeforeUbInner);
        for (int64_t ubLoopIdx = 0; ubLoopIdx < ubLoopNum; ubLoopIdx += 1) {
            if (ubLoopIdx != 0) {
                Ops::Base::BroadcastUpdateAxesIndices(axesIndices, tilingDataPtr_->outputDims, tilingDataPtr_->ubSplitAxis,
                                           tilingDataPtr_->ubOuter);
            }
            int64_t ubSplitSize = axesIndices[tilingDataPtr_->ubSplitAxis] == tilingDataPtr_->ubOuter - 1
                                      ? tilingDataPtr_->ubTail
                                      : tilingDataPtr_->ubFormer;
            CopyIn0(ubSplitSize, axesIndices, ubLoopIdx);
            CopyIn1(ubSplitSize, axesIndices, ubLoopIdx);
            CopyIn2(ubSplitSize, axesIndices, ubLoopIdx);
            Compute3(ubSplitSize, axesIndices, ubLoopIdx);
            CopyOut4(ubSplitSize, axesIndices, ubLoopIdx);
        }
    }

   private:
    __aicore__ inline void CopyIn0(int64_t ubSplitSize, const int64_t (&axesIndices)[Ops::Base::BROADCAST_MAX_DIMS],
                                   int64_t ubLoopIdx)
    {
        bufferIn0_ = queIn0_.AllocTensor<half>();
        if ((tilingDataPtr_->input2Strides[tilingDataPtr_->ubSplitAxis] != 0) ||
            (ubLoopIdx <= 1 ||
             (AscendC::GetBlockIdx() * tilingDataPtr_->blockFormer + ubLoopIdx) % tilingDataPtr_->ubOuter <= 1)) {
            Ops::Base::BroadcastNddmaWithLoop(inputGmWeight_, bufferIn0_, tilingDataPtr_->outputDims,
                                   tilingDataPtr_->outputStrides, tilingDataPtr_->input2Strides, axesIndices,
                                   tilingDataPtr_->ubSplitAxis, tilingDataPtr_->shapeLen, ubSplitSize,
                                   tilingDataPtr_->ubFormer);
        }
        queIn0_.EnQue<half>(bufferIn0_);
    }

    __aicore__ inline void CopyIn1(int64_t ubSplitSize, const int64_t (&axesIndices)[Ops::Base::BROADCAST_MAX_DIMS],
                                   int64_t ubLoopIdx)
    {
        bufferIn1_ = queIn1_.AllocTensor<half>();
        if ((tilingDataPtr_->input1Strides[tilingDataPtr_->ubSplitAxis] != 0) ||
            (ubLoopIdx <= 1 ||
             (AscendC::GetBlockIdx() * tilingDataPtr_->blockFormer + ubLoopIdx) % tilingDataPtr_->ubOuter <= 1)) {
            Ops::Base::BroadcastNddmaWithLoop(inputGmEnd_, bufferIn1_, tilingDataPtr_->outputDims, tilingDataPtr_->outputStrides,
                                   tilingDataPtr_->input1Strides, axesIndices, tilingDataPtr_->ubSplitAxis,
                                   tilingDataPtr_->shapeLen, ubSplitSize, tilingDataPtr_->ubFormer);
        }
        queIn1_.EnQue<half>(bufferIn1_);
    }

    __aicore__ inline void CopyIn2(int64_t ubSplitSize, const int64_t (&axesIndices)[Ops::Base::BROADCAST_MAX_DIMS],
                                   int64_t ubLoopIdx)
    {
        bufferIn2_ = queIn2_.AllocTensor<half>();
        if ((tilingDataPtr_->input0Strides[tilingDataPtr_->ubSplitAxis] != 0) ||
            (ubLoopIdx <= 1 ||
             (AscendC::GetBlockIdx() * tilingDataPtr_->blockFormer + ubLoopIdx) % tilingDataPtr_->ubOuter <= 1)) {
            Ops::Base::BroadcastNddmaWithLoop(inputGmStart_, bufferIn2_, tilingDataPtr_->outputDims, tilingDataPtr_->outputStrides,
                                   tilingDataPtr_->input0Strides, axesIndices, tilingDataPtr_->ubSplitAxis,
                                   tilingDataPtr_->shapeLen, ubSplitSize, tilingDataPtr_->ubFormer);
        }
        queIn2_.EnQue<half>(bufferIn2_);
    }

    __aicore__ inline void Compute3(int64_t ubSplitSize, const int64_t (&axesIndices)[Ops::Base::BROADCAST_MAX_DIMS],
                                    int64_t ubLoopIdx)
    {
        bufferIn0_ = queIn0_.DeQue<half>();
        bufferIn1_ = queIn1_.DeQue<half>();
        bufferIn2_ = queIn2_.DeQue<half>();
        bufferOut0_ = queOut0_.AllocTensor<half>();
        __VEC_SCOPE__
        {
            RegTensor<half> vreg0;
            RegTensor<float> vreg1;
            RegTensor<half> vreg2;
            RegTensor<float> vreg3;
            RegTensor<half> vreg4;
            RegTensor<float> vreg5;
            RegTensor<float> vreg6;
            RegTensor<float> vreg7;
            RegTensor<float> vreg8;
            RegTensor<float> vreg9;
            RegTensor<half> vreg10;
            MaskReg preg0;
            uint32_t size = ubSplitSize * tilingDataPtr_->outputStrides[tilingDataPtr_->ubSplitAxis];
            MaskReg preg2;
            preg0 = AscendC::MicroAPI::CreateMask<float>();
            AscendC::MicroAPI::Duplicate<float, AscendC::MicroAPI::MaskMergeMode::ZEROING, float>(
                vreg7, static_cast<float>(1.0), preg0);
            uint16_t vfLoopNum = (ubSplitSize * tilingDataPtr_->outputStrides[tilingDataPtr_->ubSplitAxis] +
                                  (AscendC::VECTOR_REG_WIDTH / 4) - 1) /
                                 (AscendC::VECTOR_REG_WIDTH / 4);
            __local_mem__ half* bufferIn0Addr = (__local_mem__ half*)bufferIn0_.GetPhyAddr();
            __local_mem__ half* bufferIn1Addr = (__local_mem__ half*)bufferIn1_.GetPhyAddr();
            __local_mem__ half* bufferIn2Addr = (__local_mem__ half*)bufferIn2_.GetPhyAddr();
            __local_mem__ half* bufferOut0Addr = (__local_mem__ half*)bufferOut0_.GetPhyAddr();
            for (uint16_t i = 0; i < vfLoopNum; i++) {
                preg0 = AscendC::MicroAPI::UpdateMask<float>(size);
                AscendC::MicroAPI::DataCopy<half, AscendC::MicroAPI::LoadDist::DIST_UNPACK_B16>(
                    vreg0, bufferIn0Addr + i * (AscendC::VECTOR_REG_WIDTH / 4));
                AscendC::MicroAPI::Cast<float, half, castTrait0>(vreg1, vreg0, preg0);
                AscendC::MicroAPI::CompareScalar<float, AscendC::CMPMODE::LT>(preg2, vreg1, static_cast<float>(0.5),
                                                                              preg0);
                AscendC::MicroAPI::DataCopy<half, AscendC::MicroAPI::LoadDist::DIST_UNPACK_B16>(
                    vreg2, bufferIn1Addr + i * (AscendC::VECTOR_REG_WIDTH / 4));
                AscendC::MicroAPI::Cast<float, half, castTrait0>(vreg3, vreg2, preg0);
                AscendC::MicroAPI::DataCopy<half, AscendC::MicroAPI::LoadDist::DIST_UNPACK_B16>(
                    vreg4, bufferIn2Addr + i * (AscendC::VECTOR_REG_WIDTH / 4));
                AscendC::MicroAPI::Cast<float, half, castTrait0>(vreg5, vreg4, preg0);
                AscendC::MicroAPI::Sub<float, AscendC::MicroAPI::MaskMergeMode::ZEROING>(vreg6, vreg3, vreg5, preg0);
                AscendC::MicroAPI::MulAddDst<float, AscendC::MicroAPI::MaskMergeMode::ZEROING>(vreg5, vreg1, vreg6,
                                                                                               preg0);
                AscendC::MicroAPI::Sub<float, AscendC::MicroAPI::MaskMergeMode::ZEROING>(vreg8, vreg1, vreg7, preg0);
                AscendC::MicroAPI::MulAddDst<float, AscendC::MicroAPI::MaskMergeMode::ZEROING>(vreg3, vreg8, vreg6,
                                                                                               preg0);
                AscendC::MicroAPI::Select<float>(vreg9, vreg5, vreg3, preg2);
                AscendC::MicroAPI::Cast<half, float, castTrait1>(vreg10, vreg9, preg0);
                AscendC::MicroAPI::DataCopy<half, AscendC::MicroAPI::StoreDist::DIST_PACK_B32>(
                    bufferOut0Addr + i * (AscendC::VECTOR_REG_WIDTH / 4), vreg10, preg0);
            }
        }
        queIn0_.FreeTensor(bufferIn0_);
        queIn1_.FreeTensor(bufferIn1_);
        queIn2_.FreeTensor(bufferIn2_);
        queOut0_.EnQue<half>(bufferOut0_);
    }

    __aicore__ inline void CopyOut4(int64_t ubSplitSize, const int64_t (&axesIndices)[Ops::Base::BROADCAST_MAX_DIMS],
                                    int64_t ubLoopIdx)
    {
        bufferOut0_ = queOut0_.DeQue<half>();
        AscendC::DataCopyExtParams dataCopyExtParams;
        dataCopyExtParams.blockCount = 1;
        dataCopyExtParams.blockLen =
            ubSplitSize * tilingDataPtr_->outputStrides[tilingDataPtr_->ubSplitAxis] * sizeof(half);
        int64_t gmOffset = Ops::Base::BroadcastGetGmOffset(axesIndices, tilingDataPtr_->outputStrides, tilingDataPtr_->ubSplitAxis,
                                                tilingDataPtr_->ubFormer);
        AscendC::DataCopyPad(outputGmY_[gmOffset], bufferOut0_[0], dataCopyExtParams);
        queOut0_.FreeTensor(bufferOut0_);
    }

   private:
    TPipe* pipePtr_;
    const LerpTilingData* tilingDataPtr_;
    GlobalTensor<half> inputGmStart_;
    GlobalTensor<half> inputGmEnd_;
    GlobalTensor<half> inputGmWeight_;
    GlobalTensor<half> outputGmY_;
    TQue<AscendC::QuePosition::VECIN, 1> queIn0_;
    TQue<AscendC::QuePosition::VECIN, 1> queIn1_;
    TQue<AscendC::QuePosition::VECIN, 1> queIn2_;
    TQue<AscendC::QuePosition::VECOUT, 1> queOut0_;
    LocalTensor<half> bufferIn0_;
    LocalTensor<half> bufferIn1_;
    LocalTensor<half> bufferIn2_;
    LocalTensor<half> bufferOut0_;
    constexpr static AscendC::MicroAPI::CastTrait castTrait0 = {
        AscendC::MicroAPI::RegLayout::ZERO, AscendC::MicroAPI::SatMode::UNKNOWN,
        AscendC::MicroAPI::MaskMergeMode::ZEROING, AscendC::RoundMode::UNKNOWN};
    constexpr static AscendC::MicroAPI::CastTrait castTrait1 = {
        AscendC::MicroAPI::RegLayout::ZERO, AscendC::MicroAPI::SatMode::NO_SAT,
        AscendC::MicroAPI::MaskMergeMode::ZEROING, AscendC::RoundMode::CAST_RINT};
};

}  // namespace Lerp
#endif  // ASCENDC_LERP_F16_NDDMA_WITH_LOOPS_H_