/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/**
 * NOTE: Portions of this code were AI-generated and have been
 * technically reviewed for functional accuracy and security
 */
/*!
 * \file ndtri_tiling.cpp
 * \brief Ndtri Tiling 实现(arch35 / Ascend950)
 */

#include "register/op_def_registry.h"
#include "op_common/log/log.h"
#include "op_common/op_host/util/math_util.h"
#include "op_common/op_host/util/platform_util.h"
#include "../op_kernel/ndtri_tiling_data.h"
#include "../op_kernel/ndtri_tiling_key.h"

namespace optiling {

constexpr uint32_t WS_USER_SIZE = 0U;  // 算子自身不需要额外 workspace
static constexpr size_t IDX_SELF = 0;
constexpr int64_t TYPE_SIZE_FP32 = 4;
constexpr int64_t TYPE_SIZE_FP16_BF16 = 2;
constexpr int64_t RESERVED_UB = 48 * 1024;
constexpr int64_t BYTE_PER_ELEM = 64;
constexpr int64_t TILE_ALIGN = 256;

static const gert::Shape K_VEC_1_SHAPE = {1};

static inline const gert::Shape EnsureNotScalar(const gert::Shape& in_shape)
{
    if (in_shape.GetDimNum() == 0) {
        return K_VEC_1_SHAPE;
    }
    return in_shape;
}

// 平台信息
static ge::graphStatus GetPlatformInfo(
    gert::TilingContext* context, uint64_t& ubSize, int64_t& coreNum,
    uint32_t& sysWorkspaceSize)
{
    fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr);
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
    coreNum = ascendcPlatform.GetCoreNumAiv();
    OP_CHECK_IF(coreNum == 0, OP_LOGE(context, "coreNum is 0"),
                return ge::GRAPH_FAILED);
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
    OP_CHECK_IF(ubSize == 0, OP_LOGE(context, "ubSize is 0"),
                return ge::GRAPH_FAILED);
    // 系统默认 workspace(框架自身保留),不要硬编码
    sysWorkspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize();
    return ge::GRAPH_SUCCESS;
}

