/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
 * \file gaussian_sort.cpp
 * \brief Gaussian sort operator kernel.
 */
#include "gaussian_sort_common.h"
#include "kernel_operator.h"
using namespace GaussianSortCommon;
using namespace AscendC;
/**
 * Per-core Gaussian depth sort.
 *
 * For every (batch, camera) pair, this core walks its slice of the load-balance schedule.
 * For each scheduled tile it sorts that tile's Gaussian ids by ascending depth and writes
 * them to sorted_gs_ids. Depths are negated before the (descending) hardware Sort, which is
 * what yields ascending order.
 *
 * - Tiles whose padded size fits in maxSortNum_ are sorted in a single UB pass (SortInUB).
 * - Larger tiles are sorted chunk by chunk into a per-core GM workspace and then merged
 *   pairwise (SortInGM).
 */
class GaussianSort {
public:
    __aicore__ inline GaussianSort() {}

    // Copies the scalar tiling parameters into members.
    __aicore__ inline void GetTilingData(GaussianSortTilingData tiling_data)
    {
        batchSize_ = tiling_data.batchSize;
        cameraNum_ = tiling_data.cameraNum;
        tileNum_ = tiling_data.tileNum;
        gaussNum_ = tiling_data.gaussNum;
        scheduleNum_ = tiling_data.scheduleNum;
        maxSortNum_ = tiling_data.maxSortNum;
        maxMaskNum_ = tiling_data.maxMaskNum;
    }

    /**
     * Binds all GM inputs/outputs and this core's slice of the user workspace.
     *
     * lb_sched is read as int64 and split into three consecutive views: core offsets
     * (starting at 0), schedule (starting at blockNum_), and tile offsets (starting at
     * blockNum_ + tileNum_).
     * NOTE(review): this layout is inferred from the offsets used here; confirm it against
     * the host tiling code.
     */
    __aicore__ inline void Init(GM_ADDR lb_sched, GM_ADDR gaussian_cnt, GM_ADDR depths, GM_ADDR gs_ids,
                                GM_ADDR sorted_offset, GM_ADDR sorted_gs_ids, GM_ADDR userWorkspace,
                                GaussianSortTilingData tiling_data)
    {
        blockNum_ = GetBlockNum();
        ASSERT(blockNum_ != 0 && "Block Dim can not be Zero!");
        GetTilingData(tiling_data);
        blockIndex_ = GetBlockIdx();
        coreOffsetsGM_.SetGlobalBuffer((__gm__ int64_t*)lb_sched);
        scheduleGM_ = coreOffsetsGM_[blockNum_];
        tileOffsetsGM_ = scheduleGM_[tileNum_];
        tileCntGaussGM_.SetGlobalBuffer((__gm__ int32_t*)gaussian_cnt);
        depthsGM_.SetGlobalBuffer((__gm__ float*)depths);
        gsIdsGM_.SetGlobalBuffer((__gm__ float*)gs_ids);
        sortedOffsetGM_.SetGlobalBuffer((__gm__ int64_t*)sorted_offset);
        sortedGsIdsGM_.SetGlobalBuffer((__gm__ int32_t*)sorted_gs_ids);
        // Widen before multiplying. maxMaskNum_ and blockIndex_ are uint32_t, so without the cast
        // the product is computed (and can wrap) in 32 bits before reaching the int64_t.
        int64_t vectorWSOffset = static_cast<int64_t>(maxMaskNum_) * blockIndex_ * MRGSORT_WS_TENSOR_NUM;
        GM_ADDR sortedTmpInWS = userWorkspace + vectorWSOffset * sizeof(float);
        sortedTmpInWS_.SetGlobalBuffer((__gm__ float*)sortedTmpInWS);
    }

