/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
 * Description: Type Util Class
 */

#include "test_util.h"
#include <cmath>
#include <cfloat>
#include <cstdarg>
#include <gtest/gtest.h>
#include "vector/vector_helper.h"
#include "type/data_type.h"
#include "type/decimal_operations.h"

using namespace omniruntime::type;
using namespace omniruntime::vec;
using namespace omniruntime::expressions;
using namespace omniruntime::codegen;

namespace omniruntime::TestUtil {
void PrintNotMatchBatches(VectorBatch *outputPages, VectorBatch *expectPage)
{
    // Diagnostic dump used on mismatch: the expected batch is printed first, then the actual result.
    const char *expectedBanner = "================ Expected Vector Batch ==================\n";
    const char *resultBanner = "================= Result Vector Batch ===================\n";

    printf("%s", expectedBanner);
    VectorHelper::PrintVecBatch(expectPage);
    printf("%s", resultBanner);
    VectorHelper::PrintVecBatch(outputPages);
}

bool VecBatchMatch(VectorBatch *outputPages, VectorBatch *expectPage)
{
    // Every failure path dumps both batches and reports "no match".
    auto reportMismatch = [outputPages, expectPage]() {
        PrintNotMatchBatches(outputPages, expectPage);
        return false;
    };

    const auto expectedRows = expectPage->GetRowCount();
    const auto actualRows = outputPages->GetRowCount();
    if (actualRows != expectedRows) {
        printf("Invalid row count. Expected=%d, actual=%d\n", expectedRows, actualRows);
        return reportMismatch();
    }

    const int32_t actualColumns = outputPages->GetVectorCount();
    const auto expectedColumns = expectPage->GetVectorCount();
    if (actualColumns != expectedColumns) {
        printf("Invalid vector count. Expected=%d, actual=%d\n", expectedColumns, actualColumns);
        return reportMismatch();
    }

    // Columns are compared positionally and in row order.
    for (int32_t column = 0; column < actualColumns; ++column) {
        if (ColumnMatch(outputPages->Get(column), expectPage->Get(column))) {
            continue;
        }
        printf("Vector %d not matched\n", column);
        return reportMismatch();
    }
    return true;
}

bool VecBatchesIgnoreOrderMatch(std::vector<VectorBatch *> &resultBatches, std::vector<VectorBatch *> &expectedBatches)
{
    if (resultBatches.size() != expectedBatches.size()) {
        printf("List of VectorBatches not match. Expecting %ld, got %ld\n", expectedBatches.size(),
            resultBatches.size());
        printf("================ Expected Vector Batch (%ld) ==================\n", expectedBatches.size());
        for (size_t i = 0; i < expectedBatches.size(); ++i) {
            printf("    ---------- Expected Vector Batch %ld / %ld ----------\n", i, expectedBatches.size());
        }
        printf("================ Result Vector Batch (%ld) ==================\n", resultBatches.size());
        for (size_t i = 0; i < resultBatches.size(); ++i) {
            printf("    ---------- Result Vector Batch %ld / %ld ----------\n", i, resultBatches.size());
        }
        return false;
    }

    for (size_t i = 0; i < resultBatches.size(); i++) {
        if (!VecBatchMatchIgnoreOrder(resultBatches[i], expectedBatches[i])) {
            printf("VectorBatch %ld not match\n", i);
            return false;
        }
    }

    return true;
}

template <typename T> ALWAYS_INLINE T GetValue(BaseVector *vector, uint32_t rowIndex)
{
    // Reads row rowIndex as a T, resolving through the dictionary container for dictionary-encoded vectors.
    if (vector->GetEncoding() == OMNI_DICTIONARY) {
        return reinterpret_cast<Vector<DictionaryContainer<T>> *>(vector)->GetValue(rowIndex);
    }
    return static_cast<Vector<T> *>(vector)->GetValue(rowIndex);
}

static ALWAYS_INLINE std::string_view GetVarcharValue(BaseVector *vector, uint32_t rowIndex)
{
    // Reads a string row from either a flat large-string vector or a dictionary vector over large strings.
    using DictionaryVector = Vector<DictionaryContainer<std::string_view, LargeStringContainer>>;
    using VarcharVector = Vector<LargeStringContainer<std::string_view>>;
    if (vector->GetEncoding() == OMNI_DICTIONARY) {
        return reinterpret_cast<DictionaryVector *>(vector)->GetValue(rowIndex);
    }
    return static_cast<VarcharVector *>(vector)->GetValue(rowIndex);
}

static ALWAYS_INLINE bool DoubleValueEqualsValueIgnoreNulls(BaseVector *leftVector, BaseVector *rightVector,
    int32_t rowIndex)
{
    // Compares row rowIndex of two double vectors with an absolute tolerance of DBL_EPSILON.
    auto leftValue = GetValue<double>(leftVector, rowIndex);
    auto rightValue = GetValue<double>(rightVector, rowIndex);
    // Exact equality is checked first: identical infinities would otherwise give inf - inf = NaN and fail.
    // DBL_EPSILON (<cfloat>) is the portable spelling of the compiler-specific __DBL_EPSILON__.
    return leftValue == rightValue || std::abs(leftValue - rightValue) < DBL_EPSILON;
}

static ALWAYS_INLINE bool VarcharValueEqualsValueIgnoreNulls(BaseVector *leftVector, BaseVector *rightVector,
    int32_t rowIndex)
{
    // Exact (length + bytes) comparison of the string stored at rowIndex in both vectors.
    const std::string_view lhs = GetVarcharValue(leftVector, rowIndex);
    const std::string_view rhs = GetVarcharValue(rightVector, rowIndex);
    return lhs == rhs;
}

template <typename T>
ALWAYS_INLINE bool PrimitiveValueEqualsValueIgnoreNulls(BaseVector *leftVector, BaseVector *rightVector,
    int32_t rowIndex)
{
    // Plain operator== comparison of row rowIndex read as T from both vectors.
    T lhs = GetValue<T>(leftVector, rowIndex);
    T rhs = GetValue<T>(rightVector, rowIndex);
    return lhs == rhs;
}

template <DataTypeId typeId>
static bool ValueEqualsValueIgnoreNulls(BaseVector *leftVector, BaseVector *rightVector, int32_t rowIndex)
{
    // Picks the comparison for the native type of typeId: tolerance for doubles, string compare for
    // string views, operator== for everything else.
    using NativeT = typename NativeType<typeId>::type;
    if constexpr (std::is_same_v<NativeT, double>) {
        return DoubleValueEqualsValueIgnoreNulls(leftVector, rightVector, rowIndex);
    } else if constexpr (std::is_same_v<NativeT, std::string_view>) {
        return VarcharValueEqualsValueIgnoreNulls(leftVector, rightVector, rowIndex);
    } else {
        return PrimitiveValueEqualsValueIgnoreNulls<NativeT>(leftVector, rightVector, rowIndex);
    }
}

bool ColumnMatch(BaseVector *actualColumn, BaseVector *expectColumn)
{
    if (actualColumn->GetSize() != expectColumn->GetSize()) {
        return false;
    }

    bool result = true;
    DataTypeId typeId = expectColumn->GetTypeId();
    for (int32_t rowIndex = 0; rowIndex < actualColumn->GetSize(); rowIndex++) {
        if (actualColumn->IsNull(rowIndex) != expectColumn->IsNull(rowIndex)) {
            return false;
        }

        // all is null
        if ((actualColumn->IsNull(rowIndex) == expectColumn->IsNull(rowIndex)) && actualColumn->IsNull(rowIndex)) {
            continue;
        }

        if (typeId == OMNI_CONTAINER) {
            auto vecCount = static_cast<ContainerVector *>(expectColumn)->GetVectorCount();
            for (int32_t vecIdx = 0; vecIdx < vecCount; vecIdx++) {
                auto actualVec = static_cast<ContainerVector *>(actualColumn)->GetValue(vecIdx);
                auto expectVec = static_cast<ContainerVector *>(expectColumn)->GetValue(vecIdx);
                result =
                    ColumnMatch(reinterpret_cast<BaseVector *>(actualVec), reinterpret_cast<BaseVector *>(expectVec));
                if (!result) {
                    return false;
                }
            }
        } else if (typeId == OMNI_ARRAY) {
            auto actualElementVector = static_cast<ArrayVector *>(actualColumn)->GetElementVector();
            auto expectElementVector = static_cast<ArrayVector *>(expectColumn)->GetElementVector();
            result =
                    ColumnMatch(reinterpret_cast<BaseVector *>(actualElementVector.get()), reinterpret_cast<BaseVector *>(expectElementVector.get()));
            if (!result) {
                return false;
            }
        } else {
            result = DYNAMIC_TYPE_DISPATCH(ValueEqualsValueIgnoreNulls, typeId, actualColumn, expectColumn, rowIndex);
            if (!result) {
                return false;
            }
        }
    }

    return true;
}

VectorBatch *CreateVectorBatch(const DataTypes &types, int32_t rowCount, ...)
{
    // Builds a batch with one column per entry of types; each column is created by CreateVector from the
    // same variadic argument list, in column order.
    va_list args;
    va_start(args, rowCount);
    auto *batch = new VectorBatch(rowCount);
    const int32_t columnCount = types.GetSize();
    for (int32_t column = 0; column < columnCount; ++column) {
        batch->Append(CreateVector(*types.GetType(column), rowCount, args));
    }
    va_end(args);
    return batch;
}

VectorBatch *CreateArrayVectorBatch(const DataTypes &types, std::vector<int32_t> &offsets,
    int32_t dataSize, int32_t elementSize, ...)
{
    // Builds a batch of dataSize-row ArrayVector columns. Each column's element vector (elementSize entries)
    // is created from the variadic arguments, and every column receives the same offsets.
    va_list args;
    va_start(args, elementSize);
    auto *vectorBatch = new VectorBatch(dataSize);
    const int32_t columnCount = types.GetSize();
    for (int32_t column = 0; column < columnCount; ++column) {
        const auto &columnType = types.GetType(column);
        // For an ArrayType column the elements use its element type; otherwise the column type itself is used.
        DataType *elementType = columnType.get();
        ArrayType *arrayType =
            columnType->GetId() == OMNI_ARRAY ? dynamic_cast<ArrayType *>(columnType.get()) : nullptr;
        if (arrayType != nullptr) {
            elementType = arrayType->ElementType().get();
        }

        std::shared_ptr<BaseVector> elementVector(CreateVector(*elementType, elementSize, args));
        auto *arrayVector = new ArrayVector(dataSize, elementVector);
        size_t offsetIndex = 0;
        for (int32_t offset : offsets) {
            arrayVector->SetOffset(offsetIndex++, offset);
        }
        vectorBatch->Append(arrayVector);
    }
    va_end(args);
    return vectorBatch;
}

void AssertStringEquals(std::vector<std::string> &expected, std::vector<uint8_t *> &result,
    std::vector<int32_t> &outLen)
{
    for (size_t i = 0; i < expected.size(); i++) {
        std::string actual(reinterpret_cast<char *>(result[i]), outLen[i]);
        EXPECT_EQ(actual, expected[i]);
    }
}

void AssertStringEquals(std::vector<std::string> &expected, int32_t offset, int32_t rowCnt,
    std::vector<uint8_t *> &result, std::vector<int32_t> &outLen)
{
    for (int32_t i = 0; i < rowCnt; i++) {
        std::string actual(reinterpret_cast<char *>(result[i]), outLen[i]);
        EXPECT_EQ(actual, expected[i + offset]);
    }
}

void AssertIntEquals(std::vector<int32_t> &expected, std::vector<int32_t> &result)
{
    for (size_t i = 0; i < expected.size(); i++) {
        EXPECT_EQ(result[i], expected[i]);
    }
}

void AssertLongEquals(std::vector<int64_t> &expected, std::vector<int64_t> &result)
{
    for (size_t i = 0; i < expected.size(); i++) {
        EXPECT_EQ(result[i], expected[i]);
    }
}

void AssertBoolEquals(std::vector<bool> &expected, bool *result)
{
    for (size_t i = 0; i < expected.size(); i++) {
        EXPECT_EQ(result[i], expected[i]);
    }
}

BaseVector *CreateVector(DataType &dataType, int32_t rowCount, va_list &args)
{
    // Dispatches on the runtime type id to the matching CreateFlatVector instantiation, forwarding rowCount
    // and the caller's argument list (presumably consumed by CreateFlatVector for the values — confirm).
    const auto typeId = dataType.GetId();
    return DYNAMIC_TYPE_DISPATCH(CreateFlatVector, typeId, rowCount, args);
}

vec::BaseVector *SliceVector(vec::BaseVector *vector, int32_t offset, int32_t length)
{
    using namespace omniruntime::type;
    const auto typeId = vector->GetTypeId();
    // Dictionary-encoded vectors go through DictionaryVectorSlice; every other encoding through FlatVectorSlice.
    if (vector->GetEncoding() == vec::OMNI_DICTIONARY) {
        return DYNAMIC_TYPE_DISPATCH(DictionaryVectorSlice, typeId, vector, offset, length);
    }
    return DYNAMIC_TYPE_DISPATCH(FlatVectorSlice, typeId, vector, offset, length);
}

void SetValue(BaseVector *vector, int32_t index, void *value)
{
    DataTypeId typeId = vector->GetTypeId();
    if (value == nullptr) {
        if (typeId == OMNI_VARCHAR || typeId == OMNI_CHAR) {
            static_cast<Vector<LargeStringContainer<std::string_view>> *>(vector)->SetNull(index);
        } else {
            vector->SetNull(index);
        }
        return;
    }
    switch (typeId) {
        case OMNI_INT:
        case OMNI_DATE32:
            static_cast<Vector<int32_t> *>(vector)->SetValue(index, *static_cast<int32_t *>(value));
            break;
        case OMNI_SHORT:
            static_cast<Vector<int16_t> *>(vector)->SetValue(index, *static_cast<int16_t *>(value));
            break;
        case OMNI_LONG:
        case OMNI_TIMESTAMP:
        case OMNI_DECIMAL64:
            static_cast<Vector<int64_t> *>(vector)->SetValue(index, *static_cast<int64_t *>(value));
            break;
        case OMNI_DOUBLE:
            static_cast<Vector<double> *>(vector)->SetValue(index, *static_cast<double *>(value));
            break;
        case OMNI_BOOLEAN:
            static_cast<Vector<bool> *>(vector)->SetValue(index, *static_cast<bool *>(value));
            break;
        case OMNI_VARCHAR:
        case OMNI_CHAR: {
            std::string_view data = std::string_view(static_cast<std::string *>(value)->data(),
                static_cast<std::string *>(value)->length());
            static_cast<Vector<LargeStringContainer<std::string_view>> *>(vector)->SetValue(index, data);
            break;
        }
        case OMNI_DECIMAL128:
            static_cast<Vector<Decimal128> *>(vector)->SetValue(index, *static_cast<Decimal128 *>(value));
            break;
        default:
            LogError("No such data type %d", typeId);
            break;
    }
}

omniruntime::op::Operator *CreateTestOperator(omniruntime::op::OperatorFactory *operatorFactory)
{
    // Single creation point for operators in tests; simply forwards to the factory.
    auto *createdOperator = operatorFactory->CreateOperator();
    return createdOperator;
}

bool VecBatchMatchIgnoreOrder(vec::VectorBatch *resultBatch, vec::VectorBatch *expectedBatch, const double error)
{
    // Like VecBatchMatch, but each column is compared as a multiset of rows (with numeric tolerance `error`).
    auto reportMismatch = [resultBatch, expectedBatch]() {
        PrintNotMatchBatches(resultBatch, expectedBatch);
        return false;
    };

    const auto expectedRows = expectedBatch->GetRowCount();
    const auto actualRows = resultBatch->GetRowCount();
    if (actualRows != expectedRows) {
        printf("Invalid row count. Expected=%d, actual=%d\n", expectedRows, actualRows);
        return reportMismatch();
    }

    const auto actualColumns = resultBatch->GetVectorCount();
    const auto expectedColumns = expectedBatch->GetVectorCount();
    if (actualColumns != expectedColumns) {
        printf("Invalid vector count. Expected=%d, actual=%d\n", expectedColumns, actualColumns);
        return reportMismatch();
    }

    for (int32_t column = 0; column < actualColumns; ++column) {
        if (ColumnMatchIgnoreOrder(resultBatch->Get(column), expectedBatch->Get(column), error)) {
            continue;
        }
        printf("Vector %d not matched\n", column);
        return reportMismatch();
    }
    return true;
}

VectorBatch *DuplicateVectorBatch(VectorBatch *input)
{
    // Builds a new batch whose columns are SliceVector results covering each input column's full row range.
    const auto columnCount = input->GetVectorCount();
    const auto rowCount = input->GetRowCount();
    auto *copy = new VectorBatch(rowCount);
    int32_t column = 0;
    while (column < columnCount) {
        copy->Append(SliceVector(input->Get(column), 0, rowCount));
        ++column;
    }
    return copy;
}

void FreeVecBatches(VectorBatch **vecBatches, int32_t vecBatchCount)
{
    // Frees each batch via VectorHelper, then the pointer array itself (which must come from new[]).
    int32_t freed = 0;
    while (freed < vecBatchCount) {
        VectorHelper::FreeVecBatch(vecBatches[freed]);
        ++freed;
    }
    delete[] vecBatches;
}

void AssertDictionaryVectorShortEquals(BaseVector *vector, int16_t *values)
{
    // Every non-null row of the int16 dictionary vector must equal values[row]; null rows are skipped.
    const auto rowCount = vector->GetSize();
    for (int32_t i = 0; i < rowCount; i++) {
        if (!vector->IsNull(i)) {
            ASSERT_EQ(static_cast<Vector<DictionaryContainer<int16_t>> *>(vector)->GetValue(i), values[i]);
        }
    }
}

void AssertDictionaryVectorIntEquals(BaseVector *vector, int32_t *values)
{
    // Every non-null row of the int32 dictionary vector must equal values[row]; null rows are skipped.
    const auto rowCount = vector->GetSize();
    for (int32_t i = 0; i < rowCount; i++) {
        if (!vector->IsNull(i)) {
            ASSERT_EQ(static_cast<Vector<DictionaryContainer<int32_t>> *>(vector)->GetValue(i), values[i]);
        }
    }
}

void AssertDictionaryVectorLongEquals(BaseVector *vector, int64_t *values)
{
    // Every non-null row of the int64 dictionary vector must equal values[row]; null rows are skipped.
    const auto rowCount = vector->GetSize();
    for (int32_t i = 0; i < rowCount; i++) {
        if (!vector->IsNull(i)) {
            ASSERT_EQ(static_cast<Vector<DictionaryContainer<int64_t>> *>(vector)->GetValue(i), values[i]);
        }
    }
}

void AssertDictionaryVectorBooleanEquals(BaseVector *vector, bool *values)
{
    // Every non-null row of the bool dictionary vector must equal values[row]; null rows are skipped.
    const auto rowCount = vector->GetSize();
    for (int32_t i = 0; i < rowCount; i++) {
        if (!vector->IsNull(i)) {
            ASSERT_EQ(static_cast<Vector<DictionaryContainer<bool>> *>(vector)->GetValue(i), values[i]);
        }
    }
}

void AssertDictionaryVectorDoubleEquals(BaseVector *vector, double *values)
{
    // Every non-null row of the double dictionary vector must be within DBL_EPSILON of values[row].
    for (int32_t i = 0; i < vector->GetSize(); i++) {
        if (vector->IsNull(i)) {
            continue;
        }
        // Read through DictionaryContainer<double>; the previous DictionaryContainer<bool> cast read the
        // wrong element type.
        double actual = static_cast<Vector<DictionaryContainer<double>> *>(vector)->GetValue(i);
        EXPECT_TRUE(std::fabs(actual - values[i]) <= DBL_EPSILON);
    }
}

void AssertDictionaryVectorVarcharEquals(BaseVector *vector, std::string *values)
{
    // Every non-null row of the string dictionary vector must equal values[row]; null rows are skipped.
    using DictionaryVarcharVector = Vector<DictionaryContainer<std::string_view, LargeStringContainer>>;
    auto *dictionaryVector = static_cast<DictionaryVarcharVector *>(vector);
    const auto rowCount = vector->GetSize();
    for (int32_t i = 0; i < rowCount; i++) {
        if (vector->IsNull(i)) {
            continue;
        }
        const std::string_view value = dictionaryVector->GetValue(i);
        std::string actual(value.data(), value.length());
        ASSERT_EQ(actual, values[i]);
    }
}

void AssertDictionaryVectorDecimal128Equals(BaseVector *vector, Decimal128 *values)
{
    // Every non-null row of the Decimal128 dictionary vector must equal values[row]; null rows are skipped.
    const auto rowCount = vector->GetSize();
    for (int32_t i = 0; i < rowCount; i++) {
        if (!vector->IsNull(i)) {
            ASSERT_EQ(static_cast<Vector<DictionaryContainer<Decimal128>> *>(vector)->GetValue(i), values[i]);
        }
    }
}

void AssertDoubleVectorEquals(BaseVector *vector, double *expectedValues)
{
    // Every non-null row of the flat double vector must be within DBL_EPSILON of expectedValues[row].
    const auto rowCount = vector->GetSize();
    for (int32_t i = 0; i < rowCount; i++) {
        if (!vector->IsNull(i)) {
            EXPECT_TRUE(std::fabs(static_cast<Vector<double> *>(vector)->GetValue(i) - expectedValues[i]) <= DBL_EPSILON);
        }
    }
}

void AssertVarcharVectorEquals(BaseVector *vector, std::string *expectedValues)
{
    // Every non-null row of the flat string vector must equal expectedValues[row]; null rows are skipped.
    auto *varcharVector = static_cast<Vector<LargeStringContainer<std::string_view>> *>(vector);
    const auto rowCount = vector->GetSize();
    for (int32_t i = 0; i < rowCount; i++) {
        if (vector->IsNull(i)) {
            continue;
        }
        const std::string_view value = varcharVector->GetValue(i);
        std::string result(value.data(), value.length());
        EXPECT_EQ(result, expectedValues[i]);
    }
}

void AssertDictionaryVectorEquals(BaseVector *vector, va_list &args)
{
    // Pulls one expected-values array from args (its pointer type chosen by the vector's type id) and
    // checks the dictionary vector against it. Unsupported types consume no argument.
    const auto typeId = vector->GetTypeId();
    switch (typeId) {
        case omniruntime::type::OMNI_VARCHAR:
        case omniruntime::type::OMNI_CHAR:
            AssertDictionaryVectorVarcharEquals(vector, va_arg(args, std::string *));
            return;
        case omniruntime::type::OMNI_INT:
        case omniruntime::type::OMNI_DATE32:
            AssertDictionaryVectorIntEquals(vector, va_arg(args, int32_t *));
            return;
        case omniruntime::type::OMNI_LONG:
        case omniruntime::type::OMNI_TIMESTAMP:
        case omniruntime::type::OMNI_DECIMAL64:
            AssertDictionaryVectorLongEquals(vector, va_arg(args, int64_t *));
            return;
        case omniruntime::type::OMNI_SHORT:
            AssertDictionaryVectorShortEquals(vector, va_arg(args, int16_t *));
            return;
        case omniruntime::type::OMNI_DOUBLE:
            AssertDictionaryVectorDoubleEquals(vector, va_arg(args, double *));
            return;
        case omniruntime::type::OMNI_BOOLEAN:
            AssertDictionaryVectorBooleanEquals(vector, va_arg(args, bool *));
            return;
        case omniruntime::type::OMNI_DECIMAL128:
            AssertDictionaryVectorDecimal128Equals(vector, va_arg(args, Decimal128 *));
            return;
        default:
            std::cerr << "unsupported type:" << typeId << std::endl;
            return;
    }
}

std::vector<std::shared_ptr<vec::BaseVector>> CreateVectors(const type::DataTypes &types, int32_t rowCount, ...)
{
    // Like CreateVectorBatch, but returns the columns as shared_ptr-owned vectors instead of a batch.
    std::vector<std::shared_ptr<BaseVector>> vectors;
    va_list args;
    va_start(args, rowCount);
    const int32_t columnCount = types.GetSize();
    for (int32_t column = 0; column < columnCount; ++column) {
        std::shared_ptr<BaseVector> created(CreateVector(*types.GetType(column), rowCount, args));
        vectors.push_back(std::move(created));
    }
    va_end(args);
    return vectors;
}

void AssertVecBatchEquals(VectorBatch *vectorBatch, int32_t expectedVecCount, int32_t expectedRowCount, ...)
{
    // Checks the batch shape, then each column against one expected-values array from the variadic
    // arguments (one pointer per column, in column order, typed by the column's type id).
    int32_t vectorCount = vectorBatch->GetVectorCount();
    int32_t rowCount = vectorBatch->GetRowCount();
    EXPECT_EQ(vectorCount, expectedVecCount);
    EXPECT_EQ(rowCount, expectedRowCount);

    // Only expectedVecCount arrays were passed; never read more variadic arguments than that.
    int32_t checkedCount = vectorCount < expectedVecCount ? vectorCount : expectedVecCount;

    va_list args;
    va_start(args, expectedRowCount);
    for (int32_t i = 0; i < checkedCount; i++) {
        BaseVector *vector = vectorBatch->Get(i);
        EXPECT_EQ(vector->GetSize(), expectedRowCount);
        if (vector->GetEncoding() == OMNI_DICTIONARY) {
            AssertDictionaryVectorEquals(vector, args);
            // Continue (not break) so the columns after a dictionary column are still verified.
            continue;
        }
        DataTypeId dataTypeId = vector->GetTypeId();
        switch (dataTypeId) {
            case omniruntime::type::OMNI_INT:
            case omniruntime::type::OMNI_DATE32:
                AssertVectorEquals<int32_t>(vector, va_arg(args, int32_t *));
                break;
            case omniruntime::type::OMNI_SHORT:
                AssertVectorEquals<int16_t>(vector, va_arg(args, int16_t *));
                break;
            case omniruntime::type::OMNI_LONG:
            case omniruntime::type::OMNI_TIMESTAMP:
            case omniruntime::type::OMNI_DECIMAL64:
                AssertVectorEquals<int64_t>(vector, va_arg(args, int64_t *));
                break;
            case omniruntime::type::OMNI_DOUBLE:
                AssertDoubleVectorEquals(vector, va_arg(args, double *));
                break;
            case omniruntime::type::OMNI_BOOLEAN:
                AssertVectorEquals<bool>(vector, va_arg(args, bool *));
                break;
            case omniruntime::type::OMNI_DECIMAL128:
                AssertVectorEquals<Decimal128>(vector, va_arg(args, Decimal128 *));
                break;
            case omniruntime::type::OMNI_VARCHAR:
            case omniruntime::type::OMNI_CHAR:
                AssertVarcharVectorEquals(vector, va_arg(args, std::string *));
                break;
            default:
                std::cerr << "Unsupported type : " << dataTypeId << std::endl;
                // Skip this column's argument so the following columns stay aligned with their arrays.
                static_cast<void>(va_arg(args, void *));
                break;
        }
    }
    va_end(args);
}

BaseVector *CreateDictionaryVector(DataType &dataType, int32_t rowCount, int32_t *ids, int32_t idsCount, ...)
{
    // Builds a rowCount-entry dictionary from the variadic values, then wraps it with the idsCount ids via
    // CreateDictionary. The temporary dictionary vector is released when this function returns.
    std::unique_ptr<BaseVector> dictionary;
    {
        va_list args;
        va_start(args, idsCount);
        dictionary.reset(CreateVector(dataType, rowCount, args));
        va_end(args);
    }
    const auto typeId = dataType.GetId();
    return DYNAMIC_TYPE_DISPATCH(CreateDictionary, typeId, dictionary.get(), ids, idsCount);
}

BaseVector *CreateVarcharVector(std::string *values, int32_t length)
{
    // Allocates a flat large-string vector of `length` rows and fills row i with values[i].
    using VarcharVector = Vector<LargeStringContainer<std::string_view>>;

    auto *varcharVector = new VarcharVector(length);
    for (int32_t row = 0; row < length; ++row) {
        const std::string &text = values[row];
        std::string_view view(text.data(), text.length());
        varcharVector->SetValue(row, view);
    }
    return varcharVector;
}

FuncExpr *GetFuncExpr(const std::string &funcName, std::vector<Expr *> args, DataTypePtr returnType)
{
    // Looks up funcName by (argument return types, return type) in the FunctionRegistry.
    // Returns a new FuncExpr when a function is found, nullptr otherwise.
    std::vector<DataTypeId> argTypes;
    argTypes.reserve(args.size());
    for (Expr *arg : args) {
        argTypes.push_back(arg->GetReturnTypeId());
    }

    auto signature = FunctionSignature(funcName, argTypes, returnType->GetId());
    auto function = FunctionRegistry::LookupFunction(&signature);
    if (function == nullptr) {
        return nullptr;
    }
    return new FuncExpr(funcName, args, returnType, function);
}

std::string GenerateSpillPath()
{
    // Returns "<current working directory>/<seconds since epoch>".
    // get_current_dir_name() returns a malloc'ed buffer (freed below) or nullptr on failure; building a
    // std::string from nullptr is undefined, so fall back to "." in that case. free(nullptr) is a no-op.
    char *dirName = get_current_dir_name();
    std::string baseDir = (dirName != nullptr) ? std::string(dirName) : std::string(".");
    free(dirName);
    return baseDir + std::string("/") + std::to_string(time(nullptr));
}

int32_t *MakeInts(int32_t size, int32_t start)
{
    // Returns a new[]-allocated array {start, start + 1, ..., start + size - 1}, or nullptr when size <= 0.
    if (size <= 0) {
        return nullptr;
    }
    auto *values = new int32_t[size];
    for (int32_t offset = 0; offset < size; ++offset) {
        values[offset] = start + offset;
    }
    return values;
}

int8_t *MakeBytes(int32_t size, int32_t start)
{
    // Returns a new[]-allocated array of size consecutive values starting at start (each converted to int8_t),
    // or nullptr when size <= 0.
    if (size <= 0) {
        return nullptr;
    }
    auto *values = new int8_t[size];
    for (int32_t offset = 0; offset < size; ++offset) {
        values[offset] = static_cast<int8_t>(start + offset);
    }
    return values;
}

int16_t *MakeShorts(int32_t size, int32_t start)
{
    // Returns a new[]-allocated array of size consecutive values starting at start (each converted to int16_t),
    // or nullptr when size <= 0.
    if (size <= 0) {
        return nullptr;
    }
    auto *values = new int16_t[size];
    for (int32_t offset = 0; offset < size; ++offset) {
        values[offset] = static_cast<int16_t>(start + offset);
    }
    return values;
}

int64_t *MakeDecimals(int32_t size, int32_t start)
{
    // Returns a new[]-allocated array of 2 * size int64 words for the values start .. start + size - 1.
    // Each value occupies a pair [value, 0 or -1]: the second word is 0 for non-negative values and -1
    // (all bits set) for negative ones. Returns nullptr when size <= 0.
    if (size <= 0) {
        return nullptr;
    }
    const int32_t INDEX_FACTOR = 2;
    auto *words = new int64_t[size * INDEX_FACTOR];
    for (int32_t k = 0; k < size; ++k) {
        const int64_t value = static_cast<int64_t>(start) + k;
        words[INDEX_FACTOR * k] = value;
        words[INDEX_FACTOR * k + 1] = value < 0 ? -1 : 0;
    }
    return words;
}

int64_t *MakeLongs(int32_t size, int64_t start)
{
    // Returns a new[]-allocated array {start, start + 1, ..., start + size - 1}, or nullptr when size <= 0.
    if (size <= 0) {
        return nullptr;
    }
    auto *values = new int64_t[size];
    for (int32_t offset = 0; offset < size; ++offset) {
        values[offset] = start + offset;
    }
    return values;
}

double *MakeDoubles(int32_t size, double start)
{
    // Returns a new[]-allocated array {start, start + 1, ..., start + size - 1}, or nullptr when size <= 0.
    // The loop is bounded by the integer count rather than a running double compared with start + size:
    // rounding could run one step past the buffer, and for large start (e.g. 1e16, where x + 1 == x)
    // the old loop never terminated.
    if (size <= 0) {
        return nullptr;
    }
    auto *values = new double[size];
    for (int32_t k = 0; k < size; ++k) {
        values[k] = start + static_cast<double>(k);
    }
    return values;
}

VectorBatch *CreateEmptyVectorBatch(const DataTypes &dataTypes)
{
    // Returns a zero-row batch with one empty flat vector per entry in dataTypes.
    // Vectors are appended directly: the previous variable-length array is a non-standard GCC extension,
    // and neither the const_cast nor the std::move of a raw pointer did anything.
    auto *vectorBatch = new VectorBatch(0);
    const auto *dataTypeIds = dataTypes.GetIds();
    const auto vectorCnt = dataTypes.GetSize();
    for (int32_t i = 0; i < vectorCnt; ++i) {
        vectorBatch->Append(VectorHelper::CreateVector(OMNI_FLAT, dataTypeIds[i], 0));
    }
    return vectorBatch;
}

int32_t DecodeAddFlag(int32_t resultCode)
{
    // The add flag is the packed result code shifted right by 16 bits (its upper half).
    constexpr int32_t addFlagShift = 16;
    return resultCode >> addFlagShift;
}

int32_t DecodeFetchFlag(int32_t resultCode)
{
    // Keeps only the low 15 bits of the packed result code (SHRT_MAX == 0x7FFF).
    // NOTE(review): bit 15 is dropped by this mask — confirm the fetch flag never uses it.
    const int32_t fetchMask = SHRT_MAX;
    return resultCode & fetchMask;
}

bool CompareArrayUnorderedRows(BaseVector *resVec, BaseVector *dstVec, const double error) {
    auto resultVec = dynamic_cast<ArrayVector *>(resVec);
    auto expectedVec = dynamic_cast<ArrayVector *>(dstVec);
    if (!resultVec || !expectedVec) {
        throw omniruntime::exception::OmniException("RUNTIME_ERROR", "ArrayVector dynamic_cast failed!");
    }
    if (resultVec->vec::BaseVector::GetSize() != expectedVec->vec::BaseVector::GetSize()) {
        throw omniruntime::exception::OmniException("RUNTIME_ERROR", "Vector size does not match!");
    }

    for (int row = 0; row < resultVec->vec::BaseVector::GetSize(); row++) {
        if (resultVec->IsNull(row) != expectedVec->IsNull(row)) {
            return false;
        }
        if (resultVec->IsNull(row)) {
            continue;
        }

        int32_t resSize = resultVec->GetSize(row);
        int32_t expSize = expectedVec->GetSize(row);
        if (resSize != expSize) {
            return false;
        }

        int32_t resOffset = resultVec->GetOffset(row);
        int32_t expOffset = expectedVec->GetOffset(row);
        auto resSubArray = resultVec->GetElementVector()->Slice(resOffset, resSize, false);
        auto expSubArray = expectedVec->GetElementVector()->Slice(expOffset, expSize, false);
        bool match = ColumnMatchIgnoreOrder(resSubArray, expSubArray, error);

        delete resSubArray;
        delete expSubArray;
        if (!match) {
            return false;
        }
    }
    return true;
}

bool CompareVarcharUnorderedRows(BaseVector *resultVector, BaseVector *expectedVector, const double error)
{
    std::multiset<std::string_view> resRows;
    std::multiset<std::string_view> expectedRows;
    size_t resNullCount = 0;
    size_t expNullCount = 0;
    for (int32_t i = 0; i < resultVector->GetSize(); ++i) {
        if (resultVector->GetEncoding() == OMNI_DICTIONARY) {
            auto leftVector = reinterpret_cast<Vector<DictionaryContainer<std::string_view>> *>(resultVector);
            if (leftVector->IsNull(i)) {
                resNullCount++;
            } else {
                resRows.emplace(leftVector->GetValue(i));
            }
        } else {
            auto leftVector = static_cast<Vector<LargeStringContainer<std::string_view>> *>(resultVector);
            if (leftVector->IsNull(i)) {
                resNullCount++;
            } else {
                resRows.emplace(leftVector->GetValue(i));
            }
        }

        if (expectedVector->GetEncoding() == OMNI_DICTIONARY) {
            auto rightVector = reinterpret_cast<Vector<DictionaryContainer<std::string_view>> *>(expectedVector);
            if (rightVector->IsNull(i)) {
                expNullCount++;
            } else {
                expectedRows.emplace(rightVector->GetValue(i));
            }
        } else {
            auto rightVector = static_cast<Vector<LargeStringContainer<std::string_view>> *>(expectedVector);
            if (rightVector->IsNull(i)) {
                expNullCount++;
            } else {
                expectedRows.emplace(rightVector->GetValue(i));
            }
        }
    }

    if (resNullCount != expNullCount) {
        return false;
    }

    if (resRows.size() != expectedRows.size()) {
        return false;
    }

    auto it1 = resRows.begin();
    auto it2 = expectedRows.begin();
    for (; it1 != resRows.end(); ++it1, ++it2) {
        if (*it1 != *it2) {
            return false;
        }
    }

    return true;
}

template <typename D, typename V>
bool CompareUnorderedRows(BaseVector *resultVector, BaseVector *expectedVector, const double error)
{
    std::multiset<D> resRows;
    std::multiset<D> expectedRows;
    size_t resNullCount = 0;
    size_t expNullCount = 0;
    for (int32_t i = 0; i < resultVector->GetSize(); ++i) {
        if (resultVector->GetEncoding() == OMNI_DICTIONARY) {
            auto leftVector = reinterpret_cast<Vector<DictionaryContainer<V>> *>(resultVector);
            if (leftVector->IsNull(i)) {
                resNullCount++;
            } else {
                resRows.emplace(leftVector->GetValue(i));
            }
        } else {
            auto leftVector = static_cast<Vector<V> *>(resultVector);
            if (leftVector->IsNull(i)) {
                resNullCount++;
            } else {
                resRows.emplace(leftVector->GetValue(i));
            }
        }

        if (expectedVector->GetEncoding() == OMNI_DICTIONARY) {
            auto rightVector = reinterpret_cast<Vector<DictionaryContainer<V>> *>(expectedVector);
            if (rightVector->IsNull(i)) {
                expNullCount++;
            } else {
                expectedRows.emplace(rightVector->GetValue(i));
            }
        } else {
            auto rightVector = static_cast<Vector<V> *>(expectedVector);
            if (rightVector->IsNull(i)) {
                expNullCount++;
            } else {
                expectedRows.emplace(rightVector->GetValue(i));
            }
        }
    }

    if (resNullCount != expNullCount) {
        return false;
    }

    if (resRows.size() != expectedRows.size()) {
        return false;
    }

    auto it1 = resRows.begin();
    auto it2 = expectedRows.begin();
    for (; it1 != resRows.end(); ++it1, ++it2) {
        if constexpr (std::is_same_v<D, double>) {
            if (fabs(*it1 - *it2) > error) {
                return false;
            }
        } else if constexpr (std::is_same_v<D, Decimal128>) {
            Decimal128Wrapper left(*it1);
            Decimal128Wrapper right(*it2);
            if (left.Subtract(right).Abs() > Decimal128Wrapper(static_cast<int64_t>(error))) {
                return false;
            }
        } else if constexpr (std::is_same_v<D, std::string_view>) {
            if (*it1 != *it2) {
                return false;
            }
        } else {
            if (abs(*it1 - *it2) > static_cast<D>(error)) {
                return false;
            }
        }
    }

    return true;
}

bool ColumnMatchIgnoreOrder(BaseVector *resultVector, BaseVector *expectedVector, const double error)
{
    // Routes to the order-insensitive comparator for the expected column's type; unsupported types never match.
    switch (expectedVector->GetTypeId()) {
        case OMNI_INT:
        case OMNI_DATE32:
            return CompareUnorderedRows<int32_t, int32_t>(resultVector, expectedVector, error);
        case OMNI_SHORT:
            return CompareUnorderedRows<int16_t, int16_t>(resultVector, expectedVector, error);
        case OMNI_LONG:
        case OMNI_TIMESTAMP:
        case OMNI_DECIMAL64:
            return CompareUnorderedRows<int64_t, int64_t>(resultVector, expectedVector, error);
        case OMNI_DOUBLE:
            return CompareUnorderedRows<double, double>(resultVector, expectedVector, error);
        case OMNI_BOOLEAN:
            return CompareUnorderedRows<bool, bool>(resultVector, expectedVector, error);
        case OMNI_DECIMAL128:
            return CompareUnorderedRows<Decimal128, Decimal128>(resultVector, expectedVector, error);
        case OMNI_CHAR:
        case OMNI_VARCHAR:
            return CompareVarcharUnorderedRows(resultVector, expectedVector, error);
        case OMNI_CONTAINER:
            return CompareUnorderedRowsContainer(static_cast<ContainerVector *>(resultVector),
                static_cast<ContainerVector *>(expectedVector), error);
        case OMNI_ARRAY:
            return CompareArrayUnorderedRows(resultVector, expectedVector, error);
        default:
            return false;
    }
}

bool CompareUnorderedRowsContainer(ContainerVector *resultContainerVector, ContainerVector *expectedContainerVector,
    const double error)
{
    // Compares the child vectors of two containers position by position, each with ColumnMatchIgnoreOrder.
    int32_t vecCount = expectedContainerVector->GetVectorCount();
    // A differing child count cannot match, and indexing the result by the expected count could run past its end.
    if (resultContainerVector->GetVectorCount() != vecCount) {
        return false;
    }
    for (int32_t vecIdx = 0; vecIdx < vecCount; vecIdx++) {
        auto resultVector = reinterpret_cast<BaseVector *>(resultContainerVector->GetValue(vecIdx));
        auto expectedVector = reinterpret_cast<BaseVector *>(expectedContainerVector->GetValue(vecIdx));
        if (!ColumnMatchIgnoreOrder(resultVector, expectedVector, error)) {
            return false;
        }
    }
    return true;
}
}