#include "geometric_kernel_attention_tiling.h"

#include <cstdint>
#include <limits>

#include "register/op_def_registry.h"
#include "tiling/platform/platform_ascendc.h"
#include "ge/utils.h"
namespace {
// Input indices; they follow the Input() declaration order of the op definition in this file.
constexpr int32_t VALUE_PTR_INDEX = 0;
constexpr int32_t SPATIAL_PTR_INDEX = 1;
constexpr int32_t LEVEL_PTR_INDEX = 2;
constexpr int32_t SAMPLING_PTR_INDEX = 3;
constexpr int32_t WEIGHT_PTR_INDEX = 4;

// Output index.
constexpr int32_t OUTPUT_PTR_INDEX = 0;

// Dimension indices read from the `value` input shape.
constexpr int32_t BS_INDEX = 0;
constexpr int32_t HEAD_INDEX = 1;
constexpr int32_t KEY_INDEX = 2;
constexpr int32_t DIM_INDEX = 3;

// Dimension index read from the `spatial_shapes` input shape.
constexpr int32_t LEVEL_INDEX = 0;

// Dimension indices read from the `sampling_locations` input shape.
constexpr int32_t QUERY_INDEX = 1;
constexpr int32_t POINT_INDEX = 4;

// Granularity passed to AlignUp for the task count, level count and dim.
constexpr int32_t ALIGN_NUM = 8;
}
namespace optiling {
/**
 * @brief Host-side tiling for GeometricKernelAttention.
 *
 * Reads the problem sizes from the input shapes and splits the
 * batchSize * numQueries * numHeads tasks over the AI vector cores. The result is serialized
 * into the raw tiling buffer. The function also sets the launch block dim, tiling key 1 and
 * the workspace size.
 *
 * Task split (all task counts are multiples of ALIGN_NUM):
 *   - scoreNum cores are assigned taskNumPerScore tasks each;
 *   - lcoreNum cores are assigned taskNumPerLcore = taskNumPerScore + ALIGN_NUM tasks each;
 *   - together: scoreNum * taskNumPerScore + lcoreNum * taskNumPerLcore == alignTaskNum.
 *   If there are fewer ALIGN_NUM-sized chunks than cores (taskNumPerScore == 0), only the
 *   lcoreNum cores are launched.
 *
 * @param context Tiling context supplied by the framework.
 * @return ge::GRAPH_SUCCESS on success. ge::GRAPH_FAILED in any of these cases:
 *         - the context, an input shape, the raw tiling buffer, the platform info or the
 *           workspace array is missing;
 *         - `value`/`spatial_shapes`/`sampling_locations` has too few dimensions;
 *         - the task count is empty or would overflow int32 once aligned;
 *         - the platform reports no vector core or no UB size.
 */
static ge::graphStatus TilingForGeometricKernelAttention(gert::TilingContext* context)
{
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }

    // Only shapes are needed for tiling; every input must nevertheless be present.
    const auto* valueShape = context->GetInputShape(VALUE_PTR_INDEX);
    const auto* spatialShape = context->GetInputShape(SPATIAL_PTR_INDEX);
    const auto* levelStartShape = context->GetInputShape(LEVEL_PTR_INDEX);
    const auto* samplingShape = context->GetInputShape(SAMPLING_PTR_INDEX);
    const auto* weightShape = context->GetInputShape(WEIGHT_PTR_INDEX);
    if (valueShape == nullptr || spatialShape == nullptr || levelStartShape == nullptr ||
        samplingShape == nullptr || weightShape == nullptr || context->GetRawTilingData() == nullptr) {
        return ge::GRAPH_FAILED;
    }

    const auto& valueDims = valueShape->GetStorageShape();
    const auto& spatialDims = spatialShape->GetStorageShape();
    const auto& samplingDims = samplingShape->GetStorageShape();
    // Each shape must have enough dimensions for the highest index read from it below.
    if (valueDims.GetDimNum() <= static_cast<size_t>(DIM_INDEX) ||
        spatialDims.GetDimNum() <= static_cast<size_t>(LEVEL_INDEX) ||
        samplingDims.GetDimNum() <= static_cast<size_t>(POINT_INDEX)) {
        return ge::GRAPH_FAILED;
    }

