#include "common/op_host/common.h"
#include "ge/utils.h"
#include "register/op_def_registry.h"
#include "tiling/tiling_api.h"
#include "tiling/platform/platform_ascendc.h"
#include "sparse_matmul_tiling.h"
using namespace ge;
using namespace matmul_tiling;
namespace {
constexpr float AVALIABLE_UB_RATIO = 0.8;
constexpr int32_t FLOAT_BYTE_SIZE = 4;
constexpr int32_t INT32_BYTE_SIZE = 4;
constexpr int32_t HALF_BYTE_SIZE = 2;
constexpr int32_t BYTE_ALIGN_SIZE = 32;
constexpr float STAGE2_UB_RATIO = 0.2;
constexpr int32_t MAX_MATMUL_TASK_PER_ITER = 256;
};
namespace optiling {
ge::graphStatus TilingForSparseMatmul(gert::TilingContext* context)
{
SparseMatmulTilingData tiling;
if (context == nullptr) {
return ge::GRAPH_FAILED;
}
auto platformInfoptr = context->GetPlatformInfo();
if (platformInfoptr == nullptr) {
return ge::GRAPH_FAILED;
}
auto ascendplatformInfo = platform_ascendc::PlatformAscendC(platformInfoptr);
uint64_t ubSize;
ascendplatformInfo.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
ubSize *= AVALIABLE_UB_RATIO;
auto aivNum = ascendplatformInfo.GetCoreNumAiv();
auto aicNum = ascendplatformInfo.GetCoreNumAic();
context->SetBlockDim(aicNum);
auto inputFeaturePtr = context->GetInputTensor(0);
auto weightPtr = context->GetInputTensor(1);
auto indicesOffsetPtr = context->GetInputTensor(2);
auto featureDataTypePtr = context->GetInputDesc(0);
if (indicesOffsetPtr == nullptr || weightPtr == nullptr || featureDataTypePtr == nullptr || inputFeaturePtr == nullptr) {
return ge::GRAPH_FAILED;
}
if (aivNum == 0 || aicNum == 0) {
return ge::GRAPH_FAILED;
}
auto inputFeatureShape = inputFeaturePtr->GetStorageShape();
auto weightShape = weightPtr->GetStorageShape();
auto indicesOffsetShape = indicesOffsetPtr->GetStorageShape();
auto featureDataType = featureDataTypePtr->GetDataType();
int32_t byteSizePerElements = featureDataType == ge::DT_FLOAT16? HALF_BYTE_SIZE : FLOAT_BYTE_SIZE;
int32_t k0 = weightShape.GetDim(0);
int32_t k1 = weightShape.GetDim(1);
int32_t k2 = weightShape.GetDim(2);
int32_t inChannels = weightShape.GetDim(3);
int32_t outChannels = weightShape.GetDim(4);
int32_t kernelSize = k0 * k1 * k2;
int32_t inChannelsAligned = CeilAlign(inChannels, BYTE_ALIGN_SIZE / byteSizePerElements);
int32_t outChannelsAligned = CeilAlign(outChannels, BYTE_ALIGN_SIZE / byteSizePerElements);
int32_t featureChannelsSize = inChannelsAligned > outChannelsAligned? inChannelsAligned : outChannelsAligned;
int32_t kernelSizeAligned = CeilAlign(kernelSize, BYTE_ALIGN_SIZE / byteSizePerElements);
int32_t outputTaskCount = indicesOffsetShape.GetDim(0) - 1;
int32_t outputCoreTaskCount = outputTaskCount / aivNum;
int32_t outputBigCoreCount = outputTaskCount % aivNum;
int32_t outputSingleLoopTask = (ubSize - k2 * inChannelsAligned * byteSizePerElements) / ((1 + kernelSizeAligned) * INT32_BYTE_SIZE + inChannelsAligned * byteSizePerElements);
int32_t featureBufLen = outputSingleLoopTask;
int32_t matmulTaskPerIter = (outputCoreTaskCount + 1) > MAX_MATMUL_TASK_PER_ITER? MAX_MATMUL_TASK_PER_ITER : (outputCoreTaskCount + 1);
matmulTaskPerIter = matmulTaskPerIter == 0? 1 : matmulTaskPerIter;
auto dataType = (byteSizePerElements == FLOAT_BYTE_SIZE) ? matmul_tiling::DataType::DT_FLOAT : matmul_tiling::DataType::DT_FLOAT16;
matmul_tiling::MatmulApiTiling mm0Tiling(ascendplatformInfo);
mm0Tiling.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType);
mm0Tiling.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType);
mm0Tiling.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType);
mm0Tiling.SetOrgShape(matmulTaskPerIter, outChannels, inChannels * kernelSize);
mm0Tiling.SetShape(matmulTaskPerIter, outChannels, inChannels * kernelSize);
mm0Tiling.SetBias(false);
mm0Tiling.SetBufferSpace(-1, -1, -1);
if (mm0Tiling.GetTiling(tiling.mm0TilingData) == -1) {
return ge::GRAPH_FAILED;
}
tiling.set_k0(k0);
tiling.set_k1(k1);
tiling.set_k2(k2);
tiling.set_inChannels(inChannels);
tiling.set_outChannels(outChannels);
tiling.set_outputCoreTaskCount(outputCoreTaskCount);
tiling.set_outputBigCoreCount(outputBigCoreCount);
tiling.set_outputSingleLoopTask(outputSingleLoopTask);
tiling.set_outputTaskCount(outputTaskCount);
tiling.set_matmulTaskPerIter(matmulTaskPerIter);
tiling.set_availableUBSize(ubSize);
tiling.set_aivNum(aivNum);
tiling.set_featureBufLen(featureBufLen);
ADD_TILING_DATA(context, tiling);
size_t systemWorkspaceSize = ascendplatformInfo.GetLibApiWorkSpaceSize();
size_t usrWorkSpaceSize = static_cast<uint64_t>(aivNum) * matmulTaskPerIter * kernelSize * inChannels * byteSizePerElements;
size_t* currentWorkspace = context->GetWorkspaceSizes(1);
CHECK_NULLPTR(currentWorkspace);
currentWorkspace[0] = systemWorkspaceSize + usrWorkSpaceSize;
return ge::GRAPH_SUCCESS;
}
}
namespace ge {
static ge::graphStatus InferShapeForSparseMatmul(gert::InferShapeContext* context)
{
auto weightShape = context->GetInputShape(1);
auto uniqueIndicesOffsetShape = context->GetInputShape(2);
if (uniqueIndicesOffsetShape == nullptr || weightShape == nullptr) {
return ge::GRAPH_FAILED;
}
gert::Shape* sparseValueShape = context->GetOutputShape(0);
gert::Shape* sparseIndicesShape = context->GetOutputShape(1);
if (sparseValueShape == nullptr || sparseIndicesShape == nullptr) {
return ge::GRAPH_FAILED;
}
uint64_t actualNum = uniqueIndicesOffsetShape->GetDim(0) - 1;
*sparseValueShape = {actualNum, weightShape->GetDim(3)};
*sparseIndicesShape = {actualNum, 4};
return GRAPH_SUCCESS;
}
static ge::graphStatus InferDtypeForSparseMatmul(gert::InferDataTypeContext* context)
{
const ge::DataType feature_dtype = context->GetInputDataType(0);
const ge::DataType indices_dtype = context->GetInputDataType(2);
context->SetOutputDataType(0, feature_dtype);
context->SetOutputDataType(1, indices_dtype);
return GRAPH_SUCCESS;
}
}
namespace ops {
class SparseMatmul : public OpDef {
public:
explicit SparseMatmul(const char* name) : OpDef(name)
{
this->Input("features")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND})
.AutoContiguous();
this->Input("weight")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND})
.AutoContiguous();
this->Input("unique_indices_offset")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND})
.AutoContiguous();
this->Input("former_sorted_indices")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND})
.AutoContiguous();
this->Input("indices")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND})
.AutoContiguous();
this->Output("sparse_value")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("sparse_indices")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->SetInferShape(ge::InferShapeForSparseMatmul).SetInferDataType(ge::InferDtypeForSparseMatmul);
this->AICore().SetTiling(optiling::TilingForSparseMatmul);
this->AICore().AddConfig("ascend910b");
this->AICore().AddConfig("ascend910_93");
}
};
OP_ADD(SparseMatmul);
}