#include "deformable_aggregation_tiling.h"
#include "register/op_def_registry.h"
#include "tiling/platform/platform_ascendc.h"
#include "common/op_host/common.h"
namespace {
// ---- Size / alignment constants -------------------------------------------
constexpr uint32_t SINGLE = 1;          // not referenced elsewhere in the visible source
constexpr uint32_t BYTE_BLOCK = 32;     // alignment granularity (bytes) used to pad num_embeds
constexpr uint32_t SIZE_OF_FP32 = 4;    // element width chosen when input 0 is DT_FLOAT
constexpr uint32_t SIZE_OF_FP16 = 2;    // element width chosen for any other input 0 dtype

// ---- Attribute indices ------------------------------------------------------
// These follow the declaration order of the Attr(...) calls in the OpDef below:
// batch_size, num_feat, num_embeds, num_anchors, num_pts, num_cams, num_scale, num_groups.
constexpr uint32_t BATCH_SIZE_IDX = 0;
constexpr uint32_t FEAT_IDX = 1;
constexpr uint32_t EMBEDS_IDX = 2;
constexpr uint32_t ANCHORS_IDX = 3;
constexpr uint32_t PTS_IDX = 4;
constexpr uint32_t CAMS_IDX = 5;
constexpr uint32_t SCALE_IDX = 6;
constexpr uint32_t GROUPS_IDX = 7;

// ---- Unified-buffer budgeting ---------------------------------------------
constexpr float UB_RATIO = 1.0F;                  // not referenced elsewhere in the visible source
constexpr uint32_t REVERSED_MEM = 80U * 1024U;    // bytes held back (reserved) from the UB size reported by the platform
}
namespace optiling {
/**
 * @brief Host-side tiling entry for the DeformableAggregation operator.
 *
 * Collects the eight integer attributes and derives the element byte width from the
 * dtype of input 0 (DT_FLOAT -> 4 bytes, any other dtype -> 2 bytes). It pads
 * num_embeds up to a multiple of BYTE_BLOCK bytes (cAligned) and serialises
 * everything into DeformableAggregationTilingData. All AIV cores are requested as
 * the block dim. The local-memory budget is the platform UB size minus REVERSED_MEM
 * bytes.
 *
 * @param context Tiling context provided by the framework.
 * @return ge::GRAPH_SUCCESS on success. Returns ge::GRAPH_FAILED in any of these cases:
 *         - the context, platform info, attributes, input descriptor or raw tiling
 *           buffer is unavailable;
 *         - any attribute is missing or negative;
 *         - the AIV core count is 0;
 *         - the UB size is not larger than REVERSED_MEM;
 *         - the tiling struct does not fit in the raw tiling buffer.
 */
static ge::graphStatus TilingForDeformableAggregation(gert::TilingContext* context)
{
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto platformInfo = context->GetPlatformInfo();
    if (platformInfo == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
    // Queried on every call. A function-local `static` would latch the value from the
    // first invocation (even a failed 0) for the rest of the process and ignore the
    // platform info of later contexts.
    const uint32_t coreNum = ascendcPlatform.GetCoreNumAiv();
    if (coreNum == 0) {
        return ge::GRAPH_FAILED;
    }

    auto attrs = context->GetAttrs();
    if (attrs == nullptr) {
        return ge::GRAPH_FAILED;
    }
    // A missing attribute is reported as -1 and then rejected by the range check below.
    auto getAttr = [attrs](size_t idx) -> int32_t {
        auto ptr = attrs->GetInt(idx);
        if (!ptr) {
            return -1;
        }
        return static_cast<int32_t>(*ptr);
    };
    const int32_t bs = getAttr(BATCH_SIZE_IDX);
    const int32_t numFeats = getAttr(FEAT_IDX);
    const int32_t numEmbeds = getAttr(EMBEDS_IDX);
    const int32_t numAnchors = getAttr(ANCHORS_IDX);
    const int32_t numPoints = getAttr(PTS_IDX);
    const int32_t numCams = getAttr(CAMS_IDX);
    const int32_t numScales = getAttr(SCALE_IDX);
    const int32_t numGroups = getAttr(GROUPS_IDX);
    // Negative values would wrap when converted to unsigned (e.g. the CeilAlign of
    // numEmbeds below) and hand the kernel nonsensical sizes.
    if (bs < 0 || numFeats < 0 || numEmbeds < 0 || numAnchors < 0 || numPoints < 0 || numCams < 0 ||
        numScales < 0 || numGroups < 0) {
        return ge::GRAPH_FAILED;
    }

    auto inputDesc = context->GetInputDesc(0);
    if (inputDesc == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const bool isFp32 = inputDesc->GetDataType() == ge::DT_FLOAT;
    const uint32_t dataByteNum = isFp32 ? SIZE_OF_FP32 : SIZE_OF_FP16;
    const uint32_t alignNum = BYTE_BLOCK / dataByteNum;
    const uint32_t cAligned = CeilAlign(static_cast<uint32_t>(numEmbeds), alignNum);

    uint64_t ubSize = 0;
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
    // Guard the subtraction: an unsigned underflow would advertise an enormous UB budget.
    if (ubSize <= REVERSED_MEM) {
        return ge::GRAPH_FAILED;
    }
    ubSize -= REVERSED_MEM;

    context->SetBlockDim(coreNum);
    context->SetLocalMemorySize(ubSize);

    DeformableAggregationTilingData tiling;
    tiling.set_ubSize(ubSize);
    tiling.set_bs(bs);
    tiling.set_numFeats(numFeats);
    tiling.set_numEmbeds(numEmbeds);
    tiling.set_numAnchors(numAnchors);
    tiling.set_numPoints(numPoints);
    tiling.set_numCams(numCams);
    tiling.set_numScales(numScales);
    tiling.set_numGroups(numGroups);
    tiling.set_cAligned(cAligned);
    tiling.set_coreNum(coreNum);

    auto rawTilingData = context->GetRawTilingData();
    if (rawTilingData == nullptr) {
        return ge::GRAPH_FAILED;
    }
    // Refuse to serialise into a buffer that cannot hold the whole tiling struct.
    if (tiling.GetDataSize() > rawTilingData->GetCapacity()) {
        return ge::GRAPH_FAILED;
    }
    tiling.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tiling.GetDataSize());
    return ge::GRAPH_SUCCESS;
}
}
namespace ge {
/**
 * @brief Infers output 0 of DeformableAggregation purely from attributes.
 *
 * The output shape is [batch_size, num_anchors, num_embeds].
 *
 * @param context Shape-inference context provided by the framework.
 * @return GRAPH_SUCCESS on success. Returns GRAPH_FAILED if the context, attributes
 *         or output shape slot is unavailable, or if any of the three attributes is
 *         missing or negative. A negative value would otherwise be written as an
 *         unknown/invalid dimension.
 */
static ge::graphStatus InferShapeForDeformableAggregation(gert::InferShapeContext* context)
{
    if (context == nullptr) {
        return GRAPH_FAILED;
    }
    auto attrs = context->GetAttrs();
    if (attrs == nullptr) {
        return GRAPH_FAILED;
    }
    // A missing attribute is reported as -1 and rejected below.
    auto getAttr = [attrs](size_t idx) -> int32_t {
        auto ptr = attrs->GetInt(idx);
        if (!ptr) {
            return -1;
        }
        return static_cast<int32_t>(*ptr);
    };
    const int32_t bs = getAttr(BATCH_SIZE_IDX);
    const int32_t anchor = getAttr(ANCHORS_IDX);
    const int32_t c = getAttr(EMBEDS_IDX);
    if (bs < 0 || anchor < 0 || c < 0) {
        return GRAPH_FAILED;
    }
    gert::Shape* outShape = context->GetOutputShape(0);
    if (outShape == nullptr) {
        return GRAPH_FAILED;
    }
    *outShape = {bs, anchor, c};
    return GRAPH_SUCCESS;
}

/**
 * @brief Propagates the dtype of input 0 (mc_ms_feat) to output 0.
 *
 * @param context Dtype-inference context provided by the framework.
 * @return GRAPH_SUCCESS, or GRAPH_FAILED if the context is null.
 */
static ge::graphStatus InferDataTypeForDeformableAggregation(gert::InferDataTypeContext* context)
{
    if (context == nullptr) {
        return GRAPH_FAILED;
    }
    const ge::DataType value_dtype = context->GetInputDataType(0);
    context->SetOutputDataType(0, value_dtype);
    return GRAPH_SUCCESS;
}
}
namespace ops {
class DeformableAggregation : public OpDef {
public:
explicit DeformableAggregation(const char* name) : OpDef(name)
{
this->Input("mc_ms_feat")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("spatial_shape")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("scale_start_index")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("sampling_location")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("weights")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("out")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Attr("batch_size").AttrType(REQUIRED).Int();
this->Attr("num_feat").AttrType(REQUIRED).Int();
this->Attr("num_embeds").AttrType(REQUIRED).Int();
this->Attr("num_anchors").AttrType(REQUIRED).Int();
this->Attr("num_pts").AttrType(REQUIRED).Int();
this->Attr("num_cams").AttrType(REQUIRED).Int();
this->Attr("num_scale").AttrType(REQUIRED).Int();
this->Attr("num_groups").AttrType(REQUIRED).Int();
this->SetInferShape(ge::InferShapeForDeformableAggregation)
.SetInferDataType(ge::InferDataTypeForDeformableAggregation);
this->AICore().SetTiling(optiling::TilingForDeformableAggregation);
this->AICore().AddConfig("ascend950");
}
};
OP_ADD(DeformableAggregation);
}