/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 *
 * Generated By CANNBot
 */

/*!
 * \file ndtri_tiling_arch35.cpp
 * \brief Ndtri Tiling 实现(arch35 / Ascend950)
 */

#include "register/op_def_registry.h"
#include "log/log.h"
#include "util/math_util.h"
#include "util/platform_util.h"
#include "../../op_kernel/arch35/ndtri_tiling_data.h"
#include "../../op_kernel/arch35/ndtri_tiling_key.h"

namespace optiling {

constexpr uint32_t WS_USER_SIZE = 0U;  // 算子自身不需要额外 workspace
static constexpr size_t IDX_SELF = 0;
constexpr int64_t TYPE_SIZE_FP32 = 4;
constexpr int64_t TYPE_SIZE_FP16_BF16 = 2;
constexpr int64_t RESERVED_UB = 48 * 1024;
// 按 Kernel 实际 UB 占用估算:
//   in/out (DB):    2 * 2 * sizeof(T) → FP32 16, FP16/BF16 8
//   pBuf + yBuf:    2 * 4 = 8
//   tmpBuf0..10:    11 * 4 = 44
//   maskBuf0..3:    ~0.5 (按 bit)
// FP32 合计 ≈ 68.5 → 取 80(保留余量)
// FP16/BF16 合计 ≈ 60.5 → 取 64
constexpr int64_t BYTE_PER_ELEM_FP32 = 80;
constexpr int64_t BYTE_PER_ELEM_FP16_BF16 = 64;
constexpr int64_t TILE_ALIGN = 256;

static const gert::Shape K_VEC_1_SHAPE = {1};

static inline const gert::Shape EnsureNotScalar(const gert::Shape& in_shape)
{
    if (in_shape.GetDimNum() == 0) {
        return K_VEC_1_SHAPE;
    }
    return in_shape;
}

// 平台信息
static ge::graphStatus GetPlatformInfo(
    gert::TilingContext* context, uint64_t& ubSize, int64_t& coreNum,
    uint32_t& sysWorkspaceSize)
{
    fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr);
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
    coreNum = ascendcPlatform.GetCoreNumAiv();
    OP_CHECK_IF(coreNum == 0, OP_LOGE(context, "coreNum is 0"),
                return ge::GRAPH_FAILED);
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
    OP_CHECK_IF(ubSize == 0, OP_LOGE(context, "ubSize is 0"),
                return ge::GRAPH_FAILED);
    sysWorkspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize();
    return ge::GRAPH_SUCCESS;
}