    // Resets UB and allocates the buffers used to sort sortProcessNum_ elements in one pass.
    __aicore__ inline void InitSortUB()
    {
        pipe_.Reset();
        pipe_.InitBuffer(inQueueGsIds_, BUFFER_NUM, sortProcessNum_ * sizeof(float));
        pipe_.InitBuffer(inQueueDepths_, BUFFER_NUM, sortProcessNum_ * sizeof(float));
        pipe_.InitBuffer(outQueueSortedGsIds_, BUFFER_NUM, sortProcessNum_ * sizeof(int32_t));
        uint32_t buffSize = sortProcessNum_ * sizeof(float);
        pipe_.InitBuffer(sortGsIdsTmpBuf_, buffSize);
        pipe_.InitBuffer(sortedTmpBuf_, buffSize * KVFACTOR);
        pipe_.InitBuffer(sortTmpBuf_, buffSize * KVFACTOR);
        sortedInLocal_ = sortedTmpBuf_.Get<float>(GetSortLen<float>(sortProcessNum_));
    }

    /**
     * Resets UB and allocates the buffers used to merge two workspace chunks of up to
     * sortProcessNum_ elements each and to extract the resulting indices.
     */
    __aicore__ inline void InitMrgSortUB()
    {
        pipe_.Reset();
        uint32_t buffSize = sortProcessNum_ * sizeof(float);
        pipe_.InitBuffer(outQueueSortedGsIds_, BUFFER_NUM, buffSize);
        pipe_.InitBuffer(wsSortedInBuf_, buffSize * KVFACTOR);
        pipe_.InitBuffer(wsSortedTargetInBuf_, buffSize * KVFACTOR);
        pipe_.InitBuffer(wsSortedTargetOutBuf_, buffSize * KVFACTOR * MRGSORT_OUT_MULT_NUM);
        // MrgSortExtraceInUB copies GetSortLen<float>(sortProcessNum_) floats (score/index pairs)
        // into this buffer. It must therefore be KVFACTOR times a plain score buffer, like the
        // other sorted buffers; a bare buffSize would be overrun by the copy.
        pipe_.InitBuffer(wsSortedOutBuf_, buffSize * KVFACTOR);
        pipe_.InitBuffer(sortTmpBuf_, buffSize * KVFACTOR);
        pipe_.InitBuffer(wsSortedDepthsBuf_, buffSize);
        sortedInLocal_ = wsSortedInBuf_.Get<float>(GetSortLen<float>(sortProcessNum_));
        sortedTargetInLocal_ = wsSortedTargetInBuf_.Get<float>(GetSortLen<float>(sortProcessNum_));
        sortedTargetOutLocal_ =
            wsSortedTargetOutBuf_.Get<float>(GetSortLen<float>(sortProcessNum_ * MRGSORT_OUT_MULT_NUM));
    }

    // Loads the target chunk (sortNumPerLoop_ elements) and the compare chunk
    // (sortProcessNum_ elements) from the workspace into UB.
    __aicore__ inline void CopyInMrgSortGsFromWS(uint32_t targetOffset, uint32_t compareOffset)
    {
        SetWaitFlag<HardEvent::V_MTE2>(HardEvent::V_MTE2);
        DataCopy(sortedTargetInLocal_, sortedTmpInWS_[targetOffset], GetSortLen<float>(sortNumPerLoop_));
        DataCopy(sortedInLocal_, sortedTmpInWS_[compareOffset], GetSortLen<float>(sortProcessNum_));
        SetWaitFlag<HardEvent::MTE2_V>(HardEvent::MTE2_V);
    }

