* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*/
#include "border_align_tiling.h"
#include "register/op_def_registry.h"
#include "tiling/platform/platform_ascendc.h"
#include <cmath>
using namespace std;
namespace optiling {
// Granularity (in elements) to which the ROI count and the per-group channel
// count are rounded up by the tiling function.
constexpr uint32_t TILE_NUM = 8U;
// Divisor applied to the channel dimension, and the number of floats reserved
// per ROI when sizing the ROI buffer.
// NOTE(review): presumably the four border/box components — confirm against the kernel.
constexpr uint32_t BOX_INFO = 4U;
// Rounds `value` up to the nearest multiple of `multiple`; `multiple` must be non-zero.
static uint32_t RoundUpToMultiple(uint32_t value, uint32_t multiple) {
    if (value % multiple == 0) {
        return value;
    }
    return (value / multiple + 1) * multiple;
}

// Computes the tiling parameters for BorderAlign and stores them in the context.
//
// Reads input 0 (dims 0..3 are taken as batch, height, width, channels), input 1
// (the ROI count is dim0 * dim1) and attribute 0 (pooledSize), splits the aligned
// ROI count across the available AIV cores, serialises the resulting
// BorderAlignTilingData into the raw tiling buffer and sets the block dimension.
//
// Returns ge::GRAPH_FAILED when the context or any required shape, attribute,
// platform info or tiling buffer is missing, when the input ranks are too small,
// when channels is not a multiple of BOX_INFO, when there are no ROIs, or when the
// platform reports zero AIV cores. Returns ge::GRAPH_SUCCESS otherwise.
static ge::graphStatus TilingForBorderAlign(gert::TilingContext *context) {
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }

