Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*/
#include "gaussian/op_host/gaussian_tiling.h"
#include "ge/utils.h"
#include "register/op_def_registry.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/tiling_api.h"
using namespace std;
namespace {
}
namespace optiling {
// ---- Memory budgeting -------------------------------------------------------
// Bytes of unified buffer (UB) subtracted from the reported UB size before the
// per-iteration task count is derived.
constexpr uint32_t RESERVED_UB_SIZE = 8U * 1024U;
// Workspace size in bytes reported to the framework for this operator.
constexpr uint32_t CURRENT_WORKSPACE_SIZE = 16U * 1024U * 1024U;

// ---- Attribute indices ------------------------------------------------------
// Positions of the operator attributes, in the order they are registered on the
// op definition (feature_map_stride, gaussian_overlap, min_radius, ...).
constexpr uint32_t FEATURE_MAP_STRIDE_DIM = 0U;  // "feature_map_stride" (Int)
constexpr uint32_t MIN_OVERLAP_DIM = 1U;         // "gaussian_overlap"   (Float)
constexpr uint32_t MIN_RADIUS_DIM = 2U;          // "min_radius"         (Int)
constexpr uint32_t NUM_MAX_OBJS_DIM = 3U;        // "num_max_objs"       (Int)
constexpr uint32_t VOXEL_X_SIZE_DIM = 4U;        // "voxel_size_x"       (Float)
constexpr uint32_t VOXEL_Y_SIZE_DIM = 5U;        // "voxel_size_y"       (Float)
constexpr uint32_t PRC_X_DIM = 6U;               // "pc_range_x"         (Float)
constexpr uint32_t PRC_Y_DIM = 7U;               // "pc_range_y"         (Float)
constexpr uint32_t FEATURE_MAP_SIZE_X_DIM = 8U;  // "feature_map_size_x" (Int)
constexpr uint32_t FEATURE_MAP_SIZE_Y_DIM = 9U;  // "feature_map_size_y" (Int)
constexpr uint32_t NORM_BBOX_DIM = 10U;          // "norm_bbox"          (Bool)
constexpr uint32_t FLIP_ANGLE_DIM = 11U;         // "flip_angle"         (Bool)
/**
 * Tiling entry point for the Gaussian operator.
 *
 * Reads the platform description (vector core count, UB size), the storage
 * shape of input 0 and the twelve operator attributes, derives the per-core
 * work split and writes everything into GaussianTilingData. It also sets the
 * block dim and reports the workspace size.
 *
 * @param context Tiling context supplied by the framework.
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when the context, the
 *         platform info, the input tensor, the attribute container, any single
 *         attribute, the raw tiling buffer or the workspace array is missing,
 *         when the platform reports zero cores or a UB smaller than the
 *         reserved amount, or when input 0 has fewer than two dimensions.
 */
static ge::graphStatus TilingFuncForGaussian(gert::TilingContext* context)
{
    CHECK_NULLPTR(context);

    // ---- Platform resources -------------------------------------------------
    auto platform = context->GetPlatformInfo();
    CHECK_NULLPTR(platform);
    auto platformInfo = platform_ascendc::PlatformAscendC(platform);
    uint32_t coreNum = platformInfo.GetCoreNumAiv();
    if (coreNum == 0) {
        // A zero core count would be used as a divisor below and cannot yield a valid block dim.
        return ge::GRAPH_FAILED;
    }
    // Initialised so that a query which does not write the out-parameter is detected below
    // instead of leaving an indeterminate value.
    uint64_t ubSize = 0;
    platformInfo.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
    if (ubSize < RESERVED_UB_SIZE) {
        // The unsigned subtraction further down would wrap around otherwise.
        return ge::GRAPH_FAILED;
    }

    GaussianTilingData tiling;

    // ---- Input shape --------------------------------------------------------
    auto gtBoxesPtr = context->GetInputTensor(0);
    CHECK_NULLPTR(gtBoxesPtr);
    auto gtBoxesShape = gtBoxesPtr->GetStorageShape();
    if (gtBoxesShape.GetDimNum() < 2) {
        // Dimensions 0 and 1 are both read below.
        return ge::GRAPH_FAILED;
    }
    uint32_t dimSize = gtBoxesShape.GetDim(0);
    uint32_t totalCoreTaskNum = gtBoxesShape.GetDim(1);
    // numObjs keeps the un-clamped size of dimension 1; totalCoreTaskNum is clamped below.
    uint32_t numObjs = totalCoreTaskNum;

    // ---- Attributes ---------------------------------------------------------
    auto attrsPtr = context->GetAttrs();
    CHECK_NULLPTR(attrsPtr);
    // Every attribute pointer is validated before it is dereferenced.
    const auto* featureMapStridePtr = attrsPtr->GetAttrPointer<int32_t>(FEATURE_MAP_STRIDE_DIM);
    const auto* minOverLapPtr = attrsPtr->GetAttrPointer<float>(MIN_OVERLAP_DIM);
    const auto* minRadiusPtr = attrsPtr->GetAttrPointer<int32_t>(MIN_RADIUS_DIM);
    const auto* numMaxObjsPtr = attrsPtr->GetAttrPointer<int32_t>(NUM_MAX_OBJS_DIM);
    const auto* voxelXSizePtr = attrsPtr->GetAttrPointer<float>(VOXEL_X_SIZE_DIM);
    const auto* voxelYSizePtr = attrsPtr->GetAttrPointer<float>(VOXEL_Y_SIZE_DIM);
    const auto* prcXPtr = attrsPtr->GetAttrPointer<float>(PRC_X_DIM);
    const auto* prcYPtr = attrsPtr->GetAttrPointer<float>(PRC_Y_DIM);
    const auto* featureMapSizeXPtr = attrsPtr->GetAttrPointer<int32_t>(FEATURE_MAP_SIZE_X_DIM);
    const auto* featureMapSizeYPtr = attrsPtr->GetAttrPointer<int32_t>(FEATURE_MAP_SIZE_Y_DIM);
    const auto* normBboxPtr = attrsPtr->GetAttrPointer<bool>(NORM_BBOX_DIM);
    const auto* flipAnglePtr = attrsPtr->GetAttrPointer<bool>(FLIP_ANGLE_DIM);
    CHECK_NULLPTR(featureMapStridePtr);
    CHECK_NULLPTR(minOverLapPtr);
    CHECK_NULLPTR(minRadiusPtr);
    CHECK_NULLPTR(numMaxObjsPtr);
    CHECK_NULLPTR(voxelXSizePtr);
    CHECK_NULLPTR(voxelYSizePtr);
    CHECK_NULLPTR(prcXPtr);
    CHECK_NULLPTR(prcYPtr);
    CHECK_NULLPTR(featureMapSizeXPtr);
    CHECK_NULLPTR(featureMapSizeYPtr);
    CHECK_NULLPTR(normBboxPtr);
    CHECK_NULLPTR(flipAnglePtr);

    uint32_t featureMapStride = *featureMapStridePtr;
    float minOverLap = *minOverLapPtr;
    uint32_t minRadius = *minRadiusPtr;
    uint32_t numMaxObjs = *numMaxObjsPtr;
    float voxelXSize = *voxelXSizePtr;
    float voxelYSize = *voxelYSizePtr;
    float prcX = *prcXPtr;
    float prcY = *prcYPtr;
    int32_t featureMapSizeX = *featureMapSizeXPtr;
    int32_t featureMapSizeY = *featureMapSizeYPtr;
    bool normBbox = *normBboxPtr;
    bool flipAngle = *flipAnglePtr;

    // ---- Work split ---------------------------------------------------------
    // Never schedule more tasks than num_max_objs.
    if (totalCoreTaskNum > numMaxObjs) {
        totalCoreTaskNum = numMaxObjs;
    }
    // Tasks per core, rounded up to a multiple of 8.
    // NOTE(review): with zero tasks coreProcessTaskNum is presumably 0 and is then used as the
    // divisor argument of Tail/Ceil; whether that is safe depends on those helpers — confirm.
    uint32_t coreProcessTaskNum = AlignUp(Ceil(totalCoreTaskNum, coreNum), 8);
    uint32_t lastCoreProcessTaskNum = Tail(totalCoreTaskNum, coreProcessTaskNum);
    uint32_t usedCoreNum = Ceil(totalCoreTaskNum, coreProcessTaskNum);
    // Tasks handled per iteration, derived from the UB left after the reservation.
    // NOTE(review): the divisor 4 presumably is the byte size of one element — confirm with the kernel.
    uint32_t singleProcessTaskNum = (ubSize - RESERVED_UB_SIZE) / 4 / (dimSize * 2 + 32);

    // ---- Publish tiling data ------------------------------------------------
    context->SetBlockDim(usedCoreNum);
    tiling.set_usedCoreNum(usedCoreNum);
    tiling.set_numObjs(numObjs);
    tiling.set_totalCoreTaskNum(totalCoreTaskNum);
    tiling.set_coreProcessTaskNum(coreProcessTaskNum);
    tiling.set_lastCoreProcessTaskNum(lastCoreProcessTaskNum);
    tiling.set_singleProcessTaskNum(singleProcessTaskNum);
    tiling.set_featureMapSizeX(featureMapSizeX);
    tiling.set_featureMapSizeY(featureMapSizeY);
    tiling.set_voxelXSize(voxelXSize);
    tiling.set_voxelYSize(voxelYSize);
    tiling.set_prcX(prcX);
    tiling.set_prcY(prcY);
    tiling.set_featureMapStride(featureMapStride);
    tiling.set_numMaxObjs(numMaxObjs);
    tiling.set_minRadius(minRadius);
    tiling.set_minOverLap(minOverLap);
    tiling.set_dimSize(dimSize);
    tiling.set_normBbox(normBbox);
    tiling.set_flipAngle(flipAngle);

    auto* rawTilingData = context->GetRawTilingData();
    if (rawTilingData == nullptr) {
        return ge::GRAPH_FAILED;
    }
    tiling.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tiling.GetDataSize());

