* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file slice_move_align_gather.h
* \brief
*/
#ifndef SLICE_MOVE_ALIGN_GATHER_H
#define SLICE_MOVE_ALIGN_GATHER_H
#include <cstdlib>
#include <type_traits>
#include "slice_base.h"
namespace Slice
{
using namespace AscendC;
template <typename T, typename U, typename V = int8_t>
class SliceMoveAlignGather : public SliceBase<T, U, V>
{
public:
__aicore__ inline SliceMoveAlignGather(){};
__aicore__ inline void Init(GM_ADDR x, GM_ADDR begin, GM_ADDR y, const SliceMoveAlignGatherTilingData* tdPtr, TPipe* pipe);
__aicore__ inline void Process();
private:
using RangeType = std::conditional_t<sizeof(T) <= sizeof(int16_t), int16_t,
std::conditional_t<sizeof(T) <= sizeof(int32_t), int32_t, int64_t>>;
using IdxType = std::conditional_t<sizeof(T) <= sizeof(int16_t), uint16_t,
std::conditional_t<sizeof(T) <= sizeof(int32_t), uint32_t, uint64_t>>;
using CastType =
std::conditional_t<sizeof(T) == 1, std::conditional_t<std::is_same_v<T, uint8_t>, uint16_t, int16_t>, T>;
private:
__aicore__ inline void ParseMoveAlignGatherTilingData(
GM_ADDR begin, const SliceMoveAlignGatherTilingData* tdPtr, int64_t blockIdx);
__aicore__ inline void ProcessPerBlock(const DataCopyExtParams& copyOutParamsMain,
const DataCopyExtParams& copyOutParamsTail);
__aicore__ inline void ParseCopyInTilingData();
__aicore__ inline void ParseLoopModeAndMoveAlignParams(LoopModeParams& loopMode, DataCopyExtParams& extParams);
__aicore__ inline void SetCopyOutAlignParams(DataCopyExtParams& copyOutParams, const DataCopyExtParams& copyInParam,
const LoopModeParams& loopMode, const uint32_t blockLen);
__aicore__ inline void GatherWithIndex(uint32_t blockCount);
__aicore__ inline void InitLoopIndex();
private:
TPipe* pipe_ = nullptr;
const SliceMoveAlignGatherTilingData* tdPtr_ = nullptr;
TQue<QuePosition::VECIN, 1> inQue_;
TQue<QuePosition::VECOUT, 1> outQue_;
TBuf<QuePosition::VECCALC> idxInitBuf_;
GlobalTensor<T> inputGM_;
GlobalTensor<T> outputGM_;
int64_t blockIdx_ = 0;
uint8_t bufferCnt_ = 2;
int64_t inUbDims_ = 0;
LoopModeParams loopMode_;
LoopModeParams loopModeTail_;
DataCopyExtParams copyInParam_;
DataCopyExtParams copyInParamTail_;
uint32_t outputMainblockLen_ = 0;
uint32_t outputTailblockLen_ = 0;
DataCopyPadExtParams<T> padParams_{false, 0, 0, 0};
uint32_t lenPerLoop_ = 0;
uint32_t loopLastOneDimInc_ = 0;
uint32_t loopLastTwoDimInc_ = 0;
LocalTensor<RangeType> loopLastOneDimInitTensor_;
LocalTensor<RangeType> loopLastTwoDimInitTensor_;
uint32_t lastOneDimOutputDim_ = 0;
uint32_t vlCnt_ = Ops::Base::GetVRegSize() / sizeof(T);
};
template <typename T, typename U, typename V>
__aicore__ inline void SliceMoveAlignGather<T, U, V>::Init(GM_ADDR x, GM_ADDR begin, GM_ADDR y,
const SliceMoveAlignGatherTilingData* tdPtr, TPipe* pipe)
{
blockIdx_ = GetBlockIdx();
inputGM_.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(x));
outputGM_.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(y));
tdPtr_ = tdPtr;
this->ParseMoveAlignGatherTilingData(begin, tdPtr_, blockIdx_);
inUbDims_ = this->inputDims_ - this->ubIndex_;
ParseCopyInTilingData();
pipe_ = pipe;
pipe_->InitBuffer(inQue_, bufferCnt_, tdPtr_->ubSizeInput);
pipe_->InitBuffer(outQue_, bufferCnt_, this->ubSize_);
pipe_->InitBuffer(idxInitBuf_, 2 * VL_SIZE_BYTE);
InitLoopIndex();
}
template <typename T, typename U, typename V>
__aicore__ inline void SliceMoveAlignGather<T, U, V>::ParseMoveAlignGatherTilingData(GM_ADDR begin,
const SliceMoveAlignGatherTilingData* tdPtr, int64_t blockIdx)
{
this->ParseBaseTilingData(begin, &(tdPtr->sliceBaseTilingData), blockIdx);
this->ubOutLoopSteps_ = tdPtr->ubOutLoopSteps;
}
template <typename T, typename U, typename V>
__aicore__ inline void SliceMoveAlignGather<T, U, V>::InitLoopIndex()
{
lastOneDimOutputDim_ = this->outputShape_[this->inputDims_ - 1];
lenPerLoop_ = VL_SIZE_BYTE / sizeof(RangeType);
loopLastTwoDimInc_ = lenPerLoop_ / lastOneDimOutputDim_;
loopLastOneDimInc_ = lenPerLoop_ % lastOneDimOutputDim_;
loopLastOneDimInitTensor_ = idxInitBuf_.GetWithOffset<RangeType>(vlCnt_, 0);
loopLastTwoDimInitTensor_ = idxInitBuf_.GetWithOffset<RangeType>(vlCnt_, VL_SIZE_BYTE);
__local_mem__ RangeType* loopLastOneDimInitAddr = (__local_mem__ RangeType*)loopLastOneDimInitTensor_.GetPhyAddr();
__local_mem__ RangeType* loopLastTwoDimInitAddr = (__local_mem__ RangeType*)loopLastTwoDimInitTensor_.GetPhyAddr();
auto lastOneOutputDim = RangeType(lastOneDimOutputDim_);
__VEC_SCOPE__
{
MicroAPI::MaskReg mask = MicroAPI::CreateMask<RangeType, MicroAPI::MaskPattern::ALL>();
MicroAPI::RegTensor<RangeType> indexReg;
MicroAPI::RegTensor<RangeType> loopLastTwoDimInitReg;
MicroAPI::RegTensor<RangeType> loopLastOneDimInitReg;
MicroAPI::RegTensor<RangeType> lastOneOutputDimReg;
AscendC::MicroAPI::Duplicate(lastOneOutputDimReg, lastOneOutputDim, mask);
MicroAPI::Arange(indexReg, 0);
MicroAPI::Div(loopLastTwoDimInitReg, indexReg, lastOneOutputDimReg, mask);
MicroAPI::Muls(loopLastOneDimInitReg, loopLastTwoDimInitReg, lastOneOutputDim, mask);
MicroAPI::Sub(loopLastOneDimInitReg, indexReg, loopLastOneDimInitReg, mask);
MicroAPI::DataCopy(loopLastTwoDimInitAddr, loopLastTwoDimInitReg, mask);
MicroAPI::DataCopy(loopLastOneDimInitAddr, loopLastOneDimInitReg, mask);
}
}
template <typename T, typename U, typename V>
__aicore__ inline void SliceMoveAlignGather<T, U, V>::Process()
{
if (blockIdx_ >= this->realCoreNum_) {
return;
}
DataCopyExtParams copyOutParamsMain;
DataCopyExtParams copyOutParamsTail;
SetCopyOutAlignParams(copyOutParamsMain, copyInParam_, loopMode_, outputMainblockLen_);
if (this->ubTailFactor_ > 0) {
SetCopyOutAlignParams(copyOutParamsTail, copyInParamTail_, loopModeTail_, outputTailblockLen_);
}
ProcessPerBlock(copyOutParamsMain, copyOutParamsTail);
}
template <typename T, typename U, typename V>
__aicore__ inline void SliceMoveAlignGather<T, U, V>::ParseLoopModeAndMoveAlignParams(LoopModeParams& loopMode,
DataCopyExtParams& extParams)
{
loopMode.loop1Size = tdPtr_->moveAlignParams.loop1Size;
loopMode.loop2Size = tdPtr_->moveAlignParams.loop2Size;
loopMode.loop1SrcStride = tdPtr_->moveAlignParams.loop1SrcStride;
loopMode.loop1DstStride = tdPtr_->moveAlignParams.loop1DstStride;
loopMode.loop2SrcStride = tdPtr_->moveAlignParams.loop2SrcStride;
loopMode.loop2DstStride = tdPtr_->moveAlignParams.loop2DstStride;
extParams.blockCount = tdPtr_->moveAlignParams.blockCount;
extParams.blockLen = tdPtr_->moveAlignParams.blockLen;
extParams.srcStride = tdPtr_->moveAlignParams.srcStride;
extParams.dstStride = 0;
}
template <typename T, typename U, typename V>
__aicore__ inline void SliceMoveAlignGather<T, U, V>::ParseCopyInTilingData()
{
outputMainblockLen_ = tdPtr_->outBlockLen;
outputTailblockLen_ = outputMainblockLen_;
ParseLoopModeAndMoveAlignParams(loopMode_, copyInParam_);
if (this->ubTailFactor_ > 0) {
ParseLoopModeAndMoveAlignParams(loopModeTail_, copyInParamTail_);
if (inUbDims_ == DIMS_2) {
copyInParamTail_.blockLen = copyInParam_.blockLen / this->ubFactor_ * this->ubTailFactor_;
outputTailblockLen_ = outputMainblockLen_ / this->ubFactor_ * this->ubTailFactor_;
} else if (inUbDims_ == DIMS_3) {
copyInParamTail_.blockCount = this->ubTailFactor_;
} else if (inUbDims_ == DIMS_4) {
loopModeTail_.loop1Size = this->ubTailFactor_;
}
}
}
template <typename T, typename U, typename V>
__aicore__ inline void SliceMoveAlignGather<T, U, V>::SetCopyOutAlignParams(DataCopyExtParams& copyOutParams,
const DataCopyExtParams& copyInParam,
const LoopModeParams& loopMode,
const uint32_t blockLen)
{
copyOutParams.blockCount = copyInParam.blockCount * loopMode.loop1Size;
copyOutParams.blockLen = blockLen;
copyOutParams.dstStride = 0;
copyOutParams.srcStride = 0;
}
template <typename T, typename U, typename V>
__aicore__ inline void SliceMoveAlignGather<T, U, V>::GatherWithIndex(uint32_t blockCount)
{
__local_mem__ RangeType* loopLastOneDimInitAddr = (__local_mem__ RangeType*)loopLastOneDimInitTensor_.GetPhyAddr();
__local_mem__ RangeType* loopLastTwoDimInitAddr = (__local_mem__ RangeType*)loopLastTwoDimInitTensor_.GetPhyAddr();
auto inTensor = inQue_.DeQue<T>();
__local_mem__ T* inAddr = (__local_mem__ T*)inTensor.GetPhyAddr();
auto outTensor = outQue_.AllocTensor<T>();
__local_mem__ T* outAddr = (__local_mem__ T*)outTensor.GetPhyAddr();
uint32_t lastOneDimBegin = this->begin_[this->inputDims_ - 1];
uint32_t lastOneInputDim = tdPtr_->lastOneInputDim;
auto lastOneOutputDim = RangeType(lastOneDimOutputDim_);
uint32_t perLoopInputPadElem =
Ops::Base::CeilAlign(tdPtr_->moveAlignParams.blockLen, uint32_t(BLOCK_SIZE_BYTE)) / sizeof(T);
uint32_t perLoopOutputPadElem =
Ops::Base::CeilAlign(tdPtr_->outBlockLen, uint32_t(BLOCK_SIZE_BYTE)) / sizeof(T);
uint32_t perOutBlkLenElem = tdPtr_->outBlockLen / sizeof(T);
uint32_t rVLCnt = VL_SIZE_BYTE / sizeof(T);
uint16_t times = CeilDivision(perOutBlkLenElem, rVLCnt);
if constexpr (sizeof(T) == 1) {
rVLCnt /= sizeof(int16_t);
perOutBlkLenElem *= sizeof(int16_t);
times *= sizeof(int16_t);
}
uint32_t loopLastOneDimInc = loopLastOneDimInc_;
uint32_t loopLastTwoDimInc = loopLastTwoDimInc_;
__VEC_SCOPE__
{
MicroAPI::RegTensor<RangeType> regIdx;
MicroAPI::RegTensor<T> regData;
MicroAPI::RegTensor<RangeType> regLastOneDimIdx;
MicroAPI::RegTensor<RangeType> regLastTwoDimIdx;
MicroAPI::MaskReg maskData;
MicroAPI::MaskReg cmpMaskReg;
MicroAPI::RegTensor<RangeType> regAllOne;
MicroAPI::RegTensor<RangeType> regAllZero;
MicroAPI::RegTensor<RangeType> regCarry;
MicroAPI::Duplicate(regAllOne, 1);
MicroAPI::Duplicate(regAllZero, 0);
MicroAPI::DataCopy(regLastOneDimIdx, loopLastOneDimInitAddr);
MicroAPI::DataCopy(regLastTwoDimIdx, loopLastTwoDimInitAddr);
uint32_t perLoopCount = perOutBlkLenElem;
for (uint16_t j = 0; j < times; j++) {
maskData = MicroAPI::UpdateMask<T>(perLoopCount);
MicroAPI::Muls(regIdx, regLastTwoDimIdx, RangeType(lastOneInputDim), maskData);
MicroAPI::Add(regIdx, regIdx, regLastOneDimIdx, maskData);
MicroAPI::Adds(regIdx, regIdx, RangeType(lastOneDimBegin), maskData);
for (uint16_t i = 0; i < (uint16_t)blockCount; i++) {
uint32_t blockCountInputOffest = i * perLoopInputPadElem;
uint32_t blockCountOutputOffest = i * perLoopOutputPadElem;
MicroAPI::DataCopyGather((MicroAPI::RegTensor<CastType>&)regData, inAddr + blockCountInputOffest,
(MicroAPI::RegTensor<IdxType>&)regIdx, maskData);
if constexpr (sizeof(T) != 1) {
MicroAPI::DataCopy(outAddr + blockCountOutputOffest + j * rVLCnt, regData,
maskData);
} else {
__local_mem__ CastType* outAddrB16 =
reinterpret_cast<__local_mem__ CastType*>(outAddr + blockCountOutputOffest + j * rVLCnt);
MicroAPI::DataCopy<CastType, MicroAPI::StoreDist::DIST_PACK_B16>(
outAddrB16, (MicroAPI::RegTensor<CastType>&)regData, maskData);
}
}
MicroAPI::Adds(regLastOneDimIdx, regLastOneDimIdx, RangeType(loopLastOneDimInc), maskData);
MicroAPI::Adds(regLastTwoDimIdx, regLastTwoDimIdx, RangeType(loopLastTwoDimInc), maskData);
MicroAPI::CompareScalar<RangeType, CMPMODE::GE>(cmpMaskReg, regLastOneDimIdx, lastOneOutputDim,
maskData);
MicroAPI::Select(regCarry, regAllOne, regAllZero, cmpMaskReg);
MicroAPI::Add(regLastTwoDimIdx, regLastTwoDimIdx, regCarry, maskData);
MicroAPI::Muls(regCarry, regCarry, lastOneOutputDim, maskData);
MicroAPI::Sub(regLastOneDimIdx, regLastOneDimIdx, regCarry, maskData);
}
}
inQue_.FreeTensor(inTensor);
outQue_.EnQue(outTensor);
}
template <typename T, typename U, typename V>
__aicore__ inline void SliceMoveAlignGather<T, U, V>::ProcessPerBlock(const DataCopyExtParams& copyOutParamsMain,
const DataCopyExtParams& copyOutParamsTail)
{
int64_t inputGmAddr = 0;
int64_t outputGmAddr = 0;
int64_t ubSplitLoopsNum = 0;
int64_t ubOuterLoopsNum = 0;
int64_t rowsOffsetOutput = 0;
this->CalcProcessLoopsNum(ubOuterLoopsNum, ubSplitLoopsNum, blockIdx_);
this->GetProcessRowsOffset(rowsOffsetOutput, blockIdx_);
for (int64_t idx = 0; idx < ubOuterLoopsNum; idx++) {
inputGmAddr =
this->GetInputGmAddrWithoutLastDim(rowsOffsetOutput + idx * this->rowsOffsetSteps_[this->ubIndex_]);
outputGmAddr = this->GetOutputGmAddr(rowsOffsetOutput + idx * this->rowsOffsetSteps_[this->ubIndex_]);
for (int64_t loops = 0; loops < ubSplitLoopsNum; loops++) {
LocalTensor<T> inputLocal = inQue_.AllocTensor<T>();
SetLoopModePara(loopMode_, DataCopyMVType::OUT_TO_UB);
DataCopyPad(inputLocal, inputGM_[inputGmAddr + loops * this->ubInLoopSteps_], copyInParam_, padParams_);
ResetLoopModePara(DataCopyMVType::OUT_TO_UB);
inQue_.EnQue(inputLocal);
GatherWithIndex(copyOutParamsMain.blockCount);
auto outTensor = outQue_.DeQue<T>();
DataCopyPad(outputGM_[outputGmAddr + loops * this->ubOutLoopSteps_], outTensor, copyOutParamsMain);
outQue_.FreeTensor(outTensor);
}
if (this->ubTailFactor_ > 0) {
LocalTensor<T> inputLocal = inQue_.AllocTensor<T>();
SetLoopModePara(loopModeTail_, DataCopyMVType::OUT_TO_UB);
DataCopyPad(inputLocal, inputGM_[inputGmAddr + ubSplitLoopsNum * this->ubInLoopSteps_], copyInParamTail_,
padParams_);
ResetLoopModePara(DataCopyMVType::OUT_TO_UB);
inQue_.EnQue(inputLocal);
GatherWithIndex(copyOutParamsTail.blockCount);
auto outTensor = outQue_.DeQue<T>();
DataCopyPad(outputGM_[outputGmAddr + ubSplitLoopsNum * this->ubOutLoopSteps_], outTensor,
copyOutParamsTail);
outQue_.FreeTensor(outTensor);
}
}
}
}
#endif