/**

* Copyright (c) 2025 Huawei Technologies Co., Ltd.

* This program is free software, you can redistribute it and/or modify it under the terms and conditions of

* CANN Open Software License Agreement Version 2.0 (the "License").

* Please refer to the License for details. You may not use this file except in compliance with the License.

* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,

* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.

* See LICENSE in the root of the software repository for the full text of the License.

*/



/*!

 * \file kernel_operator_proposal_intf_impl.h

 * \brief

 */

#ifndef ASCENDC_MODULE_OPERATOR_PROPOSAL_INTERFACE_IMPL_H

#define ASCENDC_MODULE_OPERATOR_PROPOSAL_INTERFACE_IMPL_H

#include "kernel_tensor.h"

#include "kernel_struct_proposal.h"

#include "kernel_operator_block_sync_intf.h"



#if __NPU_ARCH__ == 1001

#include "dav_c100/kernel_operator_proposal_impl.h"

#include "dav_c100/kernel_operator_vec_gather_mask_impl.h"

#elif __NPU_ARCH__ == 2002

#include "dav_m200/kernel_operator_proposal_impl.h"

#include "dav_m200/kernel_operator_vec_gather_mask_impl.h"

#elif __NPU_ARCH__ == 2201

#include "dav_c220/kernel_operator_proposal_impl.h"

#include "dav_c220/kernel_operator_vec_gather_mask_impl.h"

#elif __NPU_ARCH__ == 3002

#include "dav_m300/kernel_operator_proposal_impl.h"

#include "dav_m300/kernel_operator_vec_gather_mask_impl.h"

#elif __NPU_ARCH__ == 3102

#include "dav_m310/kernel_operator_proposal_impl.h"

#include "dav_m310/kernel_operator_vec_gather_mask_impl.h"

#elif __NPU_ARCH__ == 3101

#include "dav_c310/kernel_operator_proposal_impl.h"

#include "dav_c310/kernel_operator_vec_gather_mask_impl.h"

#elif (__NPU_ARCH__ == 5102)

#include "dav_m510/kernel_operator_proposal_impl.h"

#include "dav_m510/kernel_operator_vec_gather_mask_impl.h"

#elif __NPU_ARCH__ == 3003

#include "dav_l300/kernel_operator_proposal_impl.h"

#include "dav_l300/kernel_operator_vec_gather_mask_impl.h"

#elif __NPU_ARCH__ == 3113

#include "dav_l311/kernel_operator_proposal_impl.h"

#include "dav_l311/kernel_operator_vec_gather_mask_impl.h"

#endif



#if ASCENDC_CPU_DEBUG

#include "kernel_check.h"

#endif

namespace AscendC {

// Where the element index is kept inside a Region Proposal:
//   - fp32 source: the index is stored in the label field
//   - fp16 source: the index is stored in the label + y1 fields, and GatherMask is used to extract it

// Label field position; passed as modeNumber to ProposalConcat / ProposalExtract when handling the index.
constexpr int32_t REGION_PROPOSAL_LABEL_POSITION = 5;

// y1 field position; passed as modeNumber to ProposalConcat by Sort for fp16 sources.
constexpr int32_t REGION_PROPOSAL_Y1_POSITION = 1;

// GatherMask pattern used by Sort (fp16) for the index part written to the y1 field.
// NOTE(review): presumably selects the even-positioned 16-bit elements -- confirm against GatherMaskCal.
constexpr uint8_t GATHER_MASK_MODE_FOR_INDEX_EVEN = 1;

// GatherMask pattern used by Sort (fp16) for the index part written to the label field.
// NOTE(review): presumably selects the odd-positioned 16-bit elements -- confirm against GatherMaskCal.
constexpr uint8_t GATHER_MASK_MODE_FOR_INDEX_ODD = 2;

// gather mask mode 4 is 00100010: fetch 2nd and 6th elems for each 8 elems
constexpr uint8_t GATHER_MASK_MODE_FOR_EXTRACT_INDEX = 4;

// Score field position; passed as modeNumber by Concat (ProposalConcat) and Extract (ProposalExtract).
constexpr int32_t REGION_PROPOSAL_SCORE_POSITION = 4;



#pragma begin_pipe(V)

/* **************************************** MrgSort4 ****************************************** */

/*
 * @ingroup MrgSort4
 * @brief Arrange and merge up to four already-arranged proposal queues into one queue.
 *        Every setting is packed into a single 64-bit config word that is handed to Vmrgsort4Cal.
 * @param [out] dst output LocalTensor
 * @param [in] src input LocalTensor list (src1 .. src4)
 * @param [in] params.elementLengths length of each proposal list, checked against [0, 4095]
 * @param [in] params.ifExhaustedSuspension judge whether to stop after a queue is exhausted
 * @param [in] params.validBit judge which queues are valid; only 3 / 7 / 15 are accepted
 * @param [in] params.repeatTimes repeat times (only the low 8 bits are packed)
 */
template <typename T>
__aicore__ inline void MrgSort4(const LocalTensor<T>& dst, const MrgSortSrcList<T>& src,
    const MrgSort4Info& params)
{
    using PrimType = PrimT<T>;
    ASCENDC_ASSERT((SupportType<PrimType, half, float>()),
        {KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in MrgSort4, current api support dtype combination is "
        "src and dst both: half / float");});
    for (int8_t i = 0; i < MRG_SORT_ELEMENT_LEN; ++i) {
        ASCENDC_CHECK_VALUE_RANGE(params.elementLengths[i], 0, 4095, "elementLengths", "MrgSort4");
    }
    ASCENDC_ASSERT((params.validBit == 3 || params.validBit == 7 || params.validBit == 15),
        { KERNEL_LOG(KERNEL_ERROR, "Failed to check validBit value in MrgSort4, its valid value is 3 / 7 / 15"); });
#if ASCENDC_CPU_DEBUG
    if (!CheckFunProposal(dst, src, params, sizeof(PrimType) * Internal::REGION_PROPOSAL_ELEMENT_NUM, "MrgSort4")) {
        ASCENDC_REPORT_CHECK_ERROR("MrgSort4", KernelFuncType::NONE_MODE);
    }
#endif