    // ---- Workspace ----------------------------------------------------------
    size_t* currentWorkspace = context->GetWorkspaceSizes(1);
    if (currentWorkspace == nullptr) {
        return ge::GRAPH_FAILED;
    }
    currentWorkspace[0] = CURRENT_WORKSPACE_SIZE;
    return ge::GRAPH_SUCCESS;
}
}
namespace ge {
// Output indices of the Gaussian operator, in the order the outputs are
// registered on the op definition.
constexpr uint32_t CENTER_INT_SHAPE_DIM = 0U;  // "center_int"
constexpr uint32_t RADIUS_SHAPE_DIM = 1U;      // "radius"
constexpr uint32_t MASK_SHAPE_DIM = 2U;        // "mask"
constexpr uint32_t IND_SHAPE_DIM = 3U;         // "ind"
constexpr uint32_t RET_BOXES_SHAPE_DIM = 4U;   // "ret_boxes"
/**
 * Shape inference for the Gaussian operator.
 *
 * With input 0 shaped {boxesMode, numObjs} and numObjs clamped to the
 * "num_max_objs" attribute, the outputs are shaped:
 *   center_int : {2, numObjs}
 *   radius     : {numObjs}
 *   mask       : {numMaxObjs}
 *   ind        : {numMaxObjs}
 *   ret_boxes  : {numObjs, boxesMode + 1}
 *
 * @param context Shape-inference context supplied by the framework.
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when the context, the
 *         input shape, the attribute container, the num_max_objs attribute or
 *         any output shape is missing.
 */
static ge::graphStatus InferShapeForGaussian(gert::InferShapeContext* context)
{
    // Position of the "num_max_objs" attribute in the op's attribute list.
    constexpr uint32_t NUM_MAX_OBJS_ATTR_INDEX = 3;

    CHECK_NULLPTR(context);
    const gert::Shape* gtBoxesShape = context->GetInputShape(0);
    auto attrsPtr = context->GetAttrs();
    CHECK_NULLPTR(gtBoxesShape);
    CHECK_NULLPTR(attrsPtr);

    gert::Shape* centerIntShape = context->GetOutputShape(CENTER_INT_SHAPE_DIM);
    gert::Shape* radiusShape = context->GetOutputShape(RADIUS_SHAPE_DIM);
    gert::Shape* maskShape = context->GetOutputShape(MASK_SHAPE_DIM);
    gert::Shape* indShape = context->GetOutputShape(IND_SHAPE_DIM);
    gert::Shape* retBoxesShape = context->GetOutputShape(RET_BOXES_SHAPE_DIM);
    CHECK_NULLPTR(centerIntShape);
    CHECK_NULLPTR(radiusShape);
    CHECK_NULLPTR(maskShape);
    CHECK_NULLPTR(indShape);
    CHECK_NULLPTR(retBoxesShape);

    int64_t boxesMode = gtBoxesShape->GetDim(0);
    int64_t numObjs = gtBoxesShape->GetDim(1);

    // Validate the attribute pointer before reading through it.
    const uint32_t* numMaxObjsPtr = attrsPtr->GetAttrPointer<uint32_t>(NUM_MAX_OBJS_ATTR_INDEX);
    CHECK_NULLPTR(numMaxObjsPtr);
    uint32_t numMaxObjs = *numMaxObjsPtr;

    // The object count never exceeds num_max_objs.
    if (numObjs > numMaxObjs) {
        numObjs = numMaxObjs;
    }

    *centerIntShape = {2, numObjs};
    *radiusShape = {numObjs};
    *maskShape = {numMaxObjs};
    *indShape = {numMaxObjs};
    *retBoxesShape = {numObjs, boxesMode + 1};
    return ge::GRAPH_SUCCESS;
}
/**
 * Data-type inference for the Gaussian operator.
 *
 * Assigns a fixed data type to each of the five outputs: int32 for center_int,
 * radius and ind, uint8 for mask and float for ret_boxes.
 *
 * @param context Data-type inference context supplied by the framework.
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when context is null.
 */
static ge::graphStatus InferDataTypeForGaussian(gert::InferDataTypeContext* context)
{
    // Guard against a null context, as the other entry points in this file do.
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    context->SetOutputDataType(CENTER_INT_SHAPE_DIM, ge::DT_INT32);
    context->SetOutputDataType(RADIUS_SHAPE_DIM, ge::DT_INT32);
    context->SetOutputDataType(MASK_SHAPE_DIM, ge::DT_UINT8);
    context->SetOutputDataType(IND_SHAPE_DIM, ge::DT_INT32);
    context->SetOutputDataType(RET_BOXES_SHAPE_DIM, ge::DT_FLOAT);
    return ge::GRAPH_SUCCESS;
}
}
namespace ops {
class Gaussian : public OpDef {
public:
explicit Gaussian(const char* name) : OpDef(name)
{
this->Input("gt_boxes")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
this->Output("center_int")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Output("radius")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Output("mask")
.ParamType(REQUIRED)
.DataType({ge::DT_UINT8})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Output("ind")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Output("ret_boxes")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Attr("feature_map_stride").AttrType(REQUIRED).Int();
this->Attr("gaussian_overlap").AttrType(REQUIRED).Float();
this->Attr("min_radius").AttrType(REQUIRED).Int();
this->Attr("num_max_objs").AttrType(REQUIRED).Int();
this->Attr("voxel_size_x").AttrType(REQUIRED).Float();
this->Attr("voxel_size_y").AttrType(REQUIRED).Float();
this->Attr("pc_range_x").AttrType(REQUIRED).Float();
this->Attr("pc_range_y").AttrType(REQUIRED).Float();
this->Attr("feature_map_size_x").AttrType(REQUIRED).Int();
this->Attr("feature_map_size_y").AttrType(REQUIRED).Int();
this->Attr("norm_bbox").AttrType(REQUIRED).Bool();
this->Attr("flip_angle").AttrType(REQUIRED).Bool();
this->SetInferShape(ge::InferShapeForGaussian).SetInferDataType(ge::InferDataTypeForGaussian);
this->AICore().SetTiling(optiling::TilingFuncForGaussian);
this->AICore().AddConfig("ascend950");
}
};
OP_ADD(Gaussian);
}