* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file grid_sampler_3d.h
* \brief
*/
#ifndef GIRD_SAMPLER_3D
#define GIRD_SAMPLER_3D
#if ASC_DEVKIT_MAJOR >=9
#include "kernel_vec_intf.h"
#else
#include "kernel_operator.h"
#endif
#include "kernel_tiling/kernel_tiling.h"
#include "grid_sampler_3d_common.h"
namespace GridSample {
using namespace AscendC;
struct CoordinatesStruct {
LocalTensor<int32_t> tnwCoordinates;
LocalTensor<int32_t> tneCoordinates;
LocalTensor<int32_t> tswCoordinates;
LocalTensor<int32_t> tsecoordinates;
LocalTensor<int32_t> bnwCoordinates;
LocalTensor<int32_t> bneCoordinates;
LocalTensor<int32_t> bswCoordinates;
LocalTensor<int32_t> bsecoordinates;
};
struct WeightStruct {
LocalTensor<float> tnwWeightLocal;
LocalTensor<float> tneWeightLocal;
LocalTensor<float> tswWeightLocal;
LocalTensor<float> tseWeightLocal;
LocalTensor<float> bnwWeightLocal;
LocalTensor<float> bneWeightLocal;
LocalTensor<float> bswWeightLocal;
LocalTensor<float> bseWeightLocal;
};
struct CalculatePointBilinearParam {
LocalTensor<int32_t> coordinatesUb;
LocalTensor<float> outValueUb;
LocalTensor<float> outValueTotalLocal;
LocalTensor<float> weightUb;
LocalTensor<uint64_t> maskUbTmp;
bool isAtomicAdd;
__aicore__ inline CalculatePointBilinearParam()
{}
__aicore__ inline CalculatePointBilinearParam(LocalTensor<int32_t> coordinatesUb, LocalTensor<float> outValueUb,
LocalTensor<float> outValueTotalLocal, LocalTensor<float> weightUb, LocalTensor<uint64_t> maskUbTmp,
bool isAtomicAdd)
: coordinatesUb(coordinatesUb), outValueUb(outValueUb), outValueTotalLocal(outValueTotalLocal),
weightUb(weightUb), maskUbTmp(maskUbTmp), isAtomicAdd(isAtomicAdd)
{}
};
template <typename T>
class GridSampler3D {
public:
__aicore__ inline GridSampler3D(){};
__aicore__ inline void Init(
GM_ADDR x, GM_ADDR gird, GM_ADDR y, GM_ADDR workspace, const GridSampleTilingData *tilingData, TPipe pipeIn);
__aicore__ inline void Process();
private:
__aicore__ inline void InitBuffer();
__aicore__ inline void PerLoopCompute(ProcessParam processParam);
__aicore__ inline void Clip(LocalTensor<float> iXFpUb, LocalTensor<float> iYFpUb, LocalTensor<float> iZFpUb);
__aicore__ inline void ZeroClip(LocalTensor<float> iXFpUb, LocalTensor<float> iYFpUb, LocalTensor<float> iZFpUb);
__aicore__ inline void MTE2ForNCHW(
int32_t nIdx, PointParam pointBilinearParam, LocalTensor<int32_t> coorUb, LocalTensor<T> xLocal);
__aicore__ inline void MTE2ForNHWC(
int32_t nIdx, PointParam pointBilinearParam, LocalTensor<int32_t> coorUb, LocalTensor<T> xLocal);
__aicore__ inline void OutTransposeFp16(int32_t channelAlign, LocalTensor<T> xLocal, LocalTensor<T> outValueUb);
__aicore__ inline void calculatePointBilinear(ProcessParam processParam,
CalculatePointBilinearParam calculatePointBilinearParam, PointParam pointBilinearParam, LocalTensor<T> xLocal);
__aicore__ inline void MTE3ForNCHW(
ProcessParam processParam, PointParam pointBilinearParam, LocalTensor<T> outValueUb);
__aicore__ inline void PointBilinear(ProcessParam processParam, LocalTensor<float> outValueUb, bool isAutomicAdd);
__aicore__ inline void GetXLocal(ProcessParam processParam, CalculatePointBilinearParam calculatePointBilinearParam,
PointParam pointBilinearParam, LocalTensor<T> xLocal, LocalTensor<float> outValueLocal);
__aicore__ inline void CalculateGrid(ProcessParam processParam, LocalTensor<float> gridFp32Local,
LocalTensor<float> inputXFpLocal, LocalTensor<float> inputYFpLocal, LocalTensor<float> inputZFpLocal);
__aicore__ inline void calculateGridWeight(
LocalTensor<float> inputXFpLocal, LocalTensor<float> inputYFpLocal, LocalTensor<float> inputZFpLocal);
__aicore__ inline void clipAllCoordinate(LocalTensor<float> outValueLocal);
__aicore__ inline CoordinatesStruct GetCoordinatesStruct();
__aicore__ inline void GetWeightMaskStruct();
__aicore__ inline WeightStruct GetWeightStruct();
__aicore__ inline void GetMaskStruct();
__aicore__ inline void GetInputTensor();
__aicore__ inline void PointBilinearEachChannel(ProcessParam processParam, LocalTensor<float> outValueUb,
LocalTensor<float> outValueTotalLocal, PointParam pointBilinearParam, LocalTensor<T> xLocal);
__aicore__ inline void PointBilinearSetMask(PointParam pointBilinearParam);
private:
TPipe pipe;
TQue<QuePosition::VECIN, 1> gridQueue_;
TBuf<QuePosition::VECCALC> xBuf_;
TBuf<QuePosition::VECCALC> inputXIntBuf_;
TBuf<QuePosition::VECCALC> inputYIntBuf_;
TBuf<QuePosition::VECCALC> inputZIntBuf_;
TBuf<QuePosition::VECCALC> inputXFpBuf_;
TBuf<QuePosition::VECCALC> inputYFpBuf_;
TBuf<QuePosition::VECCALC> inputZFpBuf_;
TBuf<QuePosition::VECCALC> weightBuf_;
TBuf<QuePosition::VECCALC> weightTmpBuf_;
TBuf<QuePosition::VECCALC> weightTmp1Buf_;
TBuf<QuePosition::VECCALC> weightTmp2Buf_;
TBuf<QuePosition::VECCALC> weightTmp3Buf_;
TBuf<QuePosition::VECCALC> coorBuf_;
TBuf<QuePosition::VECCALC> outValueBuf_;
TBuf<QuePosition::VECCALC> outValueSumBuf_;
TBuf<QuePosition::VECCALC> maskBuf2_;
TBuf<QuePosition::VECCALC> maskBuf3_;
TBuf<QuePosition::VECCALC> maskBuf4_;
TBuf<QuePosition::VECCALC> maskBuf5_;
TBuf<QuePosition::VECCALC> maskBuf6_;
TBuf<QuePosition::VECCALC> maskBuf7_;
TBuf<QuePosition::VECCALC> maskBuf8_;
TBuf<QuePosition::VECCALC> weightMaskBuf2_;
TBuf<QuePosition::VECCALC> weightMaskBuf3_;
TBuf<QuePosition::VECCALC> weightMaskBuf4_;
TBuf<QuePosition::VECCALC> weightMaskBuf5_;
TBuf<QuePosition::VECCALC> weightMaskBuf6_;
TBuf<QuePosition::VECCALC> weightMaskBuf7_;
TBuf<QuePosition::VECCALC> weightMaskBuf8_;
TBuf<QuePosition::VECCALC> bufferMaskXBuf_;
TBuf<QuePosition::VECCALC> bufferMaskYBuf_;
TBuf<QuePosition::VECCALC> bufferMaskZBuf_;
TBuf<QuePosition::VECCALC> gridFp16Buf_;
GlobalTensor<T> gmX_;
GlobalTensor<T> gmGrid_;
GlobalTensor<float> gmWorkspace_;
GlobalTensor<T> gmY_;
LocalTensor<uint8_t> weightMaskUb;
LocalTensor<uint8_t> weightMaskUb2;
LocalTensor<uint8_t> weightMaskUb3;
LocalTensor<uint8_t> weightMaskUb4;
LocalTensor<uint8_t> weightMaskUb5;
LocalTensor<uint8_t> weightMaskUb6;
LocalTensor<uint8_t> weightMaskUb7;
LocalTensor<uint8_t> weightMaskUb8;
LocalTensor<uint64_t> weightMaskUbTmp;
LocalTensor<uint64_t> weightMaskUbTmp2;
LocalTensor<uint64_t> weightMaskUbTmp3;
LocalTensor<uint64_t> weightMaskUbTmp4;
LocalTensor<uint64_t> weightMaskUbTmp5;
LocalTensor<uint64_t> weightMaskUbTmp6;
LocalTensor<uint64_t> weightMaskUbTmp7;
LocalTensor<uint64_t> weightMaskUbTmp8;
LocalTensor<uint8_t> maskUb;
LocalTensor<uint64_t> maskUbTmp;
LocalTensor<uint8_t> maskUb2;
LocalTensor<uint64_t> maskUbTmp2;
LocalTensor<uint8_t> maskUb3;
LocalTensor<uint64_t> maskUbTmp3;
LocalTensor<uint8_t> maskUb4;
LocalTensor<uint64_t> maskUbTmp4;
LocalTensor<uint8_t> maskUb5;
LocalTensor<uint64_t> maskUbTmp5;
LocalTensor<uint8_t> maskUb6;
LocalTensor<uint64_t> maskUbTmp6;
LocalTensor<uint8_t> maskUb7;
LocalTensor<uint64_t> maskUbTmp7;
LocalTensor<uint8_t> maskUb8;
LocalTensor<uint64_t> maskUbTmp8;
LocalTensor<int32_t> inputXWIntLocal;
LocalTensor<int32_t> inputXEIntLocal;
LocalTensor<int32_t> inputYWIntLocal;
LocalTensor<int32_t> inputYEIntLocal;
LocalTensor<int32_t> inputZWIntLocal;
LocalTensor<int32_t> inputZEIntLocal;
LocalTensor<float> inputXWFpLocal;
LocalTensor<float> inputXEFpLocal;
LocalTensor<float> inputYWFpLocal;
LocalTensor<float> inputYEFpLocal;
LocalTensor<float> inputZWFpLocal;
LocalTensor<float> inputZEFpLocal;
const int64_t BLOCK_NUM = BLOCK_SIZE / sizeof(T);
const int64_t X_UB_SIZE_4_GENERAL = 32768;
const int64_t GRID_UB_SIZE_4_GENERAL = 6144;
const int64_t GRID_UB_SIZE_4_FP16 = 3072;
const int64_t Y_UB_SIZE_4_GENERAL = 2048;
const int64_t XYZ_UB_SIZE_4_GENERAL = 4096;
const int64_t OUT_VAL_NUM = 4096;
const int64_t X_UB_OFFSET = 512;
const int64_t OUT_FP16_OFFSET = TRANSE_REP_STRIDE * CHANNEL_BLOCK * sizeof(T);
int64_t blockIDX = 0;
GridSampleCommonParam commonParam{};
IndexBuffer indexBuffer{};
uint64_t rsvdCnt = 0;
uint32_t mask = 192;
};
template <typename T>
__aicore__ inline void GridSampler3D<T>::InitBuffer()
{
pipe.InitBuffer(gridQueue_, 1, GRID_UB_SIZE_4_GENERAL);
pipe.InitBuffer(xBuf_, X_UB_SIZE_4_GENERAL);
pipe.InitBuffer(indexBuffer.inputXYZFPBuf_, GRID_UB_SIZE_4_GENERAL);
pipe.InitBuffer(inputXIntBuf_, XYZ_UB_SIZE_4_GENERAL);
pipe.InitBuffer(inputYIntBuf_, XYZ_UB_SIZE_4_GENERAL);
pipe.InitBuffer(inputZIntBuf_, XYZ_UB_SIZE_4_GENERAL);
pipe.InitBuffer(inputXFpBuf_, XYZ_UB_SIZE_4_GENERAL);
pipe.InitBuffer(inputYFpBuf_, XYZ_UB_SIZE_4_GENERAL);
pipe.InitBuffer(inputZFpBuf_, XYZ_UB_SIZE_4_GENERAL);
pipe.InitBuffer(weightBuf_, Y_UB_SIZE_4_GENERAL * 8);
pipe.InitBuffer(weightTmpBuf_, Y_UB_SIZE_4_GENERAL * 4);
pipe.InitBuffer(indexBuffer.intTmpBuf_, Y_UB_SIZE_4_GENERAL);
pipe.InitBuffer(coorBuf_, Y_UB_SIZE_4_GENERAL);
pipe.InitBuffer(indexBuffer.coorTmpBuf_, Y_UB_SIZE_4_GENERAL);
pipe.InitBuffer(outValueBuf_, X_UB_SIZE_4_GENERAL);
pipe.InitBuffer(outValueSumBuf_, X_UB_SIZE_4_GENERAL);
pipe.InitBuffer(indexBuffer.maskBuf_, 960);
pipe.InitBuffer(maskBuf2_, 960);
pipe.InitBuffer(maskBuf3_, 960);
pipe.InitBuffer(maskBuf4_, 960);
pipe.InitBuffer(maskBuf5_, 960);
pipe.InitBuffer(maskBuf6_, 960);
pipe.InitBuffer(maskBuf7_, 960);
pipe.InitBuffer(maskBuf8_, 960);
pipe.InitBuffer(indexBuffer.weightMaskBuf_, 320);
pipe.InitBuffer(weightMaskBuf2_, 320);
pipe.InitBuffer(weightMaskBuf3_, 320);
pipe.InitBuffer(weightMaskBuf4_, 320);
pipe.InitBuffer(weightMaskBuf5_, 320);
pipe.InitBuffer(weightMaskBuf6_, 320);
pipe.InitBuffer(weightMaskBuf7_, 320);
pipe.InitBuffer(weightMaskBuf8_, 320);
pipe.InitBuffer(indexBuffer.modBuf_, Y_UB_SIZE_4_GENERAL);
pipe.InitBuffer(indexBuffer.extraBuf_, Y_UB_SIZE_4_GENERAL);
pipe.InitBuffer(indexBuffer.outTmpBuf_, XYZ_UB_SIZE_4_GENERAL);
pipe.InitBuffer(bufferMaskXBuf_, BLOCK_SIZE * 6);
pipe.InitBuffer(bufferMaskYBuf_, BLOCK_SIZE * 6);
pipe.InitBuffer(bufferMaskZBuf_, BLOCK_SIZE * 6);
if constexpr (IsSameType<T, half>::value || IsSameType<T, bfloat16_t>::value) {
pipe.InitBuffer(gridFp16Buf_, GRID_UB_SIZE_4_FP16);
}
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::Init(
GM_ADDR x, GM_ADDR gird, GM_ADDR y, GM_ADDR workspace, const GridSampleTilingData *tilingData, TPipe pipeIn)
{
pipe = pipeIn;
blockIDX = GetBlockIdx();
ParseTilingData(tilingData, commonParam);
gmX_.SetGlobalBuffer((__gm__ T *)x);
gmGrid_.SetGlobalBuffer((__gm__ T *)gird);
gmWorkspace_.SetGlobalBuffer((__gm__ float *)workspace);
gmY_.SetGlobalBuffer((__gm__ T *)y);
InitBuffer();
initBufTensor(bufferMaskXBuf_, bufferMaskYBuf_, bufferMaskZBuf_);
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::Clip(
LocalTensor<float> iXFpUb, LocalTensor<float> iYFpUb, LocalTensor<float> iZFpUb)
{
if (commonParam.paddingMode_ == PADDING_MODE_BORDER) {
BorderClip(iXFpUb, iYFpUb, iZFpUb, indexBuffer, commonParam);
} else if (commonParam.paddingMode_ == PADDING_MODE_REFLECTION) {
ReflectClip(iXFpUb, iYFpUb, iZFpUb, indexBuffer, commonParam);
} else if (commonParam.paddingMode_ == PADDING_MODE_ZEROS) {
ZeroClip(iXFpUb, iYFpUb, iZFpUb);
}
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::ZeroClip(
LocalTensor<float> iXFpUbVal, LocalTensor<float> iYFpUbVal, LocalTensor<float> iZFpUbVal)
{
LocalTensor<uint8_t> maskUbVal = indexBuffer.weightMaskBuf_.Get<uint8_t>(MASK_UB_SIZE);
LocalTensor<float> tmpUbVal = indexBuffer.inputXYZFPBuf_.Get<float>();
Muls(tmpUbVal, iXFpUbVal, (float)(0.0), CAL_D_H_W_BLOCK);
PipeBarrier<PIPE_V>();
Compare(maskUbVal, tmpUbVal, tmpUbVal, CMPMODE::EQ, CAL_D_H_W_BLOCK);
PipeBarrier<PIPE_V>();
CoordinatesSelectScalar(iXFpUbVal, iXFpUbVal, maskUbVal, -100.0f, CAL_D_H_W_BLOCK);
PipeBarrier<PIPE_V>();
Muls(tmpUbVal, iYFpUbVal, (float)(0.0), CAL_D_H_W_BLOCK);
PipeBarrier<PIPE_V>();
Compare(maskUbVal, tmpUbVal, tmpUbVal, CMPMODE::EQ, CAL_D_H_W_BLOCK);
PipeBarrier<PIPE_V>();
CoordinatesSelectScalar(iYFpUbVal, iYFpUbVal, maskUbVal, -100.0f, CAL_D_H_W_BLOCK);
PipeBarrier<PIPE_V>();
Muls(tmpUbVal, iZFpUbVal, (float)(0.0), CAL_D_H_W_BLOCK);
PipeBarrier<PIPE_V>();
Compare(maskUbVal, tmpUbVal, tmpUbVal, CMPMODE::EQ, CAL_D_H_W_BLOCK);
PipeBarrier<PIPE_V>();
CoordinatesSelectScalar(iZFpUbVal, iZFpUbVal, maskUbVal, -100.0f, CAL_D_H_W_BLOCK);
PipeBarrier<PIPE_V>();
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::MTE2ForNCHW(
int32_t nIdx, PointParam pointBilinearParam, LocalTensor<int32_t> coorUb, LocalTensor<T> xLocal)
{
for (int32_t i = 0; i < pointBilinearParam.loopElems; i++) {
int64_t coordVal = coorUb.GetValue(pointBilinearParam.loopOffset + i);
int64_t baseLoc = nIdx * commonParam.inputC_ * commonParam.inputH_ * commonParam.inputW_ + coordVal +
pointBilinearParam.cIdx * CHANNEL_BLOCK * commonParam.inputH_ * commonParam.inputW_;
for (int cIter = 0; cIter < pointBilinearParam.channelAlign; cIter++) {
int32_t xLocalOff = i * pointBilinearParam.channelAlign + cIter;
if (cIter >= pointBilinearParam.calCElems) {
if constexpr (IsSameType<T, bfloat16_t>::value) {
xLocal.SetValue(xLocalOff, ToBfloat16(0.0));
} else {
xLocal.SetValue(xLocalOff, static_cast<T>(0.0));
}
continue;
}
int64_t coord = baseLoc + cIter * commonParam.inputH_ * commonParam.inputW_;
xLocal.SetValue(xLocalOff, gmX_.GetValue(coord));
}
}
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::MTE2ForNHWC(
int32_t nIdx, PointParam pointBilinearParam, LocalTensor<int32_t> coorUb, LocalTensor<T> xLocal)
{
int64_t base = nIdx * commonParam.inputH_ * commonParam.inputW_ * commonParam.inputD_ * commonParam.inputC_ +
pointBilinearParam.cIdx * CHANNEL_BLOCK;
auto timeStep = pointBilinearParam.loopElems / 8;
DataCopyExtParams params;
params.blockCount = 1;
params.blockLen = pointBilinearParam.calCElems * sizeof(T);
params.srcStride = 0;
params.dstStride = 0;
DataCopyPadExtParams<T> padParams{false, 0, 0, 0};
for (int32_t i = 0; i < timeStep; i++) {
int64_t coordVal_0 = coorUb.GetValue(pointBilinearParam.loopOffset + i * 8) * commonParam.inputC_;
int64_t coordVal_1 = coorUb.GetValue(pointBilinearParam.loopOffset + i * 8 + 1) * commonParam.inputC_;
int64_t coordVal_2 = coorUb.GetValue(pointBilinearParam.loopOffset + i * 8 + 2) * commonParam.inputC_;
int64_t coordVal_3 = coorUb.GetValue(pointBilinearParam.loopOffset + i * 8 + 3) * commonParam.inputC_;
int64_t coordVal_4 = coorUb.GetValue(pointBilinearParam.loopOffset + i * 8 + 4) * commonParam.inputC_;
int64_t coordVal_5 = coorUb.GetValue(pointBilinearParam.loopOffset + i * 8 + 5) * commonParam.inputC_;
int64_t coordVal_6 = coorUb.GetValue(pointBilinearParam.loopOffset + i * 8 + 6) * commonParam.inputC_;
int64_t coordVal_7 = coorUb.GetValue(pointBilinearParam.loopOffset + i * 8 + 7) * commonParam.inputC_;
int64_t location_0 = base + coordVal_0;
int64_t location_1 = base + coordVal_1;
int64_t location_2 = base + coordVal_2;
int64_t location_3 = base + coordVal_3;
int64_t location_4 = base + coordVal_4;
int64_t location_5 = base + coordVal_5;
int64_t location_6 = base + coordVal_6;
int64_t location_7 = base + coordVal_7;
DataCopyPad(xLocal[(i * 8) * pointBilinearParam.channelAlign], gmX_[location_0], params, padParams);
DataCopyPad(xLocal[(i * 8 + 1) * pointBilinearParam.channelAlign], gmX_[location_1], params, padParams);
DataCopyPad(xLocal[(i * 8 + 2) * pointBilinearParam.channelAlign], gmX_[location_2], params, padParams);
DataCopyPad(xLocal[(i * 8 + 3) * pointBilinearParam.channelAlign], gmX_[location_3], params, padParams);
DataCopyPad(xLocal[(i * 8 + 4) * pointBilinearParam.channelAlign], gmX_[location_4], params, padParams);
DataCopyPad(xLocal[(i * 8 + 5) * pointBilinearParam.channelAlign], gmX_[location_5], params, padParams);
DataCopyPad(xLocal[(i * 8 + 6) * pointBilinearParam.channelAlign], gmX_[location_6], params, padParams);
DataCopyPad(xLocal[(i * 8 + 7) * pointBilinearParam.channelAlign], gmX_[location_7], params, padParams);
}
for (auto i = pointBilinearParam.loopElems / 8 * 8; i < pointBilinearParam.loopElems; i++) {
int64_t coordVal_0 = coorUb.GetValue(pointBilinearParam.loopOffset + i) * commonParam.inputC_;
int64_t location_0 = base + coordVal_0;
DataCopyPad(xLocal[i * pointBilinearParam.channelAlign], gmX_[location_0], params, padParams);
}
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::OutTransposeFp16(
int32_t channelAlign, LocalTensor<T> xLocal, LocalTensor<T> outValueUb)
{
uint64_t dstList[16];
uint64_t srcList[16];
event_t eventVS = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_S));
event_t eventSV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
TransDataTo5HDParams transDataParams;
transDataParams.dstHighHalf = false;
transDataParams.srcHighHalf = false;
if (channelAlign == B16_ALIGN_FACTOR) {
transDataParams.repeatTimes = 8;
transDataParams.dstRepStride = 1;
transDataParams.srcRepStride = 16;
for (int32_t i = 0; i < 16; i++) {
srcList[i] = (uint64_t)(xLocal[i * 16].GetPhyAddr());
}
for (int32_t i = 0; i < 16; i++) {
dstList[i] = (uint64_t)(outValueUb[i * TRANSE_REP_STRIDE].GetPhyAddr());
}
SetFlag<HardEvent::S_V>(eventSV);
WaitFlag<HardEvent::S_V>(eventSV);
TransDataTo5HD<T>(dstList, srcList, transDataParams);
SetFlag<HardEvent::V_S>(eventVS);
WaitFlag<HardEvent::V_S>(eventVS);
} else if (channelAlign <= 64) {
transDataParams.repeatTimes = channelAlign / 16;
transDataParams.dstRepStride = TRANSE_REP_STRIDE;
transDataParams.srcRepStride = 1;
for (int32_t j = 0; j < 8; j++) {
for (int32_t i = 0; i < 16; i++) {
srcList[i] = (uint64_t)(xLocal[i * channelAlign + j * 16 * channelAlign].GetPhyAddr());
}
for (int32_t i = 0; i < 16; i++) {
dstList[i] = (uint64_t)(outValueUb[i * TRANSE_REP_STRIDE + j * 16].GetPhyAddr());
}
SetFlag<HardEvent::S_V>(eventSV);
WaitFlag<HardEvent::S_V>(eventSV);
TransDataTo5HD<T>(dstList, srcList, transDataParams);
SetFlag<HardEvent::V_S>(eventVS);
WaitFlag<HardEvent::V_S>(eventVS);
}
}
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::MTE3ForNCHW(
ProcessParam processParam, PointParam pointBilinearParam, LocalTensor<T> outValueUb)
{
int64_t gmYBaseOffset = pointBilinearParam.outBaseOffset + pointBilinearParam.loopOffset +
pointBilinearParam.cIdx * CHANNEL_BLOCK * commonParam.gridDHW_;
event_t eventIdVToMte3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
SetFlag<HardEvent::V_MTE3>(eventIdVToMte3);
WaitFlag<HardEvent::V_MTE3>(eventIdVToMte3);
uint32_t blockLength = pointBilinearParam.loopElems * sizeof(T);
if (pointBilinearParam.calCElems == 1) {
DataCopyPad(gmY_[gmYBaseOffset], outValueUb, {1, blockLength, 0, 0, 0});
} else {
uint32_t srcStride = TRANSE_REP_STRIDE * sizeof(T) / BLOCK_SIZE -
((pointBilinearParam.loopElems * sizeof(T) + BLOCK_SIZE - 1) / BLOCK_SIZE);
uint32_t dstStride = commonParam.gridDHW_ * sizeof(T) - pointBilinearParam.loopElems * sizeof(T);
DataCopyPad(gmY_[gmYBaseOffset],
outValueUb,
{(uint16_t)pointBilinearParam.calCElems, blockLength, srcStride, dstStride, 0});
}
}
template <typename T>
__aicore__ inline CoordinatesStruct GridSampler3D<T>::GetCoordinatesStruct()
{
CoordinatesStruct coordinatesStruct{};
coordinatesStruct.tnwCoordinates = coorBuf_.Get<int32_t>(CAL_D_H_W_BLOCK);
coordinatesStruct.tneCoordinates = indexBuffer.inputXYZFPBuf_.Get<int32_t>(CAL_D_H_W_BLOCK);
coordinatesStruct.tswCoordinates =
indexBuffer.inputXYZFPBuf_.GetWithOffset<int32_t>(CAL_D_H_W_BLOCK, CAL_D_H_W_BLOCK * 4);
coordinatesStruct.tsecoordinates =
indexBuffer.inputXYZFPBuf_.GetWithOffset<int32_t>(CAL_D_H_W_BLOCK, CAL_D_H_W_BLOCK * 8);
coordinatesStruct.bnwCoordinates = indexBuffer.outTmpBuf_.Get<int32_t>(CAL_D_H_W_BLOCK);
coordinatesStruct.bneCoordinates =
indexBuffer.outTmpBuf_.GetWithOffset<int32_t>(CAL_D_H_W_BLOCK, CAL_D_H_W_BLOCK * 4);
coordinatesStruct.bswCoordinates = indexBuffer.modBuf_.Get<int32_t>(CAL_D_H_W_BLOCK);
coordinatesStruct.bsecoordinates = indexBuffer.extraBuf_.Get<int32_t>(CAL_D_H_W_BLOCK);
return coordinatesStruct;
}
template <typename T>
__aicore__ inline WeightStruct GridSampler3D<T>::GetWeightStruct()
{
WeightStruct weightStruct{};
weightStruct.tnwWeightLocal = weightBuf_.Get<float>(CAL_D_H_W_BLOCK);
weightStruct.tneWeightLocal = weightBuf_.GetWithOffset<float>(CAL_D_H_W_BLOCK, CAL_D_H_W_BLOCK * 4);
weightStruct.tswWeightLocal = weightBuf_.GetWithOffset<float>(CAL_D_H_W_BLOCK, CAL_D_H_W_BLOCK * 2 * 4);
weightStruct.tseWeightLocal = weightBuf_.GetWithOffset<float>(CAL_D_H_W_BLOCK, CAL_D_H_W_BLOCK * 3 * 4);
weightStruct.bnwWeightLocal = weightBuf_.GetWithOffset<float>(CAL_D_H_W_BLOCK, CAL_D_H_W_BLOCK * 4 * 4);
weightStruct.bneWeightLocal = weightBuf_.GetWithOffset<float>(CAL_D_H_W_BLOCK, CAL_D_H_W_BLOCK * 5 * 4);
weightStruct.bswWeightLocal = weightBuf_.GetWithOffset<float>(CAL_D_H_W_BLOCK, CAL_D_H_W_BLOCK * 6 * 4);
weightStruct.bseWeightLocal = weightBuf_.GetWithOffset<float>(CAL_D_H_W_BLOCK, CAL_D_H_W_BLOCK * 7 * 4);
return weightStruct;
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::GetWeightMaskStruct()
{
weightMaskUb = indexBuffer.weightMaskBuf_.Get<uint8_t>(MASK_UB_SIZE);
weightMaskUb2 = weightMaskBuf2_.Get<uint8_t>(MASK_UB_SIZE);
weightMaskUb3 = weightMaskBuf3_.Get<uint8_t>(MASK_UB_SIZE);
weightMaskUb4 = weightMaskBuf4_.Get<uint8_t>(MASK_UB_SIZE);
weightMaskUb5 = weightMaskBuf5_.Get<uint8_t>(MASK_UB_SIZE);
weightMaskUb6 = weightMaskBuf6_.Get<uint8_t>(MASK_UB_SIZE);
weightMaskUb7 = weightMaskBuf7_.Get<uint8_t>(MASK_UB_SIZE);
weightMaskUb8 = weightMaskBuf8_.Get<uint8_t>(MASK_UB_SIZE);
weightMaskUbTmp = weightMaskUb.ReinterpretCast<uint64_t>();
weightMaskUbTmp2 = weightMaskUb2.ReinterpretCast<uint64_t>();
weightMaskUbTmp3 = weightMaskUb3.ReinterpretCast<uint64_t>();
weightMaskUbTmp4 = weightMaskUb4.ReinterpretCast<uint64_t>();
weightMaskUbTmp5 = weightMaskUb5.ReinterpretCast<uint64_t>();
weightMaskUbTmp6 = weightMaskUb6.ReinterpretCast<uint64_t>();
weightMaskUbTmp7 = weightMaskUb7.ReinterpretCast<uint64_t>();
weightMaskUbTmp8 = weightMaskUb8.ReinterpretCast<uint64_t>();
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::GetMaskStruct()
{
maskUb = indexBuffer.maskBuf_.Get<uint8_t>(MASK_UB_SIZE);
maskUbTmp = maskUb.ReinterpretCast<uint64_t>();
maskUb2 = maskBuf2_.Get<uint8_t>(MASK_UB_SIZE);
maskUbTmp2 = maskUb2.ReinterpretCast<uint64_t>();
maskUb3 = maskBuf3_.Get<uint8_t>(MASK_UB_SIZE);
maskUbTmp3 = maskUb3.ReinterpretCast<uint64_t>();
maskUb4 = maskBuf4_.Get<uint8_t>(MASK_UB_SIZE);
maskUbTmp4 = maskUb4.ReinterpretCast<uint64_t>();
maskUb5 = maskBuf5_.Get<uint8_t>(MASK_UB_SIZE);
maskUbTmp5 = maskUb5.ReinterpretCast<uint64_t>();
maskUb6 = maskBuf6_.Get<uint8_t>(MASK_UB_SIZE);
maskUbTmp6 = maskUb6.ReinterpretCast<uint64_t>();
maskUb7 = maskBuf7_.Get<uint8_t>(MASK_UB_SIZE);
maskUbTmp7 = maskUb7.ReinterpretCast<uint64_t>();
maskUb8 = maskBuf8_.Get<uint8_t>(MASK_UB_SIZE);
maskUbTmp8 = maskUb8.ReinterpretCast<uint64_t>();
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::PointBilinearEachChannel(ProcessParam processParam,
LocalTensor<float> outValueUb, LocalTensor<float> outValueTotalLocal, PointParam pointBilinearParam,
LocalTensor<T> xLocal)
{
CoordinatesStruct coordinatesStruct = GetCoordinatesStruct();
WeightStruct weightStruct = GetWeightStruct();
GetMaskStruct();
pointBilinearParam.calCElems = commonParam.perLoopChannel_;
if (pointBilinearParam.cIdx == commonParam.channelLoop_ - 1) {
pointBilinearParam.calCElems = commonParam.lastLoopChannel_;
}
pointBilinearParam.channelAlign = Ceil(pointBilinearParam.calCElems, B32_ALIGN_FACTOR) * B32_ALIGN_FACTOR;
if constexpr (IsSameType<T, half>::value || IsSameType<T, bfloat16_t>::value) {
pointBilinearParam.channelAlign = Ceil(pointBilinearParam.calCElems, B16_ALIGN_FACTOR) * B16_ALIGN_FACTOR;
}
CalculatePointBilinearParam calculatePointBilinearParam{coordinatesStruct.tnwCoordinates,
outValueUb,
outValueTotalLocal,
weightStruct.tnwWeightLocal,
maskUbTmp,
false};
CalculatePointBilinearParam calculatePointBilinearParam2{coordinatesStruct.tneCoordinates,
outValueUb,
outValueTotalLocal,
weightStruct.tneWeightLocal,
maskUbTmp2,
true};
CalculatePointBilinearParam calculatePointBilinearParam3{coordinatesStruct.tswCoordinates,
outValueUb,
outValueTotalLocal,
weightStruct.tswWeightLocal,
maskUbTmp3,
true};
CalculatePointBilinearParam calculatePointBilinearParam4{coordinatesStruct.tsecoordinates,
outValueUb,
outValueTotalLocal,
weightStruct.tseWeightLocal,
maskUbTmp4,
true};
CalculatePointBilinearParam calculatePointBilinearParam5{coordinatesStruct.bnwCoordinates,
outValueUb,
outValueTotalLocal,
weightStruct.bnwWeightLocal,
maskUbTmp5,
true};
CalculatePointBilinearParam calculatePointBilinearParam6{coordinatesStruct.bneCoordinates,
outValueUb,
outValueTotalLocal,
weightStruct.bneWeightLocal,
maskUbTmp6,
true};
CalculatePointBilinearParam calculatePointBilinearParam7{coordinatesStruct.bswCoordinates,
outValueUb,
outValueTotalLocal,
weightStruct.bswWeightLocal,
maskUbTmp7,
true};
CalculatePointBilinearParam calculatePointBilinearParam8{coordinatesStruct.bsecoordinates,
outValueUb,
outValueTotalLocal,
weightStruct.bseWeightLocal,
maskUbTmp8,
true};
calculatePointBilinear(processParam, calculatePointBilinearParam, pointBilinearParam, xLocal);
calculatePointBilinear(processParam, calculatePointBilinearParam2, pointBilinearParam, xLocal);
calculatePointBilinear(processParam, calculatePointBilinearParam3, pointBilinearParam, xLocal);
calculatePointBilinear(processParam, calculatePointBilinearParam4, pointBilinearParam, xLocal);
calculatePointBilinear(processParam, calculatePointBilinearParam5, pointBilinearParam, xLocal);
calculatePointBilinear(processParam, calculatePointBilinearParam6, pointBilinearParam, xLocal);
calculatePointBilinear(processParam, calculatePointBilinearParam7, pointBilinearParam, xLocal);
calculatePointBilinear(processParam, calculatePointBilinearParam8, pointBilinearParam, xLocal);
if constexpr (IsSameType<T, half>::value || IsSameType<T, bfloat16_t>::value) {
auto outValueUbSumFp16 = outValueTotalLocal.ReinterpretCast<T>();
Cast(outValueUbSumFp16, outValueTotalLocal, RoundMode::CAST_RINT, TRANSE_REP_STRIDE * CHANNEL_BLOCK);
MTE3ForNCHW(processParam, pointBilinearParam, outValueUbSumFp16);
} else {
MTE3ForNCHW(processParam, pointBilinearParam, outValueTotalLocal);
}
event_t eventMte3V = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
SetFlag<HardEvent::MTE3_V>(eventMte3V);
WaitFlag<HardEvent::MTE3_V>(eventMte3V);
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::PointBilinearSetMask(PointParam pointBilinearParam)
{
maskUbTmp.SetValue(0, weightMaskUbTmp.GetValue(pointBilinearParam.maskOffset));
maskUbTmp.SetValue(1, weightMaskUbTmp.GetValue(pointBilinearParam.maskOffset + 1));
maskUbTmp2.SetValue(0, weightMaskUbTmp2.GetValue(pointBilinearParam.maskOffset));
maskUbTmp2.SetValue(1, weightMaskUbTmp2.GetValue(pointBilinearParam.maskOffset + 1));
maskUbTmp3.SetValue(0, weightMaskUbTmp3.GetValue(pointBilinearParam.maskOffset));
maskUbTmp3.SetValue(1, weightMaskUbTmp3.GetValue(pointBilinearParam.maskOffset + 1));
maskUbTmp4.SetValue(0, weightMaskUbTmp4.GetValue(pointBilinearParam.maskOffset));
maskUbTmp4.SetValue(1, weightMaskUbTmp4.GetValue(pointBilinearParam.maskOffset + 1));
maskUbTmp5.SetValue(0, weightMaskUbTmp5.GetValue(pointBilinearParam.maskOffset));
maskUbTmp5.SetValue(1, weightMaskUbTmp5.GetValue(pointBilinearParam.maskOffset + 1));
maskUbTmp6.SetValue(0, weightMaskUbTmp6.GetValue(pointBilinearParam.maskOffset));
maskUbTmp6.SetValue(1, weightMaskUbTmp6.GetValue(pointBilinearParam.maskOffset + 1));
maskUbTmp7.SetValue(0, weightMaskUbTmp7.GetValue(pointBilinearParam.maskOffset));
maskUbTmp7.SetValue(1, weightMaskUbTmp7.GetValue(pointBilinearParam.maskOffset + 1));
maskUbTmp8.SetValue(0, weightMaskUbTmp8.GetValue(pointBilinearParam.maskOffset));
maskUbTmp8.SetValue(1, weightMaskUbTmp8.GetValue(pointBilinearParam.maskOffset + 1));
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::PointBilinear(
ProcessParam processParam, LocalTensor<float> outValueUb, bool isAutomicAdd)
{
CoordinatesStruct coordinatesStruct = GetCoordinatesStruct();
GetWeightMaskStruct();
WeightStruct weightStruct = GetWeightStruct();
GetMaskStruct();
LocalTensor<float> outValueTotalLocal = outValueSumBuf_.Get<float>();
if (commonParam.paddingMode_ == PADDING_MODE_ZEROS) {
CoordinatesSelectScalar(
weightStruct.tnwWeightLocal, weightStruct.tnwWeightLocal, weightMaskUb, 0.0f, CAL_D_H_W_BLOCK);
CoordinatesSelectScalar(
weightStruct.tneWeightLocal, weightStruct.tneWeightLocal, weightMaskUb2, 0.0f, CAL_D_H_W_BLOCK);
CoordinatesSelectScalar(
weightStruct.tswWeightLocal, weightStruct.tswWeightLocal, weightMaskUb3, 0.0f, CAL_D_H_W_BLOCK);
CoordinatesSelectScalar(
weightStruct.tseWeightLocal, weightStruct.tseWeightLocal, weightMaskUb4, 0.0f, CAL_D_H_W_BLOCK);
CoordinatesSelectScalar(
weightStruct.bnwWeightLocal, weightStruct.bnwWeightLocal, weightMaskUb5, 0.0f, CAL_D_H_W_BLOCK);
CoordinatesSelectScalar(
weightStruct.bneWeightLocal, weightStruct.bneWeightLocal, weightMaskUb6, 0.0f, CAL_D_H_W_BLOCK);
CoordinatesSelectScalar(
weightStruct.bswWeightLocal, weightStruct.bswWeightLocal, weightMaskUb7, 0.0f, CAL_D_H_W_BLOCK);
CoordinatesSelectScalar(
weightStruct.bseWeightLocal, weightStruct.bseWeightLocal, weightMaskUb8, 0.0f, CAL_D_H_W_BLOCK);
}
PointParam pointBilinearParam{};
int32_t trans_loop = (processParam.calDHWElems + TRANSE_REP_STRIDE - 1) / TRANSE_REP_STRIDE;
pointBilinearParam.loopElems = TRANSE_REP_STRIDE;
pointBilinearParam.outBaseOffset =
processParam.nIdx * commonParam.gridDHW_ * commonParam.inputC_ + processParam.hwIdx * CAL_D_H_W_BLOCK;
for (int32_t loop_idx = 0; loop_idx < trans_loop; loop_idx++) {
if (loop_idx == trans_loop - 1) {
pointBilinearParam.loopElems = processParam.calDHWElems - TRANSE_REP_STRIDE * (trans_loop - 1);
}
pointBilinearParam.loopOffset = loop_idx * TRANSE_REP_STRIDE;
pointBilinearParam.maskOffset = loop_idx * 2;
event_t eventSV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::S_V));
SetFlag<HardEvent::S_V>(eventSV);
WaitFlag<HardEvent::S_V>(eventSV);
PointBilinearSetMask(pointBilinearParam);
LocalTensor<T> xLocal = xBuf_.Get<T>();
if (IsSameType<T, bfloat16_t>::value) {
xLocal = xLocal[TRANSE_REP_STRIDE * CHANNEL_BLOCK];
}
for (pointBilinearParam.cIdx = 0; pointBilinearParam.cIdx < commonParam.channelLoop_;
pointBilinearParam.cIdx++) {
PointBilinearEachChannel(processParam, outValueUb, outValueTotalLocal, pointBilinearParam, xLocal);
}
}
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::GetXLocal(ProcessParam processParam,
CalculatePointBilinearParam calculatePointBilinearParam, PointParam pointBilinearParam, LocalTensor<T> xLocal,
LocalTensor<float> outValueLocal)
{
if (commonParam.channelLast_ == LAYOUT_NHWC) {
MTE2ForNHWC(processParam.nIdx, pointBilinearParam, calculatePointBilinearParam.coordinatesUb, xLocal);
} else {
MTE2ForNCHW(processParam.nIdx, pointBilinearParam, calculatePointBilinearParam.coordinatesUb, xLocal);
}
event_t eventMte2V = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventMte2V);
WaitFlag<HardEvent::MTE2_V>(eventMte2V);
if constexpr (IsSameType<T, half>::value) {
LocalTensor<T> outValueFp16Ub =
outValueBuf_.GetWithOffset<T>(commonParam.perLoopChannel_ * (TRANSE_REP_STRIDE), OUT_FP16_OFFSET);
OutTransposeFp16(pointBilinearParam.channelAlign, xLocal, outValueFp16Ub);
PipeBarrier<PIPE_V>();
Cast(outValueLocal, outValueFp16Ub, RoundMode::CAST_NONE, pointBilinearParam.calCElems * TRANSE_REP_STRIDE);
} else if constexpr (IsSameType<T, bfloat16_t>::value) {
LocalTensor xLocalFp32 = xBuf_.Get<float>();
Cast(xLocalFp32, xLocal, RoundMode::CAST_NONE, pointBilinearParam.channelAlign * TRANSE_REP_STRIDE);
OutTransposeFp32(pointBilinearParam.channelAlign, xLocalFp32, outValueLocal);
} else {
OutTransposeFp32(pointBilinearParam.channelAlign, xLocal, outValueLocal);
}
PipeBarrier<PIPE_V>();
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::calculatePointBilinear(ProcessParam processParam,
CalculatePointBilinearParam calculatePointBilinearParam, PointParam pointBilinearParam, LocalTensor<T> xLocal)
{
auto outValueLocal = calculatePointBilinearParam.outValueTotalLocal;
if (calculatePointBilinearParam.isAtomicAdd) {
outValueLocal = calculatePointBilinearParam.outValueUb;
}
GetXLocal(processParam, calculatePointBilinearParam, pointBilinearParam, xLocal, outValueLocal);
for (size_t i = 0; i < pointBilinearParam.calCElems; i++) {
int32_t ubOffset = i * TRANSE_REP_STRIDE;
Select(outValueLocal[ubOffset],
calculatePointBilinearParam.maskUbTmp,
outValueLocal[ubOffset],
0.0f,
SELMODE::VSEL_TENSOR_SCALAR_MODE,
TRANSE_REP_STRIDE);
}
PipeBarrier<PIPE_V>();
if (pointBilinearParam.calCElems == 1) {
Mul(outValueLocal,
outValueLocal,
calculatePointBilinearParam.weightUb[pointBilinearParam.loopOffset],
TRANSE_REP_STRIDE);
} else {
for (int32_t i = 0; i < TRANSE_MUL_WEGHT_LOOPS; i++) {
int32_t outOffset = i * B32_MASK;
int32_t weightOffset = pointBilinearParam.loopOffset + i * B32_MASK;
Mul(outValueLocal[outOffset],
outValueLocal[outOffset],
calculatePointBilinearParam.weightUb[weightOffset],
B32_MASK,
pointBilinearParam.calCElems,
{1, 1, 1, 16, 16, 0});
}
}
if (calculatePointBilinearParam.isAtomicAdd) {
Add(calculatePointBilinearParam.outValueTotalLocal,
calculatePointBilinearParam.outValueTotalLocal,
outValueLocal,
pointBilinearParam.calCElems * TRANSE_REP_STRIDE);
}
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::CalculateGrid(ProcessParam processParam, LocalTensor<float> gridFp32Local,
LocalTensor<float> inputXFpLocal, LocalTensor<float> inputYFpLocal, LocalTensor<float> inputZFpLocal)
{
int64_t gridGmOffset = processParam.nIdx * commonParam.gridDHW_ * 3 + processParam.hwIdx * CAL_D_H_W_BLOCK * 3;
DataCopyExtParams paramsGrid;
paramsGrid.blockCount = 1;
paramsGrid.blockLen = processParam.calDHWElems * 3 * sizeof(T);
paramsGrid.srcStride = 0;
paramsGrid.dstStride = 0;
DataCopyPadExtParams<T> padParamsGrid{false, 0, 0, 0};
if constexpr (IsSameType<T, half>::value || IsSameType<T, bfloat16_t>::value) {
LocalTensor<T> gridHalfLocal = gridFp16Buf_.Get<T>();
DataCopyPad(gridHalfLocal, gmGrid_[gridGmOffset], paramsGrid, padParamsGrid);
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
Cast(gridFp32Local, gridHalfLocal, RoundMode::CAST_NONE, CAL_D_H_W_BLOCK * 3);
PipeBarrier<PIPE_V>();
} else {
DataCopyPad(gridFp32Local, gmGrid_[gridGmOffset], paramsGrid, padParamsGrid);
event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
}
LocalTensor<float> inputXYZUb = indexBuffer.inputXYZFPBuf_.Get<float>();
Adds(inputXYZUb, gridFp32Local, (float)1.0, CAL_D_H_W_BLOCK * 3);
uint16_t repeatTime = CAL_D_H_W_BLOCK * 3 / 192;
LocalTensor<uint32_t> bufZPattern = bufferMaskZBuf_.Get<uint32_t>();
LocalTensor<uint32_t> bufXPattern = bufferMaskXBuf_.Get<uint32_t>();
LocalTensor<uint32_t> bufYPattern = bufferMaskYBuf_.Get<uint32_t>();
PipeBarrier<PIPE_V>();
GatherMask(inputZFpLocal, inputXYZUb, bufZPattern, true, 192, {1, repeatTime, 24, 0}, rsvdCnt);
GatherMask(inputYFpLocal, inputXYZUb, bufYPattern, true, 192, {1, repeatTime, 24, 0}, rsvdCnt);
GatherMask(inputXFpLocal, inputXYZUb, bufXPattern, true, 192, {1, repeatTime, 24, 0}, rsvdCnt);
PipeBarrier<PIPE_V>();
if (commonParam.alignCorners_ == 1) {
Muls(inputZFpLocal, inputZFpLocal, (float)((float)0.5 * (commonParam.inputD_ - (float)1.0)), CAL_D_H_W_BLOCK);
Muls(inputYFpLocal, inputYFpLocal, (float)((float)0.5 * (commonParam.inputH_ - (float)1.0)), CAL_D_H_W_BLOCK);
Muls(inputXFpLocal, inputXFpLocal, (float)((float)0.5 * (commonParam.inputW_ - (float)1.0)), CAL_D_H_W_BLOCK);
} else {
Muls(inputZFpLocal, inputZFpLocal, (float)((float)0.5 * commonParam.inputD_), CAL_D_H_W_BLOCK);
Muls(inputYFpLocal, inputYFpLocal, (float)((float)0.5 * commonParam.inputH_), CAL_D_H_W_BLOCK);
Muls(inputXFpLocal, inputXFpLocal, (float)((float)0.5 * commonParam.inputW_), CAL_D_H_W_BLOCK);
PipeBarrier<PIPE_V>();
Adds(inputXFpLocal, inputXFpLocal, (float)(-0.5), CAL_D_H_W_BLOCK * 3);
}
PipeBarrier<PIPE_V>();
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::GetInputTensor()
{
inputXWIntLocal = inputXIntBuf_.Get<int32_t>(CAL_D_H_W_BLOCK);
inputXEIntLocal = inputXIntBuf_.GetWithOffset<int32_t>(CAL_D_H_W_BLOCK, CAL_D_H_W_BLOCK * 4);
inputYWIntLocal = inputYIntBuf_.Get<int32_t>(CAL_D_H_W_BLOCK);
inputYEIntLocal = inputYIntBuf_.GetWithOffset<int32_t>(CAL_D_H_W_BLOCK, CAL_D_H_W_BLOCK * 4);
inputZWIntLocal = inputZIntBuf_.Get<int32_t>(CAL_D_H_W_BLOCK);
inputZEIntLocal = inputZIntBuf_.GetWithOffset<int32_t>(CAL_D_H_W_BLOCK, CAL_D_H_W_BLOCK * 4);
inputXWFpLocal = inputXFpBuf_.Get<float>(CAL_D_H_W_BLOCK);
inputXEFpLocal = inputXFpBuf_.GetWithOffset<float>(CAL_D_H_W_BLOCK, CAL_D_H_W_BLOCK * 4);
inputYWFpLocal = inputYFpBuf_.Get<float>(CAL_D_H_W_BLOCK);
inputYEFpLocal = inputYFpBuf_.GetWithOffset<float>(CAL_D_H_W_BLOCK, CAL_D_H_W_BLOCK * 4);
inputZWFpLocal = inputZFpBuf_.Get<float>(CAL_D_H_W_BLOCK);
inputZEFpLocal = inputZFpBuf_.GetWithOffset<float>(CAL_D_H_W_BLOCK, CAL_D_H_W_BLOCK * 4);
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::calculateGridWeight(
LocalTensor<float> inputXFpLocal, LocalTensor<float> inputYFpLocal, LocalTensor<float> inputZFpLocal)
{
WeightStruct weightStruct = GetWeightStruct();
GetInputTensor();
LocalTensor<float> weightTmpLocal = weightTmpBuf_.Get<float>(CAL_D_H_W_BLOCK);
LocalTensor<float> weightTmp1Local = weightTmpBuf_.GetWithOffset<float>(CAL_D_H_W_BLOCK, CAL_D_H_W_BLOCK * 4);
LocalTensor<float> weightTmp2Local = weightTmpBuf_.GetWithOffset<float>(CAL_D_H_W_BLOCK, CAL_D_H_W_BLOCK * 2 * 4);
LocalTensor<float> weightTmp3Local = weightTmpBuf_.GetWithOffset<float>(CAL_D_H_W_BLOCK, CAL_D_H_W_BLOCK * 3 * 4);
Sub(weightStruct.tnwWeightLocal, inputXEFpLocal, inputXFpLocal, CAL_D_H_W_BLOCK);
Sub(weightStruct.tneWeightLocal, inputXFpLocal, inputXWFpLocal, CAL_D_H_W_BLOCK);
Sub(weightStruct.tswWeightLocal, inputXEFpLocal, inputXFpLocal, CAL_D_H_W_BLOCK);
Sub(weightStruct.tseWeightLocal, inputXFpLocal, inputXWFpLocal, CAL_D_H_W_BLOCK);
Sub(weightStruct.bnwWeightLocal, inputXEFpLocal, inputXFpLocal, CAL_D_H_W_BLOCK);
Sub(weightStruct.bneWeightLocal, inputXFpLocal, inputXWFpLocal, CAL_D_H_W_BLOCK);
Sub(weightStruct.bswWeightLocal, inputXEFpLocal, inputXFpLocal, CAL_D_H_W_BLOCK);
Sub(weightStruct.bseWeightLocal, inputXFpLocal, inputXWFpLocal, CAL_D_H_W_BLOCK);
Sub(weightTmpLocal, inputYEFpLocal, inputYFpLocal, CAL_D_H_W_BLOCK);
Sub(weightTmp1Local, inputYFpLocal, inputYWFpLocal, CAL_D_H_W_BLOCK);
Sub(weightTmp2Local, inputZEFpLocal, inputZFpLocal, CAL_D_H_W_BLOCK);
Sub(weightTmp3Local, inputZFpLocal, inputZWFpLocal, CAL_D_H_W_BLOCK);
PipeBarrier<PIPE_V>();
Mul(weightStruct.tnwWeightLocal, weightStruct.tnwWeightLocal, weightTmpLocal, CAL_D_H_W_BLOCK);
Mul(weightStruct.tneWeightLocal, weightStruct.tneWeightLocal, weightTmpLocal, CAL_D_H_W_BLOCK);
Mul(weightStruct.tswWeightLocal, weightStruct.tswWeightLocal, weightTmp1Local, CAL_D_H_W_BLOCK);
Mul(weightStruct.tseWeightLocal, weightStruct.tseWeightLocal, weightTmp1Local, CAL_D_H_W_BLOCK);
PipeBarrier<PIPE_V>();
Mul(weightStruct.tnwWeightLocal, weightStruct.tnwWeightLocal, weightTmp2Local, CAL_D_H_W_BLOCK);
Mul(weightStruct.tneWeightLocal, weightStruct.tneWeightLocal, weightTmp2Local, CAL_D_H_W_BLOCK);
Mul(weightStruct.tswWeightLocal, weightStruct.tswWeightLocal, weightTmp2Local, CAL_D_H_W_BLOCK);
Mul(weightStruct.tseWeightLocal, weightStruct.tseWeightLocal, weightTmp2Local, CAL_D_H_W_BLOCK);
Mul(weightStruct.bnwWeightLocal, weightStruct.bnwWeightLocal, weightTmpLocal, CAL_D_H_W_BLOCK);
Mul(weightStruct.bneWeightLocal, weightStruct.bneWeightLocal, weightTmpLocal, CAL_D_H_W_BLOCK);
Mul(weightStruct.bswWeightLocal, weightStruct.bswWeightLocal, weightTmp1Local, CAL_D_H_W_BLOCK);
Mul(weightStruct.bseWeightLocal, weightStruct.bseWeightLocal, weightTmp1Local, CAL_D_H_W_BLOCK);
PipeBarrier<PIPE_V>();
Mul(weightStruct.bnwWeightLocal, weightStruct.bnwWeightLocal, weightTmp3Local, CAL_D_H_W_BLOCK);
Mul(weightStruct.bneWeightLocal, weightStruct.bneWeightLocal, weightTmp3Local, CAL_D_H_W_BLOCK);
Mul(weightStruct.bswWeightLocal, weightStruct.bswWeightLocal, weightTmp3Local, CAL_D_H_W_BLOCK);
Mul(weightStruct.bseWeightLocal, weightStruct.bseWeightLocal, weightTmp3Local, CAL_D_H_W_BLOCK);
PipeBarrier<PIPE_V>();
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::clipAllCoordinate(LocalTensor<float> outValueLocal)
{
GetInputTensor();
CoordinatesStruct CoordinatesStruct = GetCoordinatesStruct();
GetWeightMaskStruct();
InputTensorStruct inputTensorStruct{
inputXWFpLocal, inputYWFpLocal, inputZWFpLocal, inputXWIntLocal, inputYWIntLocal, inputZWIntLocal};
InputTensorStruct inputTensorStruct2{
inputXEFpLocal, inputYWFpLocal, inputZWFpLocal, inputXEIntLocal, inputYWIntLocal, inputZWIntLocal};
InputTensorStruct inputTensorStruct3{
inputXWFpLocal, inputYEFpLocal, inputZWFpLocal, inputXWIntLocal, inputYEIntLocal, inputZWIntLocal};
InputTensorStruct inputTensorStruct4{
inputXEFpLocal, inputYEFpLocal, inputZWFpLocal, inputXEIntLocal, inputYEIntLocal, inputZWIntLocal};
InputTensorStruct inputTensorStruct5{
inputXWFpLocal, inputYWFpLocal, inputZEFpLocal, inputXWIntLocal, inputYWIntLocal, inputZEIntLocal};
InputTensorStruct inputTensorStruct6{
inputXEFpLocal, inputYWFpLocal, inputZEFpLocal, inputXEIntLocal, inputYWIntLocal, inputZEIntLocal};
InputTensorStruct inputTensorStruct7{
inputXWFpLocal, inputYEFpLocal, inputZEFpLocal, inputXWIntLocal, inputYEIntLocal, inputZEIntLocal};
InputTensorStruct inputTensorStruct8{
inputXEFpLocal, inputYEFpLocal, inputZEFpLocal, inputXEIntLocal, inputYEIntLocal, inputZEIntLocal};
ClipCoordinates(inputTensorStruct, CoordinatesStruct.tnwCoordinates, weightMaskUb, indexBuffer, commonParam);
ClipCoordinates(inputTensorStruct2, CoordinatesStruct.tneCoordinates, weightMaskUb2, indexBuffer, commonParam);
ClipCoordinates(inputTensorStruct3, CoordinatesStruct.tswCoordinates, weightMaskUb3, indexBuffer, commonParam);
ClipCoordinates(inputTensorStruct4, CoordinatesStruct.tsecoordinates, weightMaskUb4, indexBuffer, commonParam);
ClipCoordinates(inputTensorStruct5, CoordinatesStruct.bnwCoordinates, weightMaskUb5, indexBuffer, commonParam);
ClipCoordinates(inputTensorStruct6, CoordinatesStruct.bneCoordinates, weightMaskUb6, indexBuffer, commonParam);
ClipCoordinates(inputTensorStruct7, CoordinatesStruct.bswCoordinates, weightMaskUb7, indexBuffer, commonParam);
ClipCoordinates(inputTensorStruct8, CoordinatesStruct.bsecoordinates, weightMaskUb8, indexBuffer, commonParam);
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::PerLoopCompute(ProcessParam processParam)
{
LocalTensor<float> gridFp32Local = gridQueue_.AllocTensor<float>();
LocalTensor<float> inputXFpLocal = gridFp32Local;
LocalTensor<float> inputYFpLocal = gridFp32Local[CAL_D_H_W_BLOCK];
LocalTensor<float> inputZFpLocal = gridFp32Local[CAL_D_H_W_BLOCK * 2];
CalculateGrid(processParam, gridFp32Local, inputXFpLocal, inputYFpLocal, inputZFpLocal);
Clip(inputXFpLocal, inputYFpLocal, inputZFpLocal);
GetInputTensor();
Cast(inputXWIntLocal, inputXFpLocal, RoundMode::CAST_FLOOR, CAL_D_H_W_BLOCK);
Cast(inputYWIntLocal, inputYFpLocal, RoundMode::CAST_FLOOR, CAL_D_H_W_BLOCK);
Cast(inputZWIntLocal, inputZFpLocal, RoundMode::CAST_FLOOR, CAL_D_H_W_BLOCK);
PipeBarrier<PIPE_V>();
Cast(inputXWFpLocal, inputXWIntLocal, RoundMode::CAST_NONE, CAL_D_H_W_BLOCK);
Cast(inputYWFpLocal, inputYWIntLocal, RoundMode::CAST_NONE, CAL_D_H_W_BLOCK);
Cast(inputZWFpLocal, inputZWIntLocal, RoundMode::CAST_NONE, CAL_D_H_W_BLOCK);
Adds(inputXEIntLocal, inputXWIntLocal, 1, CAL_D_H_W_BLOCK);
Adds(inputYEIntLocal, inputYWIntLocal, 1, CAL_D_H_W_BLOCK);
Adds(inputZEIntLocal, inputZWIntLocal, 1, CAL_D_H_W_BLOCK);
PipeBarrier<PIPE_V>();
Adds(inputXEFpLocal, inputXWFpLocal, (float)1.0, CAL_D_H_W_BLOCK);
Adds(inputYEFpLocal, inputYWFpLocal, (float)1.0, CAL_D_H_W_BLOCK);
Adds(inputZEFpLocal, inputZWFpLocal, (float)1.0, CAL_D_H_W_BLOCK);
PipeBarrier<PIPE_V>();
calculateGridWeight(inputXFpLocal, inputYFpLocal, inputZFpLocal);
LocalTensor<float> outValueLocal = outValueBuf_.Get<float>();
clipAllCoordinate(outValueLocal);
PointBilinear(processParam, outValueLocal, true);
gridQueue_.FreeTensor(gridFp32Local);
}
template <typename T>
__aicore__ inline void GridSampler3D<T>::Process()
{
if (blockIDX >= commonParam.needCoreNum_) {
return;
}
ProcessParam processParam;
int32_t preLoopNum = blockIDX * commonParam.preCoreLoop_;
int64_t loopSize = commonParam.preCoreLoop_;
if (blockIDX == commonParam.needCoreNum_ - 1) {
loopSize = commonParam.lastCoreLoop_;
}
for (int32_t loopIdx = 0; loopIdx < loopSize; loopIdx++) {
processParam.nIdx = (preLoopNum + loopIdx) / commonParam.preNUbLoop_;
processParam.hwIdx = (preLoopNum + loopIdx) % commonParam.preNUbLoop_;
processParam.calDHWElems = CAL_D_H_W_BLOCK;
if (processParam.hwIdx == commonParam.preNUbLoop_ - 1) {
processParam.calDHWElems = commonParam.lastLoopDHW_;
}
PerLoopCompute(processParam);
}
}
}
#endif