    // Config word layout:
    //   [7:0]   repeat times
    //   [19:8]  length of list 0      [31:20] length of list 1
    //   [43:32] length of list 2      [55:44] length of list 3
    //   [59]    exhausted-suspension switch
    //   [63:60] valid-bit mask
    const uint64_t repeatField = static_cast<uint64_t>(params.repeatTimes & 0xFF);
    const uint64_t length0Field = static_cast<uint64_t>(params.elementLengths[0] & 0xFFF) << 8;
    const uint64_t length1Field = static_cast<uint64_t>(params.elementLengths[1] & 0xFFF) << 20;
    const uint64_t length2Field = static_cast<uint64_t>(params.elementLengths[2] & 0xFFF) << 32;
    const uint64_t length3Field = static_cast<uint64_t>(params.elementLengths[3] & 0xFFF) << 44;
    const uint64_t suspendField = static_cast<uint64_t>(params.ifExhaustedSuspension & 0x1) << 59;
    const uint64_t validField = static_cast<uint64_t>(params.validBit & 0xF) << 60;
    const uint64_t config = repeatField | length0Field | length1Field | length2Field | length3Field |
        suspendField | validField;

    // All four source addresses are always forwarded, whatever validBit says.
    __ubuf__ PrimType *srcAddrs[MRG_SORT_ELEMENT_LEN] = {
        (__ubuf__ PrimType *)src.src1.GetPhyAddr(),
        (__ubuf__ PrimType *)src.src2.GetPhyAddr(),
        (__ubuf__ PrimType *)src.src3.GetPhyAddr(),
        (__ubuf__ PrimType *)src.src4.GetPhyAddr()};
    __ubuf__ PrimType *dstAddr = (__ubuf__ PrimType *)dst.GetPhyAddr();
    Vmrgsort4Cal(dstAddr, srcAddrs, config);
}



/* **************************************** RpSort16 ****************************************** */

/*
 * @ingroup RpSort16
 * @brief Sort Region Proposals according to their score field.
 *        Validates the arguments and forwards the raw addresses to VbitsortCal.
 * @param [out] dst output LocalTensor
 * @param [in] src input LocalTensor
 * @param [in] repeatTime repeat times, checked against [0, 255]
 */
template <typename T>
__aicore__ inline void RpSort16(const LocalTensor<T>& dst, const LocalTensor<T>& src,
    const int32_t repeatTime)
{
    using PrimType = PrimT<T>;
    ASCENDC_ASSERT((SupportType<PrimType, half, float>()),
        {KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in RpSort16, current api support dtype combination is "
        "src and dst both: half / float");});
    ASCENDC_CHECK_VALUE_RANGE(repeatTime, 0, 255, "repeatTime", "RpSort16");
#if ASCENDC_CPU_DEBUG
    if (!CheckFunProposal(dst, src, repeatTime, "RpSort16")) {
        ASCENDC_REPORT_CHECK_ERROR("RpSort16", KernelFuncType::NONE_MODE);
    }
#endif

    // Only the repeat count is set; the remaining fields keep whatever default construction gives them.
    ProposalIntriParams intriParams;
    intriParams.repeat = repeatTime;

    __ubuf__ PrimType *dstAddr = (__ubuf__ PrimType*)dst.GetPhyAddr();
    __ubuf__ PrimType *srcAddr = (__ubuf__ PrimType*)src.GetPhyAddr();
    VbitsortCal(dstAddr, srcAddr, intriParams);
}