    const auto *inputStorage = context->GetInputShape(0);
    const auto *roisStorage = context->GetInputShape(1);
    if (inputStorage == nullptr || roisStorage == nullptr || context->GetOutputShape(0) == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const auto &inputShape = inputStorage->GetStorageShape();
    const auto &roisShape = roisStorage->GetStorageShape();

    // Dims 0..3 of the input and dims 0..1 of the rois are read below; reject
    // shapes that do not carry that many dimensions instead of reading past them.
    constexpr size_t INPUT_DIM_NUM = 4;
    constexpr size_t ROIS_MIN_DIM_NUM = 2;
    if (inputShape.GetDimNum() < INPUT_DIM_NUM || roisShape.GetDimNum() < ROIS_MIN_DIM_NUM) {
        return ge::GRAPH_FAILED;
    }

    const uint32_t batchSize = inputShape.GetDim(0);
    const uint32_t inputH = inputShape.GetDim(1);
    const uint32_t inputW = inputShape.GetDim(2);
    const uint32_t channels = inputShape.GetDim(3);
    if (channels % BOX_INFO != 0) {
        return ge::GRAPH_FAILED;
    }
    // Per-group channel count rounded up to TILE_NUM, then scaled back to all groups.
    // Equals `channels` when the per-group count is already aligned.
    const uint32_t channelsAligned = RoundUpToMultiple(channels / BOX_INFO, TILE_NUM) * BOX_INFO;

    auto attrsPtr = context->GetAttrs();
    if (attrsPtr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const int32_t *pooledSizePtr = attrsPtr->GetAttrPointer<int32_t>(0);
    if (pooledSizePtr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const int32_t pooledSize = *pooledSizePtr;

    const uint32_t roisNum = roisShape.GetDim(0) * roisShape.GetDim(1);
    if (roisNum == 0) {
        return ge::GRAPH_FAILED;
    }

    auto platform = context->GetPlatformInfo();
    if (platform == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto platformInfo = platform_ascendc::PlatformAscendC(platform);
    uint32_t blockDim = platformInfo.GetCoreNumAiv();
    if (blockDim == 0) {
        return ge::GRAPH_FAILED;
    }

    const uint32_t roisNumAligned = RoundUpToMultiple(roisNum, TILE_NUM);
    const uint32_t tailNum = roisNumAligned - roisNum;

    // Work is split in units of TILE_NUM ROIs. "Small" cores take roisNumPerScore
    // ROIs each and "large" cores take one extra tile; scoreNum/lcoreNum are the
    // counts of each kind such that together they cover roisNumAligned.
    const uint32_t roisNumPerScore = (roisNumAligned / blockDim / TILE_NUM) * TILE_NUM;
    const uint32_t roisNumPerLcore = roisNumPerScore + TILE_NUM;
    const uint32_t scoreNum = (blockDim * (TILE_NUM + roisNumPerScore) - roisNumAligned) / TILE_NUM;
    const uint32_t lcoreNum = blockDim - scoreNum;
    if (roisNumPerScore == 0) {
        // Fewer tiles than cores: small cores would get no work, so launch only
        // the large ones.
        blockDim = lcoreNum;
    }

    const uint32_t roisNumPerLoop = 64;
    // Byte sizes are for float elements.
    const uint32_t inputBufferSize = channelsAligned / BOX_INFO * sizeof(float);
    const uint32_t roisBufferSize = roisNumPerLoop * BOX_INFO * sizeof(float);
    // moveInLength is an element count; moveOutLength is a byte count.
    const uint32_t moveInLength = channelsAligned / BOX_INFO;
    const uint32_t moveOutLength = channels / BOX_INFO * sizeof(float);

    BorderAlignTilingData tiling;
    tiling.set_roisNumPerLoop(roisNumPerLoop);
    tiling.set_batchSize(batchSize);
    tiling.set_inputH(inputH);
    tiling.set_inputW(inputW);
    tiling.set_channels(channels);
    tiling.set_moveInLength(moveInLength);
    tiling.set_moveOutLength(moveOutLength);
    tiling.set_roisNumAligned(roisNumAligned);
    tiling.set_tailNum(tailNum);
    tiling.set_pooledSize(pooledSize);
    tiling.set_roisNumPerLcore(roisNumPerLcore);
    tiling.set_roisNumPerScore(roisNumPerScore);
    tiling.set_lcoreNum(lcoreNum);
    tiling.set_scoreNum(scoreNum);
    tiling.set_inputBufferSize(inputBufferSize);
    tiling.set_roisBufferSize(roisBufferSize);

    auto *rawTilingData = context->GetRawTilingData();
    if (rawTilingData == nullptr) {
        return ge::GRAPH_FAILED;
    }
    tiling.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tiling.GetDataSize());
    context->SetBlockDim(blockDim);
    return ge::GRAPH_SUCCESS;
}
}
namespace ge {
// Derives the output shape of BorderAlign from its inputs and attribute.
//
// The output is set to {input dim 0, rois dim 1, pooledSize + 1, input dim 3},
// where pooledSize is attribute 0.
//
// Returns ge::GRAPH_FAILED when the context, either input shape, the output
// shape, the attribute container or the pooledSize attribute is missing;
// ge::GRAPH_SUCCESS otherwise.
static ge::graphStatus InferShape(gert::InferShapeContext *context) {
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }

    const gert::Shape *inputShape = context->GetInputShape(0);
    const gert::Shape *roisShape = context->GetInputShape(1);
    gert::Shape *outputShape = context->GetOutputShape(0);
    if (inputShape == nullptr || roisShape == nullptr || outputShape == nullptr) {
        return ge::GRAPH_FAILED;
    }

    auto attrsPtr = context->GetAttrs();
    if (attrsPtr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    // NOTE(review): the tiling function reads this same attribute as int32_t —
    // confirm which width/signedness the framework actually stores.
    const uint32_t *pooledSizePtr = attrsPtr->GetAttrPointer<uint32_t>(0);
    if (pooledSizePtr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const uint32_t pooledSize = *pooledSizePtr;

    const int64_t batchSize = inputShape->GetDim(0);
    const int64_t heightTimesWidth = roisShape->GetDim(1);
    const int64_t channels = inputShape->GetDim(3);
    *outputShape = {batchSize, heightTimesWidth, pooledSize + 1, channels};
    return ge::GRAPH_SUCCESS;
}
// Propagates the data type of input 0 to output 0.
//
// Returns ge::GRAPH_FAILED when the context is null (matching the guard used by
// the other callbacks in this file); ge::GRAPH_SUCCESS otherwise.
static ge::graphStatus InferDataTypeBorderAlign(gert::InferDataTypeContext *context) {
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const ge::DataType valueDtype = context->GetInputDataType(0);
    context->SetOutputDataType(0, valueDtype);
    return ge::GRAPH_SUCCESS;
}
}
namespace ops {
// Operator prototype for BorderAlign.
// Declares two required float inputs ("input", "rois") and one required float
// output ("output"), all in ND format, plus a required integer attribute
// "pooledSize", and attaches the shape-inference, data-type-inference and tiling
// callbacks defined earlier in this file.
// NOTE(review): those callbacks access inputs by index 0/1 and the attribute by
// index 0; presumably indices follow the declaration order below — keep the
// order stable unless the callbacks are updated too.
class BorderAlign : public OpDef {
public:
// Builds the prototype; `name` is forwarded to the OpDef base constructor.
explicit BorderAlign(const char *name) : OpDef(name) {
// First declared input.
this->Input("input")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
// Second declared input.
this->Input("rois")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
// Single output, same data type and format declarations as the inputs.
this->Output("output")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
// Only attribute declared for this operator.
this->Attr("pooledSize").AttrType(REQUIRED).Int();
this->SetInferShape(ge::InferShape).SetInferDataType(ge::InferDataTypeBorderAlign);
this->AICore().SetTiling(optiling::TilingForBorderAlign);
// AI Core configurations added unconditionally.
this->AICore().AddConfig("ascend910b");
this->AICore().AddConfig("ascend910_93");
// The "ascend950" configuration is added only when the build defines
// __DRIVING_HOST_AICORE__ as 310 (an undefined macro evaluates to 0 here).
#if __DRIVING_HOST_AICORE__ == 310
this->AICore().AddConfig("ascend950");
#endif
}
};
// NOTE(review): presumably registers the BorderAlign definition with the
// framework — confirm against the OP_ADD macro definition.
OP_ADD(BorderAlign);
}