/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
 */

#include "hypot_tiling.h"

#include <algorithm>
#include <cstdint>
#include <limits>

#include "register/op_def_registry.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/tiling_api.h"

namespace optiling {
// Bytes per element; the op only registers DT_FLOAT inputs/outputs.
constexpr uint32_t SIZE_OF_FLOAT = 4;
// Elements per core: one more core is used for every started BLOCK_SIZE elements
// (capped at the number of AIV cores).
constexpr uint32_t BLOCK_SIZE = 4096;
// Alignment unit in bytes for per-core and per-tile lengths.
constexpr uint32_t BYTE_BLOCK = 32;
// Number of float elements in one BYTE_BLOCK; all lengths are multiples of this.
constexpr uint32_t ALIGN_NUM = BYTE_BLOCK / SIZE_OF_FLOAT;
// Bytes of UB held back from the tile-buffer budget.
constexpr uint32_t RESERVED_UB_SIZE = 16 * 1024;
// Buffers per queue (presumably double buffering in the kernel - confirm).
constexpr uint32_t BUFFER_NUM = 2;
// Number of tile-sized UB queues (presumably x, y and out - confirm against the kernel).
constexpr uint32_t QUEUE_NUM = 3;

namespace {
// How one core's share of elements is cut into UB-sized tiles.
struct HypotTileParams {
    uint32_t tileLength = 0;       // elements per full tile (multiple of ALIGN_NUM)
    uint32_t tileNum = 0;          // number of tiles, including the last (possibly shorter) one
    uint32_t remainTileLength = 0; // elements in the last tile
};

// Splits `length` elements into tiles so that QUEUE_NUM * BUFFER_NUM tile buffers
// roughly fit into `ubAvailable` bytes. Tile lengths are rounded up to ALIGN_NUM.
// Returns all-zero params when `length` is 0. `ubAvailable` must be non-zero.
HypotTileParams ComputeTileParams(uint32_t length, uint64_t ubAvailable) {
    HypotTileParams params;
    // 64-bit product: in 32-bit arithmetic this overflows once length exceeds ~179M elements.
    const uint64_t bytesNeeded = static_cast<uint64_t>(length) * QUEUE_NUM * SIZE_OF_FLOAT * BUFFER_NUM;
    const uint64_t sliceNum = bytesNeeded / ubAvailable + 1;
    params.tileLength = static_cast<uint32_t>((length / sliceNum + ALIGN_NUM - 1) / ALIGN_NUM * ALIGN_NUM);
    if (params.tileLength == 0) {
        return params;
    }
    // Ceil-divide without forming `length + tileLength - 1`, which could overflow.
    params.tileNum = length / params.tileLength + ((length % params.tileLength != 0) ? 1 : 0);
    params.remainTileLength = length - (params.tileNum - 1) * params.tileLength;
    return params;
}
}  // namespace

// Host-side tiling for Hypot.
// Splits the element count of input 0 across AIV cores in ALIGN_NUM-element groups.
// The first `formerNum` cores get one group more than the remaining `tailNum` cores.
// Each core's share is then split into UB-sized tiles, and the result is written
// into HypotTilingData.
// Returns GRAPH_FAILED on a missing context/platform/tensor/tiling buffer, zero AIV
// cores, UB not larger than RESERVED_UB_SIZE, or an element count that cannot be
// represented in the uint32_t lengths used below.
static ge::graphStatus TilingFunc(gert::TilingContext *context) {
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto platformInfo = context->GetPlatformInfo();
    if (platformInfo == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
    const uint32_t coreNum = ascendcPlatform.GetCoreNumAiv();
    if (coreNum == 0) {
        return ge::GRAPH_FAILED;
    }
    uint64_t ubSizePlatform = 0;
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatform);
    // Without this guard the subtraction below wraps around, or yields 0, which is
    // later used as a divisor.
    if (ubSizePlatform <= RESERVED_UB_SIZE) {
        return ge::GRAPH_FAILED;
    }
    const uint64_t ubAvailable = ubSizePlatform - RESERVED_UB_SIZE;

    const auto *inputTensor = context->GetInputTensor(0);
    if (inputTensor == nullptr) {
        return ge::GRAPH_FAILED;
    }
    // NOTE(review): only input 0 ("x") is sized here. The kernel presumably requires
    // "y" to have the same element count - confirm.
    const int64_t shapeSize = inputTensor->GetShapeSize();
    // Fail instead of silently truncating. The count must be non-negative, and its
    // ALIGN_NUM-rounded value must fit in uint32_t.
    if (shapeSize < 0 ||
        static_cast<uint64_t>(shapeSize) > std::numeric_limits<uint32_t>::max() - (ALIGN_NUM - 1)) {
        return ge::GRAPH_FAILED;
    }
    const uint32_t totalLength = static_cast<uint32_t>(shapeSize);
    const uint32_t totalLengthAligned = (totalLength + ALIGN_NUM - 1) / ALIGN_NUM * ALIGN_NUM;

    // One core per started BLOCK_SIZE elements, capped at the AIV core count.
    // An empty input runs on one core; `totalLengthAligned - 1` would wrap for it.
    uint32_t usedCoreNum = (totalLengthAligned == 0) ? 1U : (totalLengthAligned - 1) / BLOCK_SIZE + 1;
    usedCoreNum = std::min(usedCoreNum, coreNum);
    context->SetBlockDim(usedCoreNum);

    // Distribute ALIGN_NUM-element groups: `formerNum` cores take one extra group.
    const uint32_t groupNum = totalLengthAligned / ALIGN_NUM;
    const uint32_t formerNum = groupNum % usedCoreNum;
    const uint32_t tailNum = usedCoreNum - formerNum;
    const uint32_t baseLength = groupNum / usedCoreNum;
    const uint32_t formerLength = (baseLength + (formerNum != 0 ? 1 : 0)) * ALIGN_NUM;
    const uint32_t tailLength = baseLength * ALIGN_NUM;

    const HypotTileParams former = ComputeTileParams(formerLength, ubAvailable);
    const HypotTileParams tail = ComputeTileParams(tailLength, ubAvailable);

    auto *rawTilingData = context->GetRawTilingData();
    if (rawTilingData == nullptr) {
        return ge::GRAPH_FAILED;
    }

    HypotTilingData tiling;
    tiling.set_formerNum(formerNum);
    tiling.set_tailNum(tailNum);
    tiling.set_formerLength(static_cast<uint64_t>(formerLength));
    tiling.set_tailLength(tailLength);
    tiling.set_formerTileLength(former.tileLength);
    tiling.set_tailTileLength(tail.tileLength);
    tiling.set_formerTileNum(former.tileNum);
    tiling.set_tailTileNum(tail.tileNum);
    tiling.set_formerRemainTileLength(former.remainTileLength);
    tiling.set_tailRemainTileLength(tail.remainTileLength);

    tiling.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tiling.GetDataSize());
    return ge::GRAPH_SUCCESS;
}
}  // namespace optiling

