/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
 */
#include "points_in_box_tiling.h"
#include "register/op_def_registry.h"
#include "tiling/tiling_api.h"
#include "tiling/platform/platform_ascendc.h"
using namespace ge;
using namespace std;
using namespace AscendC;
namespace optiling {
// Compile-time constants, both equal to 8.
// NOTE(review): neither BLOCK_DIM nor TILE_NUM is referenced anywhere in this file, and
// namespace-scope constants have internal linkage — they look unused; confirm before removing.
constexpr uint32_t BLOCK_DIM = 8;
constexpr uint32_t TILE_NUM = 8;
/**
 * Ceiling division of value1 by value2.
 *
 * The intermediate sum is computed in 64-bit arithmetic so that a value1 close to
 * INT32_MAX does not overflow (signed overflow is undefined behaviour).
 *
 * @param value1 dividend.
 * @param value2 divisor; when it is 0 no division happens and value1 is returned unchanged.
 * @return ceil(value1 / value2) for positive operands, or value1 when value2 == 0.
 */
static int32_t GetCeilInt(int32_t value1, int32_t value2) {
    if (value2 == 0) {
        return value1;
    }
    return static_cast<int32_t>((static_cast<int64_t>(value1) + value2 - 1) / value2);
}

/**
 * Tiling function for PointsInBox.
 *
 * The point count is the element count of input 1 (`pts`) divided by 3. Points are
 * split across the AIV cores in chunks of `core_data`, rounded up to a multiple of 64.
 * `core_used` cores get work, and the last one gets `core_last` points. Each chunk is
 * further cut into `copy_loop` pieces of `available_ub_size` plus a `copy_tail`
 * (`last_copy_loop` / `last_copy_tail` for the last core). No workspace is requested.
 *
 * @param context tiling context supplied by the framework.
 * @return GRAPH_SUCCESS, or GRAPH_FAILED when:
 *         - the context, platform info, an input, the raw tiling buffer or the
 *           workspace array is missing;
 *         - an input has too few dims;
 *         - there are no points;
 *         - the UB size is unusable.
 */
static ge::graphStatus TilingFunc(gert::TilingContext *context) {
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto platformInfoptr = context->GetPlatformInfo();
    if (platformInfoptr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto ascendplatformInfo = platform_ascendc::PlatformAscendC(platformInfoptr);
    auto core_number = ascendplatformInfo.GetCoreNumAiv();

    auto boxes_tensor = context->GetInputTensor(0);
    auto points_tensor = context->GetInputTensor(1);
    if (boxes_tensor == nullptr || points_tensor == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto boxes_shape = boxes_tensor->GetStorageShape();
    auto points_shape = points_tensor->GetStorageShape();
    // boxes dim 2 and pts dim 1 are read below; reject lower-rank inputs instead of
    // writing invalid dim values into the tiling data.
    if (boxes_shape.GetDimNum() < 3 || points_shape.GetDimNum() < 2) {
        return ge::GRAPH_FAILED;
    }
    // Presumably 3 coordinates per point (x, y, z) — confirm against the op spec.
    uint32_t totalresult = static_cast<uint32_t>(points_tensor->GetShapeSize() / 3);

    // Points per core, rounded up to a multiple of 64.
    int32_t core_data = GetCeilInt(totalresult, core_number);
    core_data = GetCeilInt(core_data, 64) * 64;
    // No points (or an out-of-range count) means no valid split; fail before
    // core_data is used as a divisor.
    if (core_data <= 0) {
        return ge::GRAPH_FAILED;
    }
    int32_t core_used = GetCeilInt(totalresult, core_data);
    int32_t core_last = core_data;
    if (totalresult % core_data != 0) {
        core_last = totalresult % core_data;
    }

    uint64_t available_ub_size = 0;
    ascendplatformInfo.GetCoreMemSize(platform_ascendc::CoreMemType::UB, available_ub_size);
    // Keep 20 KB in reserve; the subtraction is unsigned, so guard against wrap-around.
    constexpr uint64_t RESERVED_UB_BYTES = 20 * 1024;
    if (available_ub_size <= RESERVED_UB_BYTES) {
        return ge::GRAPH_FAILED;
    }
    // Divide the rest by 50 * 4 (presumably 50 buffers of 4-byte elements — confirm
    // against the kernel).
    available_ub_size = (available_ub_size - RESERVED_UB_BYTES) / 50 / 4;
    // Round up to a multiple of 32 elements. This is done in uint64 rather than through
    // the int32 GetCeilInt helper, to avoid a narrowing conversion.
    // NOTE(review): rounding *up* can make 50 * 4 * available_ub_size exceed the budget
    // computed above; confirm against the kernel's actual UB allocation.
    available_ub_size = (available_ub_size + 31) / 32 * 32;
    if (available_ub_size == 0) {
        return ge::GRAPH_FAILED;
    }

    auto raw_tiling_data = context->GetRawTilingData();
    if (raw_tiling_data == nullptr) {
        return ge::GRAPH_FAILED;
    }
    PointsInBoxTilingData tiling;
    context->SetBlockDim(core_used);
    tiling.set_core_data(core_data);
    tiling.set_core_used(core_used);
    tiling.set_copy_loop(core_data / available_ub_size);
    tiling.set_copy_tail(core_data % available_ub_size);
    tiling.set_last_copy_loop(core_last / available_ub_size);
    tiling.set_last_copy_tail(core_last % available_ub_size);
    // NOTE(review): batch comes from boxes dim 0 and box_number from boxes dim 2;
    // confirm this matches the boxes layout the kernel expects.
    tiling.set_batch(boxes_shape.GetDim(0));
    tiling.set_npoints(points_shape.GetDim(1));
    tiling.set_box_number(boxes_shape.GetDim(2));
    tiling.set_available_ub_size(available_ub_size);
    tiling.SaveToBuffer(raw_tiling_data->GetData(), raw_tiling_data->GetCapacity());
    raw_tiling_data->SetDataSize(tiling.GetDataSize());

    size_t *currentWorkspace = context->GetWorkspaceSizes(1);
    if (currentWorkspace == nullptr) {
        return ge::GRAPH_FAILED;
    }
    currentWorkspace[0] = 0;
    return ge::GRAPH_SUCCESS;
}
}

namespace ge {
/**
 * Infers the output shape of PointsInBox.
 *
 * Output 0 (`boxes_idx_of_points`) is 2-D and takes dims 0 and 1 of input 1 (`pts`).
 *
 * @param context infer-shape context supplied by the framework.
 * @return GRAPH_SUCCESS, or GRAPH_FAILED when:
 *         - the context, the pts shape or the output shape is missing;
 *         - pts has fewer than two dims.
 */
static ge::graphStatus InferShape(gert::InferShapeContext *context) {
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const gert::Shape *pts_shape = context->GetInputShape(1);
    if (pts_shape == nullptr) {
        return ge::GRAPH_FAILED;
    }
    // Dims 0 and 1 are read below; a lower-rank pts would yield invalid dim values.
    if (pts_shape->GetDimNum() < 2) {
        return ge::GRAPH_FAILED;
    }
    gert::Shape *y_shape = context->GetOutputShape(0);
    if (y_shape == nullptr) {
        return ge::GRAPH_FAILED;
    }
    y_shape->SetDimNum(0);
    y_shape->AppendDim(pts_shape->GetDim(0));
    y_shape->AppendDim(pts_shape->GetDim(1));
    return GRAPH_SUCCESS;
}

/**
 * Infers the output data type of PointsInBox: output 0 is always DT_INT32.
 *
 * @param context infer-dtype context supplied by the framework.
 * @return the status reported by SetOutputDataType, or GRAPH_FAILED when context is null.
 */
static ge::graphStatus PointsInBoxInferDataType(gert::InferDataTypeContext *context) {
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    // Propagate the setter's status instead of reporting success unconditionally.
    return context->SetOutputDataType(0, ge::DT_INT32);
}
}

namespace ops {
class PointsInBox : public OpDef {
  public:
    explicit PointsInBox(const char *name) : OpDef(name) {
        this->Input("boxes")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});
        this->Input("pts")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});
        this->Output("boxes_idx_of_points")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});

        this->SetInferShape(ge::InferShape).SetInferDataType(ge::PointsInBoxInferDataType);

        this->AICore().SetTiling(optiling::TilingFunc);
        this->AICore().AddConfig("ascend910b");
        this->AICore().AddConfig("ascend910_93");
#if __DRIVING_HOST_AICORE__ == 310
        this->AICore().AddConfig("ascend950");
#endif
    }
};

// Hands the PointsInBox definition above to the OP_ADD macro from
// register/op_def_registry.h (presumably registers it with the op-def registry).
OP_ADD(PointsInBox);
}