/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "FunctionParser.h"
#include <memory>
#include <Core/Field.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/IDataType.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Parser/TypeParser.h>
#include <Common/BlockTypeUtils.h>
#include <Common/CHUtil.h>
#include "ExpressionParser.h"

namespace DB
{

/// Error codes referenced from this translation unit. They are only declared here;
/// the definitions live elsewhere in the code base.
namespace ErrorCodes
{
/// Used by FunctionParserFactory::registerFunctionParser; declared explicitly so this
/// file does not depend on another header happening to declare it.
extern const int LOGICAL_ERROR;
extern const int UNKNOWN_FUNCTION;
extern const int NOT_IMPLEMENTED;
}
}

namespace local_engine
{
using namespace DB;

/// Binds the parser to a parser context and creates the expression parser it delegates to.
///
/// `ctx` is received by value, so it is moved into the member instead of being copied a
/// second time. The expression parser is then built from the stored member (not from the
/// moved-from parameter).
FunctionParser::FunctionParser(ParserContextPtr ctx) : parser_context(std::move(ctx))
{
    expression_parser = std::make_unique<ExpressionParser>(parser_context);
}

/// Returns the name of the CH function that corresponds to the given substrait function.
///
/// This implementation ignores its argument and always returns an empty name, i.e. it
/// reports no simple one-to-one CH equivalent. parse() passes the returned name straight
/// to toFunctionNode().
/// NOTE(review): presumably concrete parsers override this — the declaration is not visible here.
String FunctionParser::getCHFunctionName(const substrait::Expression_ScalarFunction & /*substrait_func*/) const
{
    return {};
}

/// Forwards to ExpressionParser::getUniqueName and returns whatever name it produces for `name`.
String FunctionParser::getUniqueName(const String & name) const
{
    String unique_name = expression_parser->getUniqueName(name);
    return unique_name;
}


/// Adds a constant column described by (`type`, `field`) to `actions_dag` by delegating to
/// ExpressionParser::addConstColumn, and returns the node that call yields.
const DB::ActionsDAG::Node *
FunctionParser::addColumnToActionsDAG(DB::ActionsDAG & actions_dag, const DB::DataTypePtr & type, const DB::Field & field) const
{
    const auto * const_node = expression_parser->addConstColumn(actions_dag, type, field);
    return const_node;
}

/// Builds a node for function `func_name` applied to `args` in `action_dag`.
/// Thin wrapper over ExpressionParser::toFunctionNode without an explicit result name.
const DB::ActionsDAG::Node *
FunctionParser::toFunctionNode(DB::ActionsDAG & action_dag, const String & func_name, const DB::ActionsDAG::NodeRawConstPtrs & args) const
{
    const auto * function_node = expression_parser->toFunctionNode(action_dag, func_name, args);
    return function_node;
}

/// Builds a node for function `func_name` applied to `args` in `action_dag`, passing
/// `result_name` through to ExpressionParser::toFunctionNode.
/// Note the argument order differs from the delegate: here `result_name` precedes `args`.
const DB::ActionsDAG::Node * FunctionParser::toFunctionNode(
    DB::ActionsDAG & action_dag, const String & func_name, const String & result_name, const DB::ActionsDAG::NodeRawConstPtrs & args) const
{
    const auto * function_node = expression_parser->toFunctionNode(action_dag, func_name, args, result_name);
    return function_node;
}

/// Parses the scalar function carried by `rel` into `actions_dag` via the expression parser.
///
/// @param rel          expression whose scalar_function() is parsed
/// @param result_name  out-parameter; receives the result name of the produced node
/// @param actions_dag  DAG handed to the expression parser
/// @param keep_result  forwarded unchanged to ExpressionParser::parseFunction
/// @return the node returned by ExpressionParser::parseFunction
const ActionsDAG::Node * FunctionParser::parseFunctionWithDAG(
    const substrait::Expression & rel, std::string & result_name, DB::ActionsDAG & actions_dag, bool keep_result) const
{
    const auto & scalar_function = rel.scalar_function();
    const auto * parsed = expression_parser->parseFunction(scalar_function, actions_dag, keep_result);

    /// Report the node's name back to the caller through the out-parameter.
    result_name = parsed->result_name;
    return parsed;
}

/// Parses an arbitrary substrait expression into `actions_dag` by delegating to
/// ExpressionParser::parseExpression, and returns the node it produces.
const DB::ActionsDAG::Node * FunctionParser::parseExpression(DB::ActionsDAG & actions_dag, const substrait::Expression & rel) const
{
    const auto * expression_node = expression_parser->parseExpression(actions_dag, rel);
    return expression_node;
}

/// Converts a substrait literal into a (data type, field value) pair using LiteralParser::parse.
std::pair<DataTypePtr, Field> FunctionParser::parseLiteral(const substrait::Expression_Literal & literal) const
{
    auto type_and_value = LiteralParser::parse(literal);
    return type_and_value;
}

/// Parses the arguments of `substrait_func` into `actions_dag` and returns the argument nodes.
/// The work is delegated entirely to ExpressionParser::parseFunctionArguments.
ActionsDAG::NodeRawConstPtrs
FunctionParser::parseFunctionArguments(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAG & actions_dag) const
{
    return expression_parser->parseFunctionArguments(actions_dag, substrait_func);
}


/// Default parsing pipeline for a scalar function:
///   1. resolve the CH function name,
///   2. parse the substrait arguments into `actions_dag`,
///   3. build the function node,
///   4. adjust the node's result type to the declared substrait output type if needed.
/// The name is resolved before the arguments are parsed, exactly in that order.
const ActionsDAG::Node * FunctionParser::parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAG & actions_dag) const
{
    const auto function_name = getCHFunctionName(substrait_func);
    const auto arguments = parseFunctionArguments(substrait_func, actions_dag);
    return convertNodeTypeIfNeeded(substrait_func, toFunctionNode(actions_dag, function_name, arguments), actions_dag);
}

