/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file floor_tiling_arch35.cpp
 * \brief
 */

#include "register/op_impl_registry.h"
#include "log/log.h"
#include "atvoss/elewise/elewise_tiling.h"
#include "floor_tiling_arch35.h"

using namespace ge;
using namespace Ops::Base;

namespace optiling {

static constexpr uint64_t OP_KEY_INVALID = 0;
static constexpr uint64_t OP_KEY_1 = 1;
static constexpr uint64_t OP_KEY_2 = 2;
static constexpr uint64_t OP_KEY_3 = 3;
static constexpr uint64_t INDEX_0 = 0;
static constexpr uint64_t WORKSPACE_SIZE = 32;

ge::graphStatus FloorTiling::GetPlatformInfo()
{
    auto platformInfo = context_->GetPlatformInfo();
    if (platformInfo == nullptr) {
        auto compileInfoPtr = reinterpret_cast<const FloorCompileInfo*>(context_->GetCompileInfo());
        OP_CHECK_IF(compileInfoPtr == nullptr, OP_LOGE_FOR_INVALID_VALUE_WITH_REASON(context_->GetNodeName(), "compile_info", "nullptr", "compile info is null"), return ge::GRAPH_FAILED);
        coreNum = compileInfoPtr->coreNum;
        ubSize = compileInfoPtr->ubSize;
    } else {
        auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
        coreNum = ascendcPlatform.GetCoreNumAiv();
        uint64_t ubSizePlatForm = 0;
        ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm);
        ubSize = ubSizePlatForm;
    }
    return ge::GRAPH_SUCCESS;
}

uint64_t FloorTiling::GetOpKey(ge::DataType xDtype, ge::DataType yDtype)
{
    bool opKey1Flag = xDtype == DT_BF16 && yDtype == DT_BF16;
    if (opKey1Flag) {
        return OP_KEY_1;
    }
    bool opKey2Flag = xDtype == DT_FLOAT16 && yDtype == DT_FLOAT16;
    if (opKey2Flag) {
        return OP_KEY_2;
    }
    bool opKey3Flag = xDtype == DT_FLOAT && yDtype == DT_FLOAT;
    if (opKey3Flag) {
        return OP_KEY_3;
    }

    return OP_KEY_INVALID;
}

uint64_t FloorTiling::GenerateTilingKey(uint64_t innerKey)
{
    return opKey * OP_KEY_OFFSET + innerKey;
}

std::map<uint64_t, ComputeParams> FloorTiling::GetComputeMap(uint64_t paramOpKey)
{
    ComputeParams computeParams0;
    switch (paramOpKey) {
        case OP_KEY_1:
            computeParams0.maxDtypeBits = static_cast<int64_t>(BITS_SIZE::BITS16_SIZE);
            computeParams0.minDtypeBits = static_cast<int64_t>(BITS_SIZE::BITS16_SIZE);
            computeParams0.extraSize = {0};
            computeParams0.bufferDivisor = {64};
            return {{0, computeParams0}};
        case OP_KEY_2:
            computeParams0.maxDtypeBits = static_cast<int64_t>(BITS_SIZE::BITS16_SIZE);
            computeParams0.minDtypeBits = static_cast<int64_t>(BITS_SIZE::BITS16_SIZE);
            computeParams0.extraSize = {0};
            computeParams0.bufferDivisor = {64};
            return {{0, computeParams0}};
        case OP_KEY_3:
            computeParams0.maxDtypeBits = static_cast<int64_t>(BITS_SIZE::BITS32_SIZE);
            computeParams0.minDtypeBits = static_cast<int64_t>(BITS_SIZE::BITS32_SIZE);
            computeParams0.extraSize = {0};
            computeParams0.bufferDivisor = {128};
            return {{0, computeParams0}};
        default:
            return {};
    }
}

