/*

 * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.

 *

 */

#ifndef MSDA_H

#define MSDA_H



#include "kernel_operator.h"



using namespace AscendC;

using namespace MicroAPI;



/**
 * @brief Turns sampling locations into interleaved gather offsets plus a per-lane
 *        bilinear-corner validity bitmask, processing one B32 repeat per loop iteration.
 *
 * For each of the `taskRpt_` repeats the routine:
 *   1. loads two registers of interleaved (x, y) locations and two registers of interleaved
 *      (width, height) shapes and de-interleaves them;
 *   2. combines location, size and a constant -0.5 through FusedMulDstAdd and stores the
 *      re-interleaved result at the start of `locationFloat`
 *      (NOTE(review): presumably loc = loc * size - 0.5 -- confirm FusedMulDstAdd operand order);
 *   3. casts x, y and width to `U` with floor rounding and builds
 *        gmOffset1 = ((yInt * wInt + xInt) * numHeads_ + offsetInt + base) * embedDims_
 *        gmOffset2 = gmOffset1 + wInt * numHeads_ * embedDims_
 *      which are stored interleaved into `gmOffset`;
 *   4. builds a flag word (1 | 2 | 4 | 8, one bit per bilinear corner) and forces it to 0 for
 *      lanes whose location fails the -1 < x < width, -1 < y < height test, then stores it
 *      into `validMaskTensor`.
 *
 * @tparam T  Element type of the location/shape tensors (presumably float -- TODO confirm).
 * @tparam U  Element type of the offset/mask tensors (presumably a 32-bit integer -- TODO confirm).
 *
 * @param taskRpt_         Number of repeats (loop iterations) to process.
 * @param numHeads_        Scalar multiplied into the flattened (y * width + x) index.
 * @param embedDims_       Scalar multiplied into the final offset.
 * @param baseOffset       Base value written into the lanes selected by the `baseCount` mask.
 * @param nextOffset       Base value written into all remaining lanes.
 * @param baseCount        Element count handed to UpdateMask on every iteration
 *                         (NOTE(review): UpdateMask presumably consumes the count as it goes, so
 *                         only the first `baseCount` lanes overall receive `baseOffset` -- confirm).
 * @param locationFloat    Input and output: raw locations are read starting at element
 *                         2 * taskRpt_ * B32_DATA_NUM_PER_REPEAT, results are written from element 0.
 * @param shapeFloat       Interleaved (width, height) values, 2 * B32_DATA_NUM_PER_REPEAT per repeat.
 * @param offsetInt        Per-lane integer offset added before the base offset.
 * @param gmOffset         Output: interleaved (gmOffset1, gmOffset2) pairs.
 * @param validMaskTensor  Output: one validity flag word per lane.
 */
