#include <graph/types.h>
#include <register/op_def.h>
#include <cstdint>
#include "bev_pool_tiling.h"
#include "ge/utils.h"
#include "register/op_def_registry.h"
#include "tiling/platform/platform_ascendc.h"
namespace {
// Positional index of the input whose data type is read when building the tiling key.
constexpr size_t FEAT_IDX = 0;
// Positional index of the input whose dim 0 gives the row count `n` in InferShapeForBEVPoolGrad.
constexpr size_t GEOM_FEAT_IDX = 1;
// Positional index of the input whose dim 0 is taken as the interval (task) count
// by the V1 and V2 instantiations of TilingForBEVPool, respectively.
constexpr size_t INTERVAL_IDX_V1 = 3;
constexpr size_t INTERVAL_IDX_V2 = 6;
// Positional indices of the integer attributes b, d, h, w, c.
constexpr size_t B_IDX = 0;
constexpr size_t D_IDX = 1;
constexpr size_t H_IDX = 2;
constexpr size_t W_IDX = 3;
constexpr size_t C_IDX = 4;
// Tiling-key flag value: used as the initial key when stride0 * dtype size is a multiple of 32.
constexpr int32_t TILING_ALIGN32B_FLAG = 1;
// Bit positions OR-ed into the tiling key (as `1 << BIT`) according to the input data type.
constexpr int32_t TILING_FP32_BIT = 1;
constexpr int32_t TILING_FP16_BIT = 2;
constexpr int32_t TILING_BF16_BIT = 3;
/**
 * @brief Compose the tiling key from the input data type and the innermost stride.
 *
 * The key starts as TILING_ALIGN32B_FLAG when `stride0 * ge::GetSizeByDataType(dtype)`
 * is a multiple of 32, otherwise as 0. One more bit is then OR-ed in by data type:
 * DT_FLOAT -> 1 << TILING_FP32_BIT, DT_FLOAT16 -> 1 << TILING_FP16_BIT,
 * DT_BF16 -> 1 << TILING_BF16_BIT. Any other data type adds no bit.
 *
 * @param dtype  Data type of the feature input.
 * @param tiling Tiling data; only `get_stride0()` is read.
 * @return The composed tiling key.
 */
int32_t GetTilingKey(const ge::DataType dtype, optiling::BEVPoolTilingData &tiling) {
    const auto elemBytes = ge::GetSizeByDataType(dtype);
    int32_t rowBytes = tiling.get_stride0() * elemBytes;

    int32_t tilingKey = 0;
    if (rowBytes % 32 == 0) {
        tilingKey = TILING_ALIGN32B_FLAG;
    }

    if (dtype == ge::DT_FLOAT) {
        tilingKey |= 1 << TILING_FP32_BIT;
    } else if (dtype == ge::DT_FLOAT16) {
        tilingKey |= 1 << TILING_FP16_BIT;
    } else if (dtype == ge::DT_BF16) {
        tilingKey |= 1 << TILING_BF16_BIT;
    }
    return tilingKey;
}
// Operator generation selector used as the template argument of TilingForBEVPool;
// it decides which input index (INTERVAL_IDX_V1 or INTERVAL_IDX_V2) supplies the interval count.
enum Version { V1, V2 };
}
namespace optiling {
/**
 * @brief Host-side tiling for the BEVPool operators.
 *
 * Splits the intervals (dim 0 of the interval input selected by @p version) across
 * at most `coreNum` cores, derives the four strides from the integer attributes
 * b, d, h, w, c, and stores the tiling key, block dimension and tiling data on the context.
 *
 * Tiling fields written:
 *   - usedCoreNum  = min(coreNum, nInterval)
 *   - totalTaskNum = nInterval
 *   - avgTaskNum   = nInterval / usedCoreNum
 *   - tailTaskNum  = nInterval % usedCoreNum
 *   - stride0..3   = c, w*c, h*w*c, d*h*w*c
 *
 * @tparam version V1 reads the interval count from input INTERVAL_IDX_V1, V2 from INTERVAL_IDX_V2.
 * @param context  Tiling context supplied by the framework.
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when the context, platform info,
 *         interval shape, attributes or input descriptor are missing, when the interval
 *         count or core count is not positive, or when any of b, d, h, w, c is negative.
 */
template <Version version> static ge::graphStatus TilingForBEVPool(gert::TilingContext *context) {
    CHECK_NULLPTR(context);
    BEVPoolTilingData tiling;
    CHECK_NULLPTR(context->GetPlatformInfo());
    auto platform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
    auto coreNum = platform.GetCoreNum();

    auto intervalShape =
        version == V1 ? context->GetInputShape(INTERVAL_IDX_V1) : context->GetInputShape(INTERVAL_IDX_V2);
    if (intervalShape == nullptr) {
        return ge::GRAPH_FAILED;
    }
    // Validate the dimension while it is still signed: a negative value would otherwise
    // wrap to a huge unsigned task count.
    const int64_t intervalDim = intervalShape->GetStorageShape().GetDim(0);
    if (intervalDim <= 0) {
        return ge::GRAPH_FAILED;
    }
    const uint64_t nInterval = static_cast<uint64_t>(intervalDim);
    const uint64_t usedCoreNum = std::min(static_cast<uint64_t>(coreNum), nInterval);
    if (usedCoreNum == 0) {
        // Only reachable when the platform reports zero cores; also guards the division below.
        return ge::GRAPH_FAILED;
    }
    tiling.set_usedCoreNum(usedCoreNum);
    tiling.set_totalTaskNum(nInterval);
    tiling.set_avgTaskNum(nInterval / usedCoreNum);
    tiling.set_tailTaskNum(nInterval % usedCoreNum);

    auto attrs = context->GetAttrs();
    if (!attrs) {
        return ge::GRAPH_FAILED;
    }
    // Reads one integer attribute into `out`. Reports failure through the return value
    // (missing or negative attribute) so that a status code is never mistaken for a dimension,
    // and performs the sign check before converting to an unsigned type.
    auto readDim = [attrs](size_t idx, uint64_t &out) -> bool {
        auto ptr = attrs->GetInt(idx);
        if (ptr == nullptr || *ptr < 0) {
            return false;
        }
        out = static_cast<uint64_t>(*ptr);
        return true;
    };
    uint64_t b = 0;
    uint64_t d = 0;
    uint64_t h = 0;
    uint64_t w = 0;
    uint64_t c = 0;
    if (!readDim(B_IDX, b) || !readDim(D_IDX, d) || !readDim(H_IDX, h) || !readDim(W_IDX, w) ||
        !readDim(C_IDX, c)) {
        return ge::GRAPH_FAILED;
    }
    tiling.set_stride0(c);
    tiling.set_stride1(w * c);
    tiling.set_stride2(h * w * c);
    tiling.set_stride3(d * h * w * c);

    auto featDesc = context->GetInputDesc(FEAT_IDX);
    CHECK_NULLPTR(featDesc);
    auto dtype = featDesc->GetDataType();
    context->SetTilingKey(GetTilingKey(dtype, tiling));
    context->SetBlockDim(usedCoreNum);
    ADD_TILING_DATA(context, tiling)
    return ge::GRAPH_SUCCESS;
}
}
namespace ge {
/**
 * @brief Shape inference for the BEVPool forward operator.
 *
 * Sets output 0 to the 5-D shape {b, d, h, w, c}, where each value is the integer
 * attribute at B_IDX, D_IDX, H_IDX, W_IDX and C_IDX respectively.
 *
 * @param context Shape-inference context supplied by the framework.
 * @return GRAPH_SUCCESS on success; GRAPH_FAILED when the context, the attribute set,
 *         any of the five attributes, or the output shape is missing, or when any
 *         attribute is negative.
 */
static graphStatus InferShapeForBEVPool(gert::InferShapeContext *context) {
    CHECK_NULLPTR(context);
    auto attrs = context->GetAttrs();
    if (attrs == nullptr) {
        return ge::GRAPH_FAILED;
    }
    // Reads one integer attribute into `out`. Failure (missing or negative attribute) is
    // reported through the return value, so a status code can never end up as a dimension.
    auto readDim = [attrs](size_t idx, int64_t &out) -> bool {
        auto ptr = attrs->GetInt(idx);
        if (ptr == nullptr || *ptr < 0) {
            return false;
        }
        out = static_cast<int64_t>(*ptr);
        return true;
    };
    int64_t b = 0;
    int64_t d = 0;
    int64_t h = 0;
    int64_t w = 0;
    int64_t c = 0;
    if (!readDim(B_IDX, b) || !readDim(D_IDX, d) || !readDim(H_IDX, h) || !readDim(W_IDX, w) ||
        !readDim(C_IDX, c)) {
        return ge::GRAPH_FAILED;
    }
    gert::Shape *outShape = context->GetOutputShape(0);
    if (outShape == nullptr) {
        return ge::GRAPH_FAILED;
    }
    *outShape = {b, d, h, w, c};
    return GRAPH_SUCCESS;
}
/**
 * @brief Shape inference for the BEVPool backward operator.
 *
 * Sets output 0 to the 2-D shape {n, c}, where n is dim 0 of the input at
 * GEOM_FEAT_IDX and c is the integer attribute at C_IDX.
 *
 * @param context Shape-inference context supplied by the framework.
 * @return GRAPH_SUCCESS once the output shape has been written; each CHECK_NULLPTR
 *         guard is expected to return early when its pointer is null.
 */
static graphStatus InferShapeForBEVPoolGrad(gert::InferShapeContext *context) {
    CHECK_NULLPTR(context);

    // Row count comes from the leading dimension of the geometry input.
    const gert::Shape *geomShape = context->GetInputShape(GEOM_FEAT_IDX);
    CHECK_NULLPTR(geomShape);
    const auto rowCount = geomShape->GetDim(0);

    // Column count comes from the `c` attribute.
    auto opAttrs = context->GetAttrs();
    CHECK_NULLPTR(opAttrs);
    auto channelAttr = opAttrs->GetInt(C_IDX);
    CHECK_NULLPTR(channelAttr);
    auto channelCount = *channelAttr;

    gert::Shape *resultShape = context->GetOutputShape(0);
    CHECK_NULLPTR(resultShape);
    *resultShape = {rowCount, channelCount};
    return GRAPH_SUCCESS;
}
/**
 * @brief Shape inference for the BEVPoolV2 backward operator.
 *
 * Output 0 receives a copy of the shape of input 1, and output 1 receives a copy of
 * the shape of input 2. The first pair is fully processed before the second pair is
 * looked up, so output 0 may already be written when a later guard fails.
 *
 * @param context Shape-inference context supplied by the framework.
 * @return GRAPH_SUCCESS once both output shapes have been written; each CHECK_NULLPTR
 *         guard is expected to return early when its pointer is null.
 */
static graphStatus InferShapeForBEVPoolV2Grad(gert::InferShapeContext *context) {
    constexpr size_t kDepthInput = 1;
    constexpr size_t kFeatInput = 2;
    constexpr size_t kGradDepthOutput = 0;
    constexpr size_t kGradFeatOutput = 1;

    CHECK_NULLPTR(context);

    // First pair: output 0 mirrors input 1.
    gert::Shape *dstDepth = context->GetOutputShape(kGradDepthOutput);
    const gert::Shape *srcDepth = context->GetInputShape(kDepthInput);
    CHECK_NULLPTR(dstDepth);
    CHECK_NULLPTR(srcDepth);
    *dstDepth = *srcDepth;

    // Second pair: output 1 mirrors input 2.
    gert::Shape *dstFeat = context->GetOutputShape(kGradFeatOutput);
    const gert::Shape *srcFeat = context->GetInputShape(kFeatInput);
    CHECK_NULLPTR(dstFeat);
    CHECK_NULLPTR(srcFeat);
    *dstFeat = *srcFeat;

    return GRAPH_SUCCESS;
}
}
namespace ops {
class BEVPoolV2 : public OpDef {
public:
explicit BEVPoolV2(const char *name) : OpDef(name) {
this->Input("depth")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("feat")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("ranks_depth")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("ranks_feat")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("ranks_bev")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("interval_lengths")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("interval_starts")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("out")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Attr("b").AttrType(REQUIRED).Int();
this->Attr("d").AttrType(REQUIRED).Int();
this->Attr("h").AttrType(REQUIRED).Int();
this->Attr("w").AttrType(REQUIRED).Int();
this->Attr("c").AttrType(REQUIRED).Int();
this->SetInferShape(ge::InferShapeForBEVPool);
this->AICore().SetTiling(optiling::TilingForBEVPool<V2>);
this->AICore().AddConfig("ascend910b");
this->AICore().AddConfig("ascend910_93");
#if __DRIVING_HOST_AICORE__ == 310
this->AICore().AddConfig("ascend950");
#endif
}
};
/**
 * @brief: BEVPoolGrad, the backward of bev_pool
 * NOTE(review): this text describes the BEVPoolGrad (V1) signature, while the class that
 * follows is BEVPoolV2Grad, whose operands differ - confirm and update.
 * @par Inputs:
 * grad_out: input grad, 5D tensor(b, d, h, w, c), dtype: float32, format: NDHWC, ND
 * geom_feat: input coords, 2D tensor(n, 4), dtype: int32, format: ND
 * interval_starts: starting position for pooled point, 1D tensor(n_interval), dtype: int32, format: ND
 * interval_lengths: the number of points in each interval, 1D tensor(n_interval), dtype: int32, format: ND
 * @par Outputs:
 * grad_feat: output grad, 2D tensor(n, c), dtype: float32, format: ND
 * @par Attributes:
 **/
class BEVPoolV2Grad : public OpDef {
public:
explicit BEVPoolV2Grad(const char *name) : OpDef(name) {
this->Input("grad_out")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("depth")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("feat")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("ranks_depth")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("ranks_feat")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("ranks_bev")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("interval_lengths")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("interval_starts")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("grad_depth")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("grad_feat")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Attr("b").AttrType(REQUIRED).Int();
this->Attr("d").AttrType(REQUIRED).Int();
this->Attr("h").AttrType(REQUIRED).Int();
this->Attr("w").AttrType(REQUIRED).Int();
this->Attr("c").AttrType(REQUIRED).Int();
this->SetInferShape(ge::InferShapeForBEVPoolV2Grad);
this->AICore().SetTiling(optiling::TilingForBEVPool<V2>);
this->AICore().AddConfig("ascend910b");
this->AICore().AddConfig("ascend910_93");
#if __DRIVING_HOST_AICORE__ == 310
this->AICore().AddConfig("ascend950");
#endif
}
};
// Hand both operator definitions to OP_ADD (presumably the registration macro from
// register/op_def_registry.h - confirm) so the forward and backward V2 ops are made available.
OP_ADD(BEVPoolV2);
OP_ADD(BEVPoolV2Grad);
}