/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file test_axiscombine.cpp
 * \brief
 */

#include <gtest/gtest.h>
#include "interface/function/function.h"
#include "tilefwk/tilefwk.h"
#include "interface/inner/tilefwk.h"
#include "passes/tile_graph_pass/graph_constraint/axis_combine.h"
#include "passes/pass_mgr/pass_manager.h"
#include "interface/configs/config_manager.h"
#include "computational_graph_builder.h"
#include <nlohmann/json.hpp>
#include <fstream>
#include <vector>
#include <string>

using namespace npu::tile_fwk;

class TestAxisCombine : public ::testing::Test {
public:
    static void SetUpTestCase() {}
    static void TearDownTestCase() {}
    void SetUp() override
    {
        Program::GetInstance().Reset();
        config::Reset();
        config::SetHostOption(COMPILE_STAGE, CS_EXECUTE_GRAPH);
    }
    void TearDown() override {}
};

constexpr int64_t K_1 = 1;
constexpr int64_t K_2 = 2;
constexpr int64_t K_4 = 4;
constexpr int64_t K_8 = 8;
constexpr int64_t K_16 = 16;
constexpr int64_t K_32 = 32;
constexpr int64_t K_64 = 64;
constexpr int64_t K_128 = 128;

TEST_F(TestAxisCombine, Test1)
{
    ComputationalGraphBuilder graph;
    EXPECT_EQ(graph.AddTensor(DataType::DT_FP32, {4, 127}, "t1"), true);
    EXPECT_EQ(graph.AddTensor(DataType::DT_FP32, {4, 1}, "t2"), true);
    EXPECT_EQ(graph.AddTensor(DataType::DT_FP32, {4, 127}, "t3"), true);
    EXPECT_EQ(graph.AddOp(Opcode::OP_ADD, {"t1", "t2"}, {"t3"}, "add", true), true);
    auto* rootFuncPtr = graph.GetFunction();
    rootFuncPtr->paramConfigs_.combineAxis = true;
    AxisCombine pass;
    EXPECT_EQ(pass.RunOnFunction(*rootFuncPtr), SUCCESS);
    auto updatedOperations = rootFuncPtr->Operations();
    int64_t brcbCnt = 0;
    for (const auto& op : updatedOperations) {
        if (op.GetOpcode() == Opcode::OP_BRCB) {
            ++brcbCnt;
            auto outputTensor = op.GetOOperands()[0];
            EXPECT_TRUE(outputTensor->GetConsumers().size() != 0);
        }
        if (op.HasAttr(OpAttributeKey::brcbIdx)) {
            auto idx = op.GetIntAttribute(OpAttributeKey::brcbIdx) - 1;
            auto tensor = op.GetIOperands()[idx];
            EXPECT_TRUE(tensor != nullptr);
            EXPECT_EQ(tensor->shape[0], K_4);
            EXPECT_EQ(tensor->shape[1], K_8);
        }
    }
    EXPECT_EQ(brcbCnt, K_1);
}

TEST_F(TestAxisCombine, Test2)
{
    ComputationalGraphBuilder graph;
    EXPECT_EQ(graph.AddTensor(DataType::DT_FP32, {4, 128}, "t1"), true);
    EXPECT_EQ(graph.AddTensor(DataType::DT_FP32, {4, 1}, "t2"), true);
    EXPECT_EQ(graph.AddTensor(DataType::DT_FP32, {4, 128}, "t3"), true);
    EXPECT_EQ(graph.AddOp(Opcode::OP_ROWSUM_SINGLE, {"t1"}, {"t2"}, "rowmax", true), true);
    graph.GetOp("rowmax")->SetAttribute(OP_ATTR_PREFIX + "AXIS", 1);
    EXPECT_EQ(graph.AddOp(Opcode::OP_SUB, {"t1", "t2"}, {"t3"}, "add", true), true);
    auto* rootFuncPtr = graph.GetFunction();
    AxisCombine pass;
    rootFuncPtr->paramConfigs_.combineAxis = true;
    EXPECT_EQ(pass.RunOnFunction(*rootFuncPtr), SUCCESS);
    auto updatedOperations = rootFuncPtr->Operations();
    int64_t cnt = 0;
    for (const auto& op : updatedOperations) {
        if (op.GetOpcode() == Opcode::OP_BRCB) {
            ++cnt;
            auto outputTensor = op.GetOOperands()[0];
            EXPECT_TRUE(outputTensor->GetConsumers().size() != 0);
        }
        if (op.HasAttr(OpAttributeKey::brcbIdx)) {
            auto idx = op.GetIntAttribute(OpAttributeKey::brcbIdx) - 1;
            auto tensor = op.GetIOperands()[idx];
            EXPECT_TRUE(tensor != nullptr);
            EXPECT_EQ(tensor->shape[1], K_8);
            EXPECT_EQ(tensor->shape[0], K_4);
        }
    }
    EXPECT_EQ(cnt, K_1);
}