    // The task-count factors are validated in 64 bits. This rules out empty work (which would
    // otherwise produce blockDim == 0) and signed overflow in the product or in the later
    // AlignUp to ALIGN_NUM.
    const int64_t batchDim = valueDims.GetDim(BS_INDEX);
    const int64_t headDim = valueDims.GetDim(HEAD_INDEX);
    const int64_t queryDim = samplingDims.GetDim(QUERY_INDEX);
    constexpr int64_t maxTaskNum = static_cast<int64_t>(std::numeric_limits<int32_t>::max()) - ALIGN_NUM;
    if (batchDim <= 0 || headDim <= 0 || queryDim <= 0 ||
        batchDim > maxTaskNum / queryDim ||
        batchDim * queryDim > maxTaskNum / headDim) {
        return ge::GRAPH_FAILED;
    }

    const int32_t batchSize = static_cast<int32_t>(batchDim);
    const int32_t numHeads = static_cast<int32_t>(headDim);
    const int32_t numQueries = static_cast<int32_t>(queryDim);
    const int32_t numKeys = static_cast<int32_t>(valueDims.GetDim(KEY_INDEX));
    const int32_t dim = static_cast<int32_t>(valueDims.GetDim(DIM_INDEX));
    const int32_t numLevels = static_cast<int32_t>(spatialDims.GetDim(LEVEL_INDEX));
    const int32_t numPoints = static_cast<int32_t>(samplingDims.GetDim(POINT_INDEX));

    const int32_t totalTaskNum = batchSize * numQueries * numHeads;
    const int32_t alignTaskNum = AlignUp(totalTaskNum, ALIGN_NUM);
    const int32_t alignLevels = AlignUp(numLevels, ALIGN_NUM);
    const int32_t alignDim = AlignUp(dim, ALIGN_NUM);
    // Number of padding tasks introduced by the alignment.
    const int32_t tailNum = alignTaskNum - totalTaskNum;

    auto platform = context->GetPlatformInfo();
    if (platform == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto platformInfo = platform_ascendc::PlatformAscendC(platform);
    // Initialized so an indeterminate value can never reach the tiling data.
    uint64_t ubTotalSize = 0;
    platformInfo.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubTotalSize);
    uint32_t blockDim = platformInfo.GetCoreNumAiv();
    if (blockDim == 0 || ubTotalSize == 0) {
        return ge::GRAPH_FAILED;
    }

    // Split the aligned tasks, in ALIGN_NUM-sized chunks, over blockDim cores (see header).
    const uint32_t alignTaskNumU = static_cast<uint32_t>(alignTaskNum);
    const uint32_t taskNumPerScore = (alignTaskNumU / blockDim / ALIGN_NUM) * ALIGN_NUM;
    const uint32_t taskNumPerLcore = taskNumPerScore + ALIGN_NUM;
    const uint32_t scoreNum = (blockDim * taskNumPerLcore - alignTaskNumU) / ALIGN_NUM;
    const uint32_t lcoreNum = blockDim - scoreNum;
    if (taskNumPerScore == 0) {
        // Cores assigned taskNumPerScore (== 0) tasks would have nothing to do: launch only the
        // lcoreNum cores. lcoreNum >= 1 here because totalTaskNum >= 1 was checked above.
        blockDim = lcoreNum;
    }

    GeometricKernelAttentionTilingData tiling;
    tiling.set_blockDim(blockDim);
    tiling.set_ubTotalSize(ubTotalSize);
    tiling.set_batchSize(batchSize);
    tiling.set_numKeys(numKeys);
    tiling.set_numHeads(numHeads);
    tiling.set_numQueries(numQueries);
    tiling.set_numLevels(numLevels);
    tiling.set_numPoints(numPoints);
    tiling.set_dim(dim);
    tiling.set_alignLevels(alignLevels);
    tiling.set_alignDim(alignDim);
    tiling.set_totalTaskNum(totalTaskNum);
    tiling.set_alignTaskNum(alignTaskNum);
    tiling.set_tailNum(tailNum);
    tiling.set_taskNumPerScore(taskNumPerScore);
    tiling.set_taskNumPerLcore(taskNumPerLcore);
    tiling.set_scoreNum(scoreNum);
    tiling.set_lcoreNum(lcoreNum);
    tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
    context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
    context->SetBlockDim(blockDim);
    context->SetTilingKey(1);