/* **************************************** MrgSort ****************************************** */

/*
 * @ingroup MrgSort
 * @brief Arrange and merge up to four already-arranged queues into one queue.
 *        The settings are split over two 64-bit words handed to Vmrgsort4Cal: a config word
 *        (repeat times, valid-bit mask, exhausted-suspension switch) and a word holding the
 *        four list lengths, 16 bits each.
 * @param [out] dst output LocalTensor
 * @param [in] src input LocalTensor list (src1 .. src4)
 * @param [in] params.elementLengths length of each list, checked against [0, 4095]
 * @param [in] params.ifExhaustedSuspension judge whether to stop after a queue is exhausted
 * @param [in] params.validBit judge which queues are valid; only 3 / 7 / 15 are accepted
 * @param [in] params.repeatTimes repeat times (only the low 8 bits are packed)
 */
template <typename T>
__aicore__ inline void MrgSort(const LocalTensor<T>& dst, const MrgSortSrcList<T>& src,
    const MrgSort4Info& params)
{
    using PrimType = PrimT<T>;
    ASCENDC_ASSERT((SupportType<PrimType, half, float>()),
        {KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in MrgSort, current api support dtype combination is "
        "src and dst both: half / float");});
    for (int8_t i = 0; i < MRG_SORT_ELEMENT_LEN; ++i) {
        ASCENDC_CHECK_VALUE_RANGE(params.elementLengths[i], 0, 4095, "elementLengths", "MrgSort");
    }
    ASCENDC_ASSERT((params.validBit == 3 || params.validBit == 7 || params.validBit == 15),
        { KERNEL_LOG(KERNEL_ERROR, "Failed to check validBit value in MrgSort, its valid value is 3 / 7 / 15"); });
#if ASCENDC_CPU_DEBUG
    // NOTE(review): MrgSort4 passes sizeof(PrimType) * REGION_PROPOSAL_ELEMENT_NUM as the fourth argument;
    // confirm that the unscaled value is what CheckFunProposal expects for this API.
    if (!CheckFunProposal(dst, src, params, Internal::REGION_PROPOSAL_ELEMENT_NUM, "MrgSort")) {
        ASCENDC_REPORT_CHECK_ERROR("MrgSort", KernelFuncType::NONE_MODE);
    }
#endif
    uint64_t config = 0;
    config |= (params.repeatTimes & 0xFF);                          // Xt[7:0]: repeat time
    config |= (uint64_t(params.validBit & 0xF) << 8);               // Xt[11:8]: 4-bit mask signal
    config |= (uint64_t(params.ifExhaustedSuspension & 0x1) << 12); // Xt[12]: 1-enable input list exhausted suspension

    // Four list lengths, one 16-bit field each (list 0 in the lowest bits).
    uint64_t src1 = 0;
    src1 |= (uint64_t(params.elementLengths[0] & 0xFFFF));
    src1 |= (uint64_t(params.elementLengths[1] & 0xFFFF) << 16);
    src1 |= (uint64_t(params.elementLengths[2] & 0xFFFF) << 32);
    src1 |= (uint64_t(params.elementLengths[3] & 0xFFFF) << 48);

    // The guard tests the VALUE of ASCENDC_CPU_DEBUG, like every other guard in this file, so that a
    // build defining the macro as 0 does not end up on the debug-only path.
#if ASCENDC_CPU_DEBUG
    // CPU debug build: lists 3 and 4 are replaced by nullptr when validBit marks them as unused.
    __ubuf__ PrimType *addrArray[MRG_SORT_ELEMENT_LEN] = {(__ubuf__ PrimType *)src.src1.GetPhyAddr(),
        (__ubuf__ PrimType *)src.src2.GetPhyAddr(),
        (params.validBit & 0x4) ? (__ubuf__ PrimType *)src.src3.GetPhyAddr() : nullptr,
        (params.validBit & 0x8) ? (__ubuf__ PrimType *)src.src4.GetPhyAddr() : nullptr};
#else
    // Device build: all four addresses are forwarded unconditionally.
    __ubuf__ PrimType *addrArray[MRG_SORT_ELEMENT_LEN] = {(__ubuf__ PrimType *)src.src1.GetPhyAddr(),
        (__ubuf__ PrimType *)src.src2.GetPhyAddr(),
        (__ubuf__ PrimType *)src.src3.GetPhyAddr(),
        (__ubuf__ PrimType *)src.src4.GetPhyAddr()};
#endif

    Vmrgsort4Cal((__ubuf__ PrimType*)dst.GetPhyAddr(), addrArray, src1, config);
}



/* **************************************** Sort32 ****************************************** */

/*
 * @ingroup Sort32
 * @brief Sort 32 elements.
 *        Validates the arguments and forwards the raw addresses to VbitsortCal.
 * @param [out] dst output LocalTensor
 * @param [in] src0 input LocalTensor
 * @param [in] src1 input LocalTensor of uint32_t
 * @param [in] repeatTime repeat times, checked against [0, 255]
 */
