* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
*/
#include "knn_tiling.h"

#include <cstdint>

#include "common/op_host/common.h"
namespace optiling {
/**
 * Host-side tiling for the Knn operator.
 *
 * Reads the shapes of input 0 (xyz) and input 1 (center_xyz) plus attributes 0 (is_from_knn) and 1 (k).
 * It validates them and serialises them into KnnTilingData together with the number of AI vector
 * cores. The kernel is launched with one block per AIV core, and a fixed system workspace is requested.
 *
 * Dimensions consumed:
 *   batch   = center_xyz dim 0
 *   nPoint  = center_xyz dim 1
 *   nSource = xyz dim 2
 *
 * @param context Tiling context supplied by the framework.
 * @return ge::GRAPH_SUCCESS on success.
 *         ge::GRAPH_FAILED if any input, attribute, platform info or buffer is missing, if an input
 *         is not rank 3, if a consumed dimension or k is negative, or if the tiling data does not
 *         fit the raw tiling buffer.
 */
static ge::graphStatus TilingForKnn(gert::TilingContext *context) {
    constexpr size_t kXyzIdx = 0;
    constexpr size_t kCenterXyzIdx = 1;
    constexpr size_t kAttrIsFromKnnIdx = 0;
    constexpr size_t kAttrKIdx = 1;
    constexpr size_t kExpectedRank = 3;
    // System workspace requested from the framework: 16 MiB.
    constexpr size_t kSysWorkspaceSize = 16 * 1024 * 1024;

    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const gert::StorageShape *xyzShape = context->GetInputShape(kXyzIdx);
    const gert::StorageShape *centerXyzShape = context->GetInputShape(kCenterXyzIdx);
    const gert::RuntimeAttrs *attr = context->GetAttrs();
    auto platformInfoPtr = context->GetPlatformInfo();
    if ((xyzShape == nullptr) || (centerXyzShape == nullptr) || (attr == nullptr) || (platformInfoPtr == nullptr) ||
        (context->GetInputDesc(kXyzIdx) == nullptr)) {
        return ge::GRAPH_FAILED;
    }

    // Both attributes are dereferenced below, so both pointers must be validated (with their real types).
    const bool *isFromKnnPtr = attr->GetAttrPointer<bool>(kAttrIsFromKnnIdx);
    const int32_t *kPtr = attr->GetAttrPointer<int32_t>(kAttrKIdx);
    if ((isFromKnnPtr == nullptr) || (kPtr == nullptr)) {
        return ge::GRAPH_FAILED;
    }

    // Dims 0..2 are indexed below; require rank 3 (same rule as the infer-shape function).
    const gert::Shape &xyz = xyzShape->GetStorageShape();
    const gert::Shape &centerXyz = centerXyzShape->GetStorageShape();
    if ((xyz.GetDimNum() != kExpectedRank) || (centerXyz.GetDimNum() != kExpectedRank)) {
        return ge::GRAPH_FAILED;
    }

    // Read as int64_t first so a negative value is rejected instead of wrapping into a huge uint32_t.
    const int64_t batch = centerXyz.GetDim(0);
    const int64_t nPoint = centerXyz.GetDim(1);
    const int64_t nSource = xyz.GetDim(2);
    const int32_t k = *kPtr;
    if ((batch < 0) || (nPoint < 0) || (nSource < 0) || (k < 0)) {
        return ge::GRAPH_FAILED;
    }

    auto platformInfo = platform_ascendc::PlatformAscendC(platformInfoPtr);
    const uint32_t coreNum = platformInfo.GetCoreNumAiv();
    if (coreNum == 0) {
        return ge::GRAPH_FAILED;
    }

    size_t *currentWorkSpace = context->GetWorkspaceSizes(1);
    if (currentWorkSpace == nullptr) {
        return ge::GRAPH_FAILED;
    }
    currentWorkSpace[0] = kSysWorkspaceSize;

    KnnTilingData tilingData;
    tilingData.set_batch(static_cast<uint32_t>(batch));
    tilingData.set_nPoint(static_cast<uint32_t>(nPoint));
    tilingData.set_nSource(static_cast<uint32_t>(nSource));
    tilingData.set_isFromKnn(*isFromKnnPtr);
    tilingData.set_coreNum(coreNum);
    tilingData.set_k(k);
    context->SetBlockDim(coreNum);

    auto *rawTilingData = context->GetRawTilingData();
    if (rawTilingData == nullptr) {
        return ge::GRAPH_FAILED;
    }
    // Refuse to serialise into a buffer that cannot hold the whole tiling structure.
    if (tilingData.GetDataSize() > rawTilingData->GetCapacity()) {
        return ge::GRAPH_FAILED;
    }
    tilingData.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tilingData.GetDataSize());
    return ge::GRAPH_SUCCESS;
}
}
namespace ge {
/**
 * Shape inference for Knn.
 *
 * Both outputs, 0 (dist) and 1 (idx), get shape [batch, nPoint, k]:
 *   - batch and nPoint are dims 0 and 1 of input 1 (center_xyz);
 *   - k is attribute 1.
 * Inputs 0 and 1 must both be rank 3.
 *
 * @param context Infer-shape context supplied by the framework.
 * @return GRAPH_SUCCESS on success.
 *         GRAPH_FAILED if the context, a shape or the k attribute is missing, if an input is not
 *         rank 3, or if k is negative.
 */
static ge::graphStatus InfershapeForKnn(gert::InferShapeContext *context) {
    constexpr size_t kExpectedRank = 3;
    constexpr size_t kAttrKIdx = 1;

    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const gert::Shape *xyzShape = context->GetInputShape(0);
    const gert::Shape *centerXyzShape = context->GetInputShape(1);
    gert::Shape *distShape = context->GetOutputShape(0);
    gert::Shape *idxShape = context->GetOutputShape(1);
    const gert::RuntimeAttrs *attr = context->GetAttrs();
    if ((xyzShape == nullptr) || (centerXyzShape == nullptr) || (distShape == nullptr) || (idxShape == nullptr) ||
        (attr == nullptr)) {
        return ge::GRAPH_FAILED;
    }
    // NOTE(review): the op definition enables DynamicRankSupportFlag, yet an unknown-rank input
    // is rejected by this rank-3 check — confirm that is intended.
    if ((xyzShape->GetDimNum() != kExpectedRank) || (centerXyzShape->GetDimNum() != kExpectedRank)) {
        return ge::GRAPH_FAILED;
    }

    const int32_t *kPtr = attr->GetAttrPointer<int32_t>(kAttrKIdx);
    if (kPtr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const int32_t k = *kPtr;
    // A negative k cannot be a valid output dimension.
    if (k < 0) {
        return ge::GRAPH_FAILED;
    }

    // Keep the dims as int64_t. GE encodes unknown dims as negative values (e.g. -1) under dynamic
    // shape, and narrowing them to uint32_t would turn -1 into 4294967295 in the output shape.
    const int64_t batch = centerXyzShape->GetDim(0);
    const int64_t nPoint = centerXyzShape->GetDim(1);

    auto setBatchPointK = [&](gert::Shape *shape) {
        shape->SetDimNum(kExpectedRank);
        shape->SetDim(0, batch);
        shape->SetDim(1, nPoint);
        shape->SetDim(2, k);
    };
    setBatchPointK(distShape);
    setBatchPointK(idxShape);
    return GRAPH_SUCCESS;
}

/**
 * Data-type inference for Knn.
 *
 * Output 0 (dist) is DT_FLOAT and output 1 (idx) is DT_INT32, independent of the inputs.
 *
 * @param context Infer-datatype context supplied by the framework.
 * @return GRAPH_SUCCESS, or GRAPH_FAILED when context is null.
 */
static ge::graphStatus InferDataTypeForKnn(gert::InferDataTypeContext *context) {
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    context->SetOutputDataType(0, ge::DT_FLOAT);
    context->SetOutputDataType(1, ge::DT_INT32);
    return GRAPH_SUCCESS;
}
}
namespace ops {
class Knn : public OpDef {
public:
explicit Knn(const char *name) : OpDef(name) {
this->Input("xyz")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Input("center_xyz")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Attr("is_from_knn").AttrType(REQUIRED).Bool();
this->Attr("k").AttrType(REQUIRED).Int();
this->Output("dist")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Output("idx")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->SetInferShape(ge::InfershapeForKnn).SetInferDataType(ge::InferDataTypeForKnn);
this->AICore().SetTiling(optiling::TilingForKnn);
OpAICoreConfig aicore_config;
aicore_config.DynamicCompileStaticFlag(true)
.DynamicFormatFlag(true)
.DynamicRankSupportFlag(true)
.DynamicShapeSupportFlag(true);
this->AICore().AddConfig("ascend910b", aicore_config);
this->AICore().AddConfig("ascend910_93", aicore_config);
#if __DRIVING_HOST_AICORE__ == 310
this->AICore().AddConfig("ascend950");
#endif
}
};
OP_ADD(Knn);
}