/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "AggregateFunctionParser.h"

#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionHelpers.h>
#include <Parser/ExpressionParser.h>
#include <Parser/RelParsers/RelParser.h>
#include <Parser/TypeParser.h>
#include <Common/CHUtil.h>
#include <Common/Exception.h>
#include <Common/logger_useful.h>

namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int LOGICAL_ERROR;
extern const int UNKNOWN_FUNCTION;
}
}

namespace local_engine
{
using namespace DB;
AggregateFunctionParser::AggregateFunctionParser(ParserContextPtr parser_context_) : parser_context(parser_context_)
{
    expression_parser = std::make_unique<ExpressionParser>(parser_context);
}

AggregateFunctionParser::~AggregateFunctionParser()
{
}

String AggregateFunctionParser::getUniqueName(const String & name) const
{
    return expression_parser->getUniqueName(name);
}

const DB::ActionsDAG::Node *
AggregateFunctionParser::addColumnToActionsDAG(DB::ActionsDAG & actions_dag, const DB::DataTypePtr & type, const DB::Field & field) const
{
    return expression_parser->addConstColumn(actions_dag, type, field);
}

const DB::ActionsDAG::Node * AggregateFunctionParser::toFunctionNode(
    DB::ActionsDAG & action_dag, const String & func_name, const DB::ActionsDAG::NodeRawConstPtrs & args) const
{
    return expression_parser->toFunctionNode(action_dag, func_name, args);
}

const DB::ActionsDAG::Node * AggregateFunctionParser::toFunctionNode(
    DB::ActionsDAG & action_dag, const String & func_name, const String & result_name, const DB::ActionsDAG::NodeRawConstPtrs & args) const
{
    return expression_parser->toFunctionNode(action_dag, func_name, args, result_name);
}

const DB::ActionsDAG::Node * AggregateFunctionParser::parseExpression(DB::ActionsDAG & actions_dag, const substrait::Expression & rel) const
{
    return expression_parser->parseExpression(actions_dag, rel);
}

std::pair<DataTypePtr, Field> AggregateFunctionParser::parseLiteral(const substrait::Expression_Literal & literal) const
{
    return LiteralParser::parse(literal);
}

DB::ActionsDAG::NodeRawConstPtrs
AggregateFunctionParser::parseFunctionArguments(const CommonFunctionInfo & func_info, DB::ActionsDAG & actions_dag) const
{
    DB::ActionsDAG::NodeRawConstPtrs collected_args;
    for (const auto & arg : func_info.arguments)
    {
        auto arg_value = arg.value();
        const DB::ActionsDAG::Node * arg_node = parseExpression(actions_dag, arg_value);

        // If the aggregate result is required to be nullable, make all inputs be nullable at the first stage.
        auto required_output_type = DB::WhichDataType(TypeParser::parseType(func_info.output_type));
        if (required_output_type.isNullable()
            && (func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE
                || func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT)
            && !arg_node->result_type->isNullable())
        {
            DB::ActionsDAG::NodeRawConstPtrs args;
            args.emplace_back(arg_node);
            const auto * node = toFunctionNode(actions_dag, "toNullable", args);
            actions_dag.addOrReplaceInOutputs(*node);
            arg_node = node;
        }

        collected_args.push_back(arg_node);
    }

    if (func_info.has_filter)
    {
        // With `If` combinator, the function take one more argument which refers to the condition.
        const auto * action_node = parseExpression(actions_dag, func_info.filter);
        collected_args.emplace_back(action_node);
    }
    return collected_args;
}

std::pair<String, DB::DataTypes> AggregateFunctionParser::tryApplyCHCombinator(
    const CommonFunctionInfo & func_info, const String & ch_func_name, const DB::DataTypes & argument_types) const
{
    auto get_aggregate_function
        = [](const String & name, const DB::DataTypes & argument_types, const DB::Array & parameters) -> DB::AggregateFunctionPtr
    {
        DB::AggregateFunctionProperties properties;
        auto func = RelParser::getAggregateFunction(name, argument_types, properties, parameters);
        if (!func)
            throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown aggregate function {}", name);

        return func;
    };

    String combinator_function_name = ch_func_name;
    DB::DataTypes combinator_argument_types = argument_types;

    if (func_info.phase != substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE
        && func_info.phase != substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT)
    {
        if (argument_types.size() != 1)
            throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Only support one argument aggregate function in phase {}", func_info.phase);

        // Add a check here for safty.
        if (func_info.has_filter)
            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Apply filter in phase {} not supported", func_info.phase);

        const auto * aggr_func_type = DB::checkAndGetDataType<DB::DataTypeAggregateFunction>(argument_types[0].get());
        if (!aggr_func_type)
        {
            // FIXME. This is should be fixed. It's the case that count(distinct(xxx)) with other aggregate functions.
            // Gluten breaks the rule that intermediate result should have a special format name here.
            LOG_INFO(logger, "Intermediate aggregate function data is expected in phase {} for {}", func_info.phase, ch_func_name);

            auto arg_type = DB::removeNullable(argument_types[0]);
            if (auto * tupe_type = typeid_cast<const DB::DataTypeTuple *>(arg_type.get()))
                combinator_argument_types = tupe_type->getElements();

            auto agg_function = get_aggregate_function(ch_func_name, argument_types, aggr_func_type->getParameters());
            auto agg_intermediate_result_type = agg_function->getStateType();
            combinator_argument_types = {agg_intermediate_result_type};
        }
        else
        {
            // Special case for handling the intermedidate result from aggregate functions with filter.
            // It's safe to use AggregateFunctionxxx to parse intermediate result from AggregateFunctionxxxIf,
            // since they have the same binary representation
            // reproduce this case by
            //  select
            //    count(a),count(b), count(1), count(distinct(a)), count(distinct(b))
            //  from values (1, null), (2,2) as data(a,b)
            // with `first_value` enable
            if (endsWith(aggr_func_type->getFunction()->getName(), "If") && ch_func_name != aggr_func_type->getFunction()->getName())
            {
                auto original_args_types = aggr_func_type->getArgumentsDataTypes();
                combinator_argument_types = DataTypes(original_args_types.begin(), std::prev(original_args_types.end()));
                auto agg_function = get_aggregate_function(ch_func_name, combinator_argument_types, aggr_func_type->getParameters());
                combinator_argument_types = {agg_function->getStateType()};
            }
        }
        combinator_function_name += "PartialMerge";
    }
    else if (func_info.has_filter)
    {
        // Apply `If` aggregate function combinator on the original aggregate function.
        combinator_function_name += "If";
    }
    return {combinator_function_name, combinator_argument_types};
}

const DB::ActionsDAG::Node * AggregateFunctionParser::convertNodeTypeIfNeeded(
    const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * func_node, DB::ActionsDAG & actions_dag, bool with_nullability) const
{
    const auto & output_type = func_info.output_type;
    bool need_convert_type = !TypeParser::isTypeMatched(output_type, func_node->result_type, !with_nullability);
    if (need_convert_type)
    {
        func_node = ActionsDAGUtil::convertNodeType(actions_dag, func_node, TypeParser::parseType(output_type), func_node->result_name);
        actions_dag.addOrReplaceInOutputs(*func_node);
    }

    func_node = convertNanToNullIfNeed(func_info, func_node, actions_dag);

    if (output_type.has_decimal())
    {
        String checkDecimalOverflowSparkOrNull = "checkDecimalOverflowSparkOrNull";
        DB::ActionsDAG::NodeRawConstPtrs overflow_args
            = {func_node,
               expression_parser->addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), output_type.decimal().precision()),
               expression_parser->addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), output_type.decimal().scale())};
        func_node = toFunctionNode(actions_dag, checkDecimalOverflowSparkOrNull, func_node->result_name, overflow_args);
        actions_dag.addOrReplaceInOutputs(*func_node);
    }

    return func_node;
}