template <typename T>
__aicore__ inline void Sort32(const LocalTensor<T>& dst, const LocalTensor<T>& src0,
    const LocalTensor<uint32_t>& src1, const int32_t repeatTime)
{
    using PrimType = PrimT<T>;
    ASCENDC_ASSERT((SupportType<PrimType, half, float>()),
        {KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in Sort32, current api support dtype combination is "
        "src and dst both: half / float");});
    ASCENDC_CHECK_VALUE_RANGE(repeatTime, 0, 255, "repeatTime", "Sort32");
#if ASCENDC_CPU_DEBUG
    if (!CheckFunProposal(dst, src0, src1, repeatTime, "Sort32")) {
        ASCENDC_REPORT_CHECK_ERROR("Sort32", KernelFuncType::NONE_MODE);
    }
#endif

    // Only the repeat count is set; the remaining fields keep whatever default construction gives them.
    ProposalIntriParams intriParams;
    intriParams.repeat = repeatTime;

    __ubuf__ PrimType *dstAddr = (__ubuf__ PrimType *)dst.GetPhyAddr();
    __ubuf__ PrimType *src0Addr = (__ubuf__ PrimType *)src0.GetPhyAddr();
    __ubuf__ uint32_t *src1Addr = (__ubuf__ uint32_t *)src1.GetPhyAddr();
    VbitsortCal(dstAddr, src0Addr, src1Addr, intriParams);
}



/* **************************************** ProposalConcat ****************************************** */

/*
 * @ingroup ProposalConcat
 * @brief Combine continuous elements into corresponding positions in the Region Proposal.
 *        Validates the arguments and forwards the raw addresses to VconcatCal.
 * @param [out] dst output LocalTensor
 * @param [in] src input LocalTensor
 * @param [in] repeatTime repeat times, checked against [0, 255]
 * @param [in] modeNumber position parameter, checked against [0, 5]
 */
template <typename T>
__aicore__ inline void ProposalConcat(const LocalTensor<T>& dst, const LocalTensor<T>& src,
    const int32_t repeatTime, const int32_t modeNumber)
{
    using PrimType = PrimT<T>;
    ASCENDC_ASSERT((SupportType<PrimType, half, float>()), {KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in ProposalConcat,"
        " current api support dtype combination is src and dst both: half / float");});
    ASCENDC_CHECK_VALUE_RANGE(repeatTime, 0, 255, "repeatTime", "ProposalConcat");
    ASCENDC_CHECK_VALUE_RANGE(modeNumber, 0, 5, "modeNumber", "ProposalConcat");
#if ASCENDC_CPU_DEBUG
    if (!CheckFunProposal(dst, src, repeatTime, "ProposalConcat")) {
        ASCENDC_REPORT_CHECK_ERROR("ProposalConcat", KernelFuncType::NONE_MODE);
    }
#endif

    ProposalIntriParams intriParams;
    intriParams.repeat = repeatTime;
    intriParams.modeNumber = modeNumber;

    __ubuf__ PrimType *dstAddr = (__ubuf__ PrimType *)dst.GetPhyAddr();
    __ubuf__ PrimType *srcAddr = (__ubuf__ PrimType *)src.GetPhyAddr();
    VconcatCal(dstAddr, srcAddr, intriParams);
}



/* **************************************** ProposalExtract ****************************************** */

/*
 * @ingroup ProposalExtract
 * @brief Extract and rearrange the individual elements at the given position from the Region Proposals.
 *        Validates the arguments and forwards the raw addresses to VextractCal.
 * @param [out] dst output LocalTensor
 * @param [in] src input LocalTensor
 * @param [in] repeatTime repeat times, checked against [0, 255]
 * @param [in] modeNumber position parameter, checked against [0, 5]
 */
template <typename T>
__aicore__ inline void ProposalExtract(const LocalTensor<T>& dst, const LocalTensor<T>& src,
    const int32_t repeatTime, const int32_t modeNumber)
{
    using PrimType = PrimT<T>;
    ASCENDC_ASSERT((SupportType<PrimType, half, float>()), {KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in "
        "ProposalExtract, current api support dtype combination is src and dst both: half / float");});
    ASCENDC_CHECK_VALUE_RANGE(repeatTime, 0, 255, "repeatTime", "ProposalExtract");
    ASCENDC_CHECK_VALUE_RANGE(modeNumber, 0, 5, "modeNumber", "ProposalExtract");
#if ASCENDC_CPU_DEBUG
    if (!CheckFunProposal(dst, src, repeatTime, "ProposalExtract")) {
        ASCENDC_REPORT_CHECK_ERROR("ProposalExtract", KernelFuncType::NONE_MODE);
    }
