#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
#include "kernel_utils.h"
using namespace AscendC;
using namespace std;
constexpr int32_t ALIGN_NUM = 8;
constexpr int32_t FLOAT_SIZE = 4;
constexpr int32_t DOUBLE_NUM = 2;
constexpr int32_t BUFFER_NUM = 5;
constexpr int32_t ONE_VALUE = 1;
constexpr int32_t ZERO_VALUE = 0;
constexpr int32_t SRC_SHAPE_DIM = 8;
constexpr float ZERO_FLOAT_VALUE = 0.0f;
constexpr float ONE_FLOAT_VALUE = 1.0f;
constexpr float UB_RATIO = 0.8f;
class GraphSoftmaxGrad {
public:
__aicore__ inline GraphSoftmaxGrad() {}
__aicore__ inline void Init(GM_ADDR index, GM_ADDR softmax_output, GM_ADDR grad_output, GM_ADDR reduce_sum,
GM_ADDR src_grad, const GraphSoftmaxGradTilingData *tiling_data, TPipe *pipe) {
ASSERT(GetBlockNum() != 0 && "Block Dim can not be Zero!");
this->blockIndex = GetBlockIdx();
this->_pipe = pipe;
GetTilingData(tiling_data);
uint64_t bufferSize = taskNumPerLoop * SRC_SHAPE_DIM * FLOAT_SIZE;
uint64_t castEdgeNum = static_cast<uint64_t>(edgeNum);
uint64_t gmSize = castEdgeNum * SRC_SHAPE_DIM;
indexGM.SetGlobalBuffer((__gm__ DTYPE_INDEX *)index, castEdgeNum);
softmaxGM.SetGlobalBuffer((__gm__ DTYPE_SOFTMAX_OUTPUT *)softmax_output, gmSize);
gradGM.SetGlobalBuffer((__gm__ DTYPE_GRAD_OUTPUT *)grad_output, gmSize);
srcgradGM.SetGlobalBuffer((__gm__ DTYPE_SRC_GRAD *)src_grad, gmSize);
reducesumGM.SetGlobalBuffer((__gm__ DTYPE_REDUCE_SUM *)reduce_sum, nodeNum * SRC_SHAPE_DIM);
this->_pipe->InitBuffer(SoftmaxOutputBuffer, bufferSize);
this->_pipe->InitBuffer(IndexTensorBuffer, taskNumPerLoop * FLOAT_SIZE);
this->_pipe->InitBuffer(GradOutputBuffer, bufferSize);
this->_pipe->InitBuffer(SrcGradTensorBuffer, bufferSize);
this->_pipe->InitBuffer(TmpIndexSelectBuffer, bufferSize);
}
__aicore__ inline void Process() {
AllocLocalTensors();
if (taskLoop == 1) {
SingleLoopComputing();
} else {
MultiLoopComputing();
}
}
private:
__aicore__ inline void GetTilingData(const GraphSoftmaxGradTilingData *tiling_data) {
edgeNum = tiling_data->edgeNum;
alignTaskNum = tiling_data->alignTaskNum;
tailNum = tiling_data->tailNum;
nodeNum = tiling_data->nodeNum;
taskNumPerLoop = tiling_data->taskNumPerLoop;
taskLoop = tiling_data->taskLoop;
blockDim = tiling_data->blockDim;
tailCoreNum = tiling_data->tailCoreNum;
taskNumPerCore = tiling_data->taskNumPerCore;
ubTotalSize = tiling_data->ubTotalSize;
}
__aicore__ inline uint64_t GetCopyIndex(uint32_t taskLoopIndex) {
uint64_t copyIndex = blockIndex * taskNumPerCore + taskLoopIndex * taskNumPerLoop;
if (blockIndex >= tailCoreNum) {
copyIndex -= (blockIndex - tailCoreNum);
}
return copyIndex;
}
__aicore__ inline int32_t GetTaskNum(uint32_t taskLoopIndex) {
int32_t taskNumPerCurLoop;
if (taskLoopIndex == taskLoop - 1) {
taskNumPerCurLoop = taskNumPerCore - taskLoopIndex * taskNumPerLoop;
if (blockIndex >= tailCoreNum) {
taskNumPerCurLoop -= 1;
}
} else {
taskNumPerCurLoop = taskNumPerLoop;
}
return taskNumPerCurLoop;
}
__aicore__ inline void SingleLoopComputing() {
uint64_t copyIndex = GetCopyIndex(ZERO_VALUE);
int32_t taskNumPerCurLoop = GetTaskNum(ZERO_VALUE);
CopyIn(copyIndex, taskNumPerCurLoop);
ReduceSum(taskNumPerCurLoop);
SyncAll();
SingleLoopIndexSelect(taskNumPerCurLoop);
CopyOut(copyIndex, taskNumPerCurLoop);
}
__aicore__ inline void MultiLoopComputing() {
for (uint32_t taskLoopIndex = 0; taskLoopIndex < taskLoop; taskLoopIndex++) {
uint64_t copyIndex = GetCopyIndex(taskLoopIndex);
int32_t taskNumPerCurLoop = GetTaskNum(taskLoopIndex);
CopyIn(copyIndex, taskNumPerCurLoop);
ReduceSum(taskNumPerCurLoop);
}
SyncAll();
for (uint32_t taskLoopIndex = 0; taskLoopIndex < taskLoop; taskLoopIndex++) {
uint64_t copyIndex = GetCopyIndex(taskLoopIndex);
int32_t taskNumPerCurLoop = GetTaskNum(taskLoopIndex);
int32_t IndexCopyLength = (taskNumPerCurLoop + ALIGN_NUM - 1) / ALIGN_NUM * ALIGN_NUM;
DataCopy(IndexTensor, indexGM[copyIndex], IndexCopyLength);
MultiLoopIndexSelect(copyIndex, taskNumPerCurLoop);
CopyOut(copyIndex, taskNumPerCurLoop);
}
}
__aicore__ inline void AllocLocalTensors() {
IndexTensor = IndexTensorBuffer.Get<DTYPE_INDEX>();
SoftmaxOutputTensor = SoftmaxOutputBuffer.Get<DTYPE_SOFTMAX_OUTPUT>();
GradOutputTensor = GradOutputBuffer.Get<DTYPE_GRAD_OUTPUT>();
SrcGradTensor = SrcGradTensorBuffer.Get<DTYPE_SRC_GRAD>();
TmpIndexSelectTensor = TmpIndexSelectBuffer.Get<DTYPE_SRC_GRAD>();
Duplicate(GradOutputTensor, ZERO_FLOAT_VALUE, taskNumPerLoop * SRC_SHAPE_DIM);
Duplicate(TmpIndexSelectTensor, ZERO_FLOAT_VALUE, taskNumPerLoop * SRC_SHAPE_DIM);
}
__aicore__ inline void CopyIn(uint64_t copyIndex, int32_t copyLength) {
SetFlag<HardEvent::V_MTE2>(EVENT_ID0);
WaitFlag<HardEvent::V_MTE2>(EVENT_ID0);
int32_t IndexCopyLength = (copyLength + ALIGN_NUM - 1) / ALIGN_NUM * ALIGN_NUM;
DataCopy(IndexTensor, indexGM[copyIndex], IndexCopyLength);
DataCopy(SoftmaxOutputTensor, softmaxGM[copyIndex * SRC_SHAPE_DIM], copyLength * SRC_SHAPE_DIM);
DataCopy(GradOutputTensor, gradGM[copyIndex * SRC_SHAPE_DIM], copyLength * SRC_SHAPE_DIM);
SetFlag<HardEvent::MTE2_V>(EVENT_ID0);
WaitFlag<HardEvent::MTE2_V>(EVENT_ID0);
Mul(GradOutputTensor, SoftmaxOutputTensor, GradOutputTensor, GradOutputTensor.GetSize());
}
__aicore__ inline void ReduceSum(int32_t egTotalNum) {
SetFlag<HardEvent::V_MTE3>(EVENT_ID0);
WaitFlag<HardEvent::V_MTE3>(EVENT_ID0);
SetAtomicAdd<float>();
for (int32_t egIndex = 0; egIndex < egTotalNum; egIndex++) {
int32_t group = IndexTensor.GetValue(egIndex);
DataCopy(reducesumGM[group * SRC_SHAPE_DIM], GradOutputTensor[egIndex * SRC_SHAPE_DIM], SRC_SHAPE_DIM);
}
SetAtomicNone();
SetFlag<HardEvent::MTE3_V>(EVENT_ID0);
WaitFlag<HardEvent::MTE3_V>(EVENT_ID0);
}
__aicore__ inline void SingleLoopIndexSelect(int32_t egTotalNum) {
SetFlag<HardEvent::V_MTE2>(EVENT_ID0);
WaitFlag<HardEvent::V_MTE2>(EVENT_ID0);
for (int32_t egIndex = 0; egIndex < egTotalNum; egIndex++) {
int32_t group = IndexTensor.GetValue(egIndex);
DataCopy(TmpIndexSelectTensor[egIndex * SRC_SHAPE_DIM], reducesumGM[group * SRC_SHAPE_DIM], SRC_SHAPE_DIM);
}
SetFlag<HardEvent::MTE2_V>(EVENT_ID0);
WaitFlag<HardEvent::MTE2_V>(EVENT_ID0);
Mul(TmpIndexSelectTensor, TmpIndexSelectTensor, SoftmaxOutputTensor, SoftmaxOutputTensor.GetSize());
Sub(SrcGradTensor, GradOutputTensor, TmpIndexSelectTensor, TmpIndexSelectTensor.GetSize());
}
__aicore__ inline void MultiLoopIndexSelect(uint64_t copyIndex, int32_t egTotalNum) {
SetFlag<HardEvent::V_MTE2>(EVENT_ID0);
WaitFlag<HardEvent::V_MTE2>(EVENT_ID0);
for (int32_t egIndex = 0; egIndex < egTotalNum; egIndex++) {
int32_t group = IndexTensor.GetValue(egIndex);
DataCopy(TmpIndexSelectTensor[egIndex * SRC_SHAPE_DIM], reducesumGM[group * SRC_SHAPE_DIM], SRC_SHAPE_DIM);
}
DataCopy(SoftmaxOutputTensor, softmaxGM[copyIndex * SRC_SHAPE_DIM], egTotalNum * SRC_SHAPE_DIM);
DataCopy(GradOutputTensor, gradGM[copyIndex * SRC_SHAPE_DIM], egTotalNum * SRC_SHAPE_DIM);
SetFlag<HardEvent::MTE2_V>(EVENT_ID0);
WaitFlag<HardEvent::MTE2_V>(EVENT_ID0);
Mul(GradOutputTensor, SoftmaxOutputTensor, GradOutputTensor, GradOutputTensor.GetSize());
Mul(TmpIndexSelectTensor, TmpIndexSelectTensor, SoftmaxOutputTensor, SoftmaxOutputTensor.GetSize());
Sub(SrcGradTensor, GradOutputTensor, TmpIndexSelectTensor, TmpIndexSelectTensor.GetSize());
}
__aicore__ inline void CopyOut(uint64_t copyIndex, int32_t copyLength) {
SetFlag<HardEvent::V_MTE3>(EVENT_ID0);
WaitFlag<HardEvent::V_MTE3>(EVENT_ID0);
DataCopy(srcgradGM[copyIndex * SRC_SHAPE_DIM], SrcGradTensor, copyLength * SRC_SHAPE_DIM);
SetFlag<HardEvent::MTE3_V>(EVENT_ID0);
WaitFlag<HardEvent::MTE3_V>(EVENT_ID0);
}
private:
TPipe *_pipe;
TBuf<TPosition::VECCALC> IndexTensorBuffer, SoftmaxOutputBuffer, GradOutputBuffer;
TBuf<TPosition::VECCALC> SrcGradTensorBuffer;
TBuf<TPosition::VECCALC> TmpIndexSelectBuffer;
LocalTensor<int32_t> IndexTensor;
LocalTensor<float> SoftmaxOutputTensor, GradOutputTensor;
LocalTensor<float> SrcGradTensor;
LocalTensor<float> TmpIndexSelectTensor;
GlobalTensor<DTYPE_INDEX> indexGM;
GlobalTensor<DTYPE_SOFTMAX_OUTPUT> softmaxGM;
GlobalTensor<DTYPE_GRAD_OUTPUT> gradGM;
GlobalTensor<DTYPE_SRC_GRAD> srcgradGM;
GlobalTensor<DTYPE_REDUCE_SUM> reducesumGM;
int32_t edgeNum, nodeNum, alignTaskNum, tailNum, taskStartIndex, taskLoop, taskNumPerLoop, tailCoreNum;
uint32_t blockDim, taskNumPerCore;
uint64_t blockIndex, ubTotalSize;
};
extern "C" __global__ __aicore__ void graph_softmax_grad(GM_ADDR index, GM_ADDR softmax_output, GM_ADDR grad_output,
GM_ADDR reduce_sum, GM_ADDR src_grad, GM_ADDR workspace, GM_ADDR tiling) {
#if __CCE_AICORE__ == 310
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);
#endif
TPipe pipe;
GET_TILING_DATA(tiling_data, tiling);
if (TILING_KEY_IS(1)) {
GraphSoftmaxGrad op;
op.Init(index, softmax_output, grad_output, reduce_sum, src_grad, &tiling_data, &pipe);
op.Process();
}
}