TEST_F(TestAxisCombine, Test3)
{
    ComputationalGraphBuilder graph;
    EXPECT_EQ(graph.AddTensor(DataType::DT_FP32, {16, 128}, "t1"), true);
    EXPECT_EQ(graph.AddTensor(DataType::DT_FP32, {16, 1}, "t2"), true);
    EXPECT_EQ(graph.AddOp(Opcode::OP_ROWMAX_SINGLE, {"t1"}, {"t2"}, "max", true), true);
    graph.GetOp("max")->SetAttribute(OP_ATTR_PREFIX + "AXIS", 1);

    EXPECT_EQ(graph.AddTensor(DataType::DT_FP32, {16, 1}, "t3"), true);
    EXPECT_EQ(graph.AddTensor(DataType::DT_FP32, {16, 1}, "t4"), true);
    EXPECT_EQ(graph.AddOp(Opcode::OP_ADD, {"t2", "t3"}, {"t4"}, "add1", true), true);

    // left
    EXPECT_EQ(graph.AddTensor(DataType::DT_FP32, {16, 16}, "t5"), true);
    EXPECT_EQ(graph.AddTensor(DataType::DT_FP32, {16, 16}, "t6"), true);
    EXPECT_EQ(graph.AddOp(Opcode::OP_ADD, {"t2", "t5"}, {"t6"}, "add2", true), true);

    auto* rootFuncPtr = graph.GetFunction();
    rootFuncPtr->paramConfigs_.combineAxis = true;
    AxisCombine pass;
    EXPECT_EQ(pass.RunOnFunction(*rootFuncPtr), SUCCESS);
    // ================== Verify Pass Effect ==================
    auto updatedOperations = rootFuncPtr->Operations();
    for (const auto& op : updatedOperations) {
        if (op.GetOpcode() == Opcode::OP_BRCB) {
            auto outputTensor = op.GetOOperands()[0];
            EXPECT_TRUE(outputTensor->GetConsumers().size() != 0);
        }
        if (op.HasAttr(OpAttributeKey::brcbIdx)) {
            auto idx = op.GetIntAttribute(OpAttributeKey::brcbIdx) - 1;
            auto tensor = op.GetIOperands()[idx];
            EXPECT_TRUE(tensor != nullptr);
            EXPECT_EQ(tensor->shape[0], K_16);
            EXPECT_EQ(tensor->shape[1], K_8);
            EXPECT_EQ(tensor->GetRawTensor()->GetRawShape()[0], K_16);
            EXPECT_EQ(tensor->GetRawTensor()->GetRawShape()[1], K_8);
        }
    }
}

// Skip insert when Both inputs have last dim shape of 1.
TEST_F(TestAxisCombine, Test4)
{
    ComputationalGraphBuilder graph;
    EXPECT_EQ(graph.AddTensor(DataType::DT_FP32, {-1, 1}, "t1"), true);
    EXPECT_EQ(graph.AddTensor(DataType::DT_FP32, {2, 1}, "t2"), true);
    EXPECT_EQ(graph.AddOp(Opcode::OP_COPY_IN, {"t1"}, {"t2"}, "c1", true), true);

    EXPECT_EQ(graph.AddTensor(DataType::DT_FP32, {-1, 1}, "t3"), true);
    EXPECT_EQ(graph.AddTensor(DataType::DT_FP32, {1, 1}, "t4"), true);
    EXPECT_EQ(graph.AddOp(Opcode::OP_COPY_IN, {"t3"}, {"t4"}, "c2", true), true);

    EXPECT_EQ(graph.AddTensor(DataType::DT_FP32, {2, 1}, "t5"), true);
    EXPECT_EQ(graph.AddOp(Opcode::OP_EXPANDEXPDIF, {"t2", "t4"}, {"t5"}, "expanddif", true), true);

    auto* rootFuncPtr = graph.GetFunction();
    rootFuncPtr->paramConfigs_.combineAxis = true;
    AxisCombine pass;
    EXPECT_EQ(pass.RunOnFunction(*rootFuncPtr), SUCCESS);
    // ================== Verify Pass Effect ==================
    int cnt = 0;
    for (const auto& op : rootFuncPtr->Operations()) {
        if (op.GetOpcode() == Opcode::OP_EXPAND || op.GetOpcode() == Opcode::OP_BRCB) {
            ++cnt;
        }
    }
    EXPECT_EQ(cnt, 0);
}

TEST_F(TestAxisCombine, TestDD)
{
    config::SetOperationOption(KEY_COMBINE_AXIS, true);
    config::SetHostOption(COMPILE_STAGE, CS_EXECUTE_GRAPH);

    TileShape::Current().SetVecTile(K_1, K_1, K_32, K_32);
    std::vector<int64_t> tshape = {K_2, K_2, K_64, K_64};

    Tensor T(DT_FP32, tshape, "T");
    Tensor d;
    Tensor output;
    FUNCTION("Test")
    {
        d = SoftmaxNew(T);
        output = Amax(d, -1, true);
    }

    auto funcMap = Program::GetInstance().GetFunctionMap();
}