/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
 */

#include "common/common.h"
#include "common/vector_util.h"
#include "operator/window/window.h"

using namespace benchmark;
using namespace omniruntime::op;
using namespace omniruntime::vec;
using namespace omniruntime::type;

namespace om_benchmark {
class Window : public BaseOperatorFixture {
protected:
    OperatorFactory *createOperatorFactory(const State &state) override
    {
        std::vector<omniruntime::op::FunctionType> functions = WINDOW_FUNCTION[TestGroup(state)];
        std::vector<std::vector<int32_t>> sortOrders(SORT_CHANNELS[TestGroup(state)].size());
        for (unsigned int i = 0; i < SORT_CHANNELS[TestGroup(state)].size(); ++i) {
            sortOrders[i] = {1, 0};
        }
        std::vector<int32_t> outputChannels(INPUT_TYPES[TestGroup(state)].size());
        for (unsigned int i = 0; i < INPUT_TYPES[TestGroup(state)].size(); ++i) {
            outputChannels[i] = (int)i;
        }
        std::vector<DataTypePtr> resultTypes = { RESULT_TYPE[TestGroup(state)] };

        std::vector<int32_t> argumentChannels(ARGUMENT_CHANNELS[TestGroup(state)].size());
        for (unsigned int i = 0; i < ARGUMENT_CHANNELS[TestGroup(state)].size(); ++i) {
            argumentChannels[i] = ARGUMENT_CHANNELS[TestGroup(state)][i];
        }

        if (NumberOfPregroupedColumns(state) == 0) {
            return createFactory(INPUT_TYPES[TestGroup(state)], resultTypes, outputChannels, functions,
                PARTITION_CHANNELS[TestGroup(state)], {}, SORT_CHANNELS[TestGroup(state)], sortOrders, argumentChannels,
                0);
        } else if (NumberOfPregroupedColumns(state) < numberOfGroupColumns) {
            return createFactory(INPUT_TYPES[TestGroup(state)], resultTypes, outputChannels, functions,
                PARTITION_CHANNELS[TestGroup(state)], { 1 }, SORT_CHANNELS[TestGroup(state)], sortOrders,
                argumentChannels, 0);
        } else {
            return createFactory(INPUT_TYPES[TestGroup(state)], resultTypes, outputChannels, functions,
                PARTITION_CHANNELS[TestGroup(state)], { 0, 1 }, SORT_CHANNELS[TestGroup(state)], sortOrders,
                argumentChannels, (NumberOfPregroupedColumns(state) - numberOfGroupColumns));
        }
    }

    std::vector<VectorBatchSupplier> createVecBatch(const State &state) override
    {
        std::vector<VectorBatch *> vvb(totalPages);

        for (int i = 0; i < totalPages; ++i) {
            if (DictionaryBlocks(state)) {
                vvb[i] = CreateSequenceVectorBatchWithDictionaryVector(INPUT_TYPES[TestGroup(state)],
                    rowsPerPage);
            } else {
                vvb[i] = CreateSequenceVectorBatch(INPUT_TYPES[TestGroup(state)], rowsPerPage);
            }
        }

        return VectorBatchToVectorBatchSupplier(vvb);
    }

private:
    int64_t totalPages = 100;
    int rowsPerPage = 10000;
    int numberOfGroupColumns = 2;
    std::map<std::string, std::vector<int32_t>> PARTITION_CHANNELS = {
        { "group1", { 0, 1 } },       { "group2", { 0, 1, 2 } }, { "group3", { 0, 1, 2, 3, 4 } },
        { "group4", { 0, 1, 2, 3 } }, { "group5", { 0, 1 } },    { "group6", { 0, 1 } },
        { "group7", { 0, 1, 3, 4 } }
    };
    std::map<std::string, std::vector<DataTypePtr>> INPUT_TYPES = {
        { "group1", { LongType(), LongType(), LongType(), LongType() } },
        { "group2", { LongType(), LongType(), LongType(), LongType() } },
        { "group3",
          { VarcharType(50), VarcharType(50), VarcharType(50), VarcharType(50), IntType(), IntType(), LongType(),
            LongType() } },
        { "group4",
          { VarcharType(50), VarcharType(50), VarcharType(50), VarcharType(50), IntType(), IntType(), LongType() } },
        { "group5", { LongType(), IntType() } },
        { "group6", { LongType(), IntType() } },
        { "group7",
          { VarcharType(50), VarcharType(50), VarcharType(50), VarcharType(50), VarcharType(50), IntType(),
            LongType() } }
    };
    std::map<std::string, std::vector<omniruntime::op::FunctionType>> WINDOW_FUNCTION = {
        { "group1", { OMNI_WINDOW_TYPE_ROW_NUMBER } }, { "group2", { OMNI_AGGREGATION_TYPE_COUNT_COLUMN } },
        { "group3", { OMNI_AGGREGATION_TYPE_AVG } },   { "group4", { OMNI_WINDOW_TYPE_RANK } },
        { "group5", { OMNI_AGGREGATION_TYPE_AVG } },   { "group6", { OMNI_AGGREGATION_TYPE_AVG } },
        { "group7", { OMNI_AGGREGATION_TYPE_AVG } },
    };
    std::map<std::string, std::vector<int32_t>> ARGUMENT_CHANNELS = { { "group1", {} },    { "group2", { 2 } },
                                                                      { "group3", { 7 } }, { "group4", {} },
                                                                      { "group5", { 1 } }, { "group6", { 1 } },
                                                                      { "group7", { 6 } } };
    std::map<std::string, DataTypePtr> RESULT_TYPE = { { "group1", LongType() },   { "group2", LongType() },
                                                       { "group3", DoubleType() }, { "group4", LongType() },
                                                       { "group5", DoubleType() }, { "group6", DoubleType() },
                                                       { "group7", DoubleType() } };

