/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
 */
#include <vector>
#include "gtest/gtest.h"
#include "operator/union/union.h"
#include "vector/vector_helper.h"
#include "util/test_util.h"
using namespace omniruntime::op;
using namespace omniruntime::vec;
using namespace std;
using namespace omniruntime::TestUtil;
namespace UnionTest {
/**
 * Feeds two three-column batches (int32, double, int16) into a UnionOperator and
 * verifies that draining the operator yields exactly two batches whose contents
 * match the two inputs, in the order they were added.
 */
TEST(NativeOmniUnionOperator, TestUnionByThreeColumn)
{
    const int32_t dataSize = 6;

    // Column data for the first input batch.
    int32_t data1[dataSize] = {0, 1, 2, 0, 1, 2};
    double data2[dataSize] = {6.6, 5.5, 4.4, 3.3, 2.2, 1.1};
    int16_t data3[dataSize] = {6, 5, 4, 3, 2, 1};

    // Column data for the second input batch.
    int32_t data4[dataSize] = {10, 11, 12, 10, 11, 12};
    double data5[dataSize] = {16.6, 15.5, 14.4, 13.3, 12.2, 11.1};
    int16_t data6[dataSize] = {16, 15, 14, 13, 12, 11};

    std::vector<DataTypePtr> types = { IntType(), DoubleType(), ShortType() };
    DataTypes sourceTypes(types);

    // NOTE(review): the input batches are never freed directly in this test; presumably the
    // operator takes ownership in AddInput and hands them back through GetOutput - confirm.
    VectorBatch *vecBatch1 = CreateVectorBatch(sourceTypes, dataSize, data1, data2, data3);
    VectorBatch *vecBatch2 = CreateVectorBatch(sourceTypes, dataSize, data4, data5, data6);

    UnionOperatorFactory *operatorFactory =
        UnionOperatorFactory::CreateUnionOperatorFactory(sourceTypes, sourceTypes.GetSize(), false);
    UnionOperator *unionOperator = dynamic_cast<UnionOperator *>(CreateTestOperator(operatorFactory));
    // Fail the test cleanly instead of dereferencing a null pointer if the cast does not succeed.
    ASSERT_TRUE(unionOperator != nullptr) << "CreateTestOperator did not produce a UnionOperator";

    unionOperator->AddInput(vecBatch1);
    unionOperator->AddInput(vecBatch2);

    // Drain the operator until it reports that it has finished.
    std::vector<VectorBatch *> outputVecBatches;
    while (unionOperator->GetStatus() != OMNI_STATUS_FINISHED) {
        VectorBatch *outputVecBatch = nullptr;
        unionOperator->GetOutput(&outputVecBatch);
        outputVecBatches.push_back(outputVecBatch);
    }

    // Expected output: the same rows as the two inputs, batch by batch.
    int32_t expData1[dataSize] = {0, 1, 2, 0, 1, 2};
    double expData2[dataSize] = {6.6, 5.5, 4.4, 3.3, 2.2, 1.1};
    int16_t expData3[dataSize] = {6, 5, 4, 3, 2, 1};
    int32_t expData4[dataSize] = {10, 11, 12, 10, 11, 12};
    double expData5[dataSize] = {16.6, 15.5, 14.4, 13.3, 12.2, 11.1};
    int16_t expData6[dataSize] = {16, 15, 14, 13, 12, 11};
    VectorBatch *expVecBatch1 = CreateVectorBatch(sourceTypes, dataSize, expData1, expData2, expData3);
    VectorBatch *expVecBatch2 = CreateVectorBatch(sourceTypes, dataSize, expData4, expData5, expData6);

    // Unsigned literal avoids a signed/unsigned comparison against size().
    EXPECT_EQ(outputVecBatches.size(), 2U);
    // Only index into the outputs when both batches exist; a wrong count is already reported
    // above, and skipping the element checks keeps the cleanup below reachable without
    // reading out of bounds.
    if (outputVecBatches.size() == 2U) {
        EXPECT_TRUE(VecBatchMatch(outputVecBatches[0], expVecBatch1));
        EXPECT_TRUE(VecBatchMatch(outputVecBatches[1], expVecBatch2));
    }

    VectorHelper::FreeVecBatch(expVecBatch1);
    VectorHelper::FreeVecBatch(expVecBatch2);
    VectorHelper::FreeVecBatches(outputVecBatches);
    omniruntime::op::Operator::DeleteOperator(unionOperator);
    delete operatorFactory;
}
} // namespace UnionTest