// dtype 校验
static ge::graphStatus CheckDtype(gert::TilingContext* context, ge::DataType& dtype)
{
    auto selfDesc = context->GetInputDesc(IDX_SELF);
    OP_CHECK_NULL_WITH_CONTEXT(context, selfDesc);
    dtype = selfDesc->GetDataType();
    const std::set<ge::DataType> supported = {
        ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16};
    OP_CHECK_IF(supported.count(dtype) == 0,
                OP_LOGE(context, "Ndtri: unsupported dtype %d",
                        static_cast<int>(dtype)),
                return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}

// 提取 totalNum
static ge::graphStatus GetTotalNum(gert::TilingContext* context, int64_t& totalNum)
{
    auto selfShapePtr = context->GetInputShape(IDX_SELF);
    OP_CHECK_NULL_WITH_CONTEXT(context, selfShapePtr);
    auto selfShape = EnsureNotScalar(selfShapePtr->GetStorageShape());
    totalNum = selfShape.GetShapeSize();
    OP_CHECK_IF(totalNum <= 0,
                OP_LOGE(context, "Ndtri: totalNum must > 0, got %ld", totalNum),
                return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}

// 多核 + UB 切分
static ge::graphStatus DoTiling(
    gert::TilingContext* context, ge::DataType dtype, int64_t totalNum,
    uint64_t ubSize, int64_t coreNum,
    NdtriTilingData* tiling, int64_t& usedCoreNum, int64_t& alignElem)
{
    int64_t ubBlockSize = Ops::Base::GetUbBlockSize(context);
    int64_t typeSize = (dtype == ge::DT_FLOAT) ? TYPE_SIZE_FP32 : TYPE_SIZE_FP16_BF16;
    OP_CHECK_IF(typeSize <= 0, OP_LOGE(context, "typeSize<=0"),
                return ge::GRAPH_FAILED);
    alignElem = ubBlockSize / typeSize;
    OP_CHECK_IF(alignElem <= 0, OP_LOGE(context, "alignElem<=0"),
                return ge::GRAPH_FAILED);

    if (totalNum < alignElem) {
        tiling->blockFactor = totalNum;
        usedCoreNum = 1;
    } else {
        int64_t perCoreRaw = Ops::Base::CeilDiv(totalNum, coreNum);
        tiling->blockFactor = Ops::Base::CeilAlign(perCoreRaw, alignElem);
        usedCoreNum = Ops::Base::CeilDiv(totalNum, tiling->blockFactor);
    }
    OP_CHECK_IF(usedCoreNum == 0, OP_LOGE(context, "usedCoreNum is 0"),
                return ge::GRAPH_FAILED);

    int64_t availableUb = static_cast<int64_t>(ubSize) - RESERVED_UB;
    OP_CHECK_IF(availableUb <= 0, OP_LOGE(context, "availableUb<=0"),
                return ge::GRAPH_FAILED);
    int64_t bytePerElem = (dtype == ge::DT_FLOAT) ? BYTE_PER_ELEM_FP32 : BYTE_PER_ELEM_FP16_BF16;
    int64_t tileElem = availableUb / bytePerElem;
    tileElem = Ops::Base::FloorAlign(tileElem, TILE_ALIGN);
    if (tileElem < alignElem) {
        tileElem = alignElem;
    }
    tiling->ubFactor = tileElem;
    OP_LOGD(context, "Ndtri UB plan: ubFactor=%ld, bytePerElem=%ld, estUbUsage=%ldB, reserved=%ldB, available=%ldB",
            static_cast<long>(tileElem),
            static_cast<long>(bytePerElem),
            static_cast<long>(tileElem * bytePerElem),
            static_cast<long>(RESERVED_UB),
            static_cast<long>(availableUb));
    return ge::GRAPH_SUCCESS;
}

// TilingKey 派发
static void DispatchTilingKey(
    gert::TilingContext* context, ge::DataType dtype, int64_t totalNum, int64_t alignElem)
{
    uint32_t dtypeKey;
    if (dtype == ge::DT_FLOAT) {
        dtypeKey = static_cast<uint32_t>(C_DT_FLOAT);
    } else if (dtype == ge::DT_FLOAT16) {
        dtypeKey = static_cast<uint32_t>(C_DT_FLOAT16);
    } else {
        dtypeKey = static_cast<uint32_t>(C_DT_BF16);
    }
    uint32_t isAlign = (alignElem > 0 && totalNum > 0 && totalNum % alignElem == 0) ? 1U : 0U;
    ASCENDC_TPL_SEL_PARAM(context, dtypeKey, isAlign);
}

// Tiling 入口
static ge::graphStatus NdtriTilingFunc(gert::TilingContext* context)
{
    OP_LOGD(context->GetNodeName(), "Enter NdtriTilingFunc");
    uint64_t ubSize = 0;
    int64_t coreNum = 0;
    uint32_t sysWorkspaceSize = 0;
    OP_CHECK_IF(GetPlatformInfo(context, ubSize, coreNum, sysWorkspaceSize) != ge::GRAPH_SUCCESS,
                OP_LOGE(context, "GetPlatformInfo error"),
                return ge::GRAPH_FAILED);

    ge::DataType dtype;
    OP_CHECK_IF(CheckDtype(context, dtype) != ge::GRAPH_SUCCESS,
                OP_LOGE(context, "CheckDtype error"),
                return ge::GRAPH_FAILED);

    int64_t totalNum = 0;
    OP_CHECK_IF(GetTotalNum(context, totalNum) != ge::GRAPH_SUCCESS,
                OP_LOGE(context, "GetTotalNum error"),
                return ge::GRAPH_FAILED);

    size_t* currentWorkspace = context->GetWorkspaceSizes(1);
    OP_CHECK_NULL_WITH_CONTEXT(context, currentWorkspace);
    currentWorkspace[0] = WS_USER_SIZE + sysWorkspaceSize;

    NdtriTilingData* tiling = context->GetTilingData<NdtriTilingData>();
    OP_CHECK_NULL_WITH_CONTEXT(context, tiling);
    OP_CHECK_IF(memset_s(tiling, sizeof(NdtriTilingData), 0, sizeof(NdtriTilingData)) != EOK,
                OP_LOGE(context, "set tiling data error"),
                return ge::GRAPH_FAILED);
    tiling->totalNum = totalNum;

    int64_t usedCoreNum = 0;
    int64_t alignElem = 0;
    OP_CHECK_IF(DoTiling(context, dtype, totalNum, ubSize, coreNum,
                         tiling, usedCoreNum, alignElem) != ge::GRAPH_SUCCESS,
                OP_LOGE(context, "DoTiling error"),
                return ge::GRAPH_FAILED);

    context->SetBlockDim(usedCoreNum);
    DispatchTilingKey(context, dtype, totalNum, alignElem);
    return ge::GRAPH_SUCCESS;
}

static ge::graphStatus TilingParseForNdtri(
    [[maybe_unused]] gert::TilingParseContext* context)
{
    return ge::GRAPH_SUCCESS;
}

struct NdtriCompileInfo {};

IMPL_OP_OPTILING(Ndtri)
    .Tiling(NdtriTilingFunc)
    .TilingParse<NdtriCompileInfo>(TilingParseForNdtri);

} // namespace optiling