MrgSort
产品支持情况
功能说明
将已经排好序的最多4条队列,合并排列成1条队列,结果按照score域由大到小排序,排布方式如下:
Ascend 950PR/Ascend 950DT采用方式一。
Atlas A3 训练系列产品/Atlas A3 推理系列产品采用方式一。
Atlas A2 训练系列产品/Atlas A2 推理系列产品采用方式一。
-
排布方式一:
MrgSort处理的数据一般是经过Sort处理后的数据,也就是Sort接口的输出,队列的结构如下所示:
-
数据类型为float,每个结构占据8Bytes。

-
数据类型为half,每个结构也占据8Bytes,中间有2Bytes保留。

-
-
排布方式二:Region Proposal排布
输入输出数据均为Region Proposal,具体请参见Sort中的排布方式二。
函数原型
template <typename T, bool isExhaustedSuspension = false>
__aicore__ inline void MrgSort(const LocalTensor<T> &dst, const MrgSortSrcList<T> &sortList, const uint16_t elementCountList[4], uint32_t sortedNum[4], uint16_t validBit, const int32_t repeatTime)
参数说明
表 1 模板参数说明
|
Ascend 950PR/Ascend 950DT,支持的数据类型为:half、float。 |
|
某条队列耗尽(即该队列已经全部排序到目的操作数)后,是否需要停止合并。类型为bool,参数取值如下:
|
表 2 接口参数说明
|
类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 |
||
源操作数,支持2-4个队列,并且每个队列都已经排好序,类型为MrgSortSrcList结构体,具体请参考表3。MrgSortSrcList中传入要合并的队列。 template <typename T>
struct MrgSortSrcList {
LocalTensor<T> src1;
LocalTensor<T> src2;
LocalTensor<T> src3; // 当要合并的队列个数小于3,可以为空tensor
LocalTensor<T> src4; // 当要合并的队列个数小于4,可以为空tensor
};
|
||
四个源队列的长度(排序方式一:8Bytes结构的数目,排序方式二:16*sizeof(T)Bytes结构的数目),类型为长度为4的uint16_t数据类型的数组,理论上每个元素取值范围[0, 4095],但不能超出UB的存储空间。 |
||
表 3 MrgSortSrcList参数说明
|
类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Ascend 950PR/Ascend 950DT,支持的数据类型为:half、float。 |
||
|
类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Ascend 950PR/Ascend 950DT,支持的数据类型为:half、float。 |
||
|
类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Ascend 950PR/Ascend 950DT,支持的数据类型为:half、float。 |
||
|
类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Ascend 950PR/Ascend 950DT,支持的数据类型为:half、float。 |
返回值说明
无
约束说明
- 当存在score[i]与score[j]相同时,如果i>j,则score[j]将首先被选出来,排在前面,即index的顺序与输入顺序一致。
- 每次迭代内的数据会进行排序,不同迭代间的数据不会进行排序。
- 操作数地址对齐要求请参见通用地址对齐约束。
调用示例
-
处理128个half类型数据。
该样例适用于:
Ascend 950PR/Ascend 950DT
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas A3 训练系列产品/Atlas A3 推理系列产品
#include "kernel_operator.h" template <typename T> class FullSort { public: __aicore__ inline FullSort() {} __aicore__ inline void Init(__gm__ uint8_t *srcValueGm, __gm__ uint8_t *srcIndexGm, __gm__ uint8_t *dstValueGm, __gm__ uint8_t *dstIndexGm) { concatRepeatTimes = elementCount / 16; inBufferSize = elementCount * sizeof(uint32_t); outBufferSize = elementCount * sizeof(uint32_t); calcBufferSize = elementCount * 8; tmpBufferSize = elementCount * 8; sortedLocalSize = elementCount * 4; sortRepeatTimes = elementCount / 32; extractRepeatTimes = elementCount / 32; sortTmpLocalSize = elementCount * 4; valueGlobal.SetGlobalBuffer((__gm__ T *)srcValueGm); indexGlobal.SetGlobalBuffer((__gm__ uint32_t *)srcIndexGm); dstValueGlobal.SetGlobalBuffer((__gm__ T *)dstValueGm); dstIndexGlobal.SetGlobalBuffer((__gm__ uint32_t *)dstIndexGm); pipe.InitBuffer(queIn, 2, inBufferSize); pipe.InitBuffer(queOut, 2, outBufferSize); pipe.InitBuffer(queCalc, 1, calcBufferSize * sizeof(T)); pipe.InitBuffer(queTmp, 2, tmpBufferSize * sizeof(T)); } __aicore__ inline void Process() { CopyIn(); Compute(); CopyOut(); } private: __aicore__ inline void CopyIn() { AscendC::LocalTensor<T> valueLocal = queIn.AllocTensor<T>(); AscendC::DataCopy(valueLocal, valueGlobal, elementCount); queIn.EnQue(valueLocal); AscendC::LocalTensor<uint32_t> indexLocal = queIn.AllocTensor<uint32_t>(); AscendC::DataCopy(indexLocal, indexGlobal, elementCount); queIn.EnQue(indexLocal); } __aicore__ inline void Compute() { AscendC::LocalTensor<T> valueLocal = queIn.DeQue<T>(); AscendC::LocalTensor<uint32_t> indexLocal = queIn.DeQue<uint32_t>(); AscendC::LocalTensor<T> sortedLocal = queCalc.AllocTensor<T>(); AscendC::LocalTensor<T> concatTmpLocal = queTmp.AllocTensor<T>(); AscendC::LocalTensor<T> sortTmpLocal = queTmp.AllocTensor<T>(); AscendC::LocalTensor<T> dstValueLocal = queOut.AllocTensor<T>(); AscendC::LocalTensor<uint32_t> dstIndexLocal = queOut.AllocTensor<uint32_t>(); AscendC::LocalTensor<T> concatLocal; AscendC::Concat(concatLocal, valueLocal, concatTmpLocal, concatRepeatTimes); AscendC::Sort<T, false>(sortedLocal, concatLocal, indexLocal, sortTmpLocal, sortRepeatTimes); uint32_t singleMergeTmpElementCount = elementCount / 4; uint32_t baseOffset = AscendC::GetSortOffset<T>(singleMergeTmpElementCount); AscendC::MrgSortSrcList sortList = AscendC::MrgSortSrcList(sortedLocal[0], sortedLocal[baseOffset], sortedLocal[2 * baseOffset], sortedLocal[3 * baseOffset]); uint16_t singleDataSize = elementCount / 4; const uint16_t elementCountList[4] = {singleDataSize, singleDataSize, singleDataSize, singleDataSize}; uint32_t sortedNum[4]; AscendC::MrgSort<T, false>(sortTmpLocal, sortList, elementCountList, sortedNum, 0b1111, 1); AscendC::Extract(dstValueLocal, dstIndexLocal, sortTmpLocal, extractRepeatTimes); queTmp.FreeTensor(concatTmpLocal); queTmp.FreeTensor(sortTmpLocal); queIn.FreeTensor(valueLocal); queIn.FreeTensor(indexLocal); queCalc.FreeTensor(sortedLocal); queOut.EnQue(dstValueLocal); queOut.EnQue(dstIndexLocal); } __aicore__ inline void CopyOut() { AscendC::LocalTensor<T> dstValueLocal = queOut.DeQue<T>(); AscendC::LocalTensor<uint32_t> dstIndexLocal = queOut.DeQue<uint32_t>(); AscendC::DataCopy(dstValueGlobal, dstValueLocal, elementCount); AscendC::DataCopy(dstIndexGlobal, dstIndexLocal, elementCount); queOut.FreeTensor(dstValueLocal); queOut.FreeTensor(dstIndexLocal); } private: AscendC::TPipe pipe; AscendC::TQue<AscendC::TPosition::VECIN, 2> queIn; AscendC::TQue<AscendC::TPosition::VECOUT, 2> queOut; AscendC::TQue<AscendC::TPosition::VECIN, 2> queTmp; AscendC::TQue<AscendC::TPosition::VECIN, 1> queCalc; AscendC::GlobalTensor<T> valueGlobal; AscendC::GlobalTensor<uint32_t> indexGlobal; AscendC::GlobalTensor<T> dstValueGlobal; AscendC::GlobalTensor<uint32_t> dstIndexGlobal; uint32_t elementCount = 128; uint32_t concatRepeatTimes; uint32_t inBufferSize; uint32_t outBufferSize; uint32_t calcBufferSize; uint32_t tmpBufferSize; uint32_t sortedLocalSize; uint32_t sortTmpLocalSize; uint32_t sortRepeatTimes; uint32_t extractRepeatTimes; }; extern "C" __global__ __aicore__ void sort_operator(__gm__ uint8_t *src0Gm, __gm__ uint8_t *src1Gm, __gm__ uint8_t *dst0Gm, __gm__ uint8_t *dst1Gm) { FullSort<half> op; op.Init(src0Gm, src1Gm, dst0Gm, dst1Gm); op.Process(); }示例结果 输入数据(srcValueGm): 128个float类型数据 [31 30 29 ... 2 1 0 63 62 61 ... 34 33 32 95 94 93 ... 66 65 64 127 126 125 ... 98 97 96] 输入数据(srcIndexGm): [31 30 29 ... 2 1 0 63 62 61 ... 34 33 32 95 94 93 ... 66 65 64 127 126 125 ... 98 97 96] 输出数据(dstValueGm): [127 126 125 ... 2 1 0] 输出数据(dstIndexGm): [127 126 125 ... 2 1 0]