ge::graphStatus FloorTiling::GetShapeAttrsInfo()
{
    auto x = context_->GetInputDesc(INDEX_0);
    OP_CHECK_NULL_WITH_CONTEXT(context_, x);
    auto xDtype = x->GetDataType();
    auto y = context_->GetOutputDesc(INDEX_0);
    OP_CHECK_NULL_WITH_CONTEXT(context_, y);
    auto yDtype = y->GetDataType();

    opKey = GetOpKey(xDtype, yDtype);
    OP_CHECK_IF(
        (opKey == OP_KEY_INVALID), OP_LOGE_FOR_INVALID_VALUE_WITH_REASON(context_->GetNodeName(), "opKey", std::to_string(opKey), "can not get opKey"), return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}

bool FloorTiling::IsCapable()
{
    return true;
}

ge::graphStatus FloorTiling::DoOpTiling()
{
    auto xShape = context_->GetInputShape(INDEX_0);
    OP_CHECK_NULL_WITH_CONTEXT(context_, xShape);

    ElewiseTilingParams elewiseTilingParams;
    elewiseTilingParams.shape = xShape->GetStorageShape();
    elewiseTilingParams.computeMap = GetComputeMap(opKey);
    elewiseTilingParams.coreNum = coreNum;
    elewiseTilingParams.ubSize = ubSize;

    ElewiseTilingData elewiseTilingData;
    auto status = ElewiseTiling(elewiseTilingParams, elewiseTilingData);
    OP_CHECK_IF(
        (status == ge::GRAPH_FAILED), OP_LOGE(context_->GetNodeName(), "elewise tiling failed"),
        return ge::GRAPH_FAILED);

    tilingKey_ = GenerateTilingKey(elewiseTilingData.innerKey);
    blockNum = elewiseTilingData.blockNum;
    tilingData.set_dim0(elewiseTilingData.dim0);
    tilingData.set_blockFormer(elewiseTilingData.blockFormer);
    tilingData.set_ubFormer(elewiseTilingData.ubFormer);
    tilingData.set_ubLoopOfFormerBlock(elewiseTilingData.ubLoopOfFormerBlock);
    tilingData.set_ubLoopOfTailBlock(elewiseTilingData.ubLoopOfTailBlock);
    tilingData.set_ubTailOfFormerBlock(elewiseTilingData.ubTailOfFormerBlock);
    tilingData.set_ubTailOfTailBlock(elewiseTilingData.ubTailOfTailBlock);
    tilingData.set_elemNum(elewiseTilingData.elemNum);

    return ge::GRAPH_SUCCESS;
}

std::string FloorTiling::ToString(FloorTilingData& paramTilingData)
{
    std::string str;
    str += " dim0:" + std::to_string(paramTilingData.get_dim0());
    str += " blockFormer:" + std::to_string(paramTilingData.get_blockFormer());
    str += " ubFormer:" + std::to_string(paramTilingData.get_ubFormer());
    str += " ubLoopOfFormerBlock:" + std::to_string(paramTilingData.get_ubLoopOfFormerBlock());
    str += " ubLoopOfTailBlock:" + std::to_string(paramTilingData.get_ubLoopOfTailBlock());
    str += " ubTailOfFormerBlock:" + std::to_string(paramTilingData.get_ubTailOfFormerBlock());
    str += " ubTailOfTailBlock:" + std::to_string(paramTilingData.get_ubTailOfTailBlock());
    str += " elemNum:" + std::to_string(paramTilingData.get_elemNum());
    return str;
}

ge::graphStatus FloorTiling::DoLibApiTiling()
{
    return ge::GRAPH_SUCCESS;
}

uint64_t FloorTiling::GetTilingKey() const
{
    return tilingKey_;
}

ge::graphStatus FloorTiling::GetWorkspaceSize()
{
    workspaceSize_ = WORKSPACE_SIZE;
    return ge::GRAPH_SUCCESS;
}

ge::graphStatus FloorTiling::PostTiling()
{
    context_->SetTilingKey(GetTilingKey());
    context_->SetBlockDim(blockNum);
    size_t* workspaces = context_->GetWorkspaceSizes(1);
    OP_CHECK_IF(workspaces == nullptr, OP_LOGE_FOR_INVALID_VALUE_WITH_REASON(context_->GetNodeName(), "workspace", "nullptr", "workspace is null"), return ge::GRAPH_FAILED);
    workspaces[0] = workspaceSize_;
    tilingData.SaveToBuffer(context_->GetRawTilingData()->GetData(), context_->GetRawTilingData()->GetCapacity());
    context_->GetRawTilingData()->SetDataSize(tilingData.GetDataSize());
    OP_LOGI(context_, "TilingInfo: %s.", ToString(tilingData).c_str());
    return ge::GRAPH_SUCCESS;
}

static ge::graphStatus TilingForFloor(gert::TilingContext* context)
{
    auto compileInfo = context->GetCompileInfo<FloorCompileInfo>();
    OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);
    FloorTiling tiling(context);
    return tiling.DoTiling();
}

static ge::graphStatus TilingPrepareForFloor(gert::TilingParseContext* context)
{
    auto compileInfoPtr = context->GetCompiledInfo<FloorCompileInfo>();
    OP_CHECK_NULL_WITH_CONTEXT(context, compileInfoPtr);

    fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr);
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
    compileInfoPtr->coreNum = ascendcPlatform.GetCoreNumAiv();
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, compileInfoPtr->ubSize);
    return ge::GRAPH_SUCCESS;
}

IMPL_OP_OPTILING(Floor).Tiling(TilingForFloor).TilingParse<FloorCompileInfo>(TilingPrepareForFloor);

} // namespace optiling