* Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
*/
#include <graph/types.h>
#include <log/log.h>
#include <register/op_def.h>
#include <algorithm>
#include <cstdint>
#include "bev_pool_v3_tiling.h"
#include "ge/utils.h"
#include "register/op_def_registry.h"
#include "tiling/platform/platform_ascendc.h"
namespace {
// Input indices read by the tiling function for the forward op (is_grad == false).
// They match the declaration order in BEVPoolV3: depth(0), feat(1), ranks_depth(2), ranks_feat(3), ranks_bev(4).
constexpr size_t INPUT_FEAT{1};
constexpr size_t INPUT_RANKS_BEV{4};
// Input indices read by the tiling function when instantiated with is_grad == true.
constexpr size_t INPUT_FEAT_GRAD{2};
constexpr size_t INPUT_RANKS_BEV_GRAD{5};

// Tiling parameters.
constexpr uint64_t RANK_NUM_PER_TASK{1024};  // ranks handled per task when with_depth is set
constexpr int32_t ONE_BLK_SIZE{8};           // UB-derived ranks-per-task is rounded down to a multiple of this
constexpr int32_t RESERVE_UB{10 * 1024};     // bytes of UB excluded from the per-task rank budget

// Attribute indices of the output dims b, d, h, w, c (index 0 is with_depth).
constexpr size_t ATTR_B_IDX{1};
constexpr size_t ATTR_D_IDX{2};
constexpr size_t ATTR_H_IDX{3};
constexpr size_t ATTR_W_IDX{4};
constexpr size_t ATTR_C_IDX{5};
}  // namespace
namespace optiling {
/**
 * Tiling function for BEVPoolV3.
 *
 * The first dimension of the ranks_bev input (`ranks`) is split into tasks of `avgRankNum`
 * entries; the last task holds `tailRankNum` entries. The task size is:
 *   - with_depth set:   RANK_NUM_PER_TASK;
 *   - with_depth unset: (ubSize - RESERVE_UB) / (sizeof(float) * (channel + 1) * 2), rounded
 *                       down to a multiple of ONE_BLK_SIZE;
 * and is then clamped to `ranks`. Tasks are spread over usedCoreNum = min(coreNum, totalTaskNum)
 * cores. avgTaskNum = totalTaskNum / usedCoreNum and tailTaskNum = totalTaskNum % usedCoreNum
 * are passed to the kernel.
 *
 * @tparam is_grad selects the input indices of the gradient op (feat / ranks_bev at 2 / 5)
 *                 instead of the forward op (1 / 4).
 * @param context  tiling context supplied by the framework.
 * @return GRAPH_SUCCESS on success. GRAPH_FAILED on missing inputs/attributes, invalid
 *         dimensions, insufficient UB, or an empty split.
 */
template <bool is_grad> static ge::graphStatus TilingForBEVPoolV3(gert::TilingContext *context) {
    CHECK_NULLPTR(context);
    BEVPoolV3TilingData tiling;
    CHECK_NULLPTR(context->GetPlatformInfo());
    auto platform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
    // Initialised so the value is never indeterminate. A zero/too-small UB is rejected below.
    uint64_t ubSize = 0;
    platform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
    auto coreNum = platform.GetCoreNum();

    auto featShape = context->GetRequiredInputShape(is_grad ? INPUT_FEAT_GRAD : INPUT_FEAT);
    auto ranksBevShape = context->GetRequiredInputShape(is_grad ? INPUT_RANKS_BEV_GRAD : INPUT_RANKS_BEV);
    if (featShape == nullptr || ranksBevShape == nullptr) {
        return ge::GRAPH_FAILED;
    }

    auto attrsPtr = context->GetAttrs();
    CHECK_NULLPTR(attrsPtr);
    auto withDepthPtr = attrsPtr->GetBool(0);
    CHECK_NULLPTR(withDepthPtr);
    bool withDepth = *withDepthPtr;
    context->SetTilingKey(withDepth);

    const auto &featOriginShape = featShape->GetOriginShape();
    const auto &ranksOriginShape = ranksBevShape->GetOriginShape();
    // channel is the last dim of feat and ranks is dim 0 of ranks_bev. Both inputs therefore need
    // at least one dimension, otherwise `GetDimNum() - 1` underflows / dim 0 does not exist.
    if (featOriginShape.GetDimNum() == 0 || ranksOriginShape.GetDimNum() == 0) {
        return ge::GRAPH_FAILED;
    }
    auto channel = featOriginShape.GetDim(featOriginShape.GetDimNum() - 1);
    auto ranksDim = ranksOriginShape.GetDim(0);
    // Negative dims would wrap when converted to unsigned. channel == -1 would also make the
    // UB divisor below zero.
    if (channel < 0 || ranksDim < 0) {
        return ge::GRAPH_FAILED;
    }
    uint64_t ranks = static_cast<uint64_t>(ranksDim);

    uint64_t avgRankNum = RANK_NUM_PER_TASK;
    if (!withDepth) {
        // Reject instead of letting `ubSize - RESERVE_UB` wrap around to a huge budget.
        if (ubSize <= static_cast<uint64_t>(RESERVE_UB)) {
            return ge::GRAPH_FAILED;
        }
        uint64_t ubBytesPerRank = sizeof(float) * (static_cast<uint64_t>(channel) + 1) * 2;
        avgRankNum = (ubSize - static_cast<uint64_t>(RESERVE_UB)) / ubBytesPerRank / ONE_BLK_SIZE * ONE_BLK_SIZE;
    }
    avgRankNum = std::min(avgRankNum, ranks);
    if (avgRankNum == 0) {
        return ge::GRAPH_FAILED;
    }

    uint64_t totalTaskNum = (ranks + avgRankNum - 1) / avgRankNum;
    uint64_t usedCoreNum = std::min(static_cast<uint64_t>(coreNum), totalTaskNum);
    if (usedCoreNum == 0) {
        return ge::GRAPH_FAILED;
    }
    context->SetBlockDim(usedCoreNum);
    uint64_t avgTaskNum = totalTaskNum / usedCoreNum;
    uint64_t tailTaskNum = totalTaskNum % usedCoreNum;
    // Entries in the last task. Lies in (0, avgRankNum] because totalTaskNum is a ceiling division.
    uint64_t tailRankNum = ranks - (totalTaskNum - 1) * avgRankNum;

    tiling.set_usedCoreNum(usedCoreNum);
    tiling.set_totalTaskNum(totalTaskNum);
    tiling.set_avgTaskNum(avgTaskNum);
    tiling.set_tailTaskNum(tailTaskNum);
    tiling.set_avgRankNum(avgRankNum);
    tiling.set_tailRankNum(tailRankNum);
    tiling.set_channel(channel);
    // 64-bit values need %llu/%lld; passing them to %d is undefined behaviour in printf-style logging.
    MX_DRIVING_LOGI("BEVPoolV3 tiling: usedCoreNum=%llu, totalTaskNum=%llu, avgTaskNum=%llu, tailTaskNum=%llu, "
                    "avgRankNum=%llu, tailRankNum=%llu, channel=%lld",
        static_cast<unsigned long long>(usedCoreNum), static_cast<unsigned long long>(totalTaskNum),
        static_cast<unsigned long long>(avgTaskNum), static_cast<unsigned long long>(tailTaskNum),
        static_cast<unsigned long long>(avgRankNum), static_cast<unsigned long long>(tailRankNum),
        static_cast<long long>(channel));
    ADD_TILING_DATA(context, tiling);

    uint32_t sysWorkspaceSize = platform.GetLibApiWorkSpaceSize();
    size_t *currentWorkspace = context->GetWorkspaceSizes(1);
    CHECK_NULLPTR(currentWorkspace);
    currentWorkspace[0] = sysWorkspaceSize;
    return ge::GRAPH_SUCCESS;
}
}  // namespace optiling
namespace ops {
/**
 * Infers the shape of output 0 ("out") as the 5-D shape {b, d, h, w, c}, taken from the integer
 * attributes at ATTR_B_IDX..ATTR_C_IDX.
 *
 * @return GRAPH_FAILED if the attributes object or the output shape slot is missing, or if any
 *         of the five attributes is absent or not positive. Otherwise GRAPH_SUCCESS.
 */
static ge::graphStatus InferShapeForBEVPoolV3(gert::InferShapeContext *context) {
    CHECK_NULLPTR(context);
    auto attrs = context->GetAttrs();
    CHECK_NULLPTR(attrs);
    // Returns the int attribute at idx, or -1 when it is absent. The positivity check below
    // rejects -1. (Returning ge::GRAPH_FAILED here would yield 0xFFFFFFFF, which is positive and
    // would slip through as a bogus dimension.)
    auto getAttr = [attrs](size_t idx) -> int64_t {
        const auto *ptr = attrs->GetInt(idx);
        return ptr == nullptr ? -1 : static_cast<int64_t>(*ptr);
    };
    auto b = getAttr(ATTR_B_IDX);
    auto d = getAttr(ATTR_D_IDX);
    auto h = getAttr(ATTR_H_IDX);
    auto w = getAttr(ATTR_W_IDX);
    auto c = getAttr(ATTR_C_IDX);
    if (b <= 0 || d <= 0 || h <= 0 || w <= 0 || c <= 0) {
        return ge::GRAPH_FAILED;
    }
    gert::Shape *outShape = context->GetOutputShape(0);
    CHECK_NULLPTR(outShape);
    *outShape = {b, d, h, w, c};
    return ge::GRAPH_SUCCESS;
}

/**
 * Output 0 takes the data type of the "feat" input (index INPUT_FEAT).
 */
static ge::graphStatus InferDataTypeForBEVPoolV3(gert::InferDataTypeContext *context) {
    CHECK_NULLPTR(context);
    const auto outputDataType = context->GetRequiredInputDataType(INPUT_FEAT);
    context->SetOutputDataType(0, outputDataType);
    return ge::GRAPH_SUCCESS;
}
}  // namespace ops
namespace ops {
class BEVPoolV3 : public OpDef {
public:
explicit BEVPoolV3(const char *name) : OpDef(name) {
this->Input("depth")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.AutoContiguous()
.UnknownShapeFormat({ge::FORMAT_ND});
this->Input("feat")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.AutoContiguous()
.UnknownShapeFormat({ge::FORMAT_ND});
this->Input("ranks_depth")
.ParamType(OPTIONAL)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.AutoContiguous()
.UnknownShapeFormat({ge::FORMAT_ND});
this->Input("ranks_feat")
.ParamType(OPTIONAL)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.AutoContiguous()
.UnknownShapeFormat({ge::FORMAT_ND});
this->Input("ranks_bev")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.AutoContiguous()
.UnknownShapeFormat({ge::FORMAT_ND});
this->Attr("with_depth").Bool();
this->Attr("b").AttrType(REQUIRED).Int();
this->Attr("d").AttrType(REQUIRED).Int();
this->Attr("h").AttrType(REQUIRED).Int();
this->Attr("w").AttrType(REQUIRED).Int();
this->Attr("c").AttrType(REQUIRED).Int();
this->Output("out")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->AICore().SetTiling(optiling::TilingForBEVPoolV3<false>);
this->AICore().AddConfig("ascend310p");
this->AICore().AddConfig("ascend910b");
this->AICore().AddConfig("ascend910_93");
#if __DRIVING_HOST_AICORE__ == 310
this->AICore().AddConfig("ascend950");
#endif
}
};
IMPL_OP_INFERSHAPE(BEVPoolV3).InferShape(InferShapeForBEVPoolV3).InferDataType(InferDataTypeForBEVPoolV3);
OP_ADD(BEVPoolV3);
}