    // Only the library-API system workspace is requested.
    size_t systemWorkspaceSize = platformInfo.GetLibApiWorkSpaceSize();
    size_t* currentWorkspace = context->GetWorkspaceSizes(1);
    if (currentWorkspace == nullptr) {
        return ge::GRAPH_FAILED;
    }
    currentWorkspace[0] = systemWorkspaceSize;
    return ge::GRAPH_SUCCESS;
}
}
namespace ge {
/**
 * @brief Infers the output shape [batchSize, numQueries, numHeads * dim].
 *
 * batchSize, numHeads and dim are read from `value` dims BS_INDEX / HEAD_INDEX / DIM_INDEX.
 * numQueries is read from `sampling_locations` dim QUERY_INDEX. All five inputs and the output
 * must be present.
 *
 * @param context Shape-inference context supplied by the framework.
 * @return GRAPH_SUCCESS, or GRAPH_FAILED on a null context/shape or when `value` or
 *         `sampling_locations` has too few dimensions for the indices read.
 */
static ge::graphStatus InferShapeGeometricKernelAttention(gert::InferShapeContext* context)
{
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const gert::Shape* valueShape = context->GetInputShape(VALUE_PTR_INDEX);
    const gert::Shape* spatialShape = context->GetInputShape(SPATIAL_PTR_INDEX);
    const gert::Shape* levelShape = context->GetInputShape(LEVEL_PTR_INDEX);
    const gert::Shape* samplingShape = context->GetInputShape(SAMPLING_PTR_INDEX);
    const gert::Shape* attentionShape = context->GetInputShape(WEIGHT_PTR_INDEX);
    gert::Shape* outputShape = context->GetOutputShape(OUTPUT_PTR_INDEX);
    if (valueShape == nullptr || spatialShape == nullptr || levelShape == nullptr ||
        samplingShape == nullptr || attentionShape == nullptr || outputShape == nullptr) {
        return ge::GRAPH_FAILED;
    }
    // Reject shapes that do not contain the dimensions read below.
    if (valueShape->GetDimNum() <= static_cast<size_t>(DIM_INDEX) ||
        samplingShape->GetDimNum() <= static_cast<size_t>(QUERY_INDEX)) {
        return ge::GRAPH_FAILED;
    }
    // Keep the full 64-bit dims instead of truncating them to int32.
    const int64_t batchSize = valueShape->GetDim(BS_INDEX);
    const int64_t numHeads = valueShape->GetDim(HEAD_INDEX);
    const int64_t dim = valueShape->GetDim(DIM_INDEX);
    const int64_t numQueries = samplingShape->GetDim(QUERY_INDEX);
    // A negative dim is treated as unknown (assumes GE's -1 marker for dynamic dims). Multiplying
    // it would produce a bogus negative size, so the merged dim is reported as unknown instead.
    const int64_t mergedDim = (numHeads < 0 || dim < 0) ? -1 : numHeads * dim;
    *outputShape = {batchSize, numQueries, mergedDim};
    return GRAPH_SUCCESS;
}

/**
 * @brief The output dtype mirrors the dtype of the `value` input.
 *
 * @param context Dtype-inference context supplied by the framework.
 * @return The status of setting the output dtype, or GRAPH_FAILED on a null context.
 */
static ge::graphStatus InferDataTypeGeometricKernelAttention(gert::InferDataTypeContext* context)
{
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const ge::DataType valueDtype = context->GetInputDataType(VALUE_PTR_INDEX);
    return context->SetOutputDataType(OUTPUT_PTR_INDEX, valueDtype);
}
}
namespace ops {
class GeometricKernelAttention : public OpDef {
public:
explicit GeometricKernelAttention(const char* name) : OpDef(name)
{
this->Input("value")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Input("spatial_shapes")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Input("level_start_index")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Input("sampling_locations")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Input("attention_weights")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Output("output")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->SetInferShape(ge::InferShapeGeometricKernelAttention)
.SetInferDataType(ge::InferDataTypeGeometricKernelAttention);
this->AICore()
.SetTiling(optiling::TilingForGeometricKernelAttention);
this->AICore().AddConfig("ascend910b");
this->AICore().AddConfig("ascend910_93");
#if __DRIVING_HOST_AICORE__ == 310
this->AICore().AddConfig("ascend950");
#endif
}
};
OP_ADD(GeometricKernelAttention);
}