/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

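/*!
 * \file acosh_tiling.cpp
 * \brief Host tiling for the Acosh operator (arch35 / Ascend950)
 *
 * Aligned with DESIGN.md v2.1 §3.6:
 *   - Platform info is queried at run time: GetCoreNumAiv() / GetCoreMemSize(UB) (no hard-coded values)
 *   - Multi-core split: blockFactor = CeilAlign(CeilDiv(totalNum, coreNum), ubBlockSize),
 *     where ubBlockSize = GetUbBlockSize() (typically 32) is applied to the element count
 *   - UB split: ubFactor = FloorAlign((ubSize - logTmpReserve) / (8 × 4B), 32B)
 *     8 FP32-equivalent buffers including double buffering
 *     (inputQue×2 + outputQue×2 + fp32WorkBuf + dataTBuf + dataRBuf + logTmpBuf)
 *   - Log implicit tmpBuffer reserve: subtracted from ubSize; currently a constant 0 bytes in ComputeUbFactor
 *     (GetLogMaxMinTmpSize was measured to return minValue=0 on Ascend950, so it is not called here)
 *   - Empty tensors return early; the 32B-aligned tail block is handled automatically by DataCopyPad
 *   - TilingKey encoding: ASCENDC_TPL_SEL_PARAM(context, dtype); the dtype dimension is D_T_X
 *
 * Scope of iteration 1 (FP32 single-dtype path brought up end to end):
 *   - dtype check: iteration 1 only validates DT_FLOAT end to end; DT_FLOAT16/DT_BF16 are fully enabled in iteration 2
 *     (this file already accepts all 3 dtypes as supported and dispatches on dtype through the TilingKey;
 *      the kernel-side Cast path is implemented, only the full ST verification is deferred to iteration 2)
 */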

#include <algorithm>
#include <vector>
#include "register/op_def_registry.h"
#include "op_common/log/log.h"
#include "op_common/op_host/util/math_util.h"
#include "op_common/op_host/util/platform_util.h"
#include "tiling/platform/platform_ascendc.h"
#include "../../op_kernel/arch35/acosh_tiling_data.h"
#include "../../op_kernel/arch35/acosh_tiling_key.h"

/* Generated By CANNBot */

namespace optiling {

using Ops::Base::CeilDiv;
using Ops::Base::CeilAlign;
using Ops::Base::FloorAlign;
using Ops::Base::FloorDiv;
using Ops::Base::GetUbBlockSize;

// Number of workspace entries requested from the tiling context in AcoshTilingFunc.
constexpr size_t WORKSPACE_NUM = 1;
// Value written to workspace entry 0 in AcoshTilingFunc.
constexpr uint32_t WS_SYS_SIZE = 0U;

// Number of FP32-equivalent buffers that are live at once for a single tile (double buffering included):
//   inputQue × 2 (DB) + outputQue × 2 (DB) + fp32WorkBuf + dataTBuf + dataRBuf + logTmpBuf = 8
// Each element is budgeted at 4 bytes (FP32).
// On the FP16/BF16 path the inputQue/outputQue elements are 2 bytes, which is theoretically looser,
// but for simplicity the budget uses FP32 as a conservative upper bound.
constexpr int64_t FP32_BUF_COUNT_WITH_DB = 8;
constexpr int64_t TYPE_SIZE_FP32 = 4;

// Reads the platform description from the tiling context and reports the value of
// GetCoreNumAiv() through coreNum and the UB memory size through ubSize.
// Both values are queried at run time (never hard-coded).
// Returns GRAPH_FAILED when the platform info is missing or either value is 0.
static ge::graphStatus GetPlatformInfo(gert::TilingContext* context, uint64_t* ubSize, int64_t* coreNum)
{
    fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr);
    platform_ascendc::PlatformAscendC ascendcPlatform(platformInfoPtr);

    // Core count: stored in the out-parameter first, then validated.
    int64_t& aivCores = *coreNum;
    aivCores = ascendcPlatform.GetCoreNumAiv();
    OP_CHECK_IF(aivCores == 0, OP_LOGE(context, "Acosh: coreNum is 0"), return ge::GRAPH_FAILED);

    // UB size: filled in by the platform object through the reference.
    uint64_t& ubBytes = *ubSize;
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubBytes);
    OP_CHECK_IF(ubBytes == 0, OP_LOGE(context, "Acosh: ubSize is 0"), return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}

// Reads the element count and data type of input 0 and validates both.
//   totalNum (out): element count of the input's storage shape.
//   dtype    (out): data type of the input; must be DT_FLOAT, DT_FLOAT16 or DT_BF16.
// Returns GRAPH_FAILED when the shape or descriptor is missing, the element count is negative,
// or the dtype is not one of the three supported types.
static ge::graphStatus ParseInputAndCheckDtype(
    gert::TilingContext* context, int64_t* totalNum, ge::DataType* dtype)
{
    auto inputShape = context->GetInputShape(0);
    OP_CHECK_NULL_WITH_CONTEXT(context, inputShape);
    *totalNum = inputShape->GetStorageShape().GetShapeSize();
    // A negative element count cannot be split across cores; reject it here instead of
    // letting it flow into the block/UB split arithmetic.
    OP_CHECK_IF(*totalNum < 0,
        OP_LOGE(context, "Acosh: invalid totalNum %ld", *totalNum),
        return ge::GRAPH_FAILED);

    auto inputDesc = context->GetInputDesc(0);
    OP_CHECK_NULL_WITH_CONTEXT(context, inputDesc);
    *dtype = inputDesc->GetDataType();
    // Three fixed candidates: compare directly rather than building a std::set on every
    // tiling call (this also avoids depending on <set>, which this file does not include).
    const bool isSupported =
        (*dtype == ge::DT_FLOAT) || (*dtype == ge::DT_FLOAT16) || (*dtype == ge::DT_BF16);
    OP_CHECK_IF(!isSupported,
        OP_LOGE(context, "Acosh: unsupported dtype %d", static_cast<int>(*dtype)),
        return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}

// Computes the UB split (ubFactor): the UB bytes left after the Log tmpBuffer reserve are divided
// evenly among the 8 FP32-equivalent buffers, then rounded down to a whole number of UB blocks.
// Per the A1-P-log-tmpbuf measurement, GetLogMaxMinTmpSize returns minValue=0 on Ascend950 (DAV_3510),
// i.e. the natural-log interface needs no sharedTmpBuffer reserve on that platform; the reserve
// constant is kept so other platforms can be accommodated later.
// The result is aligned down to ubBlockSize bytes expressed in FP32 elements (32B = 8 elements),
// matching the estimate in DESIGN.md v2.1 §3.6 / §3.8.1.
//   ubSize      : UB capacity in bytes.
//   ubBlockSize : UB block size in bytes.
//   ubFactor    : (out) elements per UB tile; must end up > 0 or the call fails.
static ge::graphStatus ComputeUbFactor(
    gert::TilingContext* context, uint64_t ubSize, int64_t ubBlockSize, int64_t* ubFactor)
{
    constexpr int64_t reservedForLogTmp = 0;
    // Bytes consumed per element across all simultaneously live buffers.
    constexpr int64_t bytesPerElemAllBufs = TYPE_SIZE_FP32 * FP32_BUF_COUNT_WITH_DB;

    const int64_t usableBytes = static_cast<int64_t>(ubSize) - reservedForLogTmp;
    OP_CHECK_IF(usableBytes <= 0,
        OP_LOGE(context, "Acosh: ubAvail <= 0, ubSize=%lu", static_cast<unsigned long>(ubSize)),
        return ge::GRAPH_FAILED);

    const int64_t elemsPerUbBlock = ubBlockSize / TYPE_SIZE_FP32;
    const int64_t unalignedElems = FloorDiv(usableBytes, bytesPerElemAllBufs);
    *ubFactor = FloorAlign(unalignedElems, elemsPerUbBlock);

    OP_CHECK_IF(*ubFactor <= 0,
        OP_LOGE(context, "Acosh: ubFactor too small (=%ld)", *ubFactor),
        return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}

// Tiling entry point for Acosh.
// Gathers platform info and the input's element count / dtype, fills AcoshTilingData
// (totalNum, blockFactor, ubFactor), sets the workspace size, the block dim and the
// dtype-based tiling key. Returns GRAPH_FAILED as soon as any step fails.
static ge::graphStatus AcoshTilingFunc(gert::TilingContext* context)
{
    OP_LOGD(context->GetNodeName(), "Enter AcoshTilingFunc");
    // 1. Platform info
    uint64_t ubSize = 0;
    int64_t coreNum = 0;
    OP_CHECK_IF(GetPlatformInfo(context, &ubSize, &coreNum) != ge::GRAPH_SUCCESS,
        OP_LOGE(context, "Acosh: GetPlatformInfo error"), return ge::GRAPH_FAILED);

    // 2. Shape & dtype
    int64_t totalNum = 0;
    ge::DataType dtype = ge::DT_UNDEFINED;
    OP_CHECK_IF(ParseInputAndCheckDtype(context, &totalNum, &dtype) != ge::GRAPH_SUCCESS,
        OP_LOGE(context, "Acosh: ParseInputAndCheckDtype error"), return ge::GRAPH_FAILED);

    // 3. Workspace
    size_t* ws = context->GetWorkspaceSizes(WORKSPACE_NUM);
    OP_CHECK_NULL_WITH_CONTEXT(context, ws);
    ws[0] = WS_SYS_SIZE;

    // 4. TilingData: zero it so every field has a defined value on all return paths below.
    AcoshTilingData* tiling = context->GetTilingData<AcoshTilingData>();
    OP_CHECK_NULL_WITH_CONTEXT(context, tiling);
    OP_CHECK_IF(memset_s(tiling, sizeof(AcoshTilingData), 0, sizeof(AcoshTilingData)) != EOK,
        OP_LOGE(context, "Acosh: memset tiling failed"), return ge::GRAPH_FAILED);

    // 5. Empty tensor: launch a single block with zeroed tiling data and return early.
    //    (Original note: the kernel's Process() returns early in this case.)
    if (totalNum == 0) {
        context->SetBlockDim(1);
        ASCENDC_TPL_SEL_PARAM(context, static_cast<uint32_t>(dtype));
        return ge::GRAPH_SUCCESS;
    }

    // 6. Multi-core split: the per-core element count is rounded up to a multiple of ubBlockSize
    //    (original note: so that CopyOut writes of neighbouring cores do not overlap).
    int64_t ubBlockSize = GetUbBlockSize(context);  // typically 32 bytes
    // ubBlockSize is used as an alignment here and as a divisor source in ComputeUbFactor.
    OP_CHECK_IF(ubBlockSize <= 0,
        OP_LOGE(context, "Acosh: invalid ubBlockSize %ld", ubBlockSize), return ge::GRAPH_FAILED);
    const int64_t blockFactor = CeilAlign(CeilDiv(totalNum, coreNum), ubBlockSize);
    // blockFactor is the divisor for usedCoreNum; a non-positive value would yield a meaningless block dim.
    OP_CHECK_IF(blockFactor <= 0,
        OP_LOGE(context, "Acosh: invalid blockFactor %ld (totalNum=%ld, coreNum=%ld)",
                blockFactor, totalNum, coreNum),
        return ge::GRAPH_FAILED);
    tiling->totalNum = totalNum;
    tiling->blockFactor = blockFactor;
    int64_t usedCoreNum = CeilDiv(totalNum, blockFactor);

    // 7. UB split
    OP_CHECK_IF(ComputeUbFactor(context, ubSize, ubBlockSize, &tiling->ubFactor) != ge::GRAPH_SUCCESS,
        OP_LOGE(context, "Acosh: ComputeUbFactor error"), return ge::GRAPH_FAILED);

    // 8. BlockDim (explicit narrowing: usedCoreNum is derived from a positive blockFactor above)
    context->SetBlockDim(static_cast<uint32_t>(usedCoreNum));

    // 9. TilingKey: selected by the dtype dimension
    ASCENDC_TPL_SEL_PARAM(context, static_cast<uint32_t>(dtype));

    OP_LOGI(context, "Acosh: totalNum=%ld, blockFactor=%ld, ubFactor=%ld, usedCoreNum=%ld, dtype=%d",
            tiling->totalNum, tiling->blockFactor, tiling->ubFactor, usedCoreNum,
            static_cast<int>(dtype));
    return ge::GRAPH_SUCCESS;
}

// Tiling-parse callback registered for Acosh. It ignores its context, performs no work,
// and always returns GRAPH_SUCCESS.
static ge::graphStatus TilingParseForAcosh([[maybe_unused]] gert::TilingParseContext* context)
{
    return ge::GRAPH_SUCCESS;
}

// Empty compile-info type used as the TilingParse<> template argument below.
struct AcoshCompileInfo {};  // Placeholder; the in-graph scenario depends on it

// Registers AcoshTilingFunc and TilingParseForAcosh as the tiling callbacks of the Acosh op.
IMPL_OP_OPTILING(Acosh).Tiling(AcoshTilingFunc).TilingParse<AcoshCompileInfo>(TilingParseForAcosh);

} // namespace optiling