/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionHelpers.h>
#include <Interpreters/ActionsDAG.h>
#include <Parser/AggregateFunctionParser.h>
#include <Parser/aggregate_function_parser/PercentileParserBase.h>
#include <substrait/algebra.pb.h>
#include <Common/CHUtil.h>

namespace DB
{
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
}

namespace local_engine
{
using namespace DB;
/// Verifies that `size` equals `expect`.
/// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH, mentioning the function name and `phase`, otherwise.
void PercentileParserBase::assertArgumentsSize(substrait::AggregationPhase phase, size_t size, size_t expect) const
{
    /// Nothing to report when the actual count matches the expected one.
    if (size == expect)
        return;

    throw Exception(
        DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
        "Function {} in phase {} requires exactly {} arguments but got {} arguments",
        getName(),
        magic_enum::enum_name(phase),
        expect,
        size);
}

/// Returns the literal held by `expr`.
/// Throws BAD_ARGUMENTS (including the debug string of `expr`) when `expr` is not a literal.
const substrait::Expression::Literal &
PercentileParserBase::assertAndGetLiteral(substrait::AggregationPhase phase, const substrait::Expression & expr) const
{
    if (expr.has_literal())
        return expr.literal();

    throw Exception(
        DB::ErrorCodes::BAD_ARGUMENTS,
        "The argument of function {} in phase {} must be literal, but is {}",
        getName(),
        magic_enum::enum_name(phase),
        expr.DebugString());
}

/// Chooses the CH function name from the substrait output type:
/// the plural name when the output type is a list, the singular name otherwise.
String PercentileParserBase::getCHFunctionName(const CommonFunctionInfo & func_info) const
{
    if (func_info.output_type.has_list())
        return getCHPluralName();

    return getCHSingularName();
}

/// Chooses the CH function name from the argument types and rewrites `types` in place:
/// only the first type is kept, and for "percentile" a UInt64 type is appended after it.
/// Returns the plural name when the (non-nullable) type at PERCENTAGE_INDEX is an array,
/// the singular name otherwise.
String PercentileParserBase::getCHFunctionName(DB::DataTypes & types) const
{
    /// Always invoked during second stage
    assertArgumentsSize(substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT, types.size(), expectedTupleElementsNumberInSecondStage());

    /// Look at the percentage type before `types` is truncated below.
    const bool percentage_is_array = isArray(removeNullable(types[PERCENTAGE_INDEX]));

    types.resize(1);
    if (getName() == "percentile")
    {
        /// Corresponding CH function requires two arguments: quantileExactWeightedInterpolated(xxx)(col, weight)
        types.push_back(std::make_shared<DataTypeUInt64>());
    }

    if (percentage_is_array)
        return getCHPluralName();

    return getCHSingularName();
}

/// Splits the substrait arguments of the function into CH parameters (returned) and CH arguments.
///
/// First stage (phase is INITIAL_TO_INTERMEDIATE, INITIAL_TO_RESULT or UNSPECIFIED):
///   - every argument whose index is listed by getArgumentsThatAreParameters() must be a literal and
///     becomes a parameter; an array literal at PERCENTAGE_INDEX contributes one parameter per element;
///   - `arg_nodes` is replaced by the nodes whose positions are not listed as parameters.
/// Any other phase:
///   - `arg_nodes` must hold exactly one node of AggregateFunction type, whose parameters are returned;
///     `arg_nodes` is left as is.
///
/// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH or BAD_ARGUMENTS when these expectations are not met.
DB::Array PercentileParserBase::parseFunctionParameters(
    const CommonFunctionInfo & func_info, DB::ActionsDAG::NodeRawConstPtrs & arg_nodes, DB::ActionsDAG & actions_dag) const
{
    const bool is_first_stage = func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE
        || func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT
        || func_info.phase == substrait::AGGREGATION_PHASE_UNSPECIFIED;

    if (!is_first_stage)
    {
        /// The only input is the intermediate aggregate state; reuse the parameters recorded in its type.
        assertArgumentsSize(func_info.phase, arg_nodes.size(), 1);
        const auto & state_type = arg_nodes[0]->result_type;
        const auto * agg_state_type = DB::checkAndGetDataType<DB::DataTypeAggregateFunction>(state_type.get());
        if (agg_state_type)
            return agg_state_type->getParameters();

        throw Exception(
            DB::ErrorCodes::BAD_ARGUMENTS,
            "The first argument type of function {} in phase {} must be AggregateFunction, but is {}",
            getName(),
            magic_enum::enum_name(func_info.phase),
            state_type->getName());
    }

    Array params;
    const auto & arguments = func_info.arguments;
    assertArgumentsSize(func_info.phase, arguments.size(), expectedArgumentsNumberInFirstStage());

    /// Turn every literal argument that CH treats as a parameter into an entry of `params`.
    auto param_indexes = getArgumentsThatAreParameters();
    for (auto idx : param_indexes)
    {
        const auto & literal = assertAndGetLiteral(func_info.phase, arguments[idx].value());
        auto [type, field] = parseLiteral(literal);

        const bool has_multiple_percentages = idx == PERCENTAGE_INDEX && isArray(removeNullable(type));
        if (!has_multiple_percentages)
        {
            params.emplace_back(std::move(field));
            continue;
        }

        /// Multiple percentages for quantilesXXX
        const Array & percentages = field.safeGet<Array>();
        for (const auto & percentage : percentages)
            params.emplace_back(percentage);
    }

    /// Collect arguments in substrait plan that are not CH parameters as CH arguments
    ActionsDAG::NodeRawConstPtrs kept_nodes;
    for (size_t pos = 0; pos < arg_nodes.size(); ++pos)
    {
        const bool is_parameter = std::find(param_indexes.begin(), param_indexes.end(), pos) != param_indexes.end();
        if (is_parameter)
            continue;

        if (getName() == "percentile" && pos == 2)
        {
            /// In spark percentile(col, percentage, weight), the last argument weight is a signed integer
            /// But CH requires weight as an unsigned integer
            DataTypePtr weight_type = std::make_shared<DataTypeUInt64>();
            if (arg_nodes[pos]->result_type->isNullable())
                weight_type = std::make_shared<DataTypeNullable>(weight_type);

            arg_nodes[pos] = ActionsDAGUtil::convertNodeTypeIfNeeded(actions_dag, arg_nodes[pos], weight_type);
        }

        kept_nodes.emplace_back(arg_nodes[pos]);
    }
    arg_nodes.swap(kept_nodes);

    return params;
}
}