* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*/
#include <graph/ge_error_codes.h>
#include <graph/types.h>
#include <register/op_def.h>
#include <algorithm>
#include <cmath>
#include "register/op_def_registry.h"
#include "tiling/platform/platform_ascendc.h"
#include "unique_voxel_tiling.h"
namespace {
constexpr size_t POINT_IDX = 0;
constexpr int32_t RESERVE_UB = 10 * 1024;
constexpr int32_t COEF = 49;
constexpr int32_t ONE_REPEAT_FLOAT_SIZE = 64;
int32_t AlignDown(int32_t a, int32_t b)
{
if (b == 0) {
return a;
}
return a / b * b;
}
int32_t AlignUp(int32_t a, int32_t b)
{
if (b == 0) {
return a;
}
return (a + b - 1) / b * b;
}
ge::graphStatus TaskSchedule(gert::TilingContext* context, optiling::UniqueVoxelTilingData& tilingData)
{
auto platformInfo = context->GetPlatformInfo();
if (!platformInfo) {
return ge::GRAPH_FAILED;
}
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
int32_t core_num = ascendcPlatform.GetCoreNumAiv();
uint64_t ubSize;
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
int32_t avgPts = AlignDown((ubSize - RESERVE_UB) / COEF, ONE_REPEAT_FLOAT_SIZE);
auto pointShape = context->GetInputShape(POINT_IDX);
if (!pointShape) {
return ge::GRAPH_FAILED;
}
int32_t totalPts = pointShape->GetStorageShape().GetDim(0);
avgPts = std::min(avgPts, AlignUp(totalPts, ONE_REPEAT_FLOAT_SIZE));
if (avgPts == 0) {
return ge::GRAPH_FAILED;
}
int32_t tailPts = totalPts % avgPts;
int32_t totalTasks = totalPts / avgPts + (tailPts > 0 ? 1 : 0);
tailPts = tailPts == 0 ? avgPts : tailPts;
int32_t usedBlkNum = std::min(core_num, totalTasks);
if (usedBlkNum == 0) {
return ge::GRAPH_FAILED;
}
int32_t avgTasks = totalTasks / usedBlkNum;
int32_t tailTasks = totalTasks % usedBlkNum;
tilingData.set_usedBlkNum(usedBlkNum);
tilingData.set_avgTasks(avgTasks);
tilingData.set_tailTasks(tailTasks);
tilingData.set_totalTasks(totalTasks);
tilingData.set_avgPts(avgPts);
tilingData.set_tailPts(tailPts);
tilingData.set_totalPts(totalPts);
context->SetBlockDim(usedBlkNum);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus AddWorkspace(gert::TilingContext* context, optiling::UniqueVoxelTilingData& tilingData)
{
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
uint32_t sysWorkspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize();
uint32_t usrWorkspaceSize = tilingData.get_usedBlkNum() * sizeof(int32_t) * 2;
size_t* currentWorkspace = context->GetWorkspaceSizes(1);
if (currentWorkspace == nullptr) {
return ge::GRAPH_FAILED;
}
currentWorkspace[0] = sysWorkspaceSize + usrWorkspaceSize;
return ge::GRAPH_SUCCESS;
}
}
namespace optiling {
static ge::graphStatus TilingForUniqueVoxel(gert::TilingContext* context)
{
if (!context) {
return ge::GRAPH_FAILED;
}
UniqueVoxelTilingData tilingData;
if (TaskSchedule(context, tilingData) != ge::GRAPH_SUCCESS) {
return ge::GRAPH_FAILED;
}
if (AddWorkspace(context, tilingData) != ge::GRAPH_SUCCESS) {
return ge::GRAPH_FAILED;
}
if (context->GetRawTilingData() == nullptr) {
return ge::GRAPH_FAILED;
}
tilingData.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
context->GetRawTilingData()->SetDataSize(tilingData.GetDataSize());
return ge::GRAPH_SUCCESS;
}
}
namespace ge {
static ge::graphStatus InferShapeForUniqueVoxel(gert::InferShapeContext* context)
{
const gert::Shape* pointShape = context->GetInputShape(POINT_IDX);
gert::Shape* uniVoxShape = context->GetOutputShape(0);
gert::Shape* uniIdxShape = context->GetOutputShape(1);
gert::Shape* uniArgsortIdxShape = context->GetOutputShape(2);
gert::Shape* voxNumShape = context->GetOutputShape(3);
if (!pointShape || !uniVoxShape || !uniArgsortIdxShape || !uniIdxShape || !voxNumShape) {
return ge::GRAPH_FAILED;
}
if (context->GetInputShape(1) == nullptr || context->GetInputShape(2) == nullptr) {
return ge::GRAPH_FAILED;
}
*uniVoxShape = *pointShape;
*uniIdxShape = *pointShape;
*uniArgsortIdxShape = *pointShape;
*voxNumShape = {1};
return GRAPH_SUCCESS;
}
static ge::graphStatus InferDataTypeForUniqueVoxel(gert::InferDataTypeContext* context)
{
context->SetOutputDataType(0, ge::DT_INT32);
context->SetOutputDataType(1, ge::DT_INT32);
context->SetOutputDataType(2, ge::DT_INT32);
context->SetOutputDataType(3, ge::DT_INT32);
return GRAPH_SUCCESS;
}
}
namespace ops {
class UniqueVoxel : public OpDef {
public:
explicit UniqueVoxel(const char* name) : OpDef(name)
{
this->Input("voxels")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
this->Input("indices")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
this->Input("argsort_indices")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
this->Output("uni_voxels")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
this->Output("uni_indices")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
this->Output("uni_argsort_indices")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
this->Output("voxel_num")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
this->SetInferShape(ge::InferShapeForUniqueVoxel).SetInferDataType(ge::InferDataTypeForUniqueVoxel);
this->AICore().SetTiling(optiling::TilingForUniqueVoxel);
this->AICore().AddConfig("ascend910b");
this->AICore().AddConfig("ascend910_93");
}
};
OP_ADD(UniqueVoxel);
}