namespace ge {
// Shape inference for Hypot: the output shape is a copy of input 0's ("x") shape.
// NOTE(review): input 1 ("y") is not consulted. The op presumably requires x and y
// to have identical shapes (no broadcasting) - confirm.
// Returns GRAPH_FAILED if the context, input shape or output shape is null.
static ge::graphStatus Infershape(gert::InferShapeContext *context) {
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const auto inputShape = context->GetInputShape(0);
    auto outputShape = context->GetOutputShape(0);
    if (inputShape == nullptr || outputShape == nullptr) {
        return ge::GRAPH_FAILED;
    }
    *outputShape = *inputShape;
    return GRAPH_SUCCESS;
}

// Dtype inference for Hypot: the output dtype is input 0's ("x") dtype.
// Returns GRAPH_FAILED on a null context; otherwise returns the status reported by
// SetOutputDataType instead of discarding it.
static ge::graphStatus InferDtype(gert::InferDataTypeContext *context) {
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const auto outDtype = context->GetInputDataType(0);
    return context->SetOutputDataType(0, outDtype);
}
}  // namespace ge

namespace ops {
// Operator definition for Hypot.
// Declares two required inputs ("x", "y") and one required output ("out"), all
// DT_FLOAT in FORMAT_ND. Hooks up ge::Infershape, ge::InferDtype and
// optiling::TilingFunc, and registers the AI Core SoC configs listed below.
class Hypot : public OpDef {
  public:
    explicit Hypot(const char *name) : OpDef(name) {
        // All three tensors share one declaration: required, float, ND.
        const auto declareFloatNd = [](auto &param) {
            param.ParamType(REQUIRED)
                .DataType({ge::DT_FLOAT})
                .Format({ge::FORMAT_ND})
                .UnknownShapeFormat({ge::FORMAT_ND});
        };
        declareFloatNd(this->Input("x"));
        declareFloatNd(this->Input("y"));
        declareFloatNd(this->Output("out"));

        this->SetInferShape(ge::Infershape).SetInferDataType(ge::InferDtype);

        auto &aiCore = this->AICore();
        aiCore.SetTiling(optiling::TilingFunc);
        aiCore.AddConfig("ascend910b");
        aiCore.AddConfig("ascend910_93");
#if __DRIVING_HOST_AICORE__ == 310
        aiCore.AddConfig("ascend950");
#endif
    }
};

OP_ADD(Hypot);
}  // namespace ops