/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/* Generated By CANNBot */

/*!
 * \file cosh_tiling.cpp
 * \brief Cosh 算子 Host Tiling 实现(标准 Ascend C kernel,arch35,DESIGN §3.3)
 *
 * 手写 CoshTilingFunc:
 *   1. 平台信息:coreNum = GetCoreNumAiv(),ubSize = GetCoreMemSize(UB)。
 *   2. shape/dtype:totalNum = x.shape 各维乘积;dtype ∈ {DT_FLOAT, DT_FLOAT16, DT_BF16}(迭代二全放开)。
 *   3. workspace = 0(无自定义 workspace 张量)。
 *   4. 多核切分:blockFactor = CeilAlign(CeilDiv(totalNum, coreNum), 32B),
 *      usedCoreNum = CeilDiv(totalNum, blockFactor),SetBlockDim(usedCoreNum)。
 *   5. UB 切分:按可用 UB 与缓冲份数反推 ubFactor,按 256B 对齐(向量指令最佳发射粒度),
 *      并 clamp 不超过 blockFactor(LOW-1 上界保护,小 shape UB 不过分配);
 *      增加显式 ubFactor × FP32_BUF_COUNT × FP32_TYPE_SIZE ≤ ubSize 双重校验。
 *   6. 空 tensor(0 元素)host 侧直接返回,SetBlockDim(1),不进 Kernel。
 *   7. TilingKey 仅 dtype 维:ASCENDC_TPL_SEL_PARAM(context, dtype) 映射到 kernel D_T_X。
 *
 * ❌ 不使用废弃宏 BEGIN_TILING_DATA_DEF / TILING_KEY_IS。
 * ❌ 不依赖 atvoss(ElewiseBaseTiling / EleBaseTilingData / DAG)。
 */
#include "register/op_def_registry.h"
#include "op_common/log/log.h"
#include "op_common/op_host/util/math_util.h"
#include "op_common/op_host/util/platform_util.h"
#include "../../op_kernel/arch35/cosh_tiling_data.h"
#include "../../op_kernel/arch35/cosh_tiling_key.h"

namespace optiling {

using Ops::Base::CeilDiv;
using Ops::Base::CeilAlign;
using Ops::Base::FloorDiv;
using Ops::Base::FloorAlign;
using Ops::Base::GetUbBlockSize;

// Size written into the first workspace slot; Cosh requests no workspace bytes.
constexpr uint32_t WS_SYS_SIZE{0U};
// Number of workspace slots requested from the tiling context.
constexpr size_t WORKSPACE_NUM{1};

// UB buffer budget, sized for the worst case (the WithCast path, DESIGN §3.9).
// Every buffer is counted as one fp32-sized (4-byte) slot per element so that
// a single bound covers all supported dtypes:
//   input queue, double-buffered (2) + output queue, double-buffered (2)
//   + ax (1) + e1 (1) + work (1) = 7 fp32-equivalent buffers.
constexpr int64_t FP32_TYPE_SIZE{4};
constexpr int64_t FP32_BUF_COUNT{7};

// Alignment granularity applied to ubFactor (DESIGN §3.3 / §3.9 / §4.1).
// Per the design notes, 256 bytes is the preferred burst length of one vector
// instruction, i.e. 64 fp32 elements or 128 fp16/bf16 elements.
constexpr int64_t VEC_ALIGN_BYTES{256};
constexpr int64_t VEC_ALIGN_FP32_ELEMS{VEC_ALIGN_BYTES / FP32_TYPE_SIZE}; // = 64 elements

// One-dimensional shape {1}, substituted for inputs that have no dimensions.
static const gert::Shape g_vec_1_shape = {1};

/**
 * Returns the shape to use for element counting.
 *
 * @param inShape  Shape as reported for the input.
 * @return A copy of g_vec_1_shape when inShape has zero dimensions, otherwise
 *         a copy of inShape.
 */
static inline const gert::Shape EnsureNotScalar(const gert::Shape& inShape)
{
    const bool hasNoDims = (inShape.GetDimNum() == 0);
    return hasNoDims ? g_vec_1_shape : inShape;
}

/**
 * Queries the platform for the AIV core count and the UB size.
 *
 * @param context  Tiling context that supplies the platform-info handle.
 * @param ubSize   [out] Receives the value reported by GetCoreMemSize for UB.
 * @param coreNum  [out] Receives the value reported by GetCoreNumAiv.
 * @return ge::GRAPH_SUCCESS when both values are non-zero; ge::GRAPH_FAILED
 *         when either is zero. A null platform-info handle is handled by
 *         OP_CHECK_NULL_WITH_CONTEXT (presumably by returning a failure
 *         status; see the macro definition).
 */
static ge::graphStatus GetPlatformInfo(gert::TilingContext* context, uint64_t* ubSize, int64_t* coreNum)
{
    fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr);
    platform_ascendc::PlatformAscendC platform(platformInfoPtr);