    /**
     * Loads sortMoveNum_ depths and ids for chunk loopId of the current tile into the
     * input queues.
     *
     * On the last chunk the depth buffer is pre-filled with MAX_FP32. Lanes past the valid
     * data (including DataCopyPad's right padding) then sort to the end once negated.
     */
    __aicore__ inline void DataCopyIn(uint32_t loopId)
    {
        uint64_t offset = sortOffset_ + loopId * sortNumPerLoop_;
        LocalTensor<float> depthsLocal = inQueueDepths_.AllocTensor<float>();
        LocalTensor<float> gsIdsLocal = inQueueGsIds_.AllocTensor<float>();
        bool isLastLoop = (loopId == sortLoopNum_ - 1);
        if (isLastLoop) {
            Duplicate(depthsLocal, MAX_FP32, sortProcessNum_);
        }
        SetWaitFlag<HardEvent::V_MTE2>(HardEvent::V_MTE2);
        if (sortAlignedNum_ > 0 && isLastLoop) {
            // Unaligned tail: copy exactly sortMoveNum_ elements and pad the rest of the block.
            uint32_t blockLen = sortMoveNum_ * sizeof(float);
            DataCopyExtParams copyParams{1, blockLen, 0, 0, 0};
            DataCopyPadExtParams<float> depthsPadParams{true, 0, static_cast<uint8_t>(sortAlignedNum_), MAX_FP32};
            DataCopyPad(depthsLocal, depthsGM_[offset], copyParams, depthsPadParams);
            DataCopyPadExtParams<float> idsPadParams{true, 0, static_cast<uint8_t>(sortAlignedNum_), 0};
            DataCopyPad(gsIdsLocal, gsIdsGM_[offset], copyParams, idsPadParams);
        } else {
            DataCopy(depthsLocal, depthsGM_[offset], sortMoveNum_);
            DataCopy(gsIdsLocal, gsIdsGM_[offset], sortMoveNum_);
        }
        inQueueDepths_.EnQue<float>(depthsLocal);
        inQueueGsIds_.EnQue<float>(gsIdsLocal);
        SetWaitFlag<HardEvent::MTE2_V>(HardEvent::MTE2_V);
    }

    // Stores the sorted (score, index) pairs of chunk loopId into its workspace slot.
    __aicore__ inline void CopyOutSortedGsToWS(uint32_t loopId)
    {
        uint32_t sortedOffset = GetSortOffset<float>(loopId * sortNumPerLoop_);
        SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
        DataCopy(sortedTmpInWS_[sortedOffset], sortedInLocal_, GetSortLen<float>(sortProcessNum_));
        SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
    }

    // Splits the merge output back into the workspace:
    // - the first sortNumPerLoop_ pairs go to the target slot;
    // - the following sortProcessNum_ pairs go to the compare slot.
    __aicore__ inline void CopyOutMrgSortedGsToWS(uint32_t targetOffset, uint32_t compareOffset)
    {
        SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
        DataCopy(sortedTmpInWS_[targetOffset], sortedTargetOutLocal_[0], GetSortLen<float>(sortNumPerLoop_));
        uint32_t sortOffset = GetSortOffset<float>(sortNumPerLoop_);
        DataCopy(sortedTmpInWS_[compareOffset], sortedTargetOutLocal_[sortOffset], GetSortLen<float>(sortProcessNum_));
        SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
    }

    // Writes sortMoveNum_ sorted ids to sorted_gs_ids at sortedOffset_, then advances sortedOffset_.
    __aicore__ inline void DataCopyOut()
    {
        SetWaitFlag<HardEvent::V_MTE3>(HardEvent::V_MTE3);
        LocalTensor<int32_t> sortedGsIdsLocal = outQueueSortedGsIds_.DeQue<int32_t>();
        uint32_t blockLen = sortMoveNum_ * sizeof(int32_t);
        DataCopyExtParams copyParams{1, blockLen, 0, 0, 0};
        DataCopyPad(sortedGsIdsGM_[sortedOffset_], sortedGsIdsLocal, copyParams);
        outQueueSortedGsIds_.FreeTensor(sortedGsIdsLocal);
        SetWaitFlag<HardEvent::MTE3_V>(HardEvent::MTE3_V);
        sortedOffset_ += sortMoveNum_;
    }

    // Sorts chunk loopId of the current tile in UB and stores the (score, index) pairs
    // in the workspace.
    __aicore__ inline void SortSingle(uint32_t loopId)
    {
        DataCopyIn(loopId);
        LocalTensor<float> depthsLocal = inQueueDepths_.DeQue<float>();
        LocalTensor<float> gsIdsLocal = inQueueGsIds_.DeQue<float>();
        // Negate so the descending hardware sort orders by ascending depth.
        Muls(depthsLocal, depthsLocal, -1.0f, sortProcessNum_);
        LocalTensor<float> depthsLocalTmp = depthsLocal;
        LocalTensor<float> tempTensor = sortTmpBuf_.Get<float>(GetSortLen<float>(sortProcessNum_));
        Concat(depthsLocalTmp, depthsLocal, tempTensor, sortProcessNum_ / ONE_REPEAT_CONCAT_NUM);
        LocalTensor<int32_t> gsIdsLocalTmp = sortGsIdsTmpBuf_.Get<int32_t>(sortProcessNum_);
        Cast(gsIdsLocalTmp, gsIdsLocal, RoundMode::CAST_TRUNC, sortProcessNum_);
        Sort<float, true>(sortedInLocal_, depthsLocalTmp, gsIdsLocalTmp.ReinterpretCast<uint32_t>(), tempTensor,
                          sortProcessNum_ / ONE_REPEAT_SORT_NUM);
        // Release the dequeued inputs (they are no longer read), matching SortInUB.
        inQueueDepths_.FreeTensor(depthsLocal);
        inQueueGsIds_.FreeTensor(gsIdsLocal);
        CopyOutSortedGsToWS(loopId);
    }

