#include "graph_softmax_tiling.h"
#include "register/op_def_registry.h"
#include "tiling/platform/platform_ascendc.h"
namespace {
// Fraction applied to the reported UB size before it is validated in tiling.
// (Identifier spelling is kept as-is because code outside this block refers to it.)
constexpr float AVALIABLE_UB_RATIO{0.8};

// Positional indices of the operator's tensors.
constexpr uint32_t SRC_IDX{0};     // input whose shape/dtype drives tiling and inference
constexpr uint32_t OUTPUT_IDX{0};  // the single output

// Dimension indices inside the `src` shape.
constexpr uint32_t EDGE_IDX{0};
constexpr uint32_t FEATURE_IDX{1};

// Positional index of the `N` attribute.
constexpr uint32_t N_IDX{0};

// Value forwarded to the kernel as `singleLoopTaskCount`.
constexpr uint32_t SINGLE_LOOP_TASK{2360};
}  // namespace
namespace optiling {
/**
 * @brief Host-side tiling for the GraphSoftmax operator.
 *
 * Reads the vector-core count and UB size from the platform info, the first two
 * dimensions of the `src` storage shape and the `N` attribute, then:
 *   - sets the block dim to the vector-core count,
 *   - fills and serialises a GraphSoftmaxTilingData (per-core task count, per-core and
 *     total workspace rows, number of "big" cores, single-loop task count),
 *   - publishes the workspace size (library workspace + 2 float buffers of
 *     totalWorkspace x numFeature elements).
 *
 * @param context Tiling context supplied by the framework; may be null.
 * @return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED when any required handle is
 *         missing, the platform reports zero cores / zero UB, or `N` is negative.
 */
static ge::graphStatus TilingForGraphSoftmax(gert::TilingContext *context)
{
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto platformInfoPtr = context->GetPlatformInfo();
    if (platformInfoPtr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto ascendplatformInfo = platform_ascendc::PlatformAscendC(platformInfoPtr);
    auto aivNum = ascendplatformInfo.GetCoreNumAiv();

    // Zero-initialise so the check below is well defined even if the query writes nothing.
    uint64_t ubSize = 0;
    ascendplatformInfo.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
    ubSize *= AVALIABLE_UB_RATIO;
    // aivNum is used as a divisor below, so reject a zero core count before using it.
    if (aivNum == 0 || ubSize == 0) {
        return ge::GRAPH_FAILED;
    }
    context->SetBlockDim(aivNum);

    const gert::StorageShape *srcShape = context->GetInputShape(SRC_IDX);
    const gert::RuntimeAttrs *attr = context->GetAttrs();
    if (srcShape == nullptr || attr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto NPtr = attr->GetAttrPointer<int>(N_IDX);
    // A negative N would wrap to a huge unsigned value and corrupt the workspace split.
    if (NPtr == nullptr || *NPtr < 0) {
        return ge::GRAPH_FAILED;
    }
    const uint32_t N = static_cast<uint32_t>(*NPtr);
    const uint32_t numEdge = static_cast<uint32_t>(srcShape->GetStorageShape().GetDim(EDGE_IDX));
    const uint32_t numFeature = static_cast<uint32_t>(srcShape->GetStorageShape().GetDim(FEATURE_IDX));

    // One task per edge row, distributed over the cores with ceiling division.
    const uint32_t totalTask = numEdge;
    const uint32_t coreTask = (totalTask + aivNum - 1) / aivNum;
    // Workspace rows are rounded up so that every core owns the same number of rows.
    const uint32_t coreWorkspace = (N + aivNum - 1) / aivNum;
    const uint32_t totalWorkspace = coreWorkspace * aivNum;
    // Remainder of the task split; equals the core count when the split is even.
    const uint32_t remainderTask = totalTask % aivNum;
    const uint32_t bigCoreCount = (remainderTask == 0) ? aivNum : remainderTask;
    const uint32_t singleLoopTaskCount = SINGLE_LOOP_TASK;

    GraphSoftmaxTilingData tilingData;
    tilingData.set_coreTask(coreTask);
    tilingData.set_coreWorkspace(coreWorkspace);
    tilingData.set_totalTask(totalTask);
    tilingData.set_totalWorkspace(totalWorkspace);
    tilingData.set_bigCoreCount(bigCoreCount);
    tilingData.set_singleLoopTaskCount(singleLoopTaskCount);

    auto *rawTilingData = context->GetRawTilingData();
    if (rawTilingData == nullptr) {
        return ge::GRAPH_FAILED;
    }
    tilingData.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tilingData.GetDataSize());

    const size_t systemWorkspaceSize = ascendplatformInfo.GetLibApiWorkSpaceSize();
    // Widen to size_t BEFORE multiplying: the product of the 32-bit operands can exceed
    // 2^32 for large graphs and would otherwise wrap, under-allocating the workspace.
    const size_t usrWorkSpaceSize =
        static_cast<size_t>(2) * static_cast<size_t>(totalWorkspace) * static_cast<size_t>(numFeature) * sizeof(float);
    size_t *currentWorkspace = context->GetWorkspaceSizes(1);
    if (currentWorkspace == nullptr) {
        return ge::GRAPH_FAILED;
    }
    currentWorkspace[0] = systemWorkspaceSize + usrWorkSpaceSize;
    return ge::GRAPH_SUCCESS;
}
}
namespace ge {
/**
 * @brief Shape inference for GraphSoftmax: the output is 2-D and copies dimensions
 *        EDGE_IDX and FEATURE_IDX of the `src` input shape.
 *
 * @param context Inference context supplied by the framework; may be null.
 * @return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED if the context, the input
 *         shape or the output shape is unavailable.
 */
static ge::graphStatus InferShapeForGraphSoftmax(gert::InferShapeContext *context)
{
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const gert::Shape *src = context->GetInputShape(SRC_IDX);
    gert::Shape *softmaxResult = context->GetOutputShape(OUTPUT_IDX);
    if (src == nullptr || softmaxResult == nullptr) {
        return ge::GRAPH_FAILED;
    }
    // Keep the dims signed: routing them through an unsigned type and back into the
    // braced shape initialiser is a narrowing conversion.
    const int64_t numEdge = src->GetDim(EDGE_IDX);
    const int64_t numFeature = src->GetDim(FEATURE_IDX);
    *softmaxResult = {numEdge, numFeature};
    return ge::GRAPH_SUCCESS;
}
/**
 * @brief Data-type inference for GraphSoftmax: the output takes the data type of the
 *        `src` input.
 *
 * @param context Inference context supplied by the framework; may be null.
 * @return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED if the context is null.
 */
static ge::graphStatus InferDataTypeForGraphSoftmax(gert::InferDataTypeContext *context)
{
    // Guard the dereference, matching the null handling of the tiling entry point.
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const ge::DataType srcDataType = context->GetInputDataType(SRC_IDX);
    context->SetOutputDataType(OUTPUT_IDX, srcDataType);
    return ge::GRAPH_SUCCESS;
}
}
namespace ops {
// Operator prototype for GraphSoftmax. The constructor declares two inputs, one output
// and one attribute, then attaches the inference callbacks, the tiling callback and the
// AI Core configurations.
class GraphSoftmax : public OpDef {
public:
// Builds the prototype; `name` is forwarded unchanged to the OpDef base.
explicit GraphSoftmax(const char *name) : OpDef(name) {
// Required input "src": DT_FLOAT, ND format. Declared first, so it is the input read
// at index SRC_IDX by the tiling and inference callbacks.
this->Input("src")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
// Required input "index": DT_INT32, ND format.
this->Input("index")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
// Required output "softmax_result": DT_FLOAT, ND format.
this->Output("softmax_result")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
// Integer attribute "N"; it is the only attribute, so the tiling callback reads it at
// attribute index N_IDX.
this->Attr("N").Int();
// Shape and data-type inference callbacks defined in namespace ge above.
this->SetInferShape(ge::InferShapeForGraphSoftmax).SetInferDataType(ge::InferDataTypeForGraphSoftmax);
// Tiling callback and the AI Core configurations this operator is added for.
this->AICore().SetTiling(optiling::TilingForGraphSoftmax);
this->AICore().AddConfig("ascend910b");
this->AICore().AddConfig("ascend910_93");
// The "ascend950" configuration is compiled in only when __DRIVING_HOST_AICORE__ is 310.
#if __DRIVING_HOST_AICORE__ == 310
this->AICore().AddConfig("ascend950");
#endif
}
};
// Hands the GraphSoftmax definition to the OP_ADD macro (presumably the framework's
// operator registration hook; the macro is defined outside this file).
OP_ADD(GraphSoftmax);
}