* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*/
#include "scatter_mean_grad_tiling.h"
#include "register/op_def_registry.h"
#include "tiling/tiling_api.h"
#include "tiling/platform/platform_ascendc.h"
using namespace ge;
using namespace std;
using namespace AscendC;
namespace optiling {
constexpr uint32_t WORKSPACE_16MBYTE_SIZE = 16 * 1024 * 1024;
constexpr int64_t DATA_SMALL_MODE = 1;
constexpr int64_t NOT_BROAD_LINE_MODE = 2;
constexpr int64_t DATA_LARGE_MODE = 3;
constexpr uint32_t BLOCK_BYTES = 32;
constexpr uint64_t RESERVE_SAPCE = 4 * 1024;
constexpr uint32_t FLOAT_DTYPE_BYTES = 4;
constexpr uint64_t MAX_OUT_LINE = 16000;
constexpr uint64_t MAX_DEAL_NUM = 2048;
constexpr uint64_t LEAST_LINE_EACH_TASK = 4;
constexpr uint64_t BUFFER_NUM = 2;
constexpr uint64_t INDICES_ONCE_DATANUM = 2048;
constexpr uint64_t GRADOUT_UB_NUM = 2;
constexpr uint64_t INDICES_UB_NUM = 2;
class ScatterMeanGradTiling {
public:
explicit ScatterMeanGradTiling(gert::TilingContext *context) : context(context) {};
ge::graphStatus Init();
ge::graphStatus RunKernelTiling();
void TilingDataPrint();
private:
void SetModeNoTail(int32_t gradDims, int32_t indexDims, uint32_t coreNum);
void SetModeLine(int32_t gradDims, int32_t indexDims, uint32_t coreNum);
void SetHeadNumForTask(uint64_t headMaxTask, uint32_t coreNum);
void SetUbSize(uint64_t headIndicesSize);
ScatterMeanGradTilingData TilingData;
gert::TilingContext *context = nullptr;
uint64_t paramsPre = 1;
uint64_t dimRange = 1;
uint64_t dimRangeOut = 1;
uint64_t paramsPro = 1;
uint64_t tail = 1;
int32_t dim = 0;
uint32_t coreUsed = 1;
uint64_t ubSize = 192 * 1024;
uint64_t gradInUbSize = 1;
uint64_t indexUbSize = 1;
uint64_t gradOutUbSize = 1;
uint64_t indexSumUbSize = 1;
uint64_t gradInNum = 1;
uint64_t indexNum = 1;
uint64_t gradOutNum = 1;
uint64_t countNum = 1;
uint64_t headTaskSmall = 1;
uint64_t taskNumSmall = 1;
uint64_t headLastTaskSmall = 1;
uint64_t headTaskBig = 1;
uint64_t taskNumBig = 1;
uint64_t headLastTaskBig = 1;
uint64_t bigCoreNum = 1;
uint64_t taskEachHead = 1;
uint64_t tilingMode = 0;
uint32_t gradDsize = 4;
uint32_t paramsNumPerBlock = 1;
uint32_t indexNumPerBlock = 8;
uint32_t indexDsize = 4;
};
uint64_t CeilValue(uint64_t a, uint64_t b) {
if (b == 0) {
return 0;
}
return ((a - 1) / b + 1) * b;
}
static uint64_t GetCeilInt(uint64_t value1, uint64_t value2) {
if (value2 == 0) {
return value1;
}
return (value1 + value2 - 1) / value2;
}
static void ComputeTaskNumForeLine(
uint64_t ubOutNum, uint64_t outLineEachCore, uint64_t *taskNum, uint64_t *taskEachLine, uint64_t *taskLastLine) {
if (outLineEachCore <= ubOutNum) {
*taskNum = 1;
*taskEachLine = outLineEachCore;
*taskLastLine = outLineEachCore;
} else {
uint64_t taskNumTemp = GetCeilInt(outLineEachCore, ubOutNum);
*taskNum = taskNumTemp;
*taskEachLine = ubOutNum;
*taskLastLine = outLineEachCore - ubOutNum * (taskNumTemp - 1);
}
}
void ScatterMeanGradTiling::SetUbSize(uint64_t headIndicesSize) {
if (headIndicesSize > INDICES_ONCE_DATANUM) {
indexUbSize = INDICES_ONCE_DATANUM;
} else {
indexUbSize = CeilValue(headIndicesSize, indexNumPerBlock);
}
uint64_t ubAvailableSize = ubSize - indexUbSize * indexDsize * INDICES_UB_NUM;
gradOutUbSize = CeilValue(ubAvailableSize / GRADOUT_UB_NUM / gradDsize, paramsNumPerBlock);
return;
}
void ScatterMeanGradTiling::SetHeadNumForTask(uint64_t headMaxTask, uint32_t coreNum) {
uint64_t headBigCore = GetCeilInt(paramsPre, coreNum);
uint64_t headSmallCore = headBigCore - 1;
bigCoreNum = paramsPre - headSmallCore * coreNum;
headTaskSmall = std::min(headMaxTask, headSmallCore);
taskNumSmall = GetCeilInt(headSmallCore, headTaskSmall);
headLastTaskSmall = headSmallCore - (taskNumSmall - 1) * headTaskSmall;
headTaskBig = std::min(headMaxTask, headBigCore);
taskNumBig = GetCeilInt(headBigCore, headTaskBig);
headLastTaskBig = headBigCore - (taskNumBig - 1) * headTaskBig;
}
void ScatterMeanGradTiling::SetModeNoTail(int32_t gradDims, int32_t indexDims, uint32_t coreNum) {
if (coreNum == 0 || gradDims == 0 || indexDims == 0) {
return;
}
uint64_t headOutSize = dimRangeOut * paramsPro;
uint64_t headIndicesSize = dimRange * paramsPro;
uint64_t ubBytesforHead = headOutSize * gradDsize * GRADOUT_UB_NUM + headIndicesSize * indexDsize * INDICES_UB_NUM;
if (ubBytesforHead < ubSize) {
context->SetTilingKey(DATA_SMALL_MODE);
tilingMode = 0;
auto headMaxTask = ubSize / ubBytesforHead;
SetHeadNumForTask(headMaxTask, coreNum);
gradOutUbSize = headTaskBig * headOutSize;
indexUbSize = headTaskBig * headIndicesSize;
} else {
SetUbSize(headIndicesSize);
if (gradOutUbSize > headOutSize) {
context->SetTilingKey(DATA_SMALL_MODE);
tilingMode = 1;
auto headMaxTask = gradOutUbSize / headOutSize;
SetHeadNumForTask(headMaxTask, coreNum);
gradOutUbSize = headTaskBig * headOutSize;
indexUbSize = std::min(
(ubSize - gradOutUbSize * GRADOUT_UB_NUM * gradDsize) / INDICES_UB_NUM / indexDsize, headIndicesSize);
indexUbSize = CeilValue(indexUbSize, indexNumPerBlock);
} else {
context->SetTilingKey(DATA_LARGE_MODE);
tilingMode = DATA_LARGE_MODE;
taskEachHead = GetCeilInt(headOutSize, gradOutUbSize);
auto taskNum = paramsPre * taskEachHead;
taskNumSmall = taskNum / coreNum;
taskNumBig = taskNumSmall + 1;
bigCoreNum = taskNum - taskNumSmall * coreNum;
}
}
coreUsed = taskNumSmall == 0 ? bigCoreNum : coreNum;
TilingData.set_headTaskSmall(headTaskSmall);
TilingData.set_taskNumSmall(taskNumSmall);
TilingData.set_headLastTaskSmall(headLastTaskSmall);
TilingData.set_headTaskBig(headTaskBig);
TilingData.set_taskNumBig(taskNumBig);
TilingData.set_headLastTaskBig(headLastTaskBig);
TilingData.set_taskEachHead(taskEachHead);
}
void ScatterMeanGradTiling::SetModeLine(int32_t gradDims, int32_t indexDims, uint32_t coreNum) {
if (coreNum == 0 || gradDims == 0 || indexDims == 0) {
return;
}
context->SetTilingKey(NOT_BROAD_LINE_MODE);
uint64_t dataLineSmallCore = 1;
uint64_t dataLineBigCore = 1;
auto body = paramsPro / tail;
uint64_t dataLine = dimRange * paramsPre * (paramsPro / tail);
if (dataLine <= 2 * LEAST_LINE_EACH_TASK) {
coreUsed = 1;
dataLineBigCore = dataLine;
bigCoreNum = 1;
} else {
dataLineBigCore = std::max(GetCeilInt(dataLine, coreNum), LEAST_LINE_EACH_TASK);
dataLineSmallCore = dataLineBigCore - 1;
coreUsed = GetCeilInt(dataLine, dataLineBigCore);
bigCoreNum = dataLine - dataLineSmallCore * coreUsed;
}
uint64_t ubTailNum;
uint64_t taskNum;
uint64_t taskEachLine;
uint64_t taskLastLine;
if (tail % paramsNumPerBlock == 0 && tail < MAX_DEAL_NUM / gradDsize) {
tilingMode = 1;
ubTailNum = tail;
auto availSize = ubSize - ubTailNum * BUFFER_NUM * gradDsize;
taskEachLine = std::min(dataLineBigCore, availSize / gradDsize / (tail + 1));
taskNum = GetCeilInt(dataLineBigCore, taskEachLine);
taskLastLine = dataLineBigCore - taskEachLine * (taskNum - 1);
indexUbSize = CeilValue(taskEachLine, paramsNumPerBlock);
gradInUbSize = tail * taskEachLine;
gradOutUbSize = ubTailNum;
} else {
tilingMode = 0;
auto availIndicesSize =
ubSize - std::min(CeilValue(tail, paramsNumPerBlock), MAX_DEAL_NUM) * BUFFER_NUM * gradDsize;
ComputeTaskNumForeLine(availIndicesSize / indexDsize, dataLineBigCore, &taskNum, &taskEachLine, &taskLastLine);
indexUbSize = std::min(availIndicesSize / indexDsize, dataLineBigCore);
ubTailNum = (ubSize - indexUbSize * indexDsize) / BUFFER_NUM / BLOCK_BYTES * indexNumPerBlock;
ubTailNum = std::min(ubTailNum, CeilValue(tail, paramsNumPerBlock));
gradInUbSize = 0;
gradOutUbSize = ubTailNum * BUFFER_NUM;
}
TilingData.set_taskNum(taskNum);
TilingData.set_taskEachLine(taskEachLine);
TilingData.set_taskLastLine(taskLastLine);
TilingData.set_ubTailNum(ubTailNum);
TilingData.set_bacthSmallCore(dataLineSmallCore);
TilingData.set_gradInUbSize(gradInUbSize);
TilingData.set_gradOutUbSize(gradOutUbSize);
}
ge::graphStatus ScatterMeanGradTiling::Init() {
auto platformInfo = context->GetPlatformInfo();
if (platformInfo == nullptr) {
return ge::GRAPH_FAILED;
}
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
uint32_t coreNum = ascendcPlatform.GetCoreNumAiv();
if (coreNum == 0) {
return ge::GRAPH_FAILED;
}
uint64_t totalUbSize;
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, totalUbSize);
ubSize = totalUbSize - RESERVE_SAPCE;
if (context->GetInputShape(0) == nullptr || context->GetInputShape(1) == nullptr ||
context->GetInputShape(2) == nullptr) {
return ge::GRAPH_FAILED;
}
if (context->GetOutputShape(0) == nullptr) {
return ge::GRAPH_FAILED;
}
auto gradOutShape = context->GetInputShape(0)->GetStorageShape();
auto indexShape = context->GetInputShape(1)->GetStorageShape();
auto countShape = context->GetInputShape(2)->GetStorageShape();
auto gradInShape = context->GetOutputShape(0)->GetStorageShape();
auto attrs = context->GetAttrs();
if (attrs == nullptr) {
return ge::GRAPH_FAILED;
}
const int64_t *axisPtr = attrs->GetAttrPointer<int64_t>(0);
int32_t axis = static_cast<int32_t>(*axisPtr);
int32_t gradDims = gradOutShape.GetDimNum();
int32_t indexDims = indexShape.GetDimNum();
indexDims = indexDims == 0 ? 1 : indexDims;
axis = (axis + indexDims) % indexDims;
dim = axis;
for (int32_t i = 0; i < axis; i++) {
paramsPre *= gradInShape.GetDim(i);
}
dimRange = gradInShape.GetDim(axis);
dimRangeOut = gradOutShape.GetDim(axis);
if (dimRange == 0 || dimRangeOut == 0 || paramsPre == 0) {
return ge::GRAPH_FAILED;
}
for (int32_t i = axis + 1; i < gradDims; i++) {
paramsPro *= gradInShape.GetDim(i);
}
for (int32_t i = indexDims; i < gradDims; i++) {
tail *= gradInShape.GetDim(i);
}
gradInNum = paramsPre * dimRange * paramsPro;
for (int32_t i = 0; i < indexDims; i++) {
indexNum *= indexShape.GetDim(i);
}
gradOutNum = paramsPre * dimRangeOut * paramsPro;
countNum = countShape.GetShapeSize();
if (context->GetInputDesc(0) == nullptr || context->GetInputDesc(1) == nullptr) {
return ge::GRAPH_FAILED;
}
auto gradDtype = context->GetInputDesc(0)->GetDataType();
gradDsize = sizeof(gradDtype);
paramsNumPerBlock = BLOCK_BYTES / gradDsize;
auto indexDtype = context->GetInputDesc(1)->GetDataType();
indexDsize = sizeof(indexDtype);
if (tail == 1) {
SetModeNoTail(gradDims, indexDims, coreNum);
} else {
SetModeLine(gradDims, indexDims, coreNum);
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus ScatterMeanGradTiling::RunKernelTiling() {
context->SetBlockDim(coreUsed);
TilingData.set_tilingMode(tilingMode);
TilingData.set_dimRange(dimRange);
TilingData.set_dimRangeOut(dimRangeOut);
TilingData.set_paramsPro(paramsPro);
TilingData.set_gradInUbSize(gradInUbSize);
TilingData.set_indexUbSize(indexUbSize);
TilingData.set_gradOutUbSize(gradOutUbSize);
TilingData.set_indexSumUbSize(indexSumUbSize);
TilingData.set_gradInNum(gradInNum);
TilingData.set_indexNum(indexNum);
TilingData.set_gradOutNum(gradOutNum);
TilingData.set_countNum(countNum);
TilingData.set_bigCoreNum(bigCoreNum);
TilingData.set_tail(tail);
size_t sysWorkspaceSize = WORKSPACE_16MBYTE_SIZE;
size_t *currentWorkspace = context->GetWorkspaceSizes(1);
if (currentWorkspace == nullptr) {
return ge::GRAPH_FAILED;
}
currentWorkspace[0] = sysWorkspaceSize;
if (context->GetRawTilingData() == nullptr) {
return ge::GRAPH_FAILED;
}
TilingData.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
context->GetRawTilingData()->SetDataSize(TilingData.GetDataSize());
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus TilingFunc4ScatterMeanGrad(gert::TilingContext *context) {
if (context == nullptr) {
return ge::GRAPH_FAILED;
}
ScatterMeanGradTiling tilingObject(context);
tilingObject.Init();
return tilingObject.RunKernelTiling();
}
}
namespace ge {
static ge::graphStatus InferShape4ScatterMeanGrad(gert::InferShapeContext *context) {
const gert::Shape *gradOutShape = context->GetInputShape(0);
if (gradOutShape == nullptr) {
return ge::GRAPH_FAILED;
}
const gert::Shape *indexShape = context->GetInputShape(1);
if (indexShape == nullptr) {
return ge::GRAPH_FAILED;
}
auto attrs = context->GetAttrs();
if (attrs == nullptr) {
return ge::GRAPH_FAILED;
}
int32_t axis = *(attrs->GetAttrPointer<int32_t>(0));
int32_t gradDims = gradOutShape->GetDimNum();
gradDims = gradDims == 0 ? 1 : gradDims;
axis = (axis + gradDims) % gradDims;
gert::Shape *gradInShape = context->GetOutputShape(0);
if (gradInShape == nullptr) {
return ge::GRAPH_FAILED;
}
*gradInShape = *gradOutShape;
gradInShape->SetDim(axis, indexShape->GetDim(axis));
return GRAPH_SUCCESS;
}
static ge::graphStatus InferDtype4ScatterMeanGrad(gert::InferDataTypeContext *context) {
const ge::DataType grad_out_dtype = context->GetInputDataType(0);
context->SetOutputDataType(0, grad_out_dtype);
return GRAPH_SUCCESS;
}
}
namespace ops {
class ScatterMeanGrad : public OpDef {
public:
explicit ScatterMeanGrad(const char *name) : OpDef(name) {
this->Input("grad_out")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Input("index")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Input("count")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Output("grad_in")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Attr("dim").Int();
this->SetInferShape(ge::InferShape4ScatterMeanGrad).SetInferDataType(ge::InferDtype4ScatterMeanGrad);
this->AICore().SetTiling(optiling::TilingFunc4ScatterMeanGrad);
this->AICore().AddConfig("ascend910b");
this->AICore().AddConfig("ascend910_93");
#if __DRIVING_HOST_AICORE__ == 310
this->AICore().AddConfig("ascend950");
#endif
}
};
OP_ADD(ScatterMeanGrad);
}