const DB::ActionsDAG::Node * AggregateFunctionParser::convertNanToNullIfNeed(
    const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * func_node, DB::ActionsDAG & actions_dag) const
{
    if (getCHFunctionName(func_info) != "corr" || !func_node->result_type->isNullable())
        return func_node;

    /// result is nullable.
    /// if result is NaN, convert it to NULL.
    auto is_nan_func_node = toFunctionNode(actions_dag, "isNaN", getUniqueName("isNaN"), {func_node});
    auto nullable_col = func_node->result_type->createColumn();
    nullable_col->insertDefault();
    const auto * null_node
        = &actions_dag.addColumn(DB::ColumnWithTypeAndName(std::move(nullable_col), func_node->result_type, getUniqueName("null")));
    DB::ActionsDAG::NodeRawConstPtrs convert_nan_func_args = {is_nan_func_node, null_node, func_node};
    func_node = toFunctionNode(actions_dag, "if", func_node->result_name, convert_nan_func_args);
    actions_dag.addOrReplaceInOutputs(*func_node);
    return func_node;
}

AggregateFunctionParserFactory & AggregateFunctionParserFactory::instance()
{
    static AggregateFunctionParserFactory factory;
    return factory;
}

void AggregateFunctionParserFactory::registerAggregateFunctionParser(const String & name, Value value)
{
    if (!parsers.emplace(name, value).second)
        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Aggregate function parser {} is already registered", name);
}

AggregateFunctionParserPtr AggregateFunctionParserFactory::get(const String & name, ParserContextPtr parser_context) const
{
    auto parser = tryGet(name, parser_context);
    if (!parser)
        throw DB::Exception(DB::ErrorCodes::UNKNOWN_FUNCTION, "Unknown aggregate function {}", name);
    return parser;
}

AggregateFunctionParserPtr AggregateFunctionParserFactory::tryGet(const String & name, ParserContextPtr parser_context) const
{
    auto it = parsers.find(name);
    if (it == parsers.end())
        return nullptr;
    return it->second(parser_context);
}
}