/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <Columns/ColumnConst.h>
#include <Columns/ColumnString.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeString.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <Functions/Regexps.h>
#include <Interpreters/Context.h>
#include <Common/FunctionDocumentation.h>

namespace DB
{
namespace ErrorCodes
{
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int ILLEGAL_COLUMN;
    extern const int INDEX_OF_POSITIONAL_ARGUMENT_IS_OUT_OF_RANGE;
}
}

using namespace DB;

namespace local_engine
{
using SparkRegexp = OptimizedRegularExpression;
namespace
{
    class FunctionRegexpExtractAllSpark : public IFunction
    {
    public:
        using Pos = const char *;

        static constexpr auto name = "regexpExtractAllSpark";
        static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionRegexpExtractAllSpark>(); }

        String getName() const override { return name; }

        bool isVariadic() const override { return true; }
        size_t getNumberOfArguments() const override { return 0; }

        bool useDefaultImplementationForConstants() const override { return true; }
        ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }

        bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }

        DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
        {
            if (arguments.size() != 2 && arguments.size() != 3)
                throw Exception(
                    ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
                    "Number of arguments for function {} doesn't match: passed {}",
                    getName(),
                    arguments.size());

            FunctionArgumentDescriptors args{
                {"haystack", static_cast<FunctionArgumentDescriptor::TypeValidator>(&isString), nullptr, "String"},
                {"pattern", static_cast<FunctionArgumentDescriptor::TypeValidator>(&isString), isColumnConst, "const String"},
            };

            if (arguments.size() == 3)
                args.emplace_back(FunctionArgumentDescriptor{"index", static_cast<FunctionArgumentDescriptor::TypeValidator>(&isInteger), nullptr, "Integer"});

            validateFunctionArguments(*this, arguments, args);

            return std::make_shared<DataTypeArray>(std::make_shared<DataTypeString>());
        }

        ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override
        {
            const ColumnPtr column = arguments[0].column;
            const ColumnPtr column_pattern = arguments[1].column;
            const ColumnPtr column_index = arguments.size() > 2 ? arguments[2].column : nullptr;

            /// Check if the second argument is const column
            const ColumnConst * col_pattern = typeid_cast<const ColumnConst *>(column_pattern.get());
            if (!col_pattern)
                throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Second argument of function {} must be constant string", getName());

            /// Check if the first argument is string column(const or not)
            const ColumnString * col = nullptr;
            const ColumnConst * col_const = typeid_cast<const ColumnConst *>(column.get());
            if (col_const)
                col = typeid_cast<const ColumnString *>(&col_const->getDataColumn());
            else
                col = typeid_cast<const ColumnString *>(column.get());
            if (!col)
                throw Exception(
                    ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of argument of function {}", arguments[0].column->getName(), getName());

            auto col_res = ColumnArray::create(ColumnString::create());
            ColumnString & res_strings = typeid_cast<ColumnString &>(col_res->getData());
            ColumnArray::Offsets & res_offsets = col_res->getOffsets();
            ColumnString::Chars & res_strings_chars = res_strings.getChars();
            ColumnString::Offsets & res_strings_offsets = res_strings.getOffsets();

            if (col_const)
                constantVector(
                    col_const->getValue<String>(),
                    col_pattern->getValue<String>(),
                    column_index,
                    res_offsets,
                    res_strings_chars,
                    res_strings_offsets);
            else if (!column_index || isColumnConst(*column_index))
            {
                const auto * col_const_index = typeid_cast<const ColumnConst *>(column_index.get());
                ssize_t index = !col_const_index ? 1 : col_const_index->getInt(0);
                vectorConstant(
                    col->getChars(),
                    col->getOffsets(),
                    col_pattern->getValue<String>(),
                    index,
                    res_offsets,
                    res_strings_chars,
                    res_strings_offsets);
            }
            else
                vectorVector(
                    col->getChars(),
                    col->getOffsets(),
                    col_pattern->getValue<String>(),
                    column_index,
                    res_offsets,
                    res_strings_chars,
                    res_strings_offsets);

            return col_res;
        }

    private:
        static void saveMatchs(
            Pos start,
            Pos end,
            const SparkRegexp & regexp,
            OptimizedRegularExpression::MatchVec & matches,
            size_t match_index,
            ColumnArray::Offsets & res_offsets,
            ColumnString::Chars & res_strings_chars,
            ColumnString::Offsets & res_strings_offsets,
            size_t & res_offset,
            size_t & res_strings_offset)
        {
            size_t i = 0;
            Pos pos = start;
            while (pos < end)
            {
                regexp.match(pos, end - pos, matches, static_cast<unsigned>(match_index + 1));
                if (match_index >= matches.size())
                    break;

                /// Append matched segment into res_strings_chars
                const auto & match = matches[match_index];
                if (match.offset != std::string::npos)
                {
                    res_strings_chars.resize_exact(res_strings_offset + match.length + 1);
                    memcpySmallAllowReadWriteOverflow15(&res_strings_chars[res_strings_offset], pos + match.offset, match.length);
                    res_strings_offset += match.length;
                }
                else
                    res_strings_chars.resize_exact(res_strings_offset + 1);

                /// Update offsets of Column:String
                res_strings_chars[res_strings_offset] = 0;
                ++res_strings_offset;
                res_strings_offsets.push_back(res_strings_offset);
                ++i;

                pos += matches[0].offset + matches[0].length;
            }

            /// Update offsets of Column:Array(String)
            res_offset += i;
            res_offsets.push_back(res_offset);
        }

        static void vectorConstant(
            const ColumnString::Chars & data,
            const ColumnString::Offsets & offsets,
            const std::string & pattern,
            ssize_t index,
            ColumnArray::Offsets & res_offsets,
            ColumnString::Chars & res_strings_chars,
            ColumnString::Offsets & res_strings_offsets)
        {
            const SparkRegexp regexp = Regexps::createRegexp<false, false, false>(pattern);
            unsigned capture = regexp.getNumberOfSubpatterns();
            if (index < 0 || index >= capture + 1)
                throw Exception(
                    ErrorCodes::INDEX_OF_POSITIONAL_ARGUMENT_IS_OUT_OF_RANGE,
                    "Index value {} is out of range, should be in [0, {})",
                    index,
                    capture + 1);

            OptimizedRegularExpression::MatchVec matches;
            matches.reserve(index + 1);

            res_offsets.reserve_exact(offsets.size());
            res_strings_chars.reserve_exact(data.size() / 3);
            res_strings_offsets.reserve_exact(offsets.size() * 2);

            size_t res_offset = 0;
            size_t res_strings_offset = 0;
            size_t prev_offset = 0;
            for (size_t cur_offset : offsets)
            {
                Pos start = reinterpret_cast<const char *>(&data[prev_offset]);
                Pos end = start + (cur_offset - prev_offset - 1);
                saveMatchs(
                    start,
                    end,
                    regexp,
                    matches,
                    index,
                    res_offsets,
                    res_strings_chars,
                    res_strings_offsets,
                    res_offset,
                    res_strings_offset);

                prev_offset = cur_offset;
            }
        }

        static void vectorVector(
            const ColumnString::Chars & data,
            const ColumnString::Offsets & offsets,
            const std::string & pattern,
            const ColumnPtr & column_index,
            ColumnArray::Offsets & res_offsets,
            ColumnString::Chars & res_strings_chars,
            ColumnString::Offsets & res_strings_offsets)
        {
            const SparkRegexp regexp = Regexps::createRegexp<false, false, false>(pattern);
            unsigned capture = regexp.getNumberOfSubpatterns();

            OptimizedRegularExpression::MatchVec matches;
            matches.reserve(capture + 1);

            res_offsets.reserve_exact(offsets.size());
            res_strings_chars.reserve_exact(data.size() / 3);
            res_strings_offsets.reserve_exact(offsets.size() * 2);

            size_t res_offset = 0;
            size_t res_strings_offset = 0;
            size_t prev_offset = 0;
            for (size_t i = 0; i < offsets.size(); ++i)
            {
                ssize_t index = column_index->getInt(i);
                if (index < 0 || index >= capture + 1)
                    throw Exception(
                        ErrorCodes::INDEX_OF_POSITIONAL_ARGUMENT_IS_OUT_OF_RANGE,
                        "Index value {} is out of range, should be in [0, {})",
                        index,
                        capture + 1);

                size_t cur_offset = offsets[i];
                Pos start = reinterpret_cast<const char *>(&data[prev_offset]);
                Pos end = start + (cur_offset - prev_offset - 1);
                saveMatchs(
                    start,
                    end,
                    regexp,
                    matches,
                    index,
                    res_offsets,
                    res_strings_chars,
                    res_strings_offsets,
                    res_offset,
                    res_strings_offset);

                prev_offset = cur_offset;
            }
        }

        static void constantVector(
            const std::string & str,
            const std::string & pattern,
            const ColumnPtr & column_index,
            ColumnArray::Offsets & res_offsets,
            ColumnString::Chars & res_strings_chars,
            ColumnString::Offsets & res_strings_offsets)
        {
            const SparkRegexp regexp = Regexps::createRegexp<false, false, false>(pattern);
            unsigned capture = regexp.getNumberOfSubpatterns();

            /// Copy data into padded array to be able to use memcpySmallAllowReadWriteOverflow15.
            ColumnString::Chars padded_str;
            padded_str.insert(str.begin(), str.end());

            Pos start = reinterpret_cast<Pos>(padded_str.data());
            Pos pos = start;
            Pos end = start + padded_str.size();
            Pos prev_pos = nullptr;
            std::vector<OptimizedRegularExpression::MatchVec> matches_groups;
            while (pos < end)
            {
                OptimizedRegularExpression::MatchVec matches;
                matches.reserve(capture + 1);
                regexp.match(pos, end - pos, matches, static_cast<unsigned>(capture + 1));
                if (capture + 1 > matches.size())
                    break;


                /// Make all the offsets based on start
                for (auto & match : matches)
                    if (match.offset != std::string::npos)
                        match.offset += pos - start;

                prev_pos = pos;
                pos = start + matches[0].offset + matches[0].length;
                /// Avoid dead loop caused by empty captured string
                if (pos == prev_pos)
                    ++pos;

                matches_groups.emplace_back(std::move(matches));
            }

            size_t rows = column_index->size();
            res_offsets.reserve_exact(rows);
            res_strings_chars.reserve_exact(rows * str.size() / 3);
            res_strings_offsets.reserve_exact(rows * 2);

            size_t res_offset = 0;
            size_t res_strings_offset = 0;
            for (size_t row_i = 0; row_i < rows; ++row_i)
            {
                ssize_t index = column_index->getInt(row_i);
                if (index < 0 || index >= capture + 1)
                    throw Exception(
                        ErrorCodes::INDEX_OF_POSITIONAL_ARGUMENT_IS_OUT_OF_RANGE,
                        "Index value {} is out of range, should be in [0, {})",
                        index,
                        capture + 1);

                for (auto & matches : matches_groups)
                {
                    const auto & match = matches[index];

                    /// Append matched segment into res_strings_chars
                    if (match.offset != std::string::npos)
                    {
                        res_strings_chars.resize_exact(res_strings_offset + match.length + 1);
                        memcpySmallAllowReadWriteOverflow15(&res_strings_chars[res_strings_offset], start + match.offset, match.length);
                        res_strings_offset += match.length;
                    }
                    else
                        res_strings_chars.resize_exact(res_strings_offset + 1);

                    /// Update offsets of Column:String
                    res_strings_chars[res_strings_offset] = 0;
                    ++res_strings_offset;
                    res_strings_offsets.push_back(res_strings_offset);
                }

                /// Update offsets of Column:Array(String)
                res_offset += matches_groups.size();
                res_offsets.push_back(res_offset);
            }
        }
    };
}

REGISTER_FUNCTION(RegexpExtractAllSpark)
{
    factory.registerFunction<FunctionRegexpExtractAllSpark>(
        FunctionDocumentation{.description = R"(Extracts all the fragments of a string that matches the regexp pattern and corresponds to the regex group index.)"});
}

}