/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
 * Description: JNI Operator Factory Source File
 */

#include "compute/driver.h"
#include "compute/task.h"
#include "compute/local_planner.h"
#include "operator/operator.h"
#include "operator/operator_factory.h"
#include "gtest/gtest.h"
#include "test/util/test_util.h"
#include "util/config/QueryConfig.h"

using namespace omniruntime;
using namespace TestUtil;

namespace UnionTest {
class TestBatchIterator : public ColumnarBatchIterator {
public:
    TestBatchIterator(const std::vector<VectorBatch*> &date): date(date) {}

    ~TestBatchIterator() override = default;

    VectorBatch *Next() override
    {
        if (index < date.size()) {
            return date[index++];
        } else {
            return nullptr;
        }
    }

private:
    std::vector<VectorBatch*> date;
    size_t index = 0;
};

VectorBatch *CreateTestUnionByThreeColumnVecBatch1()
{
    const int32_t dataSize = 6;
    int32_t data1[dataSize] = {0, 1, 2, 0, 1, 2};
    double data2[dataSize] = {6.6, 5.5, 4.4, 3.3, 2.2, 1.1};
    int16_t data3[dataSize] = {6, 5, 4, 3, 2, 1};

    std::vector<DataTypePtr> types = { IntType(), DoubleType(), ShortType() };
    DataTypes sourceTypes(types);

    return CreateVectorBatch(sourceTypes, dataSize, data1, data2, data3);
}

VectorBatch *CreateTestUnionByThreeColumnVecBatch2()
{
    const int32_t dataSize = 6;
    int32_t data1[dataSize] = {10, 11, 12, 10, 11, 12};
    double data2[dataSize] = {16.6, 15.5, 14.4, 13.3, 12.2, 11.1};
    int16_t data3[dataSize] = {16, 15, 14, 13, 12, 11};

    std::vector<DataTypePtr> types = { IntType(), DoubleType(), ShortType() };
    DataTypes sourceTypes(types);

    return CreateVectorBatch(sourceTypes, dataSize, data1, data2, data3);
}

VectorBatch *CreateTestUnionByThreeColumnOutputVecBatch1()
{
    const int32_t dataSize = 6;
    int32_t expData1[dataSize] = {0, 1, 2, 0, 1, 2};
    double expData2[dataSize] = {6.6, 5.5, 4.4, 3.3, 2.2, 1.1};
    int16_t expData3[dataSize] = {6, 5, 4, 3, 2, 1};

    std::vector<DataTypePtr> types = { IntType(), DoubleType(), ShortType() };
    DataTypes sourceTypes(types);

    return CreateVectorBatch(sourceTypes, dataSize, expData1, expData2, expData3);
}

VectorBatch *CreateTestUnionByThreeColumnOutputVecBatch2()
{
    const int32_t dataSize = 6;
    int32_t expData1[dataSize] = {10, 11, 12, 10, 11, 12};
    double expData2[dataSize] = {16.6, 15.5, 14.4, 13.3, 12.2, 11.1};
    int16_t expData3[dataSize] = {16, 15, 14, 13, 12, 11};

    std::vector<DataTypePtr> types = { IntType(), DoubleType(), ShortType() };
    DataTypes sourceTypes(types);

    return CreateVectorBatch(sourceTypes, dataSize, expData1, expData2, expData3);
}

TEST(PipelineTest, TestUnionByThreeColumn)
{
    std::vector<DataTypePtr> types = { IntType(), DoubleType(), ShortType() };
    VectorBatch *vecBatch1 = CreateTestUnionByThreeColumnVecBatch1();
    VectorBatch *vecBatch2 = CreateTestUnionByThreeColumnVecBatch1();
    std::vector<VectorBatch*> inputVector1;
    std::vector<VectorBatch*> inputVector2;
    inputVector1.push_back(vecBatch1);
    inputVector2.push_back(vecBatch2);

    auto sourceBatchIterator1 = std::make_unique<TestBatchIterator>(inputVector1);
    auto resIterator1 = std::make_shared<ResultIterator>(std::move(sourceBatchIterator1));
    auto outTypes1 = std::make_shared<DataTypes>(types);
    auto valueStreamNode1 = std::make_shared<const ValueStreamNode>("value_stream", outTypes1, resIterator1);

    auto sourceBatchIterator2 = std::make_unique<TestBatchIterator>(inputVector2);
    auto resIterator2 = std::make_shared<ResultIterator>(std::move(sourceBatchIterator2));
    auto outTypes2 = std::make_shared<DataTypes>(types);
    auto valueStreamNode2 = std::make_shared<const ValueStreamNode>("value_stream", outTypes2, resIterator2);

    auto sources = std::vector<PlanNodePtr>{valueStreamNode1, valueStreamNode2};
    auto unionNode = std::make_shared<const UnionNode>("union", sources, false);
    std::unordered_set<PlanNodeId> emptySet;
    PlanFragment planFragment{unionNode, ExecutionStrategy::K_UNGROUPED, 1, emptySet};
    auto task = std::make_shared<OmniTask>(planFragment, config::QueryConfig());
    VectorBatch *vectorBatch1 = nullptr;
    VectorBatch *vectorBatch2 = nullptr;
    while (true) {
        auto future = OmniFuture::makeEmpty();
        auto out1 = task->Next(&future);
        auto out2 = task->Next(&future);
        if (!future.valid()) {
            vectorBatch1 = out1;
            vectorBatch2 = out2;
            break;
        }
        OMNI_CHECK(out1 == nullptr, "Expected to wait but still got non-null output from Omni task");
        OMNI_CHECK(out2 == nullptr, "Expected to wait but still got non-null output from Omni task");
        future.wait();
    }

    VectorBatch *expVecBatch1 = CreateTestUnionByThreeColumnOutputVecBatch1();
    VectorBatch *expVecBatch2 = CreateTestUnionByThreeColumnOutputVecBatch1();

    EXPECT_TRUE(VecBatchMatch(vectorBatch2, expVecBatch1));
    EXPECT_TRUE(VecBatchMatch(vectorBatch1, expVecBatch2));
    VectorHelper::FreeVecBatch(expVecBatch1);
    VectorHelper::FreeVecBatch(expVecBatch2);
    VectorHelper::FreeVecBatch(vectorBatch1);
    VectorHelper::FreeVecBatch(vectorBatch2);
}

}