    // The core count is written to the out-parameter before it is validated.
    *coreNum = platform.GetCoreNumAiv();
    OP_CHECK_IF(*coreNum == 0, OP_LOGE(context, "Cosh: coreNum is 0"), return ge::GRAPH_FAILED);

    // The UB size is delivered through the reference argument.
    platform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, *ubSize);
    OP_CHECK_IF(*ubSize == 0, OP_LOGE(context, "Cosh: ubSize is 0"), return ge::GRAPH_FAILED);

    return ge::GRAPH_SUCCESS;
}

/**
 * Reads the flattened element count and the data type of input 0.
 *
 * @param context   Tiling context to query.
 * @param totalNum  [out] Element count of input 0's storage shape; an input
 *                  with zero dimensions is counted through the {1} stand-in
 *                  shape returned by EnsureNotScalar.
 * @param dataType  [out] Data type taken from input 0's descriptor.
 * @return ge::GRAPH_SUCCESS when the dtype is DT_FLOAT, DT_FLOAT16 or DT_BF16;
 *         ge::GRAPH_FAILED for any other dtype. Null shape/descriptor pointers
 *         are handled by OP_CHECK_NULL_WITH_CONTEXT.
 */
static ge::graphStatus GetShapeAttrsInfo(gert::TilingContext* context, int64_t* totalNum, ge::DataType* dataType)
{
    auto inputX = context->GetInputShape(0);
    OP_CHECK_NULL_WITH_CONTEXT(context, inputX);
    const auto inputShapeX = EnsureNotScalar(inputX->GetStorageShape());
    *totalNum = inputShapeX.GetShapeSize();

    auto inputDesc = context->GetInputDesc(0);
    OP_CHECK_NULL_WITH_CONTEXT(context, inputDesc);
    *dataType = inputDesc->GetDataType();

    // The supported set is three fixed enum values, so compare directly
    // instead of building a std::set on every call: no heap allocation, and
    // no reliance on <set> being pulled in by another header.
    const bool isSupportedDtype =
        (*dataType == ge::DT_FLOAT) || (*dataType == ge::DT_FLOAT16) || (*dataType == ge::DT_BF16);
    OP_CHECK_IF(!isSupportedDtype,
                OP_LOGE(context, "Cosh: invalid dtype %d", static_cast<int>(*dataType)),
                return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}

/**
 * Requests WORKSPACE_NUM workspace slot(s) from the context and stores
 * WS_SYS_SIZE in the first one.
 *
 * @param context  Tiling context that owns the workspace-size array.
 * @return ge::GRAPH_SUCCESS once the slot is written. A null array is handled
 *         by OP_CHECK_NULL_WITH_CONTEXT.
 */
static ge::graphStatus GetWorkspaceSize(gert::TilingContext* context)
{
    size_t* const currentWorkspace = context->GetWorkspaceSizes(WORKSPACE_NUM);
    OP_CHECK_NULL_WITH_CONTEXT(context, currentWorkspace);
    *currentWorkspace = WS_SYS_SIZE;
    return ge::GRAPH_SUCCESS;
}

/**
 * Host tiling entry point for Cosh.
 *
 * Steps:
 *   1. Read the platform's AIV core count and UB size.
 *   2. Read the flattened element count and dtype of input 0.
 *   3. Set the workspace size.
 *   4. Zero the tiling data, then fill totalNum / blockFactor / ubFactor.
 *   5. Set the block dim and select the tiling key from the dtype.
 *
 * An empty input (0 elements) returns early with block dim 1 and all-zero
 * tiling data.
 *
 * @param context  Tiling context supplied by the framework.
 * @return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED when any step fails
 *         or a computed value is out of range.
 */
static ge::graphStatus CoshTilingFunc(gert::TilingContext* context)
{
    OP_LOGD(context->GetNodeName(), "Enter CoshTilingFunc");
    // 1. Platform information.
    uint64_t ubSize = 0;
    int64_t coreNum = 0;
    OP_CHECK_IF(GetPlatformInfo(context, &ubSize, &coreNum) != ge::GRAPH_SUCCESS,
                OP_LOGE(context, "Cosh: GetPlatformInfo error"), return ge::GRAPH_FAILED);

    // 2. Shape / dtype.
    int64_t totalNum = 0;
    ge::DataType dataType = ge::DT_UNDEFINED;
    OP_CHECK_IF(GetShapeAttrsInfo(context, &totalNum, &dataType) != ge::GRAPH_SUCCESS,
                OP_LOGE(context, "Cosh: GetShapeAttrsInfo error"), return ge::GRAPH_FAILED);
    // A negative element count would bypass the empty-tensor branch below and
    // drive the split arithmetic to a non-positive blockFactor, so reject it.
    OP_CHECK_IF(totalNum < 0,
                OP_LOGE(context, "Cosh: invalid totalNum=%ld", static_cast<long>(totalNum)),
                return ge::GRAPH_FAILED);

    // 3. Workspace.
    OP_CHECK_IF(GetWorkspaceSize(context) != ge::GRAPH_SUCCESS,
                OP_LOGE(context, "Cosh: GetWorkspaceSize error"), return ge::GRAPH_FAILED);

    // 4. Tiling data, zero-initialised before any field is written.
    CoshTilingData* tiling = context->GetTilingData<CoshTilingData>();
    OP_CHECK_NULL_WITH_CONTEXT(context, tiling);
    OP_CHECK_IF(memset_s(tiling, sizeof(CoshTilingData), 0, sizeof(CoshTilingData)) != EOK,
                OP_LOGE(context, "Cosh: memset tiling data error"), return ge::GRAPH_FAILED);

    // Empty tensor: nothing to split, so use one block and leave the tiling
    // data zeroed.
    if (totalNum == 0) {
        context->SetBlockDim(1);
        ASCENDC_TPL_SEL_PARAM(context, static_cast<uint32_t>(dataType));
        return ge::GRAPH_SUCCESS;
    }

    // Multi-core split: divide the elements evenly across the cores, then
    // round the per-core count up to a multiple of the UB block size so that
    // neighbouring cores do not overlap when writing results back.
    // NOTE(review): the GetUbBlockSize value is applied to an element count;
    // if it is a byte size, this is stricter than byte alignment — confirm.
    int64_t ubBlockSize = Ops::Base::GetUbBlockSize(context);
    OP_CHECK_IF(ubBlockSize <= 0, OP_LOGE(context, "Cosh: ubBlockSize is invalid"), return ge::GRAPH_FAILED);
    tiling->totalNum = totalNum;
    tiling->blockFactor = CeilAlign(CeilDiv(totalNum, coreNum), ubBlockSize);
    // blockFactor is the divisor below; this also catches overflow in the
    // alignment step for extreme element counts.
    OP_CHECK_IF(tiling->blockFactor <= 0, OP_LOGE(context, "Cosh: blockFactor <= 0"), return ge::GRAPH_FAILED);
    int64_t usedCoreNum = CeilDiv(totalNum, tiling->blockFactor);

    // UB split (DESIGN §3.3 / §3.9 / §4.1):
    //   1) UB capacity in fp32 slots divided by the buffer count gives the
    //      per-tile element budget;
    //   2) round that down to a multiple of 256 bytes (64 fp32 elements).
    int64_t ubBudgetElems = FloorDiv(static_cast<int64_t>(ubSize) / FP32_TYPE_SIZE, FP32_BUF_COUNT);
    int64_t ubFactorCandidate = FloorAlign(ubBudgetElems, VEC_ALIGN_FP32_ELEMS);

    // Upper clamp: when the budget exceeds one core's share, shrink the tile
    // to that share rounded up to 64 elements, so small shapes do not reserve
    // more UB than they use. The value never rises above the original budget.
    if (ubFactorCandidate > tiling->blockFactor) {
        int64_t clampedByBlock = CeilAlign(tiling->blockFactor, VEC_ALIGN_FP32_ELEMS);
        if (clampedByBlock < ubFactorCandidate) {
            ubFactorCandidate = clampedByBlock;
        }
    }
    tiling->ubFactor = ubFactorCandidate;

    // Defensive re-check: ubFactor must be positive (it is 0 when the UB
    // budget is below 64 elements) and
    // ubFactor * FP32_BUF_COUNT * FP32_TYPE_SIZE must fit in the UB.
    OP_CHECK_IF(tiling->ubFactor <= 0, OP_LOGE(context, "Cosh: ubFactor <= 0"), return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        static_cast<uint64_t>(tiling->ubFactor) * FP32_BUF_COUNT * FP32_TYPE_SIZE > ubSize,
        OP_LOGE(context, "Cosh: ubFactor=%ld exceeds UB capacity ubSize=%lu",
                tiling->ubFactor, static_cast<unsigned long>(ubSize)),
        return ge::GRAPH_FAILED);

    // blockFactor >= CeilDiv(totalNum, coreNum), hence usedCoreNum <= coreNum
    // and the explicit narrowing cast cannot truncate.
    context->SetBlockDim(static_cast<uint32_t>(usedCoreNum));

    // 5. Tiling key: dtype is the only dimension; per the design notes it maps
    //    to the kernel template parameter D_T_X.
    ASCENDC_TPL_SEL_PARAM(context, static_cast<uint32_t>(dataType));
    return ge::GRAPH_SUCCESS;
}

// Compile-info type for Cosh. It has no fields, but per the original note it
// must be defined because the in-graph execution scenario depends on it.
struct CoshCompileInfo {};

/**
 * Tiling-parse hook for Cosh. There is nothing to parse, so it ignores the
 * context and always returns ge::GRAPH_SUCCESS.
 */
static ge::graphStatus TilingParseForCosh([[maybe_unused]] gert::TilingParseContext* context)
{
    return ge::GRAPH_SUCCESS;
}

// Register the tiling function and the tiling-parse hook for the Cosh op.
IMPL_OP_OPTILING(Cosh)
    .Tiling(CoshTilingFunc)
    .TilingParse<CoshCompileInfo>(TilingParseForCosh);

} // namespace optiling