/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <Columns/ColumnArray.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Functions/IFunction.h>

/// Bring the names of namespace DB into scope so the code below can use them unqualified.
using namespace DB;

namespace DB
{
namespace ErrorCodes
{
    /// Declarations only (`extern`): the definitions live outside this file.
    /// Both codes are thrown from SparkFunctionArraysOverlap::executeImpl.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
}

namespace local_engine
{
/// Spark-compatible `arrays_overlap(a, b)` over two array columns.
///
/// The result is Nullable(UInt8). For every row:
///   * 1    - some pair of non-NULL elements, one taken from each array, compares equal;
///   * NULL - no equal pair was found, but at least one examined pair involved a NULL element
///            (this can only happen when both arrays are non-empty);
///   * 0    - otherwise (including the case where either array is empty).
class SparkFunctionArraysOverlap : public IFunction
{
public:
    static constexpr auto name = "sparkArraysOverlap";
    static FunctionPtr create(ContextPtr) { return std::make_shared<SparkFunctionArraysOverlap>(); }
    SparkFunctionArraysOverlap() = default;
    ~SparkFunctionArraysOverlap() override = default;
    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
    size_t getNumberOfArguments() const override { return 2; }
    String getName() const override { return name; }
    bool useDefaultImplementationForConstants() const override { return true; }

    /// The result type is always Nullable(UInt8), independent of the argument types.
    DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const override
    {
        auto data_type = std::make_shared<DataTypeUInt8>();
        return makeNullable(data_type);
    }

    /// Computes the overlap flag row by row.
    ///
    /// @param arguments        exactly two columns; each must be (or be a constant wrapping) an array column.
    /// @param input_rows_count number of rows to produce.
    /// @return a ColumnNullable over ColumnUInt8 with `input_rows_count` rows.
    /// @throws Exception(NUMBER_OF_ARGUMENTS_DOESNT_MATCH) if `arguments.size() != 2`.
    /// @throws Exception(ILLEGAL_TYPE_OF_ARGUMENT) if either argument is not an array column.
    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
    {
        if (arguments.size() != 2)
            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} must have 2 arguments", getName());

        /// Both buffers start zeroed: "no overlap, not NULL" is the default answer for every row.
        auto res = ColumnUInt8::create(input_rows_count, 0);
        auto null_map = ColumnUInt8::create(input_rows_count, 0);
        PaddedPODArray<UInt8> & res_data = res->getData();
        PaddedPODArray<UInt8> & null_map_data = null_map->getData();
        if (input_rows_count == 0)
            return ColumnNullable::create(std::move(res), std::move(null_map));

        /// An argument may arrive wrapped in ColumnConst (e.g. when only one of the two is constant).
        /// Materialize it so the ColumnArray check below does not reject a perfectly valid array.
        /// The ColumnPtr locals keep the materialized columns alive for the rest of the function.
        const ColumnPtr full_col_1 = arguments[0].column->convertToFullColumnIfConst();
        const ColumnPtr full_col_2 = arguments[1].column->convertToFullColumnIfConst();

        const ColumnArray * array_col_1 = checkAndGetColumn<ColumnArray>(full_col_1.get());
        const ColumnArray * array_col_2 = checkAndGetColumn<ColumnArray>(full_col_2.get());
        if (!array_col_1 || !array_col_2)
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st/2nd argument must be array type", getName());

        /// Resolve the element columns once, outside the row loop: strip the Nullable wrapper (if any)
        /// and remember the per-element null map so that values are compared on the nested columns.
        const IColumn & data_1 = array_col_1->getData();
        const IColumn & data_2 = array_col_2->getData();
        const ColumnNullable * nullable_1 = data_1.isNullable() ? assert_cast<const ColumnNullable *>(&data_1) : nullptr;
        const ColumnNullable * nullable_2 = data_2.isNullable() ? assert_cast<const ColumnNullable *>(&data_2) : nullptr;
        const IColumn & values_1 = nullable_1 ? nullable_1->getNestedColumn() : data_1;
        const IColumn & values_2 = nullable_2 ? nullable_2->getNestedColumn() : data_2;
        const ColumnUInt8 * elem_nulls_1 = nullable_1 ? &nullable_1->getNullMapColumn() : nullptr;
        const ColumnUInt8 * elem_nulls_2 = nullable_2 ? &nullable_2->getNullMapColumn() : nullptr;

        /// compareAt() is only meaningful between columns of the same kind. When the element columns
        /// differ, nothing is compared and every row stays 0 / not NULL. This guard is applied whether
        /// or not the elements are nullable, so mismatched nested columns are never handed to compareAt().
        if (values_1.getDataType() != values_2.getDataType())
            return ColumnNullable::create(std::move(res), std::move(null_map));

        const ColumnArray::Offsets & array_offsets_1 = array_col_1->getOffsets();
        const ColumnArray::Offsets & array_offsets_2 = array_col_2->getOffsets();

        /// offsets[row] is the end (exclusive) of the row's elements in the flattened data column;
        /// the previous row's end is this row's begin.
        size_t begin_1 = 0;
        size_t begin_2 = 0;
        for (size_t row = 0; row < input_rows_count; ++row)
        {
            const size_t end_1 = array_offsets_1[row];
            const size_t end_2 = array_offsets_2[row];

            bool found = false;
            bool saw_null = false;
            for (size_t j = begin_1; j < end_1 && !found; ++j)
            {
                const bool left_is_null = elem_nulls_1 && elem_nulls_1->getElement(j);
                for (size_t k = begin_2; k < end_2; ++k)
                {
                    if (left_is_null || (elem_nulls_2 && elem_nulls_2->getElement(k)))
                    {
                        /// A NULL element never matches anything, but it makes a "no match" answer unknown.
                        saw_null = true;
                    }
                    else if (values_1.compareAt(j, k, values_2, -1) == 0)
                    {
                        found = true;
                        break;
                    }
                }
            }

            /// A definite match wins over any NULL seen on the way.
            res_data[row] = found ? 1 : 0;
            null_map_data[row] = (saw_null && !found) ? 1 : 0;

            begin_1 = end_1;
            begin_2 = end_2;
        }
        return ColumnNullable::create(std::move(res), std::move(null_map));
    }
};

/// Registers SparkFunctionArraysOverlap with the function factory.
/// NOTE(review): presumably it becomes callable under the class's `name` ("sparkArraysOverlap") — confirm against FunctionFactory.
REGISTER_FUNCTION(SparkArraysOverlap)
{
    factory.registerFunction<SparkFunctionArraysOverlap>();
}

}