    std::map<std::string, std::vector<int32_t>> SORT_CHANNELS = { { "group1", { 3 } }, { "group2", { 3 } },
                                                                  { "group3", {} },    { "group4", { 4, 5 } },
                                                                  { "group5", {} },    { "group6", {} },
                                                                  { "group7", {} } };

    OMNI_BENCHMARK_PARAM(std::string, TestGroup, "group1", "group2", "group3", "group4", "group5", "group6", "group7");
    OMNI_BENCHMARK_PARAM(bool, DictionaryBlocks, false, true);
    OMNI_BENCHMARK_PARAM(int32_t, NumberOfPregroupedColumns, 0, 1, 2);

    static OperatorFactory *createFactory(const std::vector<DataTypePtr> &sourceTypes,
        const std::vector<DataTypePtr> &resultTypes, std::vector<int32_t> &outputChannels,
        std::vector<omniruntime::op::FunctionType> &functions, std::vector<int32_t> &partitionChannels,
        std::vector<int32_t> preGroupedChannels, std::vector<int32_t> &sortChannels,
        std::vector<std::vector<int32_t>> &sortOrder, std::vector<int32_t> &argumentChannels,
        int preSortedChannelPrefix)
    {
        std::vector<int32_t> windowFunctions(functions.size());
        std::vector<int32_t> windowFrameTypesFields(functions.size());
        std::vector<int32_t> windowFrameStartTypesField(functions.size());
        std::vector<int32_t> windowFrameStartChannelsField(functions.size());
        std::vector<int32_t> windowFrameEndTypesField(functions.size());
        std::vector<int32_t> windowFrameEndChannelsField(functions.size());

        for (unsigned int i = 0; i < functions.size(); ++i) {
            windowFunctions[i] = functions[i];
            windowFrameTypesFields[i] = OMNI_FRAME_TYPE_RANGE;
            windowFrameStartTypesField[i] = OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING;
            windowFrameStartChannelsField[i] = -1;
            windowFrameEndTypesField[i] = OMNI_FRAME_BOUND_UNBOUNDED_FOLLOWING;
            windowFrameEndChannelsField[i] = -1;
        }
        std::vector<int32_t> asc(sortChannels.size());
        std::vector<int32_t> nullFirst(sortChannels.size());
        for (unsigned int i = 0; i < sortChannels.size(); ++i) {
            asc[i] = sortOrder.at(i)[0];
            nullFirst[i] = sortOrder.at(i)[1];
        }

        std::vector<DataTypePtr> allTypesVec;
        allTypesVec.insert(allTypesVec.end(), sourceTypes.begin(), sourceTypes.end());
        allTypesVec.insert(allTypesVec.end(), resultTypes.begin(), resultTypes.end());
        int32_t expectedPositions = 10;
        auto factory =
            new WindowOperatorFactory(DataTypes(sourceTypes), outputChannels.data(), (int32_t)outputChannels.size(),
            windowFunctions.data(), (int32_t)windowFunctions.size(), partitionChannels.data(),
            (int32_t)partitionChannels.size(), preGroupedChannels.data(), (int32_t)preGroupedChannels.size(),
            sortChannels.data(), asc.data(), nullFirst.data(), (int32_t)sortChannels.size(), preSortedChannelPrefix,
            expectedPositions, DataTypes(allTypesVec), argumentChannels.data(), (int32_t)argumentChannels.size(),
            windowFrameTypesFields.data(), windowFrameStartTypesField.data(), windowFrameStartChannelsField.data(),
            windowFrameEndTypesField.data(), windowFrameEndChannelsField.data(), true);
        factory->Init();
        return factory;
    }
};
OMNI_BENCHMARK_DECLARE_OPERATOR_DEFAULT(Window);
}