template <typename T, typename U>
__aicore__ inline void ComputeGmOffsetVF(uint16_t taskRpt_, uint32_t numHeads_, uint32_t embedDims_,
    uint32_t baseOffset, uint32_t nextOffset, uint32_t baseCount, const LocalTensor<T> locationFloat,
    const LocalTensor<T> shapeFloat, const LocalTensor<U> offsetInt, const LocalTensor<U> gmOffset,
    const LocalTensor<U> validMaskTensor) {
    // Results are written to the front of locationFloat; the raw inputs live further back in the same tensor.
    __local_mem__ T *locationFloatPtr = (__local_mem__ T *)locationFloat.GetPhyAddr();
    __local_mem__ T *locationInputsPtr =
        (__local_mem__ T *)locationFloat[2 * taskRpt_ * B32_DATA_NUM_PER_REPEAT].GetPhyAddr();
    __local_mem__ T *shapeFloatPtr = (__local_mem__ T *)shapeFloat.GetPhyAddr();
    __local_mem__ U *offsetIntPtr = (__local_mem__ U *)offsetInt.GetPhyAddr();
    __local_mem__ U *gmOffsetPtr = (__local_mem__ U *)gmOffset.GetPhyAddr();
    __local_mem__ U *validMaskPtr = (__local_mem__ U *)validMaskTensor.GetPhyAddr();

    __VEC_SCOPE__ {
        MicroAPI::RegTensor<T> locationXY1Reg, locationXY2Reg, shapeInput1Reg, shapeInput2Reg;
        MicroAPI::RegTensor<T> locationXReg, locationYReg, widthFloatReg, heightFloatReg;

        MicroAPI::RegTensor<U> offsetReg;
        MicroAPI::RegTensor<U> widthIntReg;
        MicroAPI::RegTensor<U> locationXIntReg, locationYIntReg;
        MicroAPI::RegTensor<U> gmOffset1Reg, gmOffset2Reg;
        MicroAPI::RegTensor<U> baseOffsetReg;

        MicroAPI::RegTensor<U> validMaskReg;
        MicroAPI::RegTensor<U> bilinearValidPoint1Reg, bilinearValidPoint2Reg, bilinearValidPoint3Reg,
            bilinearValidPoint4Reg;

        MicroAPI::RegTensor<T> constOffsetReg;
        MicroAPI::RegTensor<U> zeroReg;

        // Full-width predicate used by every unmasked operation below.
        MicroAPI::MaskReg mask = MicroAPI::CreateMask<T, AscendC::MicroAPI::MaskPattern::ALL>();

        // Float -> integer conversion rounds toward negative infinity (CAST_FLOOR) and saturates.
        static constexpr AscendC::MicroAPI::CastTrait castF2ITrait = {
            MicroAPI::RegLayout::ZERO, MicroAPI::SatMode::SAT, MicroAPI::MaskMergeMode::ZEROING, RoundMode::CAST_FLOOR};

        // Loop-invariant constants: the -0.5 shift and an all-zero register used by Select.
        Duplicate(constOffsetReg, -0.5, mask);
        Duplicate(zeroReg, 0, mask);

        for (uint16_t taskIdx = 0; taskIdx < taskRpt_; ++taskIdx) {
            uint32_t localOffset = taskIdx * B32_DATA_NUM_PER_REPEAT;

            // Interleaved inputs occupy two registers per repeat, hence the 2 * localOffset stride.
            MicroAPI::DataCopy(locationXY1Reg, locationInputsPtr + 2 * localOffset);
            MicroAPI::DataCopy(locationXY2Reg, locationInputsPtr + 2 * localOffset + B32_DATA_NUM_PER_REPEAT);
            MicroAPI::DataCopy(shapeInput1Reg, shapeFloatPtr + 2 * localOffset);
            MicroAPI::DataCopy(shapeInput2Reg, shapeFloatPtr + 2 * localOffset + B32_DATA_NUM_PER_REPEAT);
            MicroAPI::DataCopy(offsetReg, offsetIntPtr + localOffset);

            // Split (x, y) and (width, height) pairs into separate registers.
            MicroAPI::DeInterleave(locationXReg, locationYReg, locationXY1Reg, locationXY2Reg);
            MicroAPI::DeInterleave(widthFloatReg, heightFloatReg, shapeInput1Reg, shapeInput2Reg);
            // NOTE(review): presumably x = x * width - 0.5 and y = y * height - 0.5 -- confirm operand order.
            MicroAPI::FusedMulDstAdd(locationXReg, widthFloatReg, constOffsetReg, mask);
            MicroAPI::FusedMulDstAdd(locationYReg, heightFloatReg, constOffsetReg, mask);

            // Write the rescaled locations back in the original interleaved layout.
            MicroAPI::Interleave(locationXY1Reg, locationXY2Reg, locationXReg, locationYReg);
            MicroAPI::DataCopy(locationFloatPtr + 2 * localOffset, locationXY1Reg, mask);
            MicroAPI::DataCopy(locationFloatPtr + 2 * localOffset + B32_DATA_NUM_PER_REPEAT, locationXY2Reg, mask);

            // Integer width and floored integer coordinates for the index arithmetic.
            MicroAPI::Cast<U, T, castF2ITrait>(widthIntReg, widthFloatReg, mask);
            MicroAPI::Cast<U, T, castF2ITrait>(locationXIntReg, locationXReg, mask);
            MicroAPI::Cast<U, T, castF2ITrait>(locationYIntReg, locationYReg, mask);

            // gmOffset1 = (yInt * wInt + xInt) * numHeads_ + offsetInt
            MicroAPI::Mul(gmOffset1Reg, locationYIntReg, widthIntReg, mask);
            MicroAPI::Add(gmOffset1Reg, gmOffset1Reg, locationXIntReg, mask);
            MicroAPI::Muls(gmOffset1Reg, gmOffset1Reg, numHeads_, mask);
            MicroAPI::Add(gmOffset1Reg, gmOffset1Reg, offsetReg, mask);

            // Fill every lane with nextOffset, then overwrite the lanes selected by baseMask with baseOffset.
            // The mask is rebuilt from baseCount each iteration, so this cannot be hoisted out of the loop.
            MicroAPI::MaskReg baseMask = MicroAPI::UpdateMask<T>(baseCount);
            MicroAPI::Duplicate<U, MaskMergeMode::MERGING>(baseOffsetReg, nextOffset, mask);
            MicroAPI::Duplicate<U, MaskMergeMode::MERGING>(baseOffsetReg, baseOffset, baseMask);
            MicroAPI::Add(gmOffset1Reg, gmOffset1Reg, baseOffsetReg, mask);
            MicroAPI::Muls(gmOffset1Reg, gmOffset1Reg, embedDims_, mask);
            // gmOffset2 is gmOffset1 advanced by wInt * numHeads_ * embedDims_ (offsetReg is reused as scratch).
            MicroAPI::Muls(offsetReg, widthIntReg, numHeads_ * embedDims_, mask);
            MicroAPI::Add(gmOffset2Reg, gmOffset1Reg, offsetReg, mask);
            // The integer coordinate registers are no longer needed and serve as interleave destinations.
            MicroAPI::Interleave(locationXIntReg, locationYIntReg, gmOffset1Reg, gmOffset2Reg);
            MicroAPI::DataCopy(gmOffsetPtr + 2 * localOffset, locationXIntReg, mask);
            MicroAPI::DataCopy(gmOffsetPtr + 2 * localOffset + B32_DATA_NUM_PER_REPEAT, locationYIntReg, mask);

            // A lane is valid when -1 < x < width and -1 < y < height.
            MicroAPI::MaskReg validMask, tmpMask;
            MicroAPI::Compares<T, CMPMODE::GT>(validMask, locationXReg, -1.0f, mask);
            MicroAPI::Compares<T, CMPMODE::GT>(tmpMask, locationYReg, -1.0f, mask);
            MicroAPI::And(validMask, validMask, tmpMask, mask);
            MicroAPI::Compare<T, CMPMODE::LT>(tmpMask, locationXReg, widthFloatReg, mask);
            MicroAPI::And(validMask, validMask, tmpMask, mask);
            MicroAPI::Compare<T, CMPMODE::LT>(tmpMask, locationYReg, heightFloatReg, mask);
            MicroAPI::And(validMask, validMask, tmpMask, mask);

            // Per-side predicates: x >= 0, y >= 0, x < width - 1, y < height - 1.
            // width/height are decremented in place; they are not read again in this iteration afterwards.
            MicroAPI::MaskReg leftMask, rightMask, bottomMask, topMask;
            MicroAPI::Adds(widthFloatReg, widthFloatReg, -1.0f, mask);
            MicroAPI::Adds(heightFloatReg, heightFloatReg, -1.0f, mask);
            MicroAPI::Compares<T, CMPMODE::GE>(leftMask, locationXReg, 0.0f, mask);
            MicroAPI::Compares<T, CMPMODE::GE>(bottomMask, locationYReg, 0.0f, mask);
            MicroAPI::Compare<T, CMPMODE::LT>(rightMask, locationXReg, widthFloatReg, mask);
            MicroAPI::Compare<T, CMPMODE::LT>(topMask, locationYReg, heightFloatReg, mask);

            // One flag value per corner (1, 2, 4, 8), each gated by one horizontal and one vertical predicate.
            // NOTE(review): relies on Duplicate's default merge mode leaving masked-off lanes at 0 -- confirm.
            MicroAPI::Duplicate(bilinearValidPoint1Reg, 1, leftMask);
            MicroAPI::Select(bilinearValidPoint1Reg, bilinearValidPoint1Reg, zeroReg, bottomMask);
            MicroAPI::Duplicate(bilinearValidPoint2Reg, 2, rightMask);
            MicroAPI::Select(bilinearValidPoint2Reg, bilinearValidPoint2Reg, zeroReg, bottomMask);
            MicroAPI::Duplicate(bilinearValidPoint3Reg, 4, leftMask);
            MicroAPI::Select(bilinearValidPoint3Reg, bilinearValidPoint3Reg, zeroReg, topMask);
            MicroAPI::Duplicate(bilinearValidPoint4Reg, 8, rightMask);
            MicroAPI::Select(bilinearValidPoint4Reg, bilinearValidPoint4Reg, zeroReg, topMask);
            // The flag values are distinct bits, so summing them combines the four corners into one word.
            MicroAPI::Add(bilinearValidPoint1Reg, bilinearValidPoint1Reg, bilinearValidPoint2Reg, mask);
            MicroAPI::Add(bilinearValidPoint3Reg, bilinearValidPoint3Reg, bilinearValidPoint4Reg, mask);
            MicroAPI::Add(validMaskReg, bilinearValidPoint1Reg, bilinearValidPoint3Reg, mask);
            // Clear the whole flag word on lanes that failed the overall validity test.
            MicroAPI::Not(tmpMask, validMask, mask);
            MicroAPI::Duplicate<U, MaskMergeMode::MERGING>(validMaskReg, 0, tmpMask);

            MicroAPI::DataCopy(validMaskPtr + localOffset, validMaskReg, mask);
        }
    }
}

#endif // MSDA_H