* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved.
*/
#include "geometric_kernel_attn_grad_tiling.h"
#include "register/op_def_registry.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/tiling_api.h"
#include "ge/utils.h"
using namespace ge;
using namespace std;
namespace {
// Positions of the operator inputs as queried from the tiling / inference contexts.
constexpr uint32_t POS_INPUT_VALUE = 0;
constexpr uint32_t POS_INPUT_SPATIAL_SHAPES = 1;
constexpr uint32_t POS_INPUT_LEVEL_START_INDEX = 2;
constexpr uint32_t POS_INPUT_SAMPLING_LOCATIONS = 3;
constexpr uint32_t POS_INPUT_ATTN_WEIGHTS = 4;
constexpr uint32_t POS_INPUT_GRAD_OUTPUT = 5;

// Positions of the operator outputs.
constexpr uint32_t POS_OUTPUT_GRAD_VALUE = 0;
constexpr uint32_t POS_OUTPUT_GRAD_ATTN_WEIGHTS = 1;

// Axis indices read from the shape of the "value" input.
constexpr uint32_t VALUE_BATCH_SIZE_DIM = 0;
constexpr uint32_t VALUE_NUM_KEYS_DIM = 1;
constexpr uint32_t VALUE_NUM_HEADS_DIM = 2;
constexpr uint32_t VALUE_EMBED_DIMS_DIM = 3;

// Axis indices read from the shape of the "attn_weights" input.
// NOTE(review): axis 2 has no constant here — presumably the heads axis; confirm against the kernel.
constexpr uint32_t ATTN_WEIGHTS_BATCH_SIZE_DIM = 0;
constexpr uint32_t ATTN_WEIGHTS_NUM_QUERIES_DIM = 1;
constexpr uint32_t ATTN_WEIGHTS_NUM_LEVELS_DIM = 3;
constexpr uint32_t ATTN_WEIGHTS_NUM_POINTS_DIM = 4;

// Bytes subtracted from the reported UB size before the tiling budget is computed.
constexpr uint64_t UB_RESERVE_BYTES = 10 * 1024;
// Element size used to convert the UB byte budget into an element count.
constexpr uint32_t FLOAT32_BYTES = 4;
// Alignment granularity in bytes; BLOCK_BYTES / FLOAT32_BYTES elements form one aligned block.
constexpr uint32_t BLOCK_BYTES = 32;
} // namespace
namespace optiling {
/**
 * @brief Host-side tiling for GeometricKernelAttnGrad.
 *
 * Reads the storage shapes of the "value" and "attn_weights" inputs, splits the
 * batchSize * numQueries tasks across all AI vector cores, derives how many tasks
 * fit into the unified buffer (UB) at once, and serialises the result into the
 * context's raw tiling data.
 *
 * @param context Tiling context supplied by the framework.
 * @return ge::GRAPH_SUCCESS when the tiling data was written; ge::GRAPH_FAILED when
 *         the context, an input tensor, the platform info or the raw tiling buffer is
 *         missing, when an input has too few dimensions, when no vector core is
 *         reported, or when the UB cannot hold the fixed-size buffers.
 */
static ge::graphStatus TilingFuncForGeometricKernelAttnGrad(gert::TilingContext* context)
{
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto valueTensorPtr = context->GetInputTensor(POS_INPUT_VALUE);
    auto attnWeightsTensorPtr = context->GetInputTensor(POS_INPUT_ATTN_WEIGHTS);
    if (valueTensorPtr == nullptr || attnWeightsTensorPtr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const auto& valueShape = valueTensorPtr->GetStorageShape();
    const auto& attnWeightShape = attnWeightsTensorPtr->GetStorageShape();
    // Every axis index read below must exist; the highest ones are the embed-dims
    // axis of "value" and the points axis of "attn_weights".
    if (valueShape.GetDimNum() <= VALUE_EMBED_DIMS_DIM ||
        attnWeightShape.GetDimNum() <= ATTN_WEIGHTS_NUM_POINTS_DIM) {
        return ge::GRAPH_FAILED;
    }

    auto platformInfoPtr = context->GetPlatformInfo();
    if (platformInfoPtr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto ascendPlatformInfo = platform_ascendc::PlatformAscendC(platformInfoPtr);
    uint32_t coreNum = ascendPlatformInfo.GetCoreNumAiv();
    if (coreNum == 0) {
        // Guards the division / modulo by coreNum below.
        return ge::GRAPH_FAILED;
    }
    context->SetBlockDim(coreNum);

    uint32_t batchSize = valueShape.GetDim(VALUE_BATCH_SIZE_DIM);
    uint32_t numHeads = valueShape.GetDim(VALUE_NUM_HEADS_DIM);
    uint32_t embedDims = valueShape.GetDim(VALUE_EMBED_DIMS_DIM);
    uint32_t numKeys = valueShape.GetDim(VALUE_NUM_KEYS_DIM);
    uint32_t numLevels = attnWeightShape.GetDim(ATTN_WEIGHTS_NUM_LEVELS_DIM);
    uint32_t numQueries = attnWeightShape.GetDim(ATTN_WEIGHTS_NUM_QUERIES_DIM);
    uint32_t numPoints = attnWeightShape.GetDim(ATTN_WEIGHTS_NUM_POINTS_DIM);

    // Number of float32 elements in one aligned block.
    uint32_t numItemsPerBlock = BLOCK_BYTES / FLOAT32_BYTES;
    uint32_t numLevelsAligned = AlignUp(numLevels, numItemsPerBlock);
    uint32_t numPointsAligned = AlignUp(numPoints, numItemsPerBlock);

    // One task per (batch, query) pair; the remainder is reported as the tail.
    uint32_t taskNum = batchSize * numQueries;
    uint32_t taskPerCore = taskNum / coreNum;
    uint32_t taskCoreTail = taskNum % coreNum;

    // Initialised so that a failed query cannot leave an indeterminate value behind.
    uint64_t ubBytesTotal = 0;
    ascendPlatformInfo.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubBytesTotal);
    if (ubBytesTotal <= UB_RESERVE_BYTES) {
        // Subtracting the reserve would wrap around.
        return ge::GRAPH_FAILED;
    }
    const uint64_t ubSize = (ubBytesTotal - UB_RESERVE_BYTES) / FLOAT32_BYTES;

    // Buffer sizes (in elements) are computed in 64 bits so the products do not wrap at 32 bits.
    const uint64_t headLevelPoints = static_cast<uint64_t>(numHeads) * numLevels * numPointsAligned;
    const uint64_t ubSize4Others = 5ULL * numLevelsAligned + 3ULL * headLevelPoints * embedDims;
    const uint64_t querySizePerTask = 10ULL * headLevelPoints + static_cast<uint64_t>(numHeads) * embedDims;
    const uint64_t ubSizeFixed = ubSize4Others + numItemsPerBlock;
    if (querySizePerTask == 0 || ubSize < ubSizeFixed) {
        // Either the division below would divide by zero, or the fixed buffers alone
        // exceed the UB and the subtraction would wrap around.
        return ge::GRAPH_FAILED;
    }
    // NOTE(review): taskCompNum may still be 0 when not even one task fits in the
    // remaining UB; confirm whether the kernel tolerates that or it should be rejected here.
    const uint32_t taskCompNum = static_cast<uint32_t>((ubSize - ubSizeFixed) / querySizePerTask);

    GeometricKernelAttnGradTilingData tiling;
    tiling.set_batchSize(batchSize);
    tiling.set_numHeads(numHeads);
    tiling.set_embedDims(embedDims);
    tiling.set_numKeys(numKeys);
    tiling.set_numLevels(numLevels);
    tiling.set_numQueries(numQueries);
    tiling.set_numPoints(numPoints);
    tiling.set_coreNum(coreNum);
    tiling.set_taskPerCore(taskPerCore);
    tiling.set_taskCompNum(taskCompNum);
    tiling.set_taskCoreTail(taskCoreTail);

    auto rawTilingData = context->GetRawTilingData();
    if (rawTilingData == nullptr) {
        return ge::GRAPH_FAILED;
    }
    tiling.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tiling.GetDataSize());
    return ge::GRAPH_SUCCESS;
}
}
namespace ge {
/**
 * @brief Shape inference for GeometricKernelAttnGrad.
 *
 * Copies the shape of the "value" input to the "grad_value" output and the shape of
 * the "attn_weights" input to the "grad_attn_weights" output.
 *
 * @param context Shape-inference context supplied by the framework.
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when the context or any of
 *         the required input/output shapes is null.
 */
static ge::graphStatus InferShapeForGeometricKernelAttnGrad(gert::InferShapeContext* context)
{
    // Same null-context guard as the tiling function.
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const gert::Shape* valueShape = context->GetInputShape(POS_INPUT_VALUE);
    const gert::Shape* attnWeightsShape = context->GetInputShape(POS_INPUT_ATTN_WEIGHTS);
    if (valueShape == nullptr || attnWeightsShape == nullptr) {
        return ge::GRAPH_FAILED;
    }
    gert::Shape* gradValueShape = context->GetOutputShape(POS_OUTPUT_GRAD_VALUE);
    gert::Shape* gradAttnWeightsShape = context->GetOutputShape(POS_OUTPUT_GRAD_ATTN_WEIGHTS);
    if ((gradValueShape == nullptr) || (gradAttnWeightsShape == nullptr)) {
        return ge::GRAPH_FAILED;
    }
    // Each gradient output takes the shape of the input it corresponds to.
    *gradValueShape = *valueShape;
    *gradAttnWeightsShape = *attnWeightsShape;
    return GRAPH_SUCCESS;
}
/**
 * @brief Data-type inference for GeometricKernelAttnGrad.
 *
 * Gives "grad_value" the data type of "value" and "grad_attn_weights" the data type
 * of "attn_weights".
 *
 * @param context Data-type inference context supplied by the framework.
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when the context is null.
 */
static ge::graphStatus InferDataTypeForGeometricKernelAttnGrad(gert::InferDataTypeContext* context)
{
    // Same null-context guard as the tiling function.
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const ge::DataType value_dtype = context->GetInputDataType(POS_INPUT_VALUE);
    const ge::DataType attn_weights_dtype = context->GetInputDataType(POS_INPUT_ATTN_WEIGHTS);
    context->SetOutputDataType(POS_OUTPUT_GRAD_VALUE, value_dtype);
    context->SetOutputDataType(POS_OUTPUT_GRAD_ATTN_WEIGHTS, attn_weights_dtype);
    return GRAPH_SUCCESS;
}
}
namespace ops {
class GeometricKernelAttnGrad : public OpDef {
public:
explicit GeometricKernelAttnGrad(const char* name) : OpDef(name)
{
this->Input("value")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
this->Input("spatial_shapes")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
this->Input("level_start_index")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
this->Input("sampling_locations")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
this->Input("attn_weights")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
this->Input("grad_output")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
this->Output("grad_value")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Output("grad_attn_weights")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->SetInferShape(ge::InferShapeForGeometricKernelAttnGrad)
.SetInferDataType(ge::InferDataTypeForGeometricKernelAttnGrad);
this->AICore().SetTiling(optiling::TilingFuncForGeometricKernelAttnGrad);
this->AICore().AddConfig("ascend910b");
this->AICore().AddConfig("ascend910_93");
#if __DRIVING_HOST_AICORE__ == 310
this->AICore().AddConfig("ascend950");
#endif
}
};
OP_ADD(GeometricKernelAttnGrad);
}