#endif

    ProposalIntriParams intriParams;
    intriParams.repeat = repeatTime;
    intriParams.modeNumber = modeNumber;

    __ubuf__ PrimType *dstAddr = (__ubuf__ PrimType *)dst.GetPhyAddr();
    __ubuf__ PrimType *srcAddr = (__ubuf__ PrimType *)src.GetPhyAddr();
    VextractCal(dstAddr, srcAddr, intriParams);
}



/* **************************************** Concat ****************************************** */

/*
 * @ingroup Concat
 * @brief Combine continuous elements into corresponding positions.
 *        On arch 1001 / 2002 the source is written into the score position of tmp through
 *        ProposalConcat and concat is assigned tmp; on the other listed archs no data is moved and
 *        concat is simply assigned src.
 * @param [out] concat output LocalTensor (re-assigned by this call)
 * @param [in] src input LocalTensor
 * @param [in] tmp tmp buffer
 * @param [in] repeatTime repeat times, checked against [0, 255]
 */
template <typename T>
__aicore__ inline void Concat(LocalTensor<T>& concat, const LocalTensor<T>& src,
    const LocalTensor<T>& tmp, const int32_t repeatTime)
{
    using PrimType = PrimT<T>;
    ASCENDC_ASSERT((SupportType<PrimType, half, float>()), {KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in Concat, "
        "current api support dtype combination is src and dst both: half / float");});
    ASCENDC_CHECK_VALUE_RANGE(repeatTime, 0, 255, "repeatTime", "Concat");

#if defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 2201) || (__NPU_ARCH__ == 3002) || (__NPU_ARCH__ == 3003) || \
    (__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 3102) || (__NPU_ARCH__ == 3113) || (__NPU_ARCH__ == 5102))
    // No rearrangement on these archs: the output refers to the input tensor.
    concat = src;
#elif (__NPU_ARCH__ == 1001) || (__NPU_ARCH__ == 2002)
    // Place the source values at the score position of tmp, then expose tmp as the result.
    ProposalConcat(tmp, src, repeatTime, REGION_PROPOSAL_SCORE_POSITION);
    concat = tmp;
#endif

    // The debug check runs after concat has been assigned, so it sees the final output tensor.
#if ASCENDC_CPU_DEBUG
    if (!CheckFunProposal(concat, src, tmp, repeatTime, "Concat")) {
        ASCENDC_REPORT_CHECK_ERROR("Concat", KernelFuncType::NONE_MODE);
    }
#endif
}



/* **************************************** Extract ****************************************** */

/*
 * @ingroup Extract
 * @brief Extract and rearrange the individual elements in the corresponding position:
 *        values go to dstValue, indices go to dstIndex.
 *        - arch 3101 / 5102: delegated to ExtractImpl
 *        - arch 2201 / 3002 / 3102 / 3003 / 3113: two GatherMaskCal passes separated by a V-pipe barrier
 *        - arch 1001 / 2002: ProposalExtract for the values; the index is only extracted when
 *          dstIndex is not empty
 * @param [out] dstValue output LocalTensor receiving the values
 * @param [out] dstIndex output LocalTensor receiving the indices
 * @param [in] sorted input LocalTensor
 * @param [in] repeatTime repeat times, checked against [0, 255]
 */
template <typename T>
__aicore__ inline void Extract(const LocalTensor<T>& dstValue, const LocalTensor<uint32_t>& dstIndex,
    const LocalTensor<T>& sorted, const int32_t repeatTime)
{
    using PrimType = PrimT<T>;
    ASCENDC_ASSERT((SupportType<PrimType, half, float>()), {KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in Extract, "
        "current api support dtype combination is src and dst both: half / float");});
    ASCENDC_CHECK_VALUE_RANGE(repeatTime, 0, 255, "repeatTime", "Extract");
#if ASCENDC_CPU_DEBUG
    if (!CheckFunProposal(dstValue, sorted, dstIndex, repeatTime, "Extract")) {
        ASCENDC_REPORT_CHECK_ERROR("Extract", KernelFuncType::NONE_MODE);
    }
#endif
#if defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 5102))
    ExtractImpl((__ubuf__ PrimType *)dstValue.GetPhyAddr(), (__ubuf__ uint32_t *)dstIndex.GetPhyAddr(),
        (__ubuf__ PrimType *)sorted.GetPhyAddr(), repeatTime);