/// Reconciles the result type of `func_node` with the output type declared by `substrait_func`.
///
/// Two independent adjustments are applied in order:
///   1. If TypeParser::isTypeMatched reports a mismatch, a conversion node is added to
///      `actions_dag`. The nullability of the existing node is kept; decimal targets are
///      converted with CastType::accurateOrNull, everything else with the default cast of
///      ActionsDAGUtil::convertNodeType.
///   2. If the substrait output type is bool and the (non-nullable part of the) resulting
///      type is UInt8, the node's result type is relabelled as Bool (Nullable(Bool) when the
///      node was nullable).
///
/// @param substrait_func  function whose output_type() is the target type
/// @param func_node       node produced for the function; must not be null (it is dereferenced)
/// @param actions_dag     DAG that receives the conversion node, if one is needed
/// @return `func_node` itself or the conversion node created for it
const ActionsDAG::Node * FunctionParser::convertNodeTypeIfNeeded(
    const substrait::Expression_ScalarFunction & substrait_func, const ActionsDAG::Node * func_node, ActionsDAG & actions_dag) const
{
    const auto & output_type = substrait_func.output_type();
    const ActionsDAG::Node * result_node = func_node;

    if (!TypeParser::isTypeMatched(output_type, func_node->result_type))
    {
        /// Parse the declared output type once; both cast flavours below share it.
        const auto parsed_output_type = TypeParser::parseType(output_type);

        /// As stated in isTypeMatched, currently we don't change nullability of the result type.
        const auto target_type = func_node->result_type->isNullable() ? local_engine::wrapNullableType(true, parsed_output_type)
                                                                      : DB::removeNullable(parsed_output_type);

        if (DB::isDecimalOrNullableDecimal(parsed_output_type))
            result_node = ActionsDAGUtil::convertNodeType(
                actions_dag, func_node, target_type, func_node->result_name, CastType::accurateOrNull);
        else
            result_node = ActionsDAGUtil::convertNodeType(actions_dag, func_node, target_type, func_node->result_name);
    }

    /// Notice that in CH Bool and UInt8 have different serialization and deserialization methods,
    /// which will cause issues when executing cast(bool as string) in Spark.
    if (output_type.has_bool_() && isUInt8(DB::removeNullable(result_node->result_type)))
    {
        auto bool_type = DB::DataTypeFactory::instance().get("Bool");
        if (result_node->result_type->isNullable())
            bool_type = DB::makeNullable(bool_type);

        /// The type is patched in place on the node. const_cast is needed because this
        /// function only holds const pointers to DAG nodes.
        const_cast<ActionsDAG::Node *>(result_node)->result_type = std::move(bool_type);
    }

    return result_node;
}

/// Registers `value` as the parser creator for `name`.
/// Throws a LOGICAL_ERROR exception if a creator is already registered under that name.
void FunctionParserFactory::registerFunctionParser(const String & name, Value value)
{
    const bool inserted = parsers.emplace(name, value).second;
    if (inserted)
        return;

    throw Exception(ErrorCodes::LOGICAL_ERROR, "FunctionParserFactory: function parser name '{}' is not unique", name);
}

/// Returns the parser created for `name`, built with `ctx`.
/// Unlike tryGet(), a missing registration is an error: an UNKNOWN_FUNCTION exception is thrown.
FunctionParserPtr FunctionParserFactory::get(const String & name, ParserContextPtr ctx)
{
    if (auto parser = tryGet(name, ctx))
        return parser;

    throw Exception(ErrorCodes::UNKNOWN_FUNCTION, "Unknown function parser {}", name);
}

/// Looks up the creator registered under `name` and, if present, invokes it with `ctx`.
/// Returns a default-constructed (empty) FunctionParserPtr when nothing is registered.
FunctionParserPtr FunctionParserFactory::tryGet(const String & name, ParserContextPtr ctx)
{
    const auto it = parsers.find(name);
    if (it == parsers.end())
        return {};

    /// Call the stored creator in place; copying it first is unnecessary for a single call.
    return it->second(ctx);
}

/// Returns the single FunctionParserFactory, held in a function-local static that is
/// constructed the first time this function is called.
FunctionParserFactory & FunctionParserFactory::instance()
{
    static FunctionParserFactory singleton;
    return singleton;
}

}