    /**
     * Merges workspace chunk targetId (sortNumPerLoop_ pairs) with chunk compareId
     * (sortProcessNum_ pairs).
     *
     * The first sortNumPerLoop_ merged pairs (highest scores) are written back to targetId's
     * slot; the remainder go to compareId's slot.
     */
    __aicore__ inline void MrgSortInUB(uint32_t targetId, uint32_t compareId)
    {
        uint32_t targetOffset = GetSortOffset<float>(targetId * sortNumPerLoop_);
        uint32_t compareOffset = GetSortOffset<float>(compareId * sortNumPerLoop_);
        CopyInMrgSortGsFromWS(targetOffset, compareOffset);
        uint16_t validBit = 0b11;  // only the first two source lists are merged
        int32_t repeatTimes = 1;
        const uint16_t elementCountList[4] = {static_cast<uint16_t>(sortNumPerLoop_),
                                              static_cast<uint16_t>(sortProcessNum_), static_cast<uint16_t>(0),
                                              static_cast<uint16_t>(0)};
        uint32_t sortedNum[4];
        MrgSortSrcList sortList =
            MrgSortSrcList(sortedTargetInLocal_, sortedInLocal_, sortedTargetInLocal_, sortedTargetInLocal_);
        MrgSort<float, false>(sortedTargetOutLocal_, sortList, elementCountList, sortedNum, validBit, repeatTimes);
        CopyOutMrgSortedGsToWS(targetOffset, compareOffset);
    }

    // Extracts the ids of sortProcessNum_ pairs starting at `offset` in the merge output
    // and enqueues them for DataCopyOut.
    __aicore__ inline void MrgSortExtraceInUB(uint32_t offset)
    {
        LocalTensor<float> sortedDepthsLocal = wsSortedDepthsBuf_.Get<float>(sortProcessNum_);
        LocalTensor<float> sortedOutLocal = wsSortedOutBuf_.Get<float>(GetSortLen<float>(sortProcessNum_));
        LocalTensor<int32_t> sortedGsIdsLocal = outQueueSortedGsIds_.AllocTensor<int32_t>();
        DataCopy(sortedOutLocal, sortedTargetOutLocal_[offset], GetSortLen<float>(sortProcessNum_));
        Extract(sortedDepthsLocal, sortedGsIdsLocal.ReinterpretCast<uint32_t>(), sortedOutLocal,
                sortProcessNum_ / ONE_REPEAT_SORT_NUM);
        outQueueSortedGsIds_.EnQue<int32_t>(sortedGsIdsLocal);
    }

