#include "ge/utils.h"
#include "register/op_def_registry.h"
#include "sparse_inverse_conv3d_tiling.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/tiling_api.h"
using namespace ge;
namespace {
const uint32_t FEATURE_IDX = 0;
const uint32_t ORIGIN_INDICES_IDX = 1;
const uint32_t OUTPUT_FEATURE_IDX = 0;
const uint32_t INPUT_INDICES_OFFSET_IDX = 1;
const uint32_t ATTR_KERNELS_IDX = 0;
const uint32_t ATTR_IN_CHANNELS_IDX = 1;
const uint32_t TOTAL_TASK_DIM_IDX = 0;
const uint32_t KERNEL_SIZE_IDX_0 = 0;
const uint32_t KERNEL_SIZE_IDX_1 = 1;
const uint32_t KERNEL_SIZE_IDX_2 = 2;
const uint32_t KERNEL_INCHANNEL_IDX = 3;
const uint32_t NUM_BUFFER = 2;
};
namespace optiling {
static ge::graphStatus TilingForSparseInverseConv3d(gert::TilingContext* context)
{
if (context == nullptr) {
return ge::GRAPH_FAILED;
}
auto platformInfoptr = context->GetPlatformInfo();
if (platformInfoptr == nullptr) {
return ge::GRAPH_FAILED;
}
auto ascendplatformInfo = platform_ascendc::PlatformAscendC(platformInfoptr);
uint32_t coreNum = ascendplatformInfo.GetCoreNumAiv();
auto featureShapePtr = context->GetInputTensor(FEATURE_IDX);
auto attrsPtr = context->GetAttrs();
auto kernelSizePtr = attrsPtr->GetAttrPointer<gert::ContinuousVector>(ATTR_KERNELS_IDX);
auto inChannelsPtr = attrsPtr->GetAttrPointer<int32_t>(ATTR_IN_CHANNELS_IDX);
if (!featureShapePtr || !attrsPtr || !kernelSizePtr || !inChannelsPtr) {
return ge::GRAPH_FAILED;
}
auto kernelSizeArr = reinterpret_cast<const int64_t*>(kernelSizePtr->GetData());
int32_t kernelSize =
kernelSizeArr[KERNEL_SIZE_IDX_0] * kernelSizeArr[KERNEL_SIZE_IDX_1] * kernelSizeArr[KERNEL_SIZE_IDX_2];
int32_t inChannel = *inChannelsPtr;
uint64_t availableUbSize;
ascendplatformInfo.GetCoreMemSize(platform_ascendc::CoreMemType::UB, availableUbSize);
uint32_t reserveUbSize = 1024;
uint32_t indicesDataTypeSize = 4;
uint32_t featureDataTypeSize = 4;
uint32_t moveLen = (availableUbSize - reserveUbSize) /
(indicesDataTypeSize + indicesDataTypeSize * kernelSize + inChannel * featureDataTypeSize) /
NUM_BUFFER;
uint32_t totalTaskCount = featureShapePtr->GetStorageShape().GetDim(TOTAL_TASK_DIM_IDX);
uint32_t vectorCoreTask = Ceil(totalTaskCount, coreNum);
uint32_t usedVectorCoreNum = Ceil(totalTaskCount, vectorCoreTask);
uint32_t vectorLastCoreTask = Tail(totalTaskCount, vectorCoreTask);
uint32_t coreRepeatTimes = Ceil(vectorCoreTask, moveLen);
uint32_t lastCoreRepeatTimes = Ceil(vectorLastCoreTask, moveLen);
uint32_t coreMoveLenTail = Tail(vectorCoreTask, moveLen);
uint32_t lastCoreMoveLenTail = Tail(vectorLastCoreTask, moveLen);
SparseInverseConv3dTilingData tilingData;
context->SetBlockDim(usedVectorCoreNum);
tilingData.set_inChannel(inChannel);
tilingData.set_kernelSize(kernelSize);
tilingData.set_vectorCoreTask(vectorCoreTask);
tilingData.set_vectorLastCoreTask(vectorLastCoreTask);
tilingData.set_moveLen(moveLen);
tilingData.set_coreRepeatTimes(coreRepeatTimes);
tilingData.set_coreMoveLenTail(coreMoveLenTail);
tilingData.set_lastCoreRepeatTimes(lastCoreRepeatTimes);
tilingData.set_lastCoreMoveLenTail(lastCoreMoveLenTail);
if (context->GetRawTilingData() == nullptr) {
return ge::GRAPH_FAILED;
}
tilingData.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
context->GetRawTilingData()->SetDataSize(tilingData.GetDataSize());
size_t* currentWorkspace = context->GetWorkspaceSizes(1);
if (currentWorkspace == nullptr) {
return ge::GRAPH_FAILED;
}
currentWorkspace[0] = 1;
return ge::GRAPH_SUCCESS;
}
}
namespace ge {
static ge::graphStatus InferShapeForSparseInverseConv3d(gert::InferShapeContext* context)
{
const gert::Shape* indicesShape = context->GetInputShape(ORIGIN_INDICES_IDX);
if (indicesShape == nullptr) {
return ge::GRAPH_FAILED;
}
auto attrsPtr = context->GetAttrs();
if (attrsPtr == nullptr) {
return ge::GRAPH_FAILED;
}
auto totalTaskCount = indicesShape->GetDim(TOTAL_TASK_DIM_IDX);
auto kernelSizePtr = attrsPtr->GetAttrPointer<gert::ContinuousVector>(ATTR_KERNELS_IDX);
auto kernelSizeData = reinterpret_cast<const int64_t*>(kernelSizePtr->GetData());
uint32_t kernelD = kernelSizeData[KERNEL_SIZE_IDX_0];
uint32_t kernelH = kernelSizeData[KERNEL_SIZE_IDX_1];
uint32_t kernelW = kernelSizeData[KERNEL_SIZE_IDX_2];
uint32_t kernelSize = kernelD * kernelH * kernelW;
auto inChannelsPtr = attrsPtr->GetAttrPointer<int32_t>(ATTR_IN_CHANNELS_IDX);
gert::Shape* outFeatureShape = context->GetOutputShape(OUTPUT_FEATURE_IDX);
if (outFeatureShape == nullptr) {
return ge::GRAPH_FAILED;
}
outFeatureShape->SetDimNum(0);
outFeatureShape->AppendDim(totalTaskCount);
outFeatureShape->AppendDim(*inChannelsPtr * kernelSize);
return GRAPH_SUCCESS;
}
static ge::graphStatus InferDtypeForSparseInverseConv3d(gert::InferDataTypeContext* context)
{
const ge::DataType feature_dtype = context->GetInputDataType(FEATURE_IDX);
context->SetOutputDataType(0, feature_dtype);
return GRAPH_SUCCESS;
}
}
namespace ops {
class SparseInverseConv3d : public OpDef {
public:
explicit SparseInverseConv3d(const char* name) : OpDef(name)
{
this->Input("features")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
this->Input("indices")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
this->Input("unique_indices_offset")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
this->Input("sorted_idx_to_former_indices")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND})
.AutoContiguous();
this->Output("output_img2col")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Attr("kernel_size").AttrType(REQUIRED).ListInt();
this->Attr("in_channels").AttrType(REQUIRED).Int();
this->SetInferShape(ge::InferShapeForSparseInverseConv3d)
.SetInferDataType(ge::InferDtypeForSparseInverseConv3d);
this->AICore().SetTiling(optiling::TilingForSparseInverseConv3d);
this->AICore().AddConfig("ascend910b");
this->AICore().AddConfig("ascend910_93");
}
};
OP_ADD(SparseInverseConv3d);
}