// dtype 校验
static ge::graphStatus CheckDtype(gert::TilingContext* context, ge::DataType& dtype)
{
    auto selfDesc = context->GetInputDesc(IDX_SELF);
    OP_CHECK_NULL_WITH_CONTEXT(context, selfDesc);
    dtype = selfDesc->GetDataType();
    const std::set<ge::DataType> supported = {
        ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16};
    OP_CHECK_IF(supported.count(dtype) == 0,
                OP_LOGE(context, "Ndtri: unsupported dtype %d",
                        static_cast<int>(dtype)),
                return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}

// 提取 totalNum
static ge::graphStatus GetTotalNum(gert::TilingContext* context, int64_t& totalNum)
{
    auto selfShapePtr = context->GetInputShape(IDX_SELF);
    OP_CHECK_NULL_WITH_CONTEXT(context, selfShapePtr);
    auto selfShape = EnsureNotScalar(selfShapePtr->GetStorageShape());
    totalNum = selfShape.GetShapeSize();
    OP_CHECK_IF(totalNum <= 0,
                OP_LOGE(context, "Ndtri: totalNum must > 0, got %ld", totalNum),
                return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}

// 多核 + UB 切分
static ge::graphStatus DoTiling(
    gert::TilingContext* context, ge::DataType dtype, int64_t totalNum,
    uint64_t ubSize, int64_t coreNum,
    NdtriTilingData* tiling, int64_t& usedCoreNum, int64_t& alignElem)
{
    int64_t ubBlockSize = Ops::Base::GetUbBlockSize(context);
    int64_t typeSize = (dtype == ge::DT_FLOAT) ? TYPE_SIZE_FP32 : TYPE_SIZE_FP16_BF16;
    OP_CHECK_IF(typeSize <= 0, OP_LOGE(context, "typeSize<=0"),
                return ge::GRAPH_FAILED);
    alignElem = ubBlockSize / typeSize;
    OP_CHECK_IF(alignElem <= 0, OP_LOGE(context, "alignElem<=0"),
                return ge::GRAPH_FAILED);

    if (totalNum < alignElem) {
        tiling->blockFactor = totalNum;
        usedCoreNum = 1;
    } else {
        int64_t perCoreRaw = Ops::Base::CeilDiv(totalNum, coreNum);
        tiling->blockFactor = Ops::Base::CeilAlign(perCoreRaw, alignElem);
        usedCoreNum = Ops::Base::CeilDiv(totalNum, tiling->blockFactor);
    }
    OP_CHECK_IF(usedCoreNum == 0, OP_LOGE(context, "usedCoreNum is 0"),
                return ge::GRAPH_FAILED);

    int64_t availableUb = static_cast<int64_t>(ubSize) - RESERVED_UB;
    OP_CHECK_IF(availableUb <= 0, OP_LOGE(context, "availableUb<=0"),
                return ge::GRAPH_FAILED);
    int64_t tileElem = availableUb / BYTE_PER_ELEM;
    tileElem = Ops::Base::FloorAlign(tileElem, TILE_ALIGN);
    if (tileElem < alignElem) {
        tileElem = alignElem;
    }
    tiling->ubFactor = tileElem;
    return ge::GRAPH_SUCCESS;
}

// TilingKey 派发
static void DispatchTilingKey(
    gert::TilingContext* context, ge::DataType dtype, int64_t totalNum, int64_t alignElem)
{
    uint32_t dtypeKey;
    if (dtype == ge::DT_FLOAT) {
        dtypeKey = static_cast<uint32_t>(C_DT_FLOAT);
    } else if (dtype == ge::DT_FLOAT16) {
        dtypeKey = static_cast<uint32_t>(C_DT_FLOAT16);
    } else {
        dtypeKey = static_cast<uint32_t>(C_DT_BF16);
    }
    // 防御性兜底:alignElem 由 DoTiling 通过 ubBlockSize/typeSize 计算,
    // typeSize 为常量 2 或 4 不会为 0;此处显式检查避免静态分析误报除零。
    uint32_t isAlign = (alignElem > 0 && totalNum % alignElem == 0) ? 1U : 0U;
    ASCENDC_TPL_SEL_PARAM(context, dtypeKey, isAlign);
}

// Tiling 入口
static ge::graphStatus NdtriTilingFunc(gert::TilingContext* context)
{
    uint64_t ubSize = 0;
    int64_t coreNum = 0;
    uint32_t sysWorkspaceSize = 0;
    OP_CHECK_IF(GetPlatformInfo(context, ubSize, coreNum, sysWorkspaceSize) != ge::GRAPH_SUCCESS,
                OP_LOGE(context, "GetPlatformInfo error"),
                return ge::GRAPH_FAILED);

    ge::DataType dtype;
    OP_CHECK_IF(CheckDtype(context, dtype) != ge::GRAPH_SUCCESS,
                OP_LOGE(context, "CheckDtype error"),
                return ge::GRAPH_FAILED);

    int64_t totalNum = 0;
    OP_CHECK_IF(GetTotalNum(context, totalNum) != ge::GRAPH_SUCCESS,
                OP_LOGE(context, "GetTotalNum error"),
                return ge::GRAPH_FAILED);

    size_t* currentWorkspace = context->GetWorkspaceSizes(1);
    OP_CHECK_NULL_WITH_CONTEXT(context, currentWorkspace);
    currentWorkspace[0] = WS_USER_SIZE + sysWorkspaceSize;

    NdtriTilingData* tiling = context->GetTilingData<NdtriTilingData>();
    OP_CHECK_NULL_WITH_CONTEXT(context, tiling);
    OP_CHECK_IF(memset_s(tiling, sizeof(NdtriTilingData), 0, sizeof(NdtriTilingData)) != EOK,
                OP_LOGE(context, "set tiling data error"),
                return ge::GRAPH_FAILED);
    tiling->totalNum = totalNum;

    int64_t usedCoreNum = 0;
    int64_t alignElem = 0;
    OP_CHECK_IF(DoTiling(context, dtype, totalNum, ubSize, coreNum,
                         tiling, usedCoreNum, alignElem) != ge::GRAPH_SUCCESS,
                OP_LOGE(context, "DoTiling error"),
                return ge::GRAPH_FAILED);

    context->SetBlockDim(usedCoreNum);
    DispatchTilingKey(context, dtype, totalNum, alignElem);
    return ge::GRAPH_SUCCESS;
}

static ge::graphStatus TilingParseForNdtri(
    [[maybe_unused]] gert::TilingParseContext* context)
{
    return ge::GRAPH_SUCCESS;
}

struct NdtriCompileInfo {};

IMPL_OP_OPTILING(Ndtri)
    .Tiling(NdtriTilingFunc)
    .TilingParse<NdtriCompileInfo>(TilingParseForNdtri);

} // namespace optiling