#include "ge/utils.h"
#include "deformable_aggregation_grad_tiling.h"
#include "register/op_def_registry.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/tiling_api.h"
using namespace ge;
using namespace std;
namespace {
// ---- Element sizes and generic counts ----
constexpr uint32_t SINGLE{1};
constexpr uint32_t BYTE_BLOCK{32};
constexpr uint32_t SIZE_OF_FP32{4};   // bytes per DT_FLOAT element
constexpr uint32_t SIZE_OF_FP16{2};   // bytes per DT_FLOAT16 element

// ---- Input indices (same order as the inputs declared in the OpDef below) ----
constexpr uint32_t INPUT_FEAT{0};               // mc_ms_feat
constexpr uint32_t INPUT_SPATIAL_SHAPE{1};      // spatial_shape
constexpr uint32_t INPUT_SAMPLING_LOCATION{3};  // sampling_location
constexpr uint32_t INPUT_WEIGHT{4};             // weights

// ---- Dimension indices of mc_ms_feat ----
constexpr uint32_t BATCH_SIZE_DIM{0};
constexpr uint32_t NUM_FEAT_DIM{1};
constexpr uint32_t NUM_EMBEDS_DIM{2};

// ---- Dimension indices of spatial_shape ----
constexpr uint32_t NUM_CAMS_DIM{0};
constexpr uint32_t NUM_SCALE_DIM{1};

// ---- Dimension indices of sampling_location ----
constexpr uint32_t NUM_ANCHORS_DIM{1};
constexpr uint32_t NUM_POINTS_DIM{2};

// ---- Dimension index of weights ----
constexpr uint32_t NUM_GROUPS_DIM{5};
}
namespace optiling {
/**
 * @brief Host-side tiling for DeformableAggregationGrad.
 *
 * Reads the storage shapes of mc_ms_feat, spatial_shape, sampling_location and weights,
 * splits batchSize * numAnchors tasks across the AIV cores, and derives how many tasks a
 * single iteration may process from the UB space left after a fixed-size reservation.
 *
 * @param context Tiling context supplied by the framework.
 * @return ge::GRAPH_SUCCESS when tiling data was written; ge::GRAPH_FAILED when a required
 *         pointer is null, an input has too few dimensions, there is no work to split,
 *         or the UB cannot hold the fixed reservation plus at least one task.
 */
static ge::graphStatus TilingForDeformableAggregationGrad(gert::TilingContext* context)
{
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto featTensorPtr = context->GetInputTensor(INPUT_FEAT);
    auto spatialShapeTensorPtr = context->GetInputTensor(INPUT_SPATIAL_SHAPE);
    auto samplingLocationTensorPtr = context->GetInputTensor(INPUT_SAMPLING_LOCATION);
    auto weightTensorPtr = context->GetInputTensor(INPUT_WEIGHT);
    auto featDescPtr = context->GetInputDesc(INPUT_FEAT);
    CHECK_NULLPTR(featTensorPtr);
    CHECK_NULLPTR(spatialShapeTensorPtr);
    CHECK_NULLPTR(samplingLocationTensorPtr);
    CHECK_NULLPTR(weightTensorPtr);
    CHECK_NULLPTR(featDescPtr);

    // The shapes are only read, so bind them by reference instead of copying.
    const auto& featShape = featTensorPtr->GetStorageShape();
    const auto& spatialShapeShape = spatialShapeTensorPtr->GetStorageShape();
    const auto& samplingLocationShape = samplingLocationTensorPtr->GetStorageShape();
    const auto& weightShape = weightTensorPtr->GetStorageShape();
    // Reject inputs whose rank is too small for the dimension indices read below.
    if (featShape.GetDimNum() <= NUM_EMBEDS_DIM || spatialShapeShape.GetDimNum() <= NUM_SCALE_DIM ||
        samplingLocationShape.GetDimNum() <= NUM_POINTS_DIM || weightShape.GetDimNum() <= NUM_GROUPS_DIM) {
        return ge::GRAPH_FAILED;
    }

    auto platformInfo = context->GetPlatformInfo();
    CHECK_NULLPTR(platformInfo);
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
    // Query the core count on every call rather than caching it in a function-local static,
    // so the value always comes from the current context's platform.
    const uint32_t coreNum = ascendcPlatform.GetCoreNumAiv();
    if (coreNum == 0) {
        return ge::GRAPH_FAILED;
    }

    const uint32_t batchSize = featShape.GetDim(BATCH_SIZE_DIM);
    const uint32_t numFeat = featShape.GetDim(NUM_FEAT_DIM);
    const uint32_t numEmbeds = featShape.GetDim(NUM_EMBEDS_DIM);
    const uint32_t numCams = spatialShapeShape.GetDim(NUM_CAMS_DIM);
    const uint32_t numScale = spatialShapeShape.GetDim(NUM_SCALE_DIM);
    const uint32_t numAnchors = samplingLocationShape.GetDim(NUM_ANCHORS_DIM);
    const uint32_t numPoints = samplingLocationShape.GetDim(NUM_POINTS_DIM);
    const uint32_t numGroups = weightShape.GetDim(NUM_GROUPS_DIM);

    // Work split: avgWeightNum = tasks per core (rounded up over coreNum);
    // tailWeightNum = Tail(totalTask, avgWeightNum), presumably the share of the last used core.
    const uint32_t totalTask = batchSize * numAnchors;
    if (totalTask == 0) {
        // No work to split; a zero avgWeightNum would also be passed back into Tail/Ceil
        // as the second operand and would yield a zero block dim.
        return ge::GRAPH_FAILED;
    }
    const uint32_t avgWeightNum = Ceil(totalTask, coreNum);
    const uint32_t tailWeightNum = Tail(totalTask, avgWeightNum);
    const uint32_t usedCoreNum = Ceil(totalTask, avgWeightNum);

    const bool isFp32 = featDescPtr->GetDataType() == ge::DT_FLOAT;
    const uint32_t dataTypeSize = isFp32 ? SIZE_OF_FP32 : SIZE_OF_FP16;
    uint64_t ubSize = 0;
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
    // Element count of the fixed-size UB reservation, evaluated in 64-bit so large shapes
    // cannot wrap around in 32-bit arithmetic.
    // NOTE(review): presumably mirrors the kernel's UB allocations -- keep the two in sync.
    const uint64_t reservedElems = 10ULL * 1024 + 15ULL * numEmbeds + 2ULL * numScale * numEmbeds +
        2ULL * numScale * numGroups + 2ULL * numPoints * numCams;
    const uint64_t usedUbSize = reservedElems * dataTypeSize;
    if (numEmbeds == 0 || usedUbSize >= ubSize) {
        // Guards against division by zero and unsigned underflow of (ubSize - usedUbSize).
        return ge::GRAPH_FAILED;
    }
    const uint64_t taskLen = (ubSize - usedUbSize) / dataTypeSize / numEmbeds;
    if (taskLen == 0) {
        // The remaining UB cannot hold even one group of numEmbeds elements.
        return ge::GRAPH_FAILED;
    }
    const uint32_t singleProcessTaskLen = static_cast<uint32_t>(taskLen);

    context->SetBlockDim(usedCoreNum);
    DeformableAggregationGradTilingData tiling;
    tiling.set_usedCoreNum(usedCoreNum);
    tiling.set_avgWeightNum(avgWeightNum);
    tiling.set_tailWeightNum(tailWeightNum);
    tiling.set_singleProcessTaskLen(singleProcessTaskLen);
    tiling.set_numPoints(numPoints);
    tiling.set_numCams(numCams);
    tiling.set_numScale(numScale);
    tiling.set_numGroups(numGroups);
    tiling.set_numEmbeds(numEmbeds);
    tiling.set_numFeat(numFeat);
    tiling.set_numAnchors(numAnchors);

    auto rawTilingData = context->GetRawTilingData();
    if (rawTilingData == nullptr) {
        return ge::GRAPH_FAILED;
    }
    tiling.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tiling.GetDataSize());
    return ge::GRAPH_SUCCESS;
}
}
namespace ge {
/**
 * @brief Shape inference: each gradient output takes the shape of the forward input it matches.
 *
 * Output 0 (grad_mc_ms_feat)        <- input mc_ms_feat
 * Output 1 (grad_sampling_location) <- input sampling_location
 * Output 2 (grad_weights)           <- input weights
 *
 * @return GRAPH_SUCCESS, or GRAPH_FAILED if the context or any involved shape is null.
 */
static ge::graphStatus InferShapeForDeformableAggregationGrad(gert::InferShapeContext* context)
{
    CHECK_NULLPTR(context);
    const gert::Shape* featShape = context->GetInputShape(INPUT_FEAT);
    const gert::Shape* samplingLocationShape = context->GetInputShape(INPUT_SAMPLING_LOCATION);
    const gert::Shape* weightShape = context->GetInputShape(INPUT_WEIGHT);
    CHECK_NULLPTR(featShape);
    CHECK_NULLPTR(samplingLocationShape);
    CHECK_NULLPTR(weightShape);
    gert::Shape* grad_mc_ms_feat_shape = context->GetOutputShape(0);
    gert::Shape* grad_sampling_location_shape = context->GetOutputShape(1);
    gert::Shape* grad_weight_shape = context->GetOutputShape(2);
    CHECK_NULLPTR(grad_mc_ms_feat_shape);
    CHECK_NULLPTR(grad_sampling_location_shape);
    CHECK_NULLPTR(grad_weight_shape);
    *grad_mc_ms_feat_shape = *featShape;
    *grad_sampling_location_shape = *samplingLocationShape;
    *grad_weight_shape = *weightShape;
    return GRAPH_SUCCESS;
}

/**
 * @brief Dtype inference: every output gets a dtype, using the same input-to-output
 *        pairing as the shape inference above.
 *
 * Previously only output 0 was set, which left grad_sampling_location and grad_weights
 * without an inferred dtype. In both registered dtype combinations these inputs share the
 * mc_ms_feat dtype, so the results agree with output 0.
 *
 * @return GRAPH_SUCCESS, or GRAPH_FAILED if the context is null.
 */
static ge::graphStatus InferDataTypeForDeformableAggregationGrad(gert::InferDataTypeContext* context)
{
    CHECK_NULLPTR(context);
    context->SetOutputDataType(0, context->GetInputDataType(INPUT_FEAT));               // grad_mc_ms_feat
    context->SetOutputDataType(1, context->GetInputDataType(INPUT_SAMPLING_LOCATION));  // grad_sampling_location
    context->SetOutputDataType(2, context->GetInputDataType(INPUT_WEIGHT));             // grad_weights
    return GRAPH_SUCCESS;
}
}
namespace ops {
class DeformableAggregationGrad : public OpDef {
public:
explicit DeformableAggregationGrad(const char* name) : OpDef(name)
{
this->Input("mc_ms_feat")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("spatial_shape")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("scale_start_index")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("sampling_location")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("weights")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("grad_output")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("grad_mc_ms_feat")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("grad_sampling_location")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("grad_weights")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->SetInferShape(ge::InferShapeForDeformableAggregationGrad)
.SetInferDataType(ge::InferDataTypeForDeformableAggregationGrad);
this->AICore().SetTiling(optiling::TilingForDeformableAggregationGrad);
this->AICore().AddConfig("ascend910b");
this->AICore().AddConfig("ascend910_93");
}
};
OP_ADD(DeformableAggregationGrad);
}