#elif defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 2201) || (__NPU_ARCH__ == 3002) || (__NPU_ARCH__ == 3102) || \
    (__NPU_ARCH__ == 3003) || (__NPU_ARCH__ == 3113))
    // (3101 is already handled by the branch above, so it is not repeated here.)
    constexpr bool isHalf = Std::is_same<PrimType, half>::value;
    // Value pass: pattern 3 for half, pattern 1 for float.
    constexpr uint8_t valuePattern = isHalf ? 3 : 1;
    // Index pass: pattern 2 for both types, on the data viewed as uint32_t.
    constexpr uint8_t indexPattern = 2;
    const uint16_t valueRepeat = static_cast<uint16_t>(repeatTime);
    // For half the index pass runs twice as many repeats as the value pass.
    const uint16_t indexRepeat = static_cast<uint16_t>(isHalf ? repeatTime * 2 : repeatTime);

    // Initialised so that no indeterminate value is ever handed to GatherMaskCal.
    uint64_t rsvdCnt = 0;
    GatherMaskCal((__ubuf__ PrimType *)dstValue.GetPhyAddr(), (__ubuf__ PrimType *)sorted.GetPhyAddr(),
        valuePattern, false, static_cast<uint32_t>(0), { 1, valueRepeat, DEFAULT_REPEAT_STRIDE, 0 }, rsvdCnt);
    PipeBarrier<PIPE_V>();
    GatherMaskCal((__ubuf__ uint32_t *)dstIndex.GetPhyAddr(), (__ubuf__ uint32_t *)sorted.GetPhyAddr(),
        indexPattern, false, static_cast<uint32_t>(0), { 1, indexRepeat, 8, 0 }, rsvdCnt);
#elif (__NPU_ARCH__ == 1001) || (__NPU_ARCH__ == 2002)
    ProposalExtract(dstValue, sorted, repeatTime, REGION_PROPOSAL_SCORE_POSITION);
    if (dstIndex.GetSize() != 0) {
        PipeBarrier<PIPE_V>();
        if constexpr (Std::is_same<PrimType, half>::value) {
            // half: the index is gathered out of the proposal (label + y1 fields) with GatherMask.
            uint64_t rsvdCnt = 0;
            GatherMaskCal((__ubuf__ PrimType *)dstIndex.GetPhyAddr(), (__ubuf__ PrimType *)sorted.GetPhyAddr(),
                GATHER_MASK_MODE_FOR_EXTRACT_INDEX, false, static_cast<uint32_t>(0),
                {1, static_cast<uint16_t>(repeatTime), DEFAULT_REPEAT_STRIDE, 0}, rsvdCnt);
        } else {
            // float: the index sits in the label field and is extracted directly.
            ProposalExtract(dstIndex.ReinterpretCast<T>(), sorted, repeatTime,
                            REGION_PROPOSAL_LABEL_POSITION);
        }
    }
#endif
}



/* **************************************** MrgSort ****************************************** */

/*
 * @ingroup MrgSort
 * @brief Arrange and merge up to four already-arranged queues into one queue.
 *        Builds a MrgSort4Info from the arguments and dispatches to MrgSort or MrgSort4 depending
 *        on the arch. When isExhaustedSuspension is set, the per-list counters are read back from
 *        get_vms4_sr() into sortedNum. Returns immediately on AIC (except on arch 5102).
 * @param [out] dst output LocalTensor
 * @param [in] sortList input LocalTensor list
 * @param [in] elementCountList length of each input list
 * @param [out] sortedNum sorted numbers per list; only written when isExhaustedSuspension is true
 * @param [in] validBit input valid bit
 * @param [in] repeatTime repeat times
 */
template <typename T, bool isExhaustedSuspension>
__aicore__ inline void MrgSort(const LocalTensor<T>& dst, const MrgSortSrcList<T>& sortList,
    const uint16_t elementCountList[4], uint32_t sortedNum[4], uint16_t validBit, const int32_t repeatTime)
{
    using PrimType = PrimT<T>;
#if (__NPU_ARCH__ != 5102)
    if ASCEND_IS_AIC {
        return;
    }
#endif
    ASCENDC_ASSERT((SupportType<PrimType, half, float>()),
        {KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in MrgSort, current api support dtype combination is "
        "src and dst both: half / float");});

    MrgSort4Info mrgSortInfo(elementCountList, isExhaustedSuspension, validBit, static_cast<uint16_t>(repeatTime));

#if defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 2201) || (__NPU_ARCH__ == 3002) || (__NPU_ARCH__ == 3003) || \
    (__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 3102) || (__NPU_ARCH__ == 3113) || (__NPU_ARCH__ == 5102))
    MrgSort(dst, sortList, mrgSortInfo);
#elif (__NPU_ARCH__ == 1001) || (__NPU_ARCH__ == 2002)
    MrgSort4(dst, sortList, mrgSortInfo);
