* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file roll_gather_simd.h
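* \brief Roll operator kernel that applies the per-axis shifts with vector gathers driven by a precomputed index template.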
*/
#ifndef __ROLL_GATHER_SIMD_H__
#define __ROLL_GATHER_SIMD_H__
#include "kernel_operator.h"
#include "roll_struct.h"
#include "op_kernel/platform_util.h"
namespace Roll {
using namespace AscendC;
// Size in bytes of one vector register. It bounds the gather index template (gatherMethodBuf is
// allocated with exactly this many bytes in Init, and the template builders size themselves as
// GATHER_REG_SIZE / sizeof(T) elements). It is also added to the in/out queue buffers, apparently as
// slack for the final, rounded-up gather chunk.
constexpr int32_t GATHER_REG_SIZE = Ops::Base::GetVRegSize();
/**
 * Roll kernel that implements per-axis shifts with vector gathers.
 *
 * Each core copies UB-sized tiles of x into xInQue_, gathers them through an index template held in
 * gatherMethodBuf, and writes the gathered data to y through AlienBuf.
 *
 * @tparam T        element type of x / y.
 * @tparam isShiftW compile-time flag; when false the index templates skip the last-axis (W) shift.
 */
template <typename T, bool isShiftW = true>
class RollGatherSimd {
public:
    // Initializers are listed in member declaration order (tilingData_ before pipe_) to avoid -Wreorder.
    __aicore__ inline RollGatherSimd(TPipe* pipe, const RollTilingData* tiling)
        : tilingData_(tiling), pipe_(pipe)
    {}
    // Derives this core's UB tiling, binds x/y global memory and allocates the UB buffers.
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR workspace);
    // Copies len contiguous elements starting at xGm_[index] into xInQue_.
    __aicore__ inline void CopyIn(int64_t index, int32_t len);
    // Index-template builders for a roll over the trailing C*H*W, H*W or W elements respectively.
    template <typename IntType, typename IndexType, bool isShiftH>
    __aicore__ inline void CalculateFullCHWIndex(__ubuf__ IndexType* indexLocalAddr, uint16_t& tmplSize);
    template <typename IntType, typename IndexType>
    __aicore__ inline void CalculateFullHWIndex(__ubuf__ IndexType* indexLocalAddr, uint16_t& tmplSize);
    template <typename IntType, typename IndexType>
    __aicore__ inline void CalculateFullWIndex(__ubuf__ IndexType* indexLocalAddr, uint16_t& tmplSize);
    // Scans from movTmplsplit - 1 down to the UB split axis for the nearest axis with a non-zero shift.
    __aicore__ inline void findTargetSplit(uint16_t& movTmplsplit, bool& isEqual);
    // Writes Count elements from AlienBuf to yGm_[outindex].
    __aicore__ inline void CopyOut(int64_t outindex, int32_t Count);
    // Maps an input element offset to its output offset and the length of the run that can be moved
    // before the shift along movTmplsplit wraps.
    __aicore__ inline void ComputeOutIndex(
        int64_t inputIndex, int64_t& outputIndex, uint16_t& movTmplsplit, int32_t& Count);
    // Runs the whole copy-in / gather / copy-out pipeline for this core.
    __aicore__ inline void Process();

private:
    // Gathers Count elements of xTensor (from element addrShift) through gatherTmpl into AlienBuf.
    template <typename IndexType>
    __aicore__ inline void Gather(
        LocalTensor<T>& xTensor, LocalTensor<T>& gatherTmpl, int32_t Count, uint16_t tmplSize, int32_t addrShift);

private:
    const RollTilingData* tilingData_;
    TPipe* pipe_;
    UbParam curCoreUbParam;  // UB tiling of this core, rescaled to element counts in Init.
    TQueBind<TPosition::GM, TPosition::VECIN, 1> xInQue_;
    TQueBind<TPosition::VECOUT, TPosition::GM, 1> AlienBuf;
    TBuf<TPosition::VECCALC> gatherMethodBuf;  // Holds the gather index template.
    GlobalTensor<T> xGm_;
    GlobalTensor<T> yGm_;
    int32_t hwLen = 0;  // shapes[-2] * shapes[-1], or shapes[-1] when dimNum < 2; set in Init.
    int64_t curCoreBaseIndex_ = 0;  // Flat input offset of the current UB tile.
    int64_t blockIdx_ = 0;
    int64_t inputIndices_[MAX_DIM_NUM] = {0};  // Scratch: per-axis coordinates from ComputeOutIndex.
    int32_t gatherKey = 10001;  // NOTE(review): not referenced by any method in this header.
};
/*
 * Sets up this core:
 *  - picks the main or tail UB tiling (the last core takes the tail share of the block split axis),
 *  - converts the UB factor from rows of UbSplitAxis to elements and derives the UB loop count and
 *    the size of the last (tail) UB tile,
 *  - computes this core's starting flat offset, binds x / y, caches hwLen and allocates UB buffers.
 */
template <typename T, bool isShiftW>
__aicore__ inline void RollGatherSimd<T, isShiftW>::Init(GM_ADDR x, GM_ADDR y, GM_ADDR workspace)
{
    blockIdx_ = GetBlockIdx();
    const auto splitStride = tilingData_->strides[tilingData_->blockSplitAxis];

    int64_t coreElemNum = 0;
    if (blockIdx_ != GetBlockNum() - 1) {
        curCoreUbParam = tilingData_->mainCoreUbParam;
        coreElemNum = tilingData_->blockFactor * splitStride;
    } else {
        curCoreUbParam = tilingData_->tailCoreUbParam;
        coreElemNum = tilingData_->blockTailFactor * splitStride;
    }

    // Fall back to the tail factor when the tiling reports a zero main factor.
    if (curCoreUbParam.UbFactor == 0) {
        curCoreUbParam.UbFactor = curCoreUbParam.UbTailFactor;
    }
    // Rows of UbSplitAxis -> elements.
    curCoreUbParam.UbFactor *= tilingData_->strides[curCoreUbParam.UbSplitAxis];
    curCoreUbParam.UbCount = (coreElemNum + curCoreUbParam.UbFactor - 1) / curCoreUbParam.UbFactor;
    curCoreUbParam.UbTailFactor = coreElemNum - (curCoreUbParam.UbCount - 1) * curCoreUbParam.UbFactor;

    curCoreBaseIndex_ = tilingData_->blockFactor * blockIdx_ * splitStride;
    xGm_.SetGlobalBuffer((__gm__ T*)x);
    yGm_.SetGlobalBuffer((__gm__ T*)y);

    const auto dimNum = tilingData_->dimNum;
    hwLen = (dimNum >= 2) ? tilingData_->shapes[dimNum - 2] * tilingData_->shapes[dimNum - 1]
                          : tilingData_->shapes[dimNum - 1];

    // Both queues get one extra vector register of room beyond the UB tile.
    const auto tileBytes = curCoreUbParam.UbFactor * sizeof(T) + GATHER_REG_SIZE;
    pipe_->InitBuffer(xInQue_, BUF_NUM, tileBytes);
    pipe_->InitBuffer(AlienBuf, BUF_NUM, tileBytes);
    pipe_->InitBuffer(gatherMethodBuf, GATHER_REG_SIZE);
}
/*
 * Moves len contiguous elements starting at xGm_[index] into a freshly allocated xInQue_ tensor
 * with a single DataCopyPad burst (no padding), then enqueues it for the compute stage.
 */
template <typename T, bool isShiftW>
__aicore__ inline void RollGatherSimd<T, isShiftW>::CopyIn(int64_t index, int32_t len)
{
    DataCopyPadExtParams padParams{false, 0, 0, static_cast<T>(0)};
    DataCopyExtParams burst;
    burst.blockCount = 1;
    burst.blockLen = len * sizeof(T);
    burst.srcStride = 0;
    burst.dstStride = 0;

    LocalTensor<T> inTile = xInQue_.AllocTensor<T>();
    DataCopyPad(inTile, xGm_[index], burst, padParams);
    xInQue_.EnQue<T>(inTile);
}
/*
 * Takes the next gathered tensor from AlienBuf, writes its first Count elements to yGm_[outindex]
 * with a single DataCopyPad burst, and releases the tensor back to the queue.
 */
template <typename T, bool isShiftW>
__aicore__ inline void RollGatherSimd<T, isShiftW>::CopyOut(int64_t outindex, int32_t Count)
{
    DataCopyExtParams burst;
    burst.blockCount = 1;
    burst.blockLen = Count * sizeof(T);
    burst.srcStride = 0;
    burst.dstStride = 0;

    LocalTensor<T> outTile = AlienBuf.DeQue<T>();
    DataCopyPad(yGm_[outindex], outTile, burst);
    AlienBuf.FreeTensor<T>(outTile);
}
/*
 * Main loop for one core.
 *  1. Builds the gather index template once. The trailing block it covers is chosen as follows:
 *     - C*H*W (dims -3..-1) when dimNum > 2, that block fits in one vector register and the C axis
 *       is shifted;
 *     - W only (dim -1) when H*W does not fit in one register, or dimNum < 2;
 *     - H*W (dims -2..-1) otherwise.
 *  2. findTargetSplit locates the nearest shifted axis above the template (movTmplsplit).
 *  3. For every UB tile, copies the tile in, then gathers and copies out runs whose length
 *     ComputeOutIndex reports as not crossing the wrap point of movTmplsplit.
 */
template <typename T, bool isShiftW>
__aicore__ inline void RollGatherSimd<T, isShiftW>::Process()
{
    LocalTensor<T> gatherTmpl = gatherMethodBuf.Get<T>();
    uint16_t tmplSize = 0;
    uint16_t gatherTmplsplit = 0;
    // b8 data is gathered through 16-bit lanes (see Gather), so each element costs twice its size
    // in the register budget.
    uint16_t is_b8 = 0;
    if constexpr (sizeof(T) == 1) {
        is_b8 = 1;
    }
    // Compare against GATHER_REG_SIZE (not a literal): CalculateFullCHWIndex sizes its template from it,
    // and a block that does not fit would yield tmplSize == 0 and a division by zero in Gather.
    if ((tilingData_->dimNum > 2) &&
        (tilingData_->shapes[tilingData_->dimNum - 3] * hwLen * sizeof(T) * (is_b8 + 1) <= GATHER_REG_SIZE) &&
        (tilingData_->shifts[tilingData_->dimNum - 3] != 0)) {
        gatherTmplsplit = tilingData_->dimNum - 3;
        if constexpr (sizeof(T) <= 2) {
            if (tilingData_->shifts[tilingData_->dimNum - 2] == 0) {
                CalculateFullCHWIndex<int16_t, uint16_t, false>((__ubuf__ uint16_t*)gatherTmpl.GetPhyAddr(), tmplSize);
            } else {
                CalculateFullCHWIndex<int16_t, uint16_t, true>((__ubuf__ uint16_t*)gatherTmpl.GetPhyAddr(), tmplSize);
            }
        } else {
            if (tilingData_->shifts[tilingData_->dimNum - 2] == 0) {
                CalculateFullCHWIndex<int32_t, uint32_t, false>((__ubuf__ uint32_t*)gatherTmpl.GetPhyAddr(), tmplSize);
            } else {
                CalculateFullCHWIndex<int32_t, uint32_t, true>((__ubuf__ uint32_t*)gatherTmpl.GetPhyAddr(), tmplSize);
            }
        }
    } else if (hwLen * sizeof(T) * (is_b8 + 1) > GATHER_REG_SIZE || tilingData_->dimNum < 2) {
        gatherTmplsplit = tilingData_->dimNum - 1;
        if constexpr (sizeof(T) <= 2) {
            CalculateFullWIndex<int16_t, uint16_t>((__ubuf__ uint16_t*)gatherTmpl.GetPhyAddr(), tmplSize);
        } else {
            CalculateFullWIndex<int32_t, uint32_t>((__ubuf__ uint32_t*)gatherTmpl.GetPhyAddr(), tmplSize);
        }
    } else {
        gatherTmplsplit = tilingData_->dimNum - 2;
        if constexpr (sizeof(T) <= 2) {
            CalculateFullHWIndex<int16_t, uint16_t>((__ubuf__ uint16_t*)gatherTmpl.GetPhyAddr(), tmplSize);
        } else {
            CalculateFullHWIndex<int32_t, uint32_t>((__ubuf__ uint32_t*)gatherTmpl.GetPhyAddr(), tmplSize);
        }
    }

    bool isEqual = false;
    uint16_t movTmplsplit = gatherTmplsplit;
    findTargetSplit(movTmplsplit, isEqual);

    for (int32_t loopNum = 0; loopNum < curCoreUbParam.UbCount; loopNum++) {
        int32_t curUbFactor;
        if (loopNum == curCoreUbParam.UbCount - 1) {
            curUbFactor = curCoreUbParam.UbTailFactor;
        } else {
            curUbFactor = curCoreUbParam.UbFactor;
        }
        // By default the whole tile is one slice; otherwise iterate over slices of axis movTmplsplit - 1.
        int32_t CopyLoop = 1;
        int32_t loopSize = curUbFactor;
        if (!isEqual) {
            CopyLoop =
                curUbFactor / (tilingData_->strides[movTmplsplit] * tilingData_->shapes[movTmplsplit]);
            loopSize = tilingData_->strides[movTmplsplit - 1];
        }
        CopyIn(curCoreBaseIndex_, curUbFactor);
        LocalTensor<T> xTensor = xInQue_.DeQue<T>();
        // Counter has the same width as its bound so it cannot wrap (a uint16_t counter loops forever
        // once CopyLoop exceeds 65535).
        for (int32_t i = 0; i < CopyLoop; i++) {
            int32_t Count = 0;
            int32_t countLoopSize = loopSize;
            int32_t addrShift = i * loopSize;
            while (countLoopSize > 0) {
                int64_t outIndex = 0;
                if (gatherTmplsplit != movTmplsplit) {
                    ComputeOutIndex(curCoreBaseIndex_ + addrShift, outIndex, movTmplsplit, Count);
                } else {
                    // NOTE(review): outIndex stays 0 on this path, so every run is written to yGm_[0].
                    // That looks correct only when the whole tensor is one UB tile on one core; confirm
                    // that the tiling guarantees this.
                    Count = loopSize;
                }
                if (countLoopSize < Count) {
                    Count = countLoopSize;
                }
                if constexpr (sizeof(T) <= 2) {
                    Gather<uint16_t>(xTensor, gatherTmpl, Count, tmplSize, addrShift);
                } else {
                    Gather<uint32_t>(xTensor, gatherTmpl, Count, tmplSize, addrShift);
                }
                CopyOut(outIndex, Count);
                countLoopSize -= Count;
                addrShift += Count;
            }
        }
        xInQue_.FreeTensor<T>(xTensor);
        curCoreBaseIndex_ += curCoreUbParam.UbFactor;
    }
    gatherMethodBuf.FreeTensor<T>(gatherTmpl);
}
/*
 * Builds the gather index template for a roll over the trailing C*H*W block (dims -3..-1).
 *
 * For output lane p of one C*H*W block, the source index is obtained in three steps:
 *  - W (only if isShiftW): p - shiftsW, plus W when the result falls below the start of p's W-row;
 *  - H (only if isShiftH): minus shiftsH*W, plus H*W when the result falls below the start of its H*W slice;
 *  - C: minus shiftsC*H*W, plus C*H*W on the first shiftsC*H*W lanes (mask from UpdateMask).
 * The block pattern is written loopSize times, each copy offset by C*H*W, so that one gather covers
 * loopSize consecutive blocks. tmplSize receives loopSize * C*H*W.
 * b8 data uses 16-bit index lanes, so only half as many elements fit in the register.
 */
template <typename T, bool isShiftW>
template <typename IntType, typename IndexType, bool isShiftH>
__aicore__ inline void RollGatherSimd<T, isShiftW>::CalculateFullCHWIndex(
    __ubuf__ IndexType* indexLocalAddr, uint16_t& tmplSize)
{
    int32_t shapesW = tilingData_->shapes[tilingData_->dimNum - 1];
    int32_t shapesH = tilingData_->shapes[tilingData_->dimNum - 2];
    int32_t shapesC = tilingData_->shapes[tilingData_->dimNum - 3];
    int32_t shiftsW = tilingData_->shifts[tilingData_->dimNum - 1];
    int32_t shiftsH = tilingData_->shifts[tilingData_->dimNum - 2];
    int32_t shiftsC = tilingData_->shifts[tilingData_->dimNum - 3];
    int32_t shapesLen = shapesW * shapesH * shapesC;
    // Number of whole C*H*W blocks that fit in one register-sized template.
    uint16_t loopSize = 0;
    if constexpr (sizeof(T) == 1) {
        loopSize = GATHER_REG_SIZE / sizeof(T) / 2 / shapesLen;
    } else {
        loopSize = GATHER_REG_SIZE / sizeof(T) / shapesLen;
    }
    tmplSize = loopSize * shapesLen;
    IndexType copyLenth = shapesLen;
    IntType startScalar = 0;
    IntType negShiftWScalar = -1 * static_cast<IntType>(shiftsW);
    IntType wScalar = static_cast<IntType>(shapesW);
    IntType negShiftHWScalar = -1 * static_cast<IntType>(shapesW) * static_cast<IntType>(shiftsH);
    IntType hwScalar = static_cast<IntType>(shapesW) * static_cast<IntType>(shapesH);
    IntType negShiftCHWScalar = -1 * hwScalar * static_cast<IntType>(shiftsC);
    IntType chwScalar = static_cast<IntType>(shapesLen);
    uint32_t shiftCHWSizeScalar = static_cast<uint32_t>(hwScalar) * static_cast<uint32_t>(shiftsC);
    __VEC_SCOPE__
    {
        AscendC::MicroAPI::RegTensor<IndexType> indexVReg;
        AscendC::MicroAPI::RegTensor<IntType> calculateVReg, helpCmpVReg, helpAddVReg, helpDivVReg;
        AscendC::MicroAPI::MaskReg cmpResultMaskReg, helpAddCHWMaskReg, fullMaskReg;
        AscendC::MicroAPI::UnalignReg u0;
        fullMaskReg = AscendC::MicroAPI::CreateMask<IndexType, AscendC::MicroAPI::MaskPattern::ALL>();
        helpAddCHWMaskReg = AscendC::MicroAPI::UpdateMask<IndexType>(shiftCHWSizeScalar);
        AscendC::MicroAPI::Duplicate(helpDivVReg, wScalar, fullMaskReg);
        AscendC::MicroAPI::Arange(calculateVReg, startScalar);
        if constexpr (isShiftW) {
            // helpCmpVReg = start of each lane's W-row; wrap lanes that moved before it.
            AscendC::MicroAPI::Div(helpCmpVReg, calculateVReg, helpDivVReg, fullMaskReg);
            AscendC::MicroAPI::Muls(helpCmpVReg, helpCmpVReg, wScalar, fullMaskReg);
            AscendC::MicroAPI::Adds(calculateVReg, calculateVReg, negShiftWScalar, fullMaskReg);
            AscendC::MicroAPI::Compare<IntType, CMPMODE::LT>(cmpResultMaskReg, calculateVReg, helpCmpVReg, fullMaskReg);
            AscendC::MicroAPI::Duplicate(helpAddVReg, wScalar, cmpResultMaskReg);
            AscendC::MicroAPI::Add(calculateVReg, calculateVReg, helpAddVReg, fullMaskReg);
        }
        if constexpr (isShiftH) {
            // helpCmpVReg = start of each lane's H*W slice; wrap lanes that moved before it.
            AscendC::MicroAPI::Duplicate(helpDivVReg, hwScalar, fullMaskReg);
            AscendC::MicroAPI::Div(helpCmpVReg, calculateVReg, helpDivVReg, fullMaskReg);
            AscendC::MicroAPI::Muls(helpCmpVReg, helpCmpVReg, hwScalar, fullMaskReg);
            AscendC::MicroAPI::Adds(calculateVReg, calculateVReg, negShiftHWScalar, fullMaskReg);
            AscendC::MicroAPI::Compare<IntType, CMPMODE::LT>(cmpResultMaskReg, calculateVReg, helpCmpVReg, fullMaskReg);
            AscendC::MicroAPI::Duplicate(helpAddVReg, hwScalar, cmpResultMaskReg);
            AscendC::MicroAPI::Add(calculateVReg, calculateVReg, helpAddVReg, fullMaskReg);
        }
        // C roll: the first shiftsC*H*W lanes wrap by a full block. Use the IntType-typed scalar,
        // like every other Duplicate in this function.
        AscendC::MicroAPI::Adds(calculateVReg, calculateVReg, negShiftCHWScalar, fullMaskReg);
        AscendC::MicroAPI::Duplicate(helpAddVReg, chwScalar, helpAddCHWMaskReg);
        AscendC::MicroAPI::Add(calculateVReg, calculateVReg, helpAddVReg, fullMaskReg);
        AscendC::MicroAPI::Copy(indexVReg, (MicroAPI::RegTensor<IndexType>&)calculateVReg);
        // Emit the block pattern loopSize times, offsetting each copy by one block.
        AscendC::MicroAPI::DataCopyUnAlign(indexLocalAddr, indexVReg, u0, copyLenth);
        for (uint16_t i = 1; i < loopSize; i++) {
            AscendC::MicroAPI::Adds(indexVReg, indexVReg, copyLenth, fullMaskReg);
            AscendC::MicroAPI::DataCopyUnAlign(indexLocalAddr, indexVReg, u0, copyLenth);
        }
        AscendC::MicroAPI::DataCopyUnAlignPost(indexLocalAddr, u0, 0);
    }
}
/*
 * Builds the gather index template for a roll over the trailing H*W block (dims -2..-1).
 *
 * For output lane p of one H*W block:
 *  - W (only if isShiftW): p - shiftsW, plus W when the result falls below the start of p's W-row;
 *  - H: minus shiftsH*W, plus H*W on the first shiftsH*W lanes (mask from UpdateMask).
 * The block pattern is written loopSize times, each copy offset by H*W. tmplSize receives loopSize * H*W.
 * b8 data uses 16-bit index lanes, so only half as many elements fit in the register.
 */
template <typename T, bool isShiftW>
template <typename IntType, typename IndexType>
__aicore__ inline void RollGatherSimd<T, isShiftW>::CalculateFullHWIndex(
    __ubuf__ IndexType* indexLocalAddr, uint16_t& tmplSize)
{
    int32_t shapesW = tilingData_->shapes[tilingData_->dimNum - 1];
    int32_t shiftsW = tilingData_->shifts[tilingData_->dimNum - 1];
    int32_t shiftsH = tilingData_->shifts[tilingData_->dimNum - 2];
    // Number of whole H*W blocks that fit in one register-sized template.
    uint16_t loopSize = GATHER_REG_SIZE / sizeof(T) / hwLen;
    if constexpr (sizeof(T) == 1) {
        loopSize = GATHER_REG_SIZE / sizeof(T) / 2 / hwLen;
    }
    tmplSize = loopSize * hwLen;
    IndexType copyLenth = hwLen;
    IntType startScalar = 0;
    IntType negShiftWScalar = -1 * static_cast<IntType>(shiftsW);
    IntType wScalar = static_cast<IntType>(shapesW);
    IntType negShiftHWScalar = -1 * static_cast<IntType>(shiftsH) * static_cast<IntType>(shapesW);
    IntType HWScalar = static_cast<IntType>(hwLen);
    uint32_t shifthWSizeScalar = static_cast<uint32_t>(shiftsH) * static_cast<uint32_t>(shapesW);
    __VEC_SCOPE__
    {
        AscendC::MicroAPI::RegTensor<IndexType> indexVReg;
        AscendC::MicroAPI::RegTensor<IntType> calculateVReg, helpCmpVReg, helpAddVReg, helpDivVReg;
        AscendC::MicroAPI::MaskReg cmpResultMaskReg, helpAddHWMaskReg, fullMaskReg;
        AscendC::MicroAPI::UnalignReg u0;
        fullMaskReg = AscendC::MicroAPI::CreateMask<IndexType, AscendC::MicroAPI::MaskPattern::ALL>();
        helpAddHWMaskReg = AscendC::MicroAPI::UpdateMask<IndexType>(shifthWSizeScalar);
        AscendC::MicroAPI::Duplicate(helpDivVReg, wScalar, fullMaskReg);
        AscendC::MicroAPI::Arange(calculateVReg, startScalar);
        if constexpr (isShiftW) {
            // helpCmpVReg = start of each lane's W-row; wrap lanes that moved before it.
            AscendC::MicroAPI::Div(helpCmpVReg, calculateVReg, helpDivVReg, fullMaskReg);
            AscendC::MicroAPI::Muls(helpCmpVReg, helpCmpVReg, wScalar, fullMaskReg);
            AscendC::MicroAPI::Adds(calculateVReg, calculateVReg, negShiftWScalar, fullMaskReg);
            AscendC::MicroAPI::Compare<IntType, CMPMODE::LT>(cmpResultMaskReg, calculateVReg, helpCmpVReg, fullMaskReg);
            AscendC::MicroAPI::Duplicate(helpAddVReg, wScalar, cmpResultMaskReg);
            AscendC::MicroAPI::Add(calculateVReg, calculateVReg, helpAddVReg, fullMaskReg);
        }
        // H roll: the first shiftsH*W lanes wrap by a full block. Use the IntType-typed scalar,
        // like every other Duplicate in this function.
        AscendC::MicroAPI::Adds(calculateVReg, calculateVReg, negShiftHWScalar, fullMaskReg);
        AscendC::MicroAPI::Duplicate(helpAddVReg, HWScalar, helpAddHWMaskReg);
        AscendC::MicroAPI::Add(calculateVReg, calculateVReg, helpAddVReg, fullMaskReg);
        AscendC::MicroAPI::Copy(indexVReg, (MicroAPI::RegTensor<IndexType>&)calculateVReg);
        // Emit the block pattern loopSize times, offsetting each copy by one block.
        AscendC::MicroAPI::DataCopyUnAlign(indexLocalAddr, indexVReg, u0, copyLenth);
        for (uint16_t i = 1; i < loopSize; i++) {
            AscendC::MicroAPI::Adds(indexVReg, indexVReg, copyLenth, fullMaskReg);
            AscendC::MicroAPI::DataCopyUnAlign(indexLocalAddr, indexVReg, u0, copyLenth);
        }
        AscendC::MicroAPI::DataCopyUnAlignPost(indexLocalAddr, u0, 0);
    }
}
// Builds the gather index template for a roll along the last axis only (W).
// If isShiftW is set, then for output lane p in [0, W) the source index is p - shiftsW. W is added on
// the lanes selected by UpdateMask(shiftsW), which are presumably the first shiftsW lanes. If isShiftW
// is not set, the template is the identity arange 0, 1, 2, ...
// In the shifted case the W-long pattern is written loopSize times, each copy offset by W, so one
// gather covers loopSize consecutive W-rows. tmplSize receives loopSize * W.
// b8 data uses 16-bit index lanes, so only half as many elements fit in the register.
// NOTE(review): if W exceeds the per-register element budget, loopSize (and so tmplSize) becomes 0
// and Gather would divide by zero; this presumably relies on the tiling never picking this path then.
template <typename T, bool isShiftW>
template <typename IntType, typename IndexType>
__aicore__ inline void RollGatherSimd<T, isShiftW>::CalculateFullWIndex(
__ubuf__ IndexType* indexLocalAddr, uint16_t& tmplSize)
{
int64_t shapesW = tilingData_->shapes[tilingData_->dimNum - 1];
int64_t shiftsW = tilingData_->shifts[tilingData_->dimNum - 1];
uint16_t wSize = static_cast<uint16_t>(shapesW);
// Number of whole W-rows that fit in one register-sized template.
uint16_t loopSize = GATHER_REG_SIZE / sizeof(T) / wSize;
if constexpr (sizeof(T) == 1) {
loopSize = GATHER_REG_SIZE / sizeof(T) / 2 / wSize;
}
tmplSize = loopSize * wSize;
IndexType copyLenth = wSize;
IntType startScalar = 0;
IntType negShiftWScalar = -1 * static_cast<IntType>(shiftsW);
IntType wScalar = static_cast<IntType>(shapesW);
uint32_t shiftwSizeScalar = static_cast<uint32_t>(shiftsW);
__VEC_SCOPE__
{
AscendC::MicroAPI::RegTensor<IndexType> indexVReg;
AscendC::MicroAPI::RegTensor<IntType> calculateVReg, helpAddVReg;
AscendC::MicroAPI::MaskReg helpAddHWMaskReg, fullMaskReg;
AscendC::MicroAPI::UnalignReg u0;
fullMaskReg = AscendC::MicroAPI::CreateMask<IndexType, AscendC::MicroAPI::MaskPattern::ALL>();
// Mask of the lanes that wrap around the end of the row (built from shiftsW).
helpAddHWMaskReg = AscendC::MicroAPI::UpdateMask<IndexType>(shiftwSizeScalar);
AscendC::MicroAPI::Arange(calculateVReg, startScalar);
if constexpr (isShiftW) {
// index = p - shiftsW, then + W on the wrapping lanes.
AscendC::MicroAPI::Adds(calculateVReg, calculateVReg, negShiftWScalar, fullMaskReg);
AscendC::MicroAPI::Duplicate(helpAddVReg, wScalar, helpAddHWMaskReg);
AscendC::MicroAPI::Add(calculateVReg, calculateVReg, helpAddVReg, fullMaskReg);
}
AscendC::MicroAPI::Copy(indexVReg, (MicroAPI::RegTensor<IndexType>&)calculateVReg);
if constexpr (isShiftW) {
// Shifted: emit the W-long pattern loopSize times, bumping every index by W per copy.
AscendC::MicroAPI::DataCopyUnAlign(indexLocalAddr, indexVReg, u0, copyLenth);
for (uint16_t i = 1; i < loopSize; i++) {
AscendC::MicroAPI::Adds(indexVReg, indexVReg, copyLenth, fullMaskReg);
AscendC::MicroAPI::DataCopyUnAlign(indexLocalAddr, indexVReg, u0, copyLenth);
}
} else {
// Unshifted: the arange itself already is the whole tmplSize-long template.
AscendC::MicroAPI::DataCopyUnAlign(indexLocalAddr, indexVReg, u0, copyLenth * loopSize);
}
// Flush whatever is still buffered in the unaligned register u0.
AscendC::MicroAPI::DataCopyUnAlignPost(indexLocalAddr, u0, 0);
}
}
/*
 * Scans the axes from movTmplsplit - 1 down to the UB split axis and returns (in movTmplsplit) the
 * nearest one with a non-zero shift. isEqual is set when that axis is the UB split axis itself.
 * If no shifted axis is found, movTmplsplit becomes the UB split axis and isEqual is true.
 */
template <typename T, bool isShiftW>
__aicore__ inline void RollGatherSimd<T, isShiftW>::findTargetSplit(uint16_t& movTmplsplit, bool& isEqual)
{
    // Compare against a signed copy. When movTmplsplit is 0 the scan starts at -1, and if UbSplitAxis
    // has an unsigned type, -1 >= UbSplitAxis would wrap to true and read shifts[-1].
    const int32_t ubSplitAxis = static_cast<int32_t>(curCoreUbParam.UbSplitAxis);
    for (int32_t i = static_cast<int32_t>(movTmplsplit) - 1; i >= ubSplitAxis; i--) {
        if (tilingData_->shifts[i] != 0) {
            movTmplsplit = static_cast<uint16_t>(i);
            isEqual = (i == ubSplitAxis);
            return;
        }
    }
    movTmplsplit = static_cast<uint16_t>(ubSplitAxis);
    isEqual = true;
}
/*
 * Maps the flat input offset inputIndex to its output offset, applying the shifts of axes
 * 0..movTmplsplit. Lower axes are left to the gather template.
 *
 * outputIndex is overwritten; callers no longer need to pre-zero it.
 * Count receives the number of elements, in whole strides of movTmplsplit, that can be moved from
 * inputIndex before the shifted coordinate along movTmplsplit wraps (either up to the shift point
 * or up to the end of the axis).
 */
template <typename T, bool isShiftW>
__aicore__ inline void RollGatherSimd<T, isShiftW>::ComputeOutIndex(
    int64_t inputIndex, int64_t& outputIndex, uint16_t& movTmplsplit, int32_t& Count)
{
    const int64_t splitAxis = static_cast<int64_t>(movTmplsplit);
    int64_t outOffset = 0;
    // Only axes up to movTmplsplit contribute, so stop decomposing there.
    for (int64_t dim = 0; dim < tilingData_->dimNum && dim <= splitAxis; dim++) {
        inputIndices_[dim] = inputIndex / tilingData_->strides[dim];
        inputIndex = inputIndex % tilingData_->strides[dim];
        outOffset +=
            (inputIndices_[dim] + tilingData_->shifts[dim]) % tilingData_->shapes[dim] * tilingData_->strides[dim];
    }
    outputIndex = outOffset;

    const int64_t currentPos = inputIndices_[splitAxis];
    const int64_t axisSize = tilingData_->shapes[splitAxis];
    const int64_t shiftPoint = axisSize - tilingData_->shifts[splitAxis];
    // Positions below the shift point land in [shift, axisSize); the rest wrap to the axis start.
    const int64_t runRows = (currentPos < shiftPoint) ? (shiftPoint - currentPos) : (axisSize - currentPos);
    Count = static_cast<int32_t>(runRows * tilingData_->strides[splitAxis]);
}
// Gathers Count elements of xTensor, starting at element addrShift, into a newly allocated AlienBuf
// tensor using the index template in gatherTmpl, then enqueues that tensor for CopyOut.
// The template is applied once per tmplSize-element chunk (chunk i reads from srcPtr + i * tmplSize),
// so ceil(Count / tmplSize) * tmplSize elements are written. The extra tail lands in the
// GATHER_REG_SIZE headroom that Init adds to the queue buffers, and the caller copies out only Count.
// IndexType is uint16_t for element sizes <= 2 bytes and uint32_t otherwise (chosen by Process).
template <typename T, bool isShiftW>
template <typename IndexType>
__aicore__ inline void RollGatherSimd<T, isShiftW>::Gather(
LocalTensor<T>& xTensor, LocalTensor<T>& gatherTmpl, int32_t Count, uint16_t tmplSize, int32_t addrShift)
{
LocalTensor<T> AlienTensor = AlienBuf.AllocTensor<T>();
auto dstPtr = (__ubuf__ T*)AlienTensor.GetPhyAddr();
auto srcPtr = (__ubuf__ T*)xTensor.GetPhyAddr() + addrShift;
auto indexPtr = (__ubuf__ IndexType*)gatherTmpl.GetPhyAddr();
// Number of template-sized chunks needed to cover Count (rounded up).
uint16_t loopSize = (Count + tmplSize - 1) / tmplSize;
uint32_t tmpMaskNum = tmplSize;
__VEC_SCOPE__
{
MicroAPI::RegTensor<IndexType> indexReg;
MicroAPI::RegTensor<T> dstReg0;
MicroAPI::RegTensor<T> dstReg1;
MicroAPI::MaskReg maskReg;
MicroAPI::UnalignReg u0;
// Activate tmplSize lanes; b8 elements are gathered through 16-bit lanes, so the mask is built in uint16_t units.
if constexpr (sizeof(T) == 1) {
maskReg = MicroAPI::UpdateMask<uint16_t>(tmpMaskNum);
} else {
maskReg = MicroAPI::UpdateMask<T>(tmpMaskNum);
}
// The index template is loaded once and reused for every chunk.
MicroAPI::DataCopy(indexReg, indexPtr);
for (uint16_t i = 0; i < loopSize; i++) {
if constexpr (sizeof(T) == 1) {
// Gather into 16-bit lanes, then Pack (HighLowPart::LOWEST) narrows them back to 8-bit data.
MicroAPI::DataCopyGather(
(MicroAPI::RegTensor<uint16_t>&)dstReg0, srcPtr + i * tmplSize, indexReg, maskReg);
MicroAPI::Pack<uint8_t, uint16_t, MicroAPI::HighLowPart::LOWEST>(
(MicroAPI::RegTensor<uint8_t>&)dstReg1, (MicroAPI::RegTensor<uint16_t>&)dstReg0);
} else {
MicroAPI::DataCopyGather(dstReg1, srcPtr + i * tmplSize, indexReg, maskReg);
}
// Append tmplSize results. The repeated calls on the same dstPtr rely on the unaligned store
// advancing it; confirm DataCopyUnAlign post-update semantics.
MicroAPI::DataCopyUnAlign(dstPtr, (MicroAPI::RegTensor<T>&)dstReg1, u0, tmplSize);
}
// Flush whatever is still buffered in the unaligned register u0.
AscendC::MicroAPI::DataCopyUnAlignPost(dstPtr, u0, 0);
}
AlienBuf.EnQue<T>(AlienTensor);
}
}
#endif