/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

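/*!
 * \file lerp_tiling_arch35.cpp
 * \brief Tiling implementation for the Lerp operator: maps the input/output dtype combination to an op key and
 *        fills LerpTilingData from the broadcast tiling result.
 */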

#include "lerp_tiling_arch35.h"
#include "log/log.h"
#include "platform/platform_info.h"


using namespace AscendC;
using namespace ge;

namespace optiling {

// Operator keys returned by LerpTiling::GetOpKey for each supported dtype combination.
// OP_KEY_INVALID marks a combination that GetOpKey does not recognise.
static constexpr uint64_t OP_KEY_INVALID = 0;
static constexpr uint64_t OP_KEY_1 = 1;  // start/end/weight/y all DT_FLOAT16
static constexpr uint64_t OP_KEY_2 = 2;  // start/end/weight/y all DT_FLOAT
static constexpr uint64_t OP_KEY_3 = 3;  // start/end/weight/y all DT_BF16
static constexpr uint64_t OP_KEY_4 = 4;  // start/end/y DT_BF16, weight DT_FLOAT
static constexpr uint64_t OP_KEY_5 = 5;  // start/end/y DT_FLOAT16, weight DT_FLOAT
// Positional indices used for the input/output descriptors and for the dims/strides entries
// of the broadcast tiling data.
static constexpr uint64_t INDEX_0 = 0;
static constexpr uint64_t INDEX_1 = 1;
static constexpr uint64_t INDEX_2 = 2;
static constexpr uint64_t INDEX_3 = 3;
// Value assigned to workspaceSize_ by LerpTiling::GetWorkspaceSize.
static constexpr uint64_t WORKSPACE_SIZE = 32;

/*!
 * \brief Populate coreNum and ubSize for this tiling run.
 *
 * When the tiling context carries platform info, the values are queried from it; otherwise they are read
 * from the LerpCompileInfo attached to the context.
 *
 * \return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED when neither platform info nor compile info is
 *         available.
 */
ge::graphStatus LerpTiling::GetPlatformInfo()
{
    auto platformInfo = context_->GetPlatformInfo();
    if (platformInfo == nullptr) {
        // No platform info on the context: fall back to the values stored in the compile info.
        auto compileInfoPtr = static_cast<const LerpCompileInfo*>(context_->GetCompileInfo());
        OP_CHECK_IF(compileInfoPtr == nullptr, OP_LOGE(context_, "compile info is null"),
                        return ge::GRAPH_FAILED);
        coreNum = compileInfoPtr->coreNum;
        ubSize = compileInfoPtr->ubSize;
        return ge::GRAPH_SUCCESS;
    }

    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
    coreNum = ascendcPlatform.GetCoreNumAiv();
    // Zero-initialise the out-parameter so ubSize never receives an indeterminate value if the query
    // leaves it untouched.
    uint64_t ubSizePlatForm = 0;
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm);
    ubSize = ubSizePlatForm;
    return ge::GRAPH_SUCCESS;
}

/*!
 * \brief Map the dtype combination of (start, end, weight, y) to an operator key.
 *
 * Supported combinations:
 *   - all DT_FLOAT16                         -> OP_KEY_1
 *   - all DT_FLOAT                           -> OP_KEY_2
 *   - all DT_BF16                            -> OP_KEY_3
 *   - start/end/y DT_BF16, weight DT_FLOAT   -> OP_KEY_4
 *   - start/end/y DT_FLOAT16, weight DT_FLOAT-> OP_KEY_5
 *
 * \return the matching key, or OP_KEY_INVALID for any other combination.
 */
uint64_t LerpTiling::GetOpKey(ge::DataType startDtype, ge::DataType endDtype, ge::DataType weightDtype,
                              ge::DataType yDtype)
{
    // Every supported combination has start, end and y sharing a single dtype.
    if (startDtype != endDtype || startDtype != yDtype) {
        return OP_KEY_INVALID;
    }
    const ge::DataType ioDtype = startDtype;

    // Homogeneous case: weight uses the same dtype as the other tensors.
    if (weightDtype == ioDtype) {
        if (ioDtype == DT_FLOAT16) {
            return OP_KEY_1;
        }
        if (ioDtype == DT_FLOAT) {
            return OP_KEY_2;
        }
        if (ioDtype == DT_BF16) {
            return OP_KEY_3;
        }
        return OP_KEY_INVALID;
    }

    // Mixed case: 16-bit floating tensors combined with a DT_FLOAT weight.
    if (weightDtype == DT_FLOAT) {
        if (ioDtype == DT_BF16) {
            return OP_KEY_4;
        }
        if (ioDtype == DT_FLOAT16) {
            return OP_KEY_5;
        }
    }

    return OP_KEY_INVALID;
}

/*!
 * \brief Combine the operator key with the broadcast inner key into the final tiling key.
 *
 * \param innerKey inner key produced by the broadcast tiling.
 * \return opKey * Ops::Base::BROADCAST_OP_KEY_OFFSET + innerKey.
 */
uint64_t LerpTiling::GenerateTilingKey(uint64_t innerKey)
{
    const auto opKeyPart = opKey * Ops::Base::BROADCAST_OP_KEY_OFFSET;
    return opKeyPart + innerKey;
}

std::map<uint64_t, Ops::Base::BroadcastComputeParams> LerpTiling::GetComputeMap(uint64_t inputOpKey)
{
    Ops::Base::BroadcastComputeParams computeParams0;
    switch (inputOpKey) {
        case OP_KEY_1:
            computeParams0.maxDtypeBits = static_cast<int64_t>(Ops::Base::BROADCAST_BITS_SIZE::BITS32_SIZE);
            computeParams0.minDtypeBits = static_cast<int64_t>(Ops::Base::BROADCAST_BITS_SIZE::BITS1_SIZE);
            computeParams0.extraSize = {0, 0};
            computeParams0.bufferDivisor = {128, 128};
            return {{1, computeParams0}};
        case OP_KEY_2:
            computeParams0.maxDtypeBits = static_cast<int64_t>(Ops::Base::BROADCAST_BITS_SIZE::BITS32_SIZE);
            computeParams0.minDtypeBits = static_cast<int64_t>(Ops::Base::BROADCAST_BITS_SIZE::BITS1_SIZE);
            computeParams0.extraSize = {0, 0};
            computeParams0.bufferDivisor = {256, 256};
            return {{1, computeParams0}};
        case OP_KEY_3:
            computeParams0.maxDtypeBits = static_cast<int64_t>(Ops::Base::BROADCAST_BITS_SIZE::BITS32_SIZE);
            computeParams0.minDtypeBits = static_cast<int64_t>(Ops::Base::BROADCAST_BITS_SIZE::BITS1_SIZE);
            computeParams0.extraSize = {0, 0};
            computeParams0.bufferDivisor = {128, 128};
            return {{1, computeParams0}};
        case OP_KEY_4:
            computeParams0.maxDtypeBits = static_cast<int64_t>(Ops::Base::BROADCAST_BITS_SIZE::BITS32_SIZE);
            computeParams0.minDtypeBits = static_cast<int64_t>(Ops::Base::BROADCAST_BITS_SIZE::BITS1_SIZE);
            computeParams0.extraSize = {0, 0};
            computeParams0.bufferDivisor = {160, 160};
            return {{1, computeParams0}};
        case OP_KEY_5:
            computeParams0.maxDtypeBits = static_cast<int64_t>(Ops::Base::BROADCAST_BITS_SIZE::BITS32_SIZE);
            computeParams0.minDtypeBits = static_cast<int64_t>(Ops::Base::BROADCAST_BITS_SIZE::BITS1_SIZE);
            computeParams0.extraSize = {0, 0};
            computeParams0.bufferDivisor = {160, 160};
            return {{1, computeParams0}};
        default:
            return {};
    }
}

/*!
 * \brief Read the dtypes of the three inputs and the output and derive opKey from them.
 *
 * \return ge::GRAPH_SUCCESS when a supported dtype combination is found; ge::GRAPH_FAILED when a descriptor
 *         is missing or the combination is not supported.
 */
ge::graphStatus LerpTiling::GetShapeAttrsInfo()
{
    // Inputs are looked up positionally: 0 -> start, 1 -> end, 2 -> weight; output 0 -> y.
    const auto* start = context_->GetInputDesc(INDEX_0);
    OP_CHECK_NULL_WITH_CONTEXT(context_, start);
    const auto* end = context_->GetInputDesc(INDEX_1);
    OP_CHECK_NULL_WITH_CONTEXT(context_, end);
    const auto* weight = context_->GetInputDesc(INDEX_2);
    OP_CHECK_NULL_WITH_CONTEXT(context_, weight);
    const auto* y = context_->GetOutputDesc(INDEX_0);
    OP_CHECK_NULL_WITH_CONTEXT(context_, y);

    opKey = GetOpKey(start->GetDataType(), end->GetDataType(), weight->GetDataType(), y->GetDataType());
    if (opKey == OP_KEY_INVALID) {
        OP_LOGE(context_->GetNodeName(), "can not get opKey");
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}

/*!
 * \brief Report whether this tiling implementation can handle the current op; unconditionally true.
 */
bool LerpTiling::IsCapable()
{
    return true;
}

/*!
 * \brief Run the broadcast tiling for the op's shapes and copy the result into tilingData.
 *
 * Collects the storage shapes of every input and of output 0, runs BroadcastTiling with the compute map for
 * opKey, then stores the tiling key, block number, scalar fields and per-tensor dims/strides.
 *
 * \return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when a shape is missing or BroadcastTiling fails.
 */
ge::graphStatus LerpTiling::DoOpTiling()
{
    Ops::Base::BroadcastTilingParams params;
    params.computeMap = GetComputeMap(opKey);
    params.coreNum = coreNum;
    params.ubSize = ubSize;

    const auto inputNum = context_->GetComputeNodeInputNum();
    for (uint64_t i = 0; i < inputNum; ++i) {
        auto shape = context_->GetInputShape(i);
        OP_CHECK_NULL_WITH_CONTEXT(context_, shape);
        params.inShape.push_back(Ops::Base::EnsureNotScalar(shape->GetStorageShape()));
    }
    auto outputShape = context_->GetOutputShape(0);
    OP_CHECK_NULL_WITH_CONTEXT(context_, outputShape);
    params.outShape = Ops::Base::EnsureNotScalar(outputShape->GetStorageShape());

    Ops::Base::BroadcastTilingData result;
    if (BroadcastTiling(params, result) != ge::GRAPH_SUCCESS) {
        OP_LOGE(context_->GetNodeName(), "broadcast tiling failed.");
        return ge::GRAPH_FAILED;
    }

    tilingKey_ = GenerateTilingKey(result.innerKey);
    blockNum = result.blockNum;

    // Scalar tiling fields.
    tilingData.set_blockFormer(result.blockFormer);
    tilingData.set_blockTail(result.blockTail);
    tilingData.set_ubFormer(result.ubFormer);
    tilingData.set_ubOuter(result.ubOuter);
    tilingData.set_ubTail(result.ubTail);
    tilingData.set_ubSplitAxis(result.ubSplitAxis);
    tilingData.set_shapeLen(result.shapeLen);
    tilingData.set_dimProductBeforeUbInner(result.dimProductBeforeUbInner);
    tilingData.set_elemNum(result.elemNum);

    // Copies one dims/strides row into the matching member buffer.
    const auto copyRow = [](const auto& src, auto* dst) { std::copy(src.begin(), src.end(), dst); };

    // Rows 0..2 are filled into the three input buffers, row 3 into the output buffer.
    copyRow(result.dims[INDEX_0], input0Dims);
    tilingData.set_input0Dims(input0Dims);
    copyRow(result.dims[INDEX_1], input1Dims);
    tilingData.set_input1Dims(input1Dims);
    copyRow(result.dims[INDEX_2], input2Dims);
    tilingData.set_input2Dims(input2Dims);
    copyRow(result.dims[INDEX_3], outputDims);
    tilingData.set_outputDims(outputDims);

    copyRow(result.strides[INDEX_0], input0Strides);
    tilingData.set_input0Strides(input0Strides);
    copyRow(result.strides[INDEX_1], input1Strides);
    tilingData.set_input1Strides(input1Strides);
    copyRow(result.strides[INDEX_2], input2Strides);
    tilingData.set_input2Strides(input2Strides);
    copyRow(result.strides[INDEX_3], outputStrides);
    tilingData.set_outputStrides(outputStrides);

    return ge::GRAPH_SUCCESS;
}

/*!
 * \brief Render the scalar fields of a LerpTilingData as a single " name:value" sequence for logging.
 *
 * \param inputTilingData tiling data to describe.
 * \return the formatted string.
 */
std::string LerpTiling::ToString(LerpTilingData& inputTilingData)
{
    std::string text;
    // Appends one "<label><value>" pair; labels carry their own leading space and trailing colon.
    const auto appendField = [&text](const char* label, auto value) {
        text.append(label).append(std::to_string(value));
    };

    appendField(" blockFormer:", inputTilingData.get_blockFormer());
    appendField(" ubFormer:", inputTilingData.get_ubFormer());
    appendField(" ubOuter:", inputTilingData.get_ubOuter());
    appendField(" ubTail:", inputTilingData.get_ubTail());
    appendField(" blockTail:", inputTilingData.get_blockTail());
    appendField(" shapeLen:", inputTilingData.get_shapeLen());
    appendField(" ubSplitAxis:", inputTilingData.get_ubSplitAxis());
    appendField(" dimProductBeforeUbInner:", inputTilingData.get_dimProductBeforeUbInner());
    appendField(" elemNum:", inputTilingData.get_elemNum());
    return text;
}

/*!
 * \brief Library-API tiling step; this implementation does nothing and always returns ge::GRAPH_SUCCESS.
 */
ge::graphStatus LerpTiling::DoLibApiTiling()
{
    return ge::GRAPH_SUCCESS;
}

/*!
 * \brief Return the tiling key stored in tilingKey_ (assigned by DoOpTiling).
 */
uint64_t LerpTiling::GetTilingKey() const
{
    return tilingKey_;
}

/*!
 * \brief Set workspaceSize_ to the fixed WORKSPACE_SIZE constant.
 *
 * \return always ge::GRAPH_SUCCESS.
 */
ge::graphStatus LerpTiling::GetWorkspaceSize()
{
    workspaceSize_ = WORKSPACE_SIZE;
    return ge::GRAPH_SUCCESS;
}

/*!
 * \brief Publish the tiling result to the context: tiling key, block dim, workspace size and the
 *        serialized tiling data.
 *
 * \return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when the context provides no workspace-size
 *         array or no raw tiling data buffer.
 */
ge::graphStatus LerpTiling::PostTiling()
{
    context_->SetTilingKey(GetTilingKey());
    context_->SetBlockDim(blockNum);

    // Guard against a null workspace array before writing to it.
    size_t* workspaces = context_->GetWorkspaceSizes(1);
    OP_CHECK_NULL_WITH_CONTEXT(context_, workspaces);
    workspaces[0] = workspaceSize_;

    // Fetch the raw tiling buffer once and verify it exists before serializing into it.
    auto rawTilingData = context_->GetRawTilingData();
    OP_CHECK_NULL_WITH_CONTEXT(context_, rawTilingData);
    tilingData.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tilingData.GetDataSize());

    OP_LOGI(context_, "TilingInfo: %s.", ToString(tilingData).c_str());
    return ge::GRAPH_SUCCESS;
}

/*!
 * \brief Tiling entry point for Lerp: builds a LerpTiling on the given context and runs DoTiling.
 *
 * \param context tiling context handed in by the framework.
 * \return the status returned by LerpTiling::DoTiling.
 */
ge::graphStatus TilingForLerp(gert::TilingContext* context)
{
    return LerpTiling(context).DoTiling();
}

/*!
 * \brief Tiling-parse step for Lerp: stores the AIV core count and UB size in LerpCompileInfo.
 *
 * \param context tiling parse context providing the compile-info object and platform info.
 * \return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when the compile info or platform info is null.
 */
ge::graphStatus TilingPrepareForLerp(gert::TilingParseContext* context)
{
    auto* compileInfoPtr = context->GetCompiledInfo<LerpCompileInfo>();
    OP_CHECK_NULL_WITH_CONTEXT(context, compileInfoPtr);
    auto* platformInfoPtr = context->GetPlatformInfo();
    OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr);

    platform_ascendc::PlatformAscendC ascendcPlatform(platformInfoPtr);
    // GetCoreMemSize writes the UB size straight into the compile info.
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, compileInfoPtr->ubSize);
    compileInfoPtr->coreNum = ascendcPlatform.GetCoreNumAiv();
    return ge::GRAPH_SUCCESS;
}

// Register the Lerp op: TilingForLerp as its tiling function and TilingPrepareForLerp as the
// tiling-parse function operating on LerpCompileInfo.
IMPL_OP_OPTILING(Lerp).Tiling(TilingForLerp).TilingParse<LerpCompileInfo>(TilingPrepareForLerp);
}  // namespace optiling