#include "graph_softmax_tiling.h"
#include "register/op_def_registry.h"
#include "tiling/platform/platform_ascendc.h"

namespace {
constexpr float AVALIABLE_UB_RATIO = 0.8;
constexpr uint32_t SRC_IDX = 0; // number of edge
constexpr uint32_t OUTPUT_IDX = 0;
constexpr uint32_t EDGE_IDX = 0;
constexpr uint32_t FEATURE_IDX = 1;
constexpr uint32_t N_IDX = 0;
constexpr uint32_t SINGLE_LOOP_TASK =
    2360; // 198*1024*0.8/4/(8*2+1),  UB*AVALIABLE_UB_RATIO*(size_of_float/int)/(8*2+1)
} // some const express

namespace optiling {
static ge::graphStatus TilingForGraphSoftmax(gert::TilingContext *context) {
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto platformInfoPtr = context->GetPlatformInfo();
    if (platformInfoPtr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto ascendplatformInfo = platform_ascendc::PlatformAscendC(platformInfoPtr);
    auto aivNum = ascendplatformInfo.GetCoreNumAiv();
    context->SetBlockDim(aivNum);
    uint64_t ubSize;
    ascendplatformInfo.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
    ubSize *= AVALIABLE_UB_RATIO;

    if (aivNum == 0 || ubSize == 0) {
        return ge::GRAPH_FAILED;
    }

    const gert::StorageShape *srcShape = context->GetInputShape(SRC_IDX); // [num_edge, num_feature]
    const gert::RuntimeAttrs *attr = context->GetAttrs();
    if (srcShape == nullptr || attr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto NPtr = attr->GetAttrPointer<int>(N_IDX);
    if (NPtr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    uint32_t N = *NPtr;
    uint32_t numEdge = srcShape->GetStorageShape().GetDim(EDGE_IDX);
    uint32_t numFeature = srcShape->GetStorageShape().GetDim(FEATURE_IDX);
    uint32_t totalTask = numEdge;
    uint32_t coreTask = (totalTask + aivNum - 1) / aivNum;
    uint32_t coreWorkspace = (N + aivNum - 1) / aivNum;
    uint32_t totalWorkspace = coreWorkspace * aivNum;
    uint32_t bigCoreCount = totalTask % aivNum == 0 ? aivNum : totalTask % aivNum;
    uint32_t singleLoopTaskCount = SINGLE_LOOP_TASK;

    GraphSoftmaxTilingData tilingData;

    tilingData.set_coreTask(coreTask);
    tilingData.set_coreWorkspace(coreWorkspace);
    tilingData.set_totalTask(totalTask);
    tilingData.set_totalWorkspace(totalWorkspace);
    tilingData.set_bigCoreCount(bigCoreCount);
    tilingData.set_singleLoopTaskCount(singleLoopTaskCount);

    if (context->GetRawTilingData() == nullptr) {
        return ge::GRAPH_FAILED;
    }
    tilingData.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
    context->GetRawTilingData()->SetDataSize(tilingData.GetDataSize());

    size_t systemWorkspaceSize = ascendplatformInfo.GetLibApiWorkSpaceSize();
    size_t usrWorkSpaceSize = 2 * totalWorkspace * numFeature * sizeof(float);
    size_t *currentWorkspace = context->GetWorkspaceSizes(1);
    if (currentWorkspace == nullptr) {
        return ge::GRAPH_FAILED;
    }
    currentWorkspace[0] = systemWorkspaceSize + usrWorkSpaceSize;
    return ge::GRAPH_SUCCESS;
}
}

namespace ge {
static ge::graphStatus InferShapeForGraphSoftmax(gert::InferShapeContext *context) {
    const gert::Shape *src = context->GetInputShape(SRC_IDX);
    gert::Shape *softmaxResult = context->GetOutputShape(OUTPUT_IDX);
    if (src == nullptr || softmaxResult == nullptr) {
        return ge::GRAPH_FAILED;
    }
    uint64_t numEdge = src->GetDim(EDGE_IDX);
    uint64_t numFeature = src->GetDim(FEATURE_IDX);

    *softmaxResult = {numEdge, numFeature};
    return GRAPH_SUCCESS;
}

static ge::graphStatus InferDataTypeForGraphSoftmax(gert::InferDataTypeContext *context) {
    const ge::DataType num_valid_dtype = context->GetInputDataType(SRC_IDX);
    context->SetOutputDataType(OUTPUT_IDX, num_valid_dtype);
    return GRAPH_SUCCESS;
}
}

namespace ops {
class GraphSoftmax : public OpDef {
  public:
    explicit GraphSoftmax(const char *name) : OpDef(name) {
        this->Input("src")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND})
            .AutoContiguous();

        this->Input("index")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND})
            .AutoContiguous();

        this->Output("softmax_result")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND})
            .AutoContiguous();

        this->Attr("N").Int();

        this->SetInferShape(ge::InferShapeForGraphSoftmax).SetInferDataType(ge::InferDataTypeForGraphSoftmax);
        this->AICore().SetTiling(optiling::TilingForGraphSoftmax);
        this->AICore().AddConfig("ascend910b");
        this->AICore().AddConfig("ascend910_93");
#if __DRIVING_HOST_AICORE__ == 310
        this->AICore().AddConfig("ascend950");
#endif
    }
};

OP_ADD(GraphSoftmax);
}