#endif

    // isExhaustedSuspension is a template argument, so the read-back is selected at compile time.
    if constexpr (isExhaustedSuspension) {
#if __NPU_ARCH__ == 2201 || (__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 5102) || (__NPU_ARCH__ == 3003) || (__NPU_ARCH__ == 3113)
        constexpr uint32_t counterMask = 0xFFFF;
        constexpr uint32_t counterWidth = 16;     // register is 16 bit per num
#elif __NPU_ARCH__ == 2002
        constexpr uint32_t counterMask = 0x1FFF;
        constexpr uint32_t counterWidth = 13;     // register is 13 bit per num
#else
        // not support: every sortedNum entry ends up as 0
        constexpr uint32_t counterMask = 0;
        constexpr uint32_t counterWidth = 0;
#endif
        constexpr uint32_t listCount = 4;
        auto statusReg = get_vms4_sr();
        // Counter k occupies bits [k * counterWidth, (k + 1) * counterWidth) of the status value.
        for (uint32_t listIdx = 0; listIdx < listCount; ++listIdx) {
            sortedNum[listIdx] = (statusReg >> (listIdx * counterWidth)) & counterMask;
        }
    }
}



/* **************************************** Sort ****************************************** */

/*
 * @ingroup Sort
 * @brief Sort the input according to the value.
 *        - arch 2201 / 3002 / 3102 / 5102 / 3003 / 3113 / 3101: forwarded to Sort32
 *        - arch 1001 / 2002: when index is not empty it is first merged into concat through
 *          ProposalConcat (for half it is split with GatherMask into two parts that go to the y1
 *          and label positions), then RpSort16 performs the sort
 *        When isFullSort is set, DoFullSort runs afterwards behind a V-pipe barrier.
 * @param [out] dst output LocalTensor (also used as scratch on arch 1001 / 2002 for half)
 * @param [in] concat input LocalTensor
 * @param [in] index input LocalTensor
 * @param [in] tmp tmp buffer
 * @param [in] repeatTime repeat times, checked against [0, 255]
 */
template <typename T, bool isFullSort>
__aicore__ inline void Sort(const LocalTensor<T>& dst, const LocalTensor<T>& concat,
    const LocalTensor<uint32_t>& index, LocalTensor<T>& tmp, const int32_t repeatTime)
{
    using PrimType = PrimT<T>;
    ASCENDC_ASSERT((SupportType<PrimType, half, float>()), {KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in Sort, current "
        "api support dtype combination is src and dst both: half / float");});
    ASCENDC_CHECK_VALUE_RANGE(repeatTime, 0, 255, "repeatTime", "Sort");
#if ASCENDC_CPU_DEBUG
    if (!CheckFuncSort<T, uint32_t, isFullSort>(dst, concat, index, tmp, repeatTime, "Sort")) {
        ASCENDC_REPORT_CHECK_ERROR("Sort", KernelFuncType::NONE_MODE);
    }
#endif

#if defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 2201) || (__NPU_ARCH__ == 3002) || (__NPU_ARCH__ == 3003) || \
    (__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 3102) || (__NPU_ARCH__ == 3113) || (__NPU_ARCH__ == 5102))
    Sort32(dst, concat, index, repeatTime);
#elif (__NPU_ARCH__ == 1001) || (__NPU_ARCH__ == 2002)
    const bool hasIndex = (index.GetSize() != 0);
    if (hasIndex) {
        if constexpr (!Std::is_same<PrimType, half>::value) {
            // float: the 32-bit index fits the label field as is.
            ProposalConcat(concat, index.ReinterpretCast<T>(), static_cast<uint16_t>(repeatTime),
                           REGION_PROPOSAL_LABEL_POSITION);
        } else {
            // sort process 16-elem each repeat, while gatherMask process 64-elem(uint32_t) each repeat
            // repeat time for gather mask is 1/4 of sort's repeat time
            // align repeat time to 64-elem
            constexpr uint16_t sortElemPerRepeat = 16;
            constexpr uint16_t gatherElemPerRepeat = 64;
            const uint16_t gatherRepTimes = (repeatTime * sortElemPerRepeat + gatherElemPerRepeat - 1) /
                gatherElemPerRepeat;

            // dst doubles as scratch for the gathered index parts until RpSort16 overwrites it.
            __ubuf__ PrimType *scratchAddr = (__ubuf__ PrimType *)dst.GetPhyAddr();
            __ubuf__ PrimType *indexAddr = (__ubuf__ PrimType *)index.GetPhyAddr();
            uint64_t rsvdCnt = 0;

            // First part of the index -> y1 position.
            GatherMaskCal(scratchAddr, indexAddr, GATHER_MASK_MODE_FOR_INDEX_EVEN, false,
                          static_cast<uint32_t>(0), {1, gatherRepTimes, DEFAULT_REPEAT_STRIDE, 0}, rsvdCnt);
            PipeBarrier<PIPE_V>();
            ProposalConcat(concat, dst, repeatTime, REGION_PROPOSAL_Y1_POSITION);
            PipeBarrier<PIPE_V>();

            // Second part of the index -> label position.
            GatherMaskCal(scratchAddr, indexAddr, GATHER_MASK_MODE_FOR_INDEX_ODD, false,
                          static_cast<uint32_t>(0), {1, gatherRepTimes, DEFAULT_REPEAT_STRIDE, 0}, rsvdCnt);
            PipeBarrier<PIPE_V>();
            ProposalConcat(concat, dst, repeatTime, REGION_PROPOSAL_LABEL_POSITION);
        }
        PipeBarrier<PIPE_V>();
    }
    RpSort16(dst, concat, repeatTime);
