#include "subm_sparse_conv3d_tiling.h"
#include "register/op_def_registry.h"
#include "tiling/tiling_api.h"
#include "tiling/platform/platform_ascendc.h"
using namespace ge;
using namespace std;
using namespace AscendC;
namespace optiling {
// Fixed tiling constants for this operator.
// NOTE(review): neither value appears to be referenced in this file — confirm before removing.
constexpr uint32_t BLOCK_DIM{8};
constexpr uint32_t TILE_NUM{8};
/**
 * Integer ceiling division: returns ceil(value1 / value2).
 *
 * Built from the truncated quotient and remainder, so it cannot overflow for
 * dividends near INT32_MAX. It also rounds negative operands correctly; the
 * `(value1 + value2 - 1) / value2` form failed on both counts.
 *
 * @param value1 dividend.
 * @param value2 divisor; when 0, value1 is returned unchanged instead of dividing by zero.
 * @return the quotient rounded toward positive infinity.
 */
static int32_t GetCeilInt(int32_t value1, int32_t value2)
{
    if (value2 == 0) {
        return value1;
    }
    int32_t quotient = value1 / value2;
    const int32_t remainder = value1 % value2;
    // C++ division truncates toward zero. Round up only when the division is
    // inexact and the exact quotient is positive (operands share a sign).
    if (remainder != 0 && ((remainder > 0) == (value2 > 0))) {
        ++quotient;
    }
    return quotient;
}

/**
 * Host-side tiling for SubmSparseConv3d.
 *
 * Splits dim 0 of input 0 ("feature") across AIV cores in chunks rounded up to
 * a multiple of 64. It then derives a per-loop size from the UB capacity and
 * serialises everything into the raw tiling buffer.
 *
 * @param context tiling context supplied by the framework.
 * @return GRAPH_SUCCESS on success. GRAPH_FAILED if any of these is missing or
 *         malformed: context, platform info, input tensor, attribute, tiling
 *         buffer or workspace. Also GRAPH_FAILED when input 0 has no rows, or
 *         when the UB budget left after the reserved bytes rounds to zero.
 */
static ge::graphStatus TilingFunc(gert::TilingContext* context)
{
    // kernel_size and outSpatialShape are each read as three int64 entries.
    constexpr size_t kSpatialDims = 3;
    // Per-core row chunks are rounded up to this multiple.
    constexpr int32_t kRowAlign = 64;
    // The per-loop UB size is rounded up to this multiple.
    constexpr uint64_t kUbAlign = 64;
    // UB bytes held back before splitting (value inherited from the original code).
    constexpr uint64_t kReservedUb = 20 * 1024;
    // The remaining UB bytes are divided by this to get the per-loop size.
    // NOTE(review): presumably bytes needed per row across the kernel's buffers — confirm.
    constexpr uint64_t kUbDivisor = 20;

    SubmSparseConv3dTilingData tiling;
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto platformInfoptr = context->GetPlatformInfo();
    if (platformInfoptr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto ascendplatformInfo = platform_ascendc::PlatformAscendC(platformInfoptr);
    auto core_number = ascendplatformInfo.GetCoreNumAiv();

    // Inputs 0 (feature) and 1 (indices) drive the tiling. Fail cleanly
    // instead of dereferencing a missing tensor descriptor.
    auto feature_tensor = context->GetInputTensor(0);
    auto indices_tensor = context->GetInputTensor(1);
    if (feature_tensor == nullptr || indices_tensor == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto feature_shape = feature_tensor->GetStorageShape();
    auto indices_shape = indices_tensor->GetStorageShape();
    // Both shapes are indexed at dim 1 below.
    if (feature_shape.GetDimNum() < 2 || indices_shape.GetDimNum() < 2) {
        return ge::GRAPH_FAILED;
    }
    uint32_t totalresult = feature_shape.GetDim(0);

    auto attrsPtr = context->GetAttrs();
    if (attrsPtr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto kernel_size = attrsPtr->GetAttrPointer<gert::ContinuousVector>(0);
    auto out_channel_ptr = attrsPtr->GetAttrPointer<int32_t>(1);
    auto out_spatial_shape = attrsPtr->GetAttrPointer<gert::ContinuousVector>(2);
    auto batch_size_ptr = attrsPtr->GetAttrPointer<int32_t>(3);
    if (kernel_size == nullptr || out_channel_ptr == nullptr ||
        out_spatial_shape == nullptr || batch_size_ptr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    if (kernel_size->GetSize() < kSpatialDims || out_spatial_shape->GetSize() < kSpatialDims) {
        return ge::GRAPH_FAILED;
    }
    auto kernel_size_data = reinterpret_cast<const int64_t*>(kernel_size->GetData());
    auto out_spatial_shape_data = reinterpret_cast<const int64_t*>(out_spatial_shape->GetData());
    tiling.set_K0(kernel_size_data[0]);
    tiling.set_K1(kernel_size_data[1]);
    tiling.set_K2(kernel_size_data[2]);
    tiling.set_D(out_spatial_shape_data[0]);
    tiling.set_H(out_spatial_shape_data[1]);
    tiling.set_W(out_spatial_shape_data[2]);
    tiling.set_feature_map_size(out_spatial_shape_data[0] * out_spatial_shape_data[1] * out_spatial_shape_data[2]);
    auto out_channel = *out_channel_ptr;
    auto batch_size = *batch_size_ptr;

    // Rows per core, rounded up to a multiple of kRowAlign. The last used core
    // takes whatever remains.
    int32_t core_data = GetCeilInt(totalresult, core_number);
    core_data = GetCeilInt(core_data, kRowAlign) * kRowAlign;
    if (core_data == 0) {
        // e.g. input 0 has no rows; core_data is used as a divisor below.
        return ge::GRAPH_FAILED;
    }
    int32_t core_used = GetCeilInt(totalresult, core_data);
    int32_t core_last = core_data;
    if (totalresult % core_data != 0) {
        core_last = totalresult % core_data;
    }

    uint64_t ub_total = 0;
    ascendplatformInfo.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub_total);
    const int64_t total_kernel = kernel_size_data[0] * kernel_size_data[1] * kernel_size_data[2];
    const int64_t in_channel = feature_shape.GetDim(1);
    if (total_kernel < 0 || in_channel < 0) {
        return ge::GRAPH_FAILED;
    }
    // Bytes subtracted before the split: fixed reserve + total_kernel*6*4 + in_channel*4.
    const uint64_t ub_reserved = kReservedUb + static_cast<uint64_t>(total_kernel) * 6 * 4 +
                                 static_cast<uint64_t>(in_channel) * 4;
    // Without this check the unsigned subtraction below would wrap to a huge budget.
    if (ub_total <= ub_reserved) {
        return ge::GRAPH_FAILED;
    }
    uint64_t available_ub_size = (ub_total - ub_reserved) / kUbDivisor;
    // Rounded in 64-bit arithmetic, so there is no int32 narrowing.
    // NOTE(review): rounding up can exceed the budget by < kUbAlign * kUbDivisor bytes;
    // presumably the fixed reserve absorbs this — confirm.
    available_ub_size = (available_ub_size + kUbAlign - 1) / kUbAlign * kUbAlign;
    if (available_ub_size == 0) {
        // Would otherwise divide by zero in the copy-loop computations below.
        return ge::GRAPH_FAILED;
    }

    context->SetBlockDim(core_used);
    tiling.set_core_data(core_data);
    tiling.set_core_used(core_used);
    tiling.set_copy_loop(core_data / available_ub_size);
    tiling.set_copy_tail(core_data % available_ub_size);
    tiling.set_last_copy_loop(core_last / available_ub_size);
    tiling.set_last_copy_tail(core_last % available_ub_size);
    tiling.set_inchannel(feature_shape.GetDim(1));
    tiling.set_outchannel(out_channel);
    tiling.set_indices_number(indices_shape.GetDim(1));
    tiling.set_available_ub_size(available_ub_size);
    tiling.set_batch_size(batch_size);

    auto rawTilingData = context->GetRawTilingData();
    if (rawTilingData == nullptr) {
        return ge::GRAPH_FAILED;
    }
    tiling.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tiling.GetDataSize());

    size_t *currentWorkspace = context->GetWorkspaceSizes(1);
    if (currentWorkspace == nullptr) {
        return ge::GRAPH_FAILED;
    }
    currentWorkspace[0] = 1;
    return ge::GRAPH_SUCCESS;
}
}

namespace ge {
/**
 * Infers output shapes for SubmSparseConv3d.
 *
 * Let N = dim 0 of input 1 ("indices"), K = kernel_size[0] * kernel_size[1] *
 * kernel_size[2], and M = N * K. Then:
 *   output 0: [M, out_channel]
 *   output 1: [M]
 *   output 2: [M, dim 1 of input 1]
 *
 * @param context shape-inference context supplied by the framework.
 * @return GRAPH_SUCCESS, or GRAPH_FAILED when the context, an attribute, the
 *         indices shape or an output shape is missing, or kernel_size has fewer
 *         than three entries.
 */
static ge::graphStatus InferShape(gert::InferShapeContext* context)
{
    // kernel_size is read as three int64 entries.
    constexpr size_t kSpatialDims = 3;
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto attrsPtr = context->GetAttrs();
    if (attrsPtr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto kernel_size = attrsPtr->GetAttrPointer<gert::ContinuousVector>(0);
    auto out_channel_ptr = attrsPtr->GetAttrPointer<int32_t>(1);
    if (kernel_size == nullptr || out_channel_ptr == nullptr || kernel_size->GetSize() < kSpatialDims) {
        return ge::GRAPH_FAILED;
    }
    auto kernel_size_data = reinterpret_cast<const int64_t*>(kernel_size->GetData());
    const gert::Shape* indices_shape = context->GetInputShape(1);
    if (indices_shape == nullptr) {
        return ge::GRAPH_FAILED;
    }
    gert::Shape* y_shape = context->GetOutputShape(0);
    if (y_shape == nullptr) {
        return ge::GRAPH_FAILED;
    }
    gert::Shape* indices_out_shape = context->GetOutputShape(1);
    if (indices_out_shape == nullptr) {
        return ge::GRAPH_FAILED;
    }
    gert::Shape* indices_pair_shape = context->GetOutputShape(2);
    if (indices_pair_shape == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto kernel_num = kernel_size_data[0] * kernel_size_data[1] * kernel_size_data[2];
    auto output_num = indices_shape->GetDim(0) * kernel_num;
    auto out_channel = *out_channel_ptr;

    y_shape->SetDimNum(0);
    y_shape->AppendDim(output_num);
    y_shape->AppendDim(out_channel);
    indices_out_shape->SetDimNum(0);
    indices_out_shape->AppendDim(output_num);
    indices_pair_shape->SetDimNum(0);
    indices_pair_shape->AppendDim(output_num);
    indices_pair_shape->AppendDim(indices_shape->GetDim(1));
    return GRAPH_SUCCESS;
}
}


namespace ops {
class SubmSparseConv3d : public OpDef {
public:
    explicit SubmSparseConv3d(const char* name) : OpDef(name)
    {
        this->Input("feature")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND})
            .AutoContiguous();
        this->Input("indices")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND})
            .AutoContiguous();
        this->Input("weight")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND})
            .AutoContiguous();
        this->Input("temp")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND})
            .AutoContiguous();
        this->Output("feature_out")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});
        this->Output("indices_offset")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});
        this->Output("indices_pair")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});
        this->Attr("kernel_size")
            .AttrType(REQUIRED)
            .ListInt();
        this->Attr("out_channel")
            .AttrType(REQUIRED)
            .Int();
        this->Attr("outSpatialShape")
            .AttrType(REQUIRED)
            .ListInt();
        this->Attr("batch_size")
            .AttrType(REQUIRED)
            .Int();

        this->SetInferShape(ge::InferShape);

        this->AICore()
            .SetTiling(optiling::TilingFunc);
        this->AICore().AddConfig("ascend910b");
        this->AICore().AddConfig("ascend910_93");
    }
};

OP_ADD(SubmSparseConv3d);
}