#include "kernel_operator.h"
#include "lib/matmul_intf.h"
#include "sparse_conv3d_grad_simt.h"
using namespace AscendC;
namespace {
// Bytes in one 32-byte data block; used as UB padding and to derive per-block element counts.
constexpr int32_t BLOCK_SIZE = 32;
// Bytes per int32_t element.
constexpr int32_t INT32_SIZE = static_cast<int32_t>(sizeof(int32_t));
// int32_t elements that fill one data block (tail index copies are padded up to this).
constexpr int32_t INT32_BLOCK_NUM = BLOCK_SIZE / INT32_SIZE;
// Padding granularity, in uint8_t elements, for the kernel-index tile and its comparison mask.
constexpr int32_t INT8_PER_LOOP = 256;
// Default ("normal") matmul configuration shared by both gradient matmuls.
constexpr MatmulConfig SPARSE_CONV3D_CFG = GetNormalConfig();
} // namespace
// Backward pass of a sparse 3D convolution (features_grad and weight_grad) for a mixed cube/vector launch.
// For every kernel offset k, vector (AIV) cores gather the grad_out rows and feature rows of the points whose
// stored kernel index equals k into a per-core workspace slot. The cube (AIC) core runs the two matmuls below on
// that slot, and the vector core then scatter-adds the resulting feature gradients into features_grad.
template<typename T>
class SparseConv3dGrad {
public:
// Matmul operand descriptors: every operand lives in GM with ND layout. The trailing `true` on
// weightMatType / imgToColMatType enables transposed use (they are passed with transpose=true below).
using weightMatType = matmul::MatmulType<TPosition::GM, CubeFormat::ND, T, true>;
using imgToColMatType = matmul::MatmulType<TPosition::GM, CubeFormat::ND, T, true>;
using gradOutMatType = matmul::MatmulType<TPosition::GM, CubeFormat::ND, T>;
using weightGradMatType = matmul::MatmulType<TPosition::GM, CubeFormat::ND, T>;
using featureGradMatType = matmul::MatmulType<TPosition::GM, CubeFormat::ND, T>;
// Feature-gradient matmul: gradOut (sparseM x outC) times weight[k] used transposed -> (sparseM x inC) in workspace.
matmul::MatmulImpl<gradOutMatType, weightMatType, featureGradMatType, featureGradMatType, SPARSE_CONV3D_CFG>
featureMatmul_;
// Weight-gradient matmul: features used transposed (inC x sparseM) times gradOut (sparseM x outC)
// -> (inC x outC) accumulated into weight_grad[k].
matmul::MatmulImpl<imgToColMatType, gradOutMatType, weightGradMatType, weightGradMatType, SPARSE_CONV3D_CFG>
weightMatmul_;
__aicore__ inline SparseConv3dGrad() {};
// Reads the tiling data, then binds the GM tensors and carves the UB/workspace buffers.
// Only the raw `pipe` pointer is stored, so the pipe must outlive this object.
__aicore__ inline void Init(TPipe* pipe, GM_ADDR features, GM_ADDR weight, GM_ADDR grad_out_features,
GM_ADDR former_sorted_indices, GM_ADDR indices_offset, GM_ADDR features_grad, GM_ADDR weight_grad,
GM_ADDR usrWorkspace, SparseConv3dGradTilingData* tilingData)
{
pipe_ = pipe;
InitTiling(tilingData);
InitBuffer(features, weight, grad_out_features, former_sorted_indices, indices_offset, features_grad,
weight_grad, usrWorkspace);
}
// Runs the AIV (gather / scatter-add) or AIC (matmul) side, depending on the executing core type.
__aicore__ inline void Process();
protected:
TPipe* pipe_;
// Operator inputs/outputs in GM.
GlobalTensor<T> featuresGM_, weightGM_, gradOutFeaturesGM_, featuresGradGM_, weightGradGM_;
// Per-core workspace slot: gathered features, matmul feature-gradient output, gathered grad_out.
GlobalTensor<T> featureWsp_, gradOutWsp_, tempGradFeatureWsp_;
GlobalTensor<int32_t> sortedIndicesGM_, indicesOffsetGM_;
// Shared workspace tables: per-point input/output indices and per-(core, k) gathered-row counts.
GlobalTensor<int32_t> inputIdxWsp_, outIdxWsp_, sparseNumWsp_;
// Shared workspace table: per-point kernel index, stored as uint8_t.
GlobalTensor<uint8_t> kIdxWsp_;
TBuf<TPosition::VECCALC> tmpInputBuf_, tmpGradFeaturesBuf_, indexBuf_, kIdxBuf_;
// UB staging for gathered grad_out rows and feature rows (both carved from tmpInputBuf_).
LocalTensor<T> tmpGradOutLocal_, tmpFeaturesLocal_;
// UB staging for feature-gradient rows read back before the scatter-add.
LocalTensor<T> tmpGradFeaturesLocal_;
// UB index tiles: input idx, output idx, gathered input idx (scatter targets), and a one-value scratch.
LocalTensor<int32_t> inputIdxLocal_, outIdxLocal_, inputIdxPtrLocal_, idxInfoLocal_;
// UB kernel-index tile and the bit mask produced by comparing it against the current k.
LocalTensor<uint8_t> kIdxLocal_, maskLocal_;
uint32_t blockIdx_, aicNum_, aivNum_, featureByteSize_, blockDataNum_, loopPointCount_;
uint32_t usedVectorNum_, kernelSize_, inChannels_, outChannels_, totalTaskNum_, sparseRatio_, ubMaxTaskNum_;
int32_t totalPointsCount_, startOffset_;
// Size (in T elements) of one vector core's workspace slot.
uint64_t featureWspOffset_;
// Offset (in int32 elements) from sparseNumWsp_ to kIdxWsp_.
uint32_t sparseWspOffset_;
private:
__aicore__ inline void InitTiling(SparseConv3dGradTilingData* tilingData);
__aicore__ inline void InitBuffer(GM_ADDR features, GM_ADDR weight, GM_ADDR grad_out_features,
GM_ADDR former_sorted_indices, GM_ADDR indices_offset, GM_ADDR features_grad, GM_ADDR weight_grad,
GM_ADDR usrWorkspace);
__aicore__ inline void CopyInHashMap(uint32_t startOffset, int32_t pointCount);
__aicore__ inline void CopyInFeature(uint32_t bitLoopNum, const int32_t pointCount, int32_t& totalSparseM);
__aicore__ inline void CalGradFeaturesMatmul(uint8_t k, int32_t sparseM, uint8_t subBlockIdx);
__aicore__ inline void CalGradWeightMatmul(uint8_t k, int32_t sparseM, uint8_t subBlockIdx);
__aicore__ inline void GradFeaturesScatterAdd(uint32_t sparseM);
};
template<typename T>
__aicore__ inline void SparseConv3dGrad<T>::InitTiling(SparseConv3dGradTilingData* tilingData)
{
    // Core topology: each launched cube core is paired with two vector sub-blocks.
    blockIdx_ = GetBlockIdx();
    aicNum_ = GetBlockNum();
    aivNum_ = 2 * aicNum_;

    // Element width of the feature dtype and how many elements fit in one 32-byte block.
    featureByteSize_ = sizeof(T);
    blockDataNum_ = BLOCK_SIZE / featureByteSize_;

    // Host-computed tiling parameters, copied verbatim.
    const SparseConv3dGradTilingData& tiling = *tilingData;
    usedVectorNum_ = tiling.usedVectorNum;
    totalTaskNum_ = tiling.totalTaskNum;
    totalPointsCount_ = tiling.totalPointsCount;
    startOffset_ = tiling.startOffset;
    loopPointCount_ = tiling.loopPointCount;
    ubMaxTaskNum_ = tiling.ubMaxTaskNum;

    kernelSize_ = tiling.kernelSize;
    inChannels_ = tiling.inChannels;
    outChannels_ = tiling.outChannels;

    featureWspOffset_ = tiling.featureWspOffset;
    sparseWspOffset_ = tiling.sparseWspOffset;
}
template<typename T>
__aicore__ inline void SparseConv3dGrad<T>::InitBuffer(GM_ADDR features, GM_ADDR weight, GM_ADDR grad_out_features,
    GM_ADDR former_sorted_indices, GM_ADDR indices_offset, GM_ADDR features_grad, GM_ADDR weight_grad,
    GM_ADDR usrWorkspace)
{
    // ---- GM views of the operator inputs and outputs. ----
    featuresGM_.SetGlobalBuffer((__gm__ T*)features);
    weightGM_.SetGlobalBuffer((__gm__ T*)weight);
    gradOutFeaturesGM_.SetGlobalBuffer((__gm__ T*)grad_out_features);
    featuresGradGM_.SetGlobalBuffer((__gm__ T*)features_grad);
    weightGradGM_.SetGlobalBuffer((__gm__ T*)weight_grad);
    sortedIndicesGM_.SetGlobalBuffer((__gm__ int32_t*)former_sorted_indices);
    indicesOffsetGM_.SetGlobalBuffer((__gm__ int32_t*)indices_offset);

    // ---- UB carve-up. The pipe_->InitBuffer calls keep their original order (it fixes the UB layout). ----
    const uint32_t gatherBytes = ubMaxTaskNum_ * (outChannels_ + inChannels_) * featureByteSize_;
    const uint32_t scatterBytes = ubMaxTaskNum_ * inChannels_ * featureByteSize_;
    const uint32_t indexBytes = 3 * loopPointCount_ * INT32_SIZE + BLOCK_SIZE;
    pipe_->InitBuffer(tmpInputBuf_, gatherBytes);
    pipe_->InitBuffer(tmpGradFeaturesBuf_, scatterBytes);
    pipe_->InitBuffer(indexBuf_, indexBytes);

    // tmpInputBuf_ = [ubMaxTaskNum_ x outChannels_ grad_out rows | ubMaxTaskNum_ x inChannels_ feature rows]
    tmpGradOutLocal_ = tmpInputBuf_.Get<T>();
    tmpFeaturesLocal_ = tmpGradOutLocal_[ubMaxTaskNum_ * outChannels_];
    tmpGradFeaturesLocal_ = tmpGradFeaturesBuf_.Get<T>();

    // indexBuf_ = three int32 arrays of loopPointCount_ entries, then one spare block (idxInfoLocal_).
    inputIdxLocal_ = indexBuf_.Get<int32_t>();
    outIdxLocal_ = inputIdxLocal_[loopPointCount_];
    inputIdxPtrLocal_ = outIdxLocal_[loopPointCount_];
    idxInfoLocal_ = inputIdxPtrLocal_[loopPointCount_];

    // kIdxBuf_ = kernel-index bytes padded to INT8_PER_LOOP, followed by one mask bit per byte.
    const uint32_t kIdxBytes = AlignUp(loopPointCount_, INT8_PER_LOOP);
    pipe_->InitBuffer(kIdxBuf_, kIdxBytes + kIdxBytes / 8);
    kIdxLocal_ = kIdxBuf_.Get<uint8_t>();
    maskLocal_ = kIdxLocal_[kIdxBytes];

    // ---- Workspace. A vector core owns one featureWspOffset_-element slot; a cube core starts at the
    // first of the two consecutive slots belonging to its pair of vector sub-blocks. ----
    if ASCEND_IS_AIV {
        featureWsp_.SetGlobalBuffer((__gm__ T*)(usrWorkspace) + blockIdx_ * featureWspOffset_);
    }
    if ASCEND_IS_AIC {
        featureWsp_.SetGlobalBuffer((__gm__ T*)(usrWorkspace) + blockIdx_ * 2 * featureWspOffset_);
    }
    // Slot layout: [features | temp grad features | grad_out]; the first two span loopPointCount_ x inChannels_.
    tempGradFeatureWsp_ = featureWsp_[loopPointCount_ * inChannels_];
    gradOutWsp_ = tempGradFeatureWsp_[loopPointCount_ * inChannels_];

    // Shared index tables follow the usedVectorNum_ feature slots.
    // NOTE(review): this offset is counted in int32_t units while the slots are sized in T units;
    // the two agree only when sizeof(T) == 4. Confirm against the host-side workspace sizing.
    inputIdxWsp_.SetGlobalBuffer((__gm__ int32_t*)(usrWorkspace) + usedVectorNum_ * featureWspOffset_);
    outIdxWsp_ = inputIdxWsp_[totalPointsCount_];
    sparseNumWsp_ = outIdxWsp_[totalPointsCount_];
    kIdxWsp_ = sparseNumWsp_[sparseWspOffset_].template ReinterpretCast<uint8_t>();
}
template<typename T>
__aicore__ inline void SparseConv3dGrad<T>::Process()
{
    if ASCEND_IS_AIV {
        // Fill the shared index tables (inputIdxWsp_, outIdxWsp_, kIdxWsp_) via the SIMT helper declared in
        // sparse_conv3d_grad_simt.h. SyncAll() then makes all cores wait until the tables are complete.
        Simt::VF_CALL<PrepareGlobalHashMap>(
            Simt::Dim3 {THREAD_NUM},
            (__gm__ int32_t*)sortedIndicesGM_.GetPhyAddr(),
            (__gm__ int32_t*)indicesOffsetGM_.GetPhyAddr(),
            (__gm__ int32_t*)inputIdxWsp_.GetPhyAddr(),
            (__gm__ int32_t*)outIdxWsp_.GetPhyAddr(),
            (__gm__ uint8_t*)kIdxWsp_.GetPhyAddr(),
            totalTaskNum_,
            kernelSize_,
            startOffset_);
        SyncAll();
        // Each vector core takes every usedVectorNum_-th tile of loopPointCount_ points.
        for (int32_t pointIdx = blockIdx_ * loopPointCount_; pointIdx < totalPointsCount_;
            pointIdx += loopPointCount_ * usedVectorNum_) {
            int32_t ubPointCount = min(totalPointsCount_ - pointIdx, (int32_t)(loopPointCount_));
            CopyInHashMap(pointIdx, ubPointCount);
            SetFlag<HardEvent::MTE2_V>(0);
            WaitFlag<HardEvent::MTE2_V>(0);
            uint32_t maskAlign = AlignUp(ubPointCount, INT8_PER_LOOP);
            uint32_t bitLoopCount = DivCeil(ubPointCount, 64);
            // 32-bit counter: with a uint8_t counter, `kIdx < kernelSize_` can never become false once
            // kernelSize_ >= 256, and the loop would not terminate. Kernel indices are stored as uint8_t
            // (kIdxWsp_), so the value used for comparisons/matmuls is narrowed explicitly.
            for (uint32_t kIdx = 0; kIdx < kernelSize_; ++kIdx) {
                const uint8_t kValue = static_cast<uint8_t>(kIdx);
                uint16_t flagId = static_cast<uint16_t>(kIdx % 8);
                // Bit mask of the points in this tile whose kernel index equals k.
                CompareScalar(maskLocal_, kIdxLocal_, kValue, CMPMODE::EQ, maskAlign);
                SetFlag<HardEvent::V_S>(0);
                WaitFlag<HardEvent::V_S>(0);
                int32_t totalSparseNum = 0;
                CopyInFeature(bitLoopCount, ubPointCount, totalSparseNum);
                // Publish the gathered row count for (this core, k) so the cube core can size its matmuls.
                idxInfoLocal_.SetValue(0, totalSparseNum);
                SetFlag<HardEvent::S_MTE3>(flagId);
                WaitFlag<HardEvent::S_MTE3>(flagId);
                DataCopyPad(sparseNumWsp_[blockIdx_ * kernelSize_ + kIdx], idxInfoLocal_,
                    {static_cast<uint16_t>(1), static_cast<uint32_t>(1 * INT32_SIZE), 0, 0, 0});
                // Hand the slot to the cube core, then wait until its matmuls for this k are done.
                CrossCoreSetFlag<0x4, PIPE_MTE3>(flagId);
                CrossCoreWaitFlag<0x4>(flagId);
                GradFeaturesScatterAdd(totalSparseNum);
            }
        }
    }
    if ASCEND_IS_AIC {
        // Each cube core serves the tiles of its two vector sub-blocks (2 * loopPointCount_ points per step).
        for (uint32_t aicTaskOffset = 2 * loopPointCount_ * blockIdx_; aicTaskOffset < totalPointsCount_;
            aicTaskOffset += 2 * loopPointCount_ * aicNum_) {
            bool sub1Used = (totalPointsCount_ - aicTaskOffset > loopPointCount_) ? true : false;
            // Same 32-bit counter as the vector side so both loops run exactly kernelSize_ iterations.
            for (uint32_t kIdx = 0; kIdx < kernelSize_; ++kIdx) {
                const uint8_t kValue = static_cast<uint8_t>(kIdx);
                uint16_t sub0FlagId = static_cast<uint16_t>(kIdx % 8);
                uint16_t sub1FlagId = sub0FlagId + 16;
                CrossCoreWaitFlag<0x4>(sub0FlagId);
                DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(
                    sparseNumWsp_[2 * blockIdx_ * kernelSize_ + kIdx]);
                int32_t sub0SparseNum = sparseNumWsp_.GetValue(2 * blockIdx_ * kernelSize_ + kIdx);
                CalGradFeaturesMatmul(kValue, sub0SparseNum, 0);
                CalGradWeightMatmul(kValue, sub0SparseNum, 0);
                CrossCoreSetFlag<0x4, PIPE_FIX>(sub0FlagId);
                if (sub1Used) {
                    CrossCoreWaitFlag<0x4>(sub1FlagId);
                    DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, AscendC::DcciDst::CACHELINE_OUT>(
                        sparseNumWsp_[(2 * blockIdx_ + 1) * kernelSize_ + kIdx]);
                    int32_t sub1SparseNum = sparseNumWsp_.GetValue((2 * blockIdx_ + 1) * kernelSize_ + kIdx);
                    CalGradFeaturesMatmul(kValue, sub1SparseNum, 1);
                    CalGradWeightMatmul(kValue, sub1SparseNum, 1);
                    CrossCoreSetFlag<0x4, PIPE_FIX>(sub1FlagId);
                }
            }
        }
    }
}
template<typename T>
__aicore__ inline void SparseConv3dGrad<T>::CopyInHashMap(uint32_t startOffset, int32_t pointCount)
{
    // A full tile moves exactly loopPointCount_ index entries; a tail tile is padded up to a whole
    // 32-byte block of int32 values.
    uint32_t idxCopyLen = loopPointCount_;
    if (pointCount != loopPointCount_) {
        idxCopyLen = AlignUp(pointCount, INT32_BLOCK_NUM);
    }
    // Kernel indices are uint8_t; their copy length is padded to the mask-comparison granularity.
    const uint32_t kIdxCopyLen = AlignUp(idxCopyLen, INT8_PER_LOOP);
    DataCopy(kIdxLocal_, kIdxWsp_[startOffset], kIdxCopyLen);
    DataCopy(inputIdxLocal_, inputIdxWsp_[startOffset], idxCopyLen);
    DataCopy(outIdxLocal_, outIdxWsp_[startOffset], idxCopyLen);
}
template<typename T>
__aicore__ inline void SparseConv3dGrad<T>::CopyInFeature(
    uint32_t bitLoopNum, const int32_t pointCount, int32_t& totalSparseM)
{
    // Gathers the grad_out row and feature row of every point whose bit is set in maskLocal_ into UB,
    // and flushes them to this core's workspace slot every ubMaxTaskNum_ rows (and once more at the end).
    //   bitLoopNum   : number of 64-bit mask words covering the tile.
    //   pointCount   : valid points in the tile; mask bits at or beyond it are ignored.
    //   totalSparseM : in/out count of rows already flushed to the workspace. It is also the base
    //                  position for the gathered input indices written to inputIdxPtrLocal_.
    int32_t sparseM = 0;
    for (uint32_t outerIdx = 0; outerIdx < bitLoopNum; ++outerIdx) {
        uint64_t maskValue = maskLocal_.ReinterpretCast<uint64_t>().GetValue(outerIdx);
        uint32_t innerLoopCount = outerIdx == (bitLoopNum - 1) ? (pointCount - 64 * outerIdx) : 64;
        // ScalarGetSFFValue<1> yields the lowest set bit, or -1 once none remain. The position is kept
        // signed so the `>= 0` test is meaningful; it was always true on the former uint32_t counter.
        for (int64_t innerIdx = ScalarGetSFFValue<1>(maskValue);
            innerIdx >= 0 && innerIdx < static_cast<int64_t>(innerLoopCount);
            innerIdx = ScalarGetSFFValue<1>(maskValue)) {
            const uint32_t bitPos = static_cast<uint32_t>(innerIdx);
            maskValue = sbitset0(maskValue, bitPos);
            uint32_t bitIdx = outerIdx * 64 + bitPos;
            uint32_t inputIdx = inputIdxLocal_.GetValue(bitIdx);
            uint32_t outIdx = outIdxLocal_.GetValue(bitIdx);
            inputIdxPtrLocal_.SetValue(totalSparseM + sparseM, inputIdx);
            // Widen before multiplying: row index * channel count can exceed 32 bits on large point clouds,
            // and GlobalTensor indexing accepts a 64-bit offset.
            const uint64_t gradOutRowOffset = static_cast<uint64_t>(outIdx) * outChannels_;
            const uint64_t featureRowOffset = static_cast<uint64_t>(inputIdx) * inChannels_;
            DataCopy(tmpGradOutLocal_[sparseM * outChannels_], gradOutFeaturesGM_[gradOutRowOffset], outChannels_);
            DataCopy(tmpFeaturesLocal_[sparseM * inChannels_], featuresGM_[featureRowOffset], inChannels_);
            sparseM++;
            if (sparseM == ubMaxTaskNum_) {
                // UB staging is full: flush to workspace, then wait for the flush before reusing UB.
                SetFlag<HardEvent::MTE2_MTE3>(0);
                WaitFlag<HardEvent::MTE2_MTE3>(0);
                DataCopy(gradOutWsp_[totalSparseM * outChannels_], tmpGradOutLocal_, sparseM * outChannels_);
                DataCopy(featureWsp_[totalSparseM * inChannels_], tmpFeaturesLocal_, sparseM * inChannels_);
                totalSparseM += sparseM;
                sparseM = 0;
                SetFlag<HardEvent::MTE3_MTE2>(0);
                WaitFlag<HardEvent::MTE3_MTE2>(0);
            }
        }
    }
    // Flush any remaining partially filled staging rows.
    if (sparseM > 0) {
        SetFlag<HardEvent::MTE2_MTE3>(0);
        WaitFlag<HardEvent::MTE2_MTE3>(0);
        DataCopy(gradOutWsp_[totalSparseM * outChannels_], tmpGradOutLocal_, sparseM * outChannels_);
        DataCopy(featureWsp_[totalSparseM * inChannels_], tmpFeaturesLocal_, sparseM * inChannels_);
        totalSparseM += sparseM;
    }
}
template<typename T>
__aicore__ inline void SparseConv3dGrad<T>::GradFeaturesScatterAdd(uint32_t m)
{
    // No rows were gathered for this kernel offset, so there is nothing to scatter and no flags are touched.
    if (m == 0) {
        return;
    }
    // Pre-arm V_MTE2 so the first chunk's wait returns immediately; each chunk re-arms it after its
    // SIMT pass, which keeps the next load from overwriting UB while the scatter still reads it.
    SetFlag<HardEvent::V_MTE2>(1);
    uint32_t rowBase = 0;
    while (rowBase < m) {
        const uint32_t rowsLeft = m - rowBase;
        const uint32_t chunkRows = min(ubMaxTaskNum_, rowsLeft);
        WaitFlag<HardEvent::V_MTE2>(1);
        DataCopy(tmpGradFeaturesLocal_, tempGradFeatureWsp_[rowBase * inChannels_], chunkRows * inChannels_);
        SetFlag<HardEvent::MTE2_V>(1);
        WaitFlag<HardEvent::MTE2_V>(1);
        // NOTE(review): the GM/UB pointers are cast to float regardless of T; confirm that only
        // float features are instantiated, or that ScatterAddSimt is meant for float only.
        Simt::VF_CALL<ScatterAddSimt>(
            Simt::Dim3 {inChannels_, THREAD_NUM / inChannels_},
            (__gm__ float*)featuresGradGM_.GetPhyAddr(),
            (__ubuf__ float*)tmpGradFeaturesLocal_.GetPhyAddr(),
            (__ubuf__ int32_t*)inputIdxPtrLocal_[rowBase].GetPhyAddr(),
            chunkRows,
            inChannels_);
        SetFlag<HardEvent::V_MTE2>(1);
        rowBase += ubMaxTaskNum_;
    }
    // Drain the last arm so the event is balanced on exit.
    WaitFlag<HardEvent::V_MTE2>(1);
}
template<typename T>
__aicore__ inline void SparseConv3dGrad<T>::CalGradFeaturesMatmul(uint8_t k, int32_t sparseM, uint8_t subBlockIdx)
{
    // Feature-gradient rows for kernel offset k: gradOut (sparseM x outC) times weight[k] used transposed,
    // written (no accumulation) to the sub-block's temp-grad workspace region.
    if (sparseM > 0) {
        const uint64_t slotBase = featureWspOffset_ * subBlockIdx;
        const uint32_t kernelSliceBase = k * inChannels_ * outChannels_;
        featureMatmul_.SetOrgShape(sparseM, inChannels_, outChannels_);
        featureMatmul_.SetTensorA(gradOutWsp_[slotBase]);
        featureMatmul_.SetTensorB(weightGM_[kernelSliceBase], true);
        featureMatmul_.SetSingleShape(sparseM, inChannels_, outChannels_);
        featureMatmul_.template IterateAll<false>(tempGradFeatureWsp_[slotBase], 0);
        featureMatmul_.End();
    }
}
template<typename T>
__aicore__ inline void SparseConv3dGrad<T>::CalGradWeightMatmul(uint8_t k, int32_t sparseM, uint8_t subBlockIdx)
{
    // weight_grad[k] (inC x outC) gets features used transposed (inC x sparseM) times gradOut (sparseM x outC).
    if (sparseM > 0) {
        const uint64_t slotBase = featureWspOffset_ * subBlockIdx;
        const uint32_t kernelSliceBase = k * inChannels_ * outChannels_;
        weightMatmul_.SetOrgShape(inChannels_, outChannels_, sparseM);
        weightMatmul_.SetTensorA(featureWsp_[slotBase], true);
        weightMatmul_.SetTensorB(gradOutWsp_[slotBase]);
        weightMatmul_.SetSingleShape(inChannels_, outChannels_, sparseM);
        // Atomic mode 1: accumulate, because every tile (on every cube core) contributes to the same k slice.
        weightMatmul_.template IterateAll<false>(weightGradGM_[kernelSliceBase], 1);
        weightMatmul_.End();
    }
}
// Kernel entry point for the sparse 3D convolution backward pass, launched as a mixed
// cube:vector = 1:2 task (KERNEL_TYPE_MIX_AIC_1_2).
extern "C" __global__ __aicore__ void sparse_conv3d_grad(GM_ADDR features, GM_ADDR weight, GM_ADDR grad_out_features,
GM_ADDR former_sorted_indices, GM_ADDR indices_offset, GM_ADDR features_grad, GM_ADDR weight_grad,
GM_ADDR workspace, GM_ADDR tiling)
{
GET_TILING_DATA(tiling_data, tiling);
GM_ADDR usrWorkspace = GetUserWorkspace(workspace);
// All intermediates (gathered rows, index tables, row counts) live in the user workspace; without it, do nothing.
if (usrWorkspace == nullptr) {
return;
}
TPipe pipe;
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2);
SparseConv3dGrad<DTYPE_FEATURES> op;
// NOTE(review): both matmul objects are initialised on the same pipe *before* op.Init allocates the
// operator's own UB buffers. Keep this order unless the resulting UB layout is re-verified.
op.featureMatmul_.SetSubBlockIdx(0);
op.featureMatmul_.Init(&tiling_data.featureMatmulTilingData, &pipe);
op.weightMatmul_.SetSubBlockIdx(0);
op.weightMatmul_.Init(&tiling_data.weightMatmulTilingData, &pipe);
op.Init(&pipe, features, weight, grad_out_features, former_sorted_indices, indices_offset, features_grad,
weight_grad, usrWorkspace, &tiling_data);
op.Process();
}