#endif

    if constexpr (isFullSort) {
        PipeBarrier<PIPE_V>();
        DoFullSort(dst, concat, index, tmp, repeatTime);
    }
}



// Per-element multipliers applied by GetSortOffset / GetSortLen on the archs that use the
// Sort32 / MrgSort path: 4 when the element type is half, 2 when it is float.
constexpr uint32_t halfSortedDataSize = 4;
constexpr uint32_t floatSortedDataSize = 2;

/* **************************************** GetSortOffset ****************************************** */

/*
 * @ingroup GetSortOffset
 * @brief Get the offset inside the sorted struct that corresponds to an element offset.
 *        On arch 2201 / 3002 / 3102 / 5102 / 3003 / 3113 / 3101 the element offset is multiplied by
 *        halfSortedDataSize (half) or floatSortedDataSize (float); on the other archs it is
 *        multiplied by regionProposalDataSize.
 * @param [in] elemOffset element number offset
 * @return the scaled offset
 */
template <typename T>
__aicore__ inline uint32_t GetSortOffset(const uint32_t elemOffset)
{
    using PrimType = PrimT<T>;
    ASCENDC_ASSERT((SupportType<PrimType, half, float>()),
        {KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in GetSortOffset, current api support dtype combination is "
        "half / float");});
#if defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 2201) || (__NPU_ARCH__ == 3002) || (__NPU_ARCH__ == 3003) || \
    (__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 3102) || (__NPU_ARCH__ == 3113) || (__NPU_ARCH__ == 5102))
    // The multiplier depends only on the element type, so it is picked at compile time.
    constexpr uint32_t sortedUnitSize =
        Std::is_same<PrimType, half>::value ? halfSortedDataSize : floatSortedDataSize;
    return elemOffset * sortedUnitSize;
#else
    return elemOffset * regionProposalDataSize;
#endif
}



/* **************************************** GetSortLen ****************************************** */

/*
 * @ingroup GetSortLen
 * @brief Get the length inside the sorted struct that corresponds to an element count.
 *        On arch 2201 / 3002 / 3102 / 5102 / 3003 / 3113 / 3101 the element count is multiplied by
 *        halfSortedDataSize (half) or floatSortedDataSize (float); on the other archs it is
 *        multiplied by regionProposalDataSize.
 * @param [in] elemCount element count
 * @return the scaled length
 */
template <typename T>
__aicore__ inline uint32_t GetSortLen(const uint32_t elemCount)
{
    using PrimType = PrimT<T>;
    ASCENDC_ASSERT((SupportType<PrimType, half, float>()),
        {KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in GetSortLen, current api support dtype combination is "
        "half / float");});
#if defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 2201) || (__NPU_ARCH__ == 3002) || (__NPU_ARCH__ == 3003) || \
    (__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 3102) || (__NPU_ARCH__ == 3113) || (__NPU_ARCH__ == 5102))
    // The multiplier depends only on the element type, so it is picked at compile time.
    constexpr uint32_t sortedUnitSize =
        Std::is_same<PrimType, half>::value ? halfSortedDataSize : floatSortedDataSize;
    return elemCount * sortedUnitSize;
#else
    return elemCount * regionProposalDataSize;
#endif
}

#pragma end_pipe

/*
 * @ingroup GetMrgSortResult
 * @brief Thin wrapper that forwards the four output references to GetMrgSortResultImpl.
 *        On arch 2201 nothing is written when running on an AIC core.
 *        NOTE(review): presumably reports the per-list counts of the most recent merge sort --
 *        confirm against GetMrgSortResultImpl.
 * @param [out] mrgSortList1 result for list 1
 * @param [out] mrgSortList2 result for list 2
 * @param [out] mrgSortList3 result for list 3
 * @param [out] mrgSortList4 result for list 4
 */
__aicore__ inline __inout_pipe__(S) void GetMrgSortResult(
    uint16_t &mrgSortList1, uint16_t &mrgSortList2, uint16_t &mrgSortList3, uint16_t &mrgSortList4)
{
#if __NPU_ARCH__ == 2201
    // Early exit on the AIC core: the outputs are left untouched.
    if (g_coreType == AIC) {
        return;
    }
#endif
    GetMrgSortResultImpl(mrgSortList1, mrgSortList2, mrgSortList3, mrgSortList4);
}

} // namespace AscendC

#endif // ASCENDC_MODULE_OPERATOR_PROPOSAL_INTERFACE_IMPL_H