    /**
     * Splits the current tile (sortTileNum_ elements) into chunks of sortNumPerLoop_ for the
     * GM path.
     *
     * Computed here:
     * - sortLoopNum_: the chunk count;
     * - sortTailNum_: the real element count of the last chunk;
     * - sortTailSortNum_: that count padded to a whole number of sort repeats;
     * - sortAlignedNum_: the DataCopyPad padding of the last chunk.
     */
    __aicore__ inline void SortTiling()
    {
        uint32_t subSeqSortNum = maxSortNum_ / MRGSORT_OUT_MULT_NUM;
        uint32_t subSeqSortNumAlign = Align<uint32_t>(subSeqSortNum, sizeof(float));
        sortNumPerLoop_ = (subSeqSortNumAlign / ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
        sortNumPerLoop_ = sortNumPerLoop_ ? sortNumPerLoop_ : ONE_REPEAT_SORT_NUM;
        sortLoopNum_ = Ceil<uint32_t>(sortTileNum_, sortNumPerLoop_);
        sortTailNum_ = sortTileNum_ % sortNumPerLoop_;
        uint32_t sortTailNumAlign = Align<uint32_t>(sortTailNum_, sizeof(float));
        sortTailSortNum_ = Ceil<uint32_t>(sortTailNumAlign, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
        sortAlignedNum_ = sortTailNumAlign - sortTailNum_;
        // A zero remainder means the last chunk is full-sized.
        sortTailSortNum_ = sortTailSortNum_ ? sortTailSortNum_ : sortNumPerLoop_;
        sortTailNum_ = sortTailNum_ ? sortTailNum_ : sortNumPerLoop_;
    }

    // Sorts every chunk of the current tile independently into the workspace.
    __aicore__ inline void SubQueSort()
    {
        sortProcessNum_ = sortNumPerLoop_;
        sortMoveNum_ = sortNumPerLoop_;
        for (uint32_t loopId = 0; loopId < sortLoopNum_; loopId++) {
            if (loopId == sortLoopNum_ - 1) {
                sortProcessNum_ = sortTailSortNum_;
                sortMoveNum_ = sortTailNum_;
            }
            InitSortUB();
            SortSingle(loopId);
        }
    }

    /**
     * Produces the final order of a multi-chunk tile and writes it out chunk by chunk.
     *
     * Chunk i is merged against every later chunk. After that pass it holds the next
     * sortNumPerLoop_ highest-score pairs, whose ids are extracted and written to GM.
     *
     * The last chunk has no later partner. Its final contents are the second half of the
     * previous merge output, which is still resident in UB. With a single chunk there is no
     * merge at all, so that chunk is reloaded from the workspace instead.
     */
    __aicore__ inline void BubleMrgSort()
    {
        if (sortLoopNum_ == 1) {
            // Single chunk: SubQueSort already sorted it into the workspace. InitMrgSortUB has
            // not run, so sortedTargetOutLocal_ must be (re)bound and filled here.
            sortProcessNum_ = sortNumPerLoop_;
            InitMrgSortUB();
            sortProcessNum_ = sortTailSortNum_;
            sortMoveNum_ = sortTailNum_;
            SetWaitFlag<HardEvent::V_MTE2>(HardEvent::V_MTE2);
            DataCopy(sortedTargetOutLocal_, sortedTmpInWS_[0], GetSortLen<float>(sortProcessNum_));
            SetWaitFlag<HardEvent::MTE2_V>(HardEvent::MTE2_V);
            MrgSortExtraceInUB(0);
            DataCopyOut();
            return;
        }
        for (uint32_t i = 0; i < sortLoopNum_; i++) {
            uint32_t moveOutOffset = 0;
            sortProcessNum_ = sortNumPerLoop_;
            for (uint32_t j = i + 1; j < sortLoopNum_; j++) {
                InitMrgSortUB();
                if (j == sortLoopNum_ - 1) {
                    sortProcessNum_ = sortTailSortNum_;
                }
                MrgSortInUB(i, j);
            }
            sortProcessNum_ = sortNumPerLoop_;
            sortMoveNum_ = sortNumPerLoop_;
            if (i == sortLoopNum_ - 1) {
                // No merge ran for the last chunk; read its contents from the compare half of the
                // previous (i-1, last) merge output.
                moveOutOffset = GetSortOffset<float>(sortNumPerLoop_);
                sortProcessNum_ = sortTailSortNum_;
                sortMoveNum_ = sortTailNum_;
            }
            MrgSortExtraceInUB(moveOutOffset);
            DataCopyOut();
        }
    }

    // GM path: chunked sort into the workspace followed by pairwise merging.
    __aicore__ inline void SortInGM()
    {
        SortTiling();
        SubQueSort();
        BubleMrgSort();
    }

    // UB path: the whole tile (sortProcessNum_ padded elements) is sorted in one pass.
    __aicore__ inline void SortInUB()
    {
        InitSortUB();
        sortLoopNum_ = 1;
        sortMoveNum_ = sortTileNum_;
        DataCopyIn(0);
        LocalTensor<float> depthsLocal = inQueueDepths_.DeQue<float>();
        LocalTensor<float> gsIdsLocal = inQueueGsIds_.DeQue<float>();
        // Negate so the descending hardware sort orders by ascending depth.
        Muls(depthsLocal, depthsLocal, -1.0f, sortProcessNum_);
        LocalTensor<float> sortedLocal = sortedTmpBuf_.Get<float>(GetSortLen<float>(sortProcessNum_));
        LocalTensor<float> depthsLocalTmp = depthsLocal;
        LocalTensor<float> tempTensor = sortTmpBuf_.Get<float>(GetSortLen<float>(sortProcessNum_));
        Concat(depthsLocalTmp, depthsLocal, tempTensor, sortProcessNum_ / ONE_REPEAT_CONCAT_NUM);
        // The output-queue tensor first holds the int32 ids fed to Sort as indices. Extract later
        // overwrites it with the sorted ids.
        LocalTensor<int32_t> sortedGsIdsLocal = outQueueSortedGsIds_.AllocTensor<int32_t>();
        LocalTensor<int32_t> gsIdsLocalTmp = sortedGsIdsLocal;
        Cast(gsIdsLocalTmp, gsIdsLocal, RoundMode::CAST_TRUNC, sortProcessNum_);
        Sort<float, true>(sortedLocal, depthsLocalTmp, gsIdsLocalTmp.ReinterpretCast<uint32_t>(), tempTensor,
                          sortProcessNum_ / ONE_REPEAT_SORT_NUM);
        // Sorted scores are written back into depthsLocal and discarded; only the ids are kept.
        Extract(depthsLocal, sortedGsIdsLocal.ReinterpretCast<uint32_t>(), sortedLocal,
                sortProcessNum_ / ONE_REPEAT_SORT_NUM);
        outQueueSortedGsIds_.EnQue<int32_t>(sortedGsIdsLocal);
        inQueueDepths_.FreeTensor(depthsLocal);
        inQueueGsIds_.FreeTensor(gsIdsLocal);
        DataCopyOut();
    }

    // Chooses the UB or GM sort path for the current tile based on its padded size.
    __aicore__ inline void TileSort()
    {
        uint32_t tileSortGsNumAlign = Align<uint32_t>(sortTileNum_, sizeof(float));
        sortAlignedNum_ = tileSortGsNumAlign - sortTileNum_;
        sortProcessNum_ = Ceil<uint32_t>(tileSortGsNumAlign, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM;
        if (sortProcessNum_ <= maxSortNum_) {
            SortInUB();
        } else {
            SortInGM();
        }
    }

    /**
     * Main loop over (batch, camera) pairs.
     *
     * This core handles schedule entries [coreOffsets[blockIndex_-1], coreOffsets[blockIndex_]),
     * with the start taken as 0 for core 0. Each entry names a tile. The tile's output position
     * is the camera's base offset (sorted_offset[idx-1]) plus the tile's offset within the
     * camera (tileOffsets[tileIdx-1]).
     * NOTE(review): presumably both offset arrays are inclusive prefix sums; confirm on host.
     */
    __aicore__ inline void LoopProcess()
    {
        for (uint32_t batchIdx = 0; batchIdx < batchSize_; batchIdx++) {
            for (uint32_t cameraIdx = 0; cameraIdx < cameraNum_; cameraIdx++) {
                int64_t baseOffset = batchIdx * cameraNum_ + cameraIdx;
                int64_t baseScheduleOffset = baseOffset * scheduleNum_;
                int64_t startScheduleIdx = 0;
                if (blockIndex_ > 0) {
                    startScheduleIdx = coreOffsetsGM_.GetValue(baseScheduleOffset + blockIndex_ - 1);
                }
                int64_t endScheduleIdx = coreOffsetsGM_.GetValue(baseScheduleOffset + blockIndex_);
                int64_t sortedOffsetIdx = baseOffset;
                int64_t baseSortedOffset = 0;
                if (sortedOffsetIdx > 0) {
                    baseSortedOffset = sortedOffsetGM_.GetValue(sortedOffsetIdx - 1);
                }
                sortedOffset_ = 0;
                sortOffset_ = 0;
                for (int64_t scheduleIdx = startScheduleIdx; scheduleIdx < endScheduleIdx; scheduleIdx++) {
                    int64_t tileIdx = scheduleGM_.GetValue(baseScheduleOffset + scheduleIdx);
                    int64_t tileOffset = baseOffset * tileNum_ + tileIdx;
                    int64_t sortedTileOffset = 0;
                    if (tileIdx > 0) {
                        sortedTileOffset = tileOffsetsGM_.GetValue(baseScheduleOffset + tileIdx - 1);
                    }
                    sortOffset_ = tileOffset * gaussNum_;
                    sortedOffset_ = baseSortedOffset + sortedTileOffset;
                    sortTileNum_ = tileCntGaussGM_.GetValue(tileOffset);
                    if (sortTileNum_ > 0) {
                        TileSort();
                    }
                }
            }
        }
    }

private:
    TPipe pipe_;
    TQue<QuePosition::VECIN, QUEUE_DEPTHS_NUM> inQueueGsIds_, inQueueDepths_;
    TQue<QuePosition::VECOUT, QUEUE_DEPTHS_NUM> outQueueSortedGsIds_;
    TBuf<TPosition::VECCALC> sortTmpBuf_, sortGsIdsTmpBuf_, sortedTmpBuf_;
    TBuf<TPosition::VECCALC> wsSortedInBuf_, wsSortedTargetInBuf_, wsSortedTargetOutBuf_, wsSortedOutBuf_,
        wsSortedDepthsBuf_;
    GlobalTensor<int64_t> coreOffsetsGM_, scheduleGM_, tileOffsetsGM_, sortedOffsetGM_;
    GlobalTensor<int32_t> tileCntGaussGM_;
    GlobalTensor<float> depthsGM_, gsIdsGM_;
    GlobalTensor<int32_t> sortedGsIdsGM_;
    GlobalTensor<float> sortedTmpInWS_;  // this core's slice of the user workspace
    LocalTensor<float> sortedInLocal_;
    LocalTensor<float> sortedTargetInLocal_, sortedTargetOutLocal_;
    uint32_t blockIndex_, blockNum_;
    uint32_t batchSize_, cameraNum_, tileNum_, gaussNum_, scheduleNum_, maxMaskNum_, maxSortNum_;
    uint32_t sortTileNum_, sortAlignedNum_, sortMoveNum_, sortProcessNum_, sortLoopNum_, sortNumPerLoop_,
        sortTailSortNum_, sortTailNum_;
    uint64_t sortOffset_;   // element offset of the current tile in depths / gs_ids
    uint64_t sortedOffset_; // element offset of the next write into sorted_gs_ids
};
/**
 * Kernel entry point (vector-core only) for the Gaussian depth sort.
 *
 * Unpacks the tiling data, resolves the user part of the workspace, and runs one
 * GaussianSort instance over all (batch, camera) pairs assigned to this core.
 *
 * @param lb_sched       load-balance schedule (core offsets, schedule, tile offsets)
 * @param gaussian_cnt   per-tile Gaussian counts
 * @param depths         per-Gaussian depths to sort by
 * @param gs_ids         per-Gaussian ids to reorder
 * @param sorted_offset  per-(batch, camera) output offsets
 * @param sorted_gs_ids  output: ids ordered by depth within each tile
 * @param workspace      system + user workspace
 * @param tiling         serialized GaussianSortTilingData
 */
extern "C" __global__ __aicore__ void gaussian_sort(GM_ADDR lb_sched, GM_ADDR gaussian_cnt, GM_ADDR depths,
                                                    GM_ADDR gs_ids, GM_ADDR sorted_offset, GM_ADDR sorted_gs_ids,
                                                    GM_ADDR workspace, GM_ADDR tiling)
{
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);
    GET_TILING_DATA(tiling_data, tiling);
    GaussianSort sorter;
    GM_ADDR usrWorkspace = GetUserWorkspace(workspace);
    sorter.Init(lb_sched, gaussian_cnt, depths, gs_ids, sorted_offset, sorted_gs_ids, usrWorkspace, tiling_data);
    sorter.LoopProcess();
}