/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "SparkFunctionMonthsBetween.h"
#include <string>
#include <DataTypes/DataTypeDate32.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeDateTime64.h>
#include <DataTypes/DataTypeNullable.h>
#include <Functions/DateTimeTransforms.h>
#include <Functions/FunctionFactory.h>
#include <Functions/TransformDateTime64.h>
#include <Poco/Logger.h>
#include <Common/DateLUT.h>
#include <Common/Exception.h>
#include <Common/logger_useful.h>
#include "Core/Field.h"
#include "base/Decimal.h"
#include "base/types.h"

namespace DB
{
/// Declarations of the error codes this translation unit refers to; `extern` means the
/// definitions live elsewhere.
namespace ErrorCodes
{
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    /// NOTE(review): NOT_IMPLEMENTED is not referenced in the visible part of this file.
    extern const int NOT_IMPLEMENTED;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
}

namespace local_engine
{
using namespace DB;
/// Validates the argument types and reports the result type.
///
/// Accepted signatures have 3 or 4 arguments:
///   1. Date, Date32, DateTime or DateTime64
///   2. Date, Date32, DateTime or DateTime64
///   3. not type-checked here (NOTE(review): executeImpl reads it as UInt8 — confirm whether a check belongs here)
///   4. optional, must be String (timezone)
///
/// Returns Float64 wrapped by makeNullableSafe.
/// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH or ILLEGAL_TYPE_OF_ARGUMENT when validation fails.
DB::DataTypePtr SparkFunctionMonthsBetween::getReturnTypeImpl(const DB::DataTypes & arguments) const
{
    const bool has_timezone = arguments.size() == 4;
    if (!(arguments.size() == 3 || has_timezone))
        throw Exception(
            ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
            "Number of arguments for function {} doesn't match: passed {}, should be 3 or 4",
            getName(),
            arguments.size());

    /// Shared predicate for the two date/time operands.
    const auto is_date_or_time = [](const DB::DataTypePtr & type)
    {
        return isDate(type) || isDate32(type) || isDateTime(type) || isDateTime64(type);
    };

    if (!is_date_or_time(arguments[0]))
        throw Exception(
            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
            "First argument for function {} must be Date, Date32, DateTime or DateTime64",
            getName());

    if (!is_date_or_time(arguments[1]))
        throw Exception(
            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
            "Second argument for function {} must be Date, Date32, DateTime or DateTime64",
            getName());

    if (has_timezone && !isString(arguments[3]))
        throw Exception(
            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
            "Fourth argument for function {} (timezone) must be String",
            getName());

    auto float_type = std::make_shared<DataTypeFloat64>();
    return DB::makeNullableSafe(float_type);
}

/// Returns `res` rounded to 8 digits after the decimal point when `round` is true,
/// and `res` unchanged otherwise.
ALWAYS_INLINE Float64 roundTo8IfNeed(bool round, Float64 res)
{
    if (!round)
        return res;

    /// Scale up, round to the nearest integer, scale back down.
    constexpr Float64 scale = 1e8;
    return std::round(res * scale) / scale;
}

/// Computes the number of months from `y` to `x`.
///
/// `x` and `y` are interpreted as microseconds since the Unix epoch and are split into
/// calendar fields using `timezone`.
///  - If both fall on the same day of month, or both fall on the last day of their month,
///    the result is the whole-month difference.
///  - Otherwise the day-of-month and time-of-day difference (in seconds) is added as a
///    fraction of a 31-day month.
/// When `round` is true the result is rounded to 8 digits after the decimal point.
Float64 monthsBetween(DateTime64 x, DateTime64 y, const DateLUTImpl & timezone, bool round)
{
    /// Spark uses microseconds, while the calendar lookups below work on whole seconds.
    /// Convert with floor semantics: plain integer division truncates towards zero, which would
    /// move a pre-epoch instant with a fractional second one second forward and could push it
    /// into the next day or month (e.g. 1969-12-31 23:59:59.5 -> 1970-01-01 00:00:00).
    const auto floor_to_seconds = [](DateTime64 & ts)
    {
        constexpr Int64 micros_per_second = 1000000;
        const Int64 micros = ts.value;
        ts.value = micros / micros_per_second - (micros % micros_per_second < 0 ? 1 : 0);
    };
    floor_to_seconds(x);
    floor_to_seconds(y);

    int x_year = timezone.toYear(x);
    int y_year = timezone.toYear(y);
    auto x_month = timezone.toMonth(x);
    auto y_month = timezone.toMonth(y);
    auto x_day = timezone.toDayOfMonth(x);
    auto y_day = timezone.toDayOfMonth(y);
    auto month_diff = static_cast<Float64>(x_year * 12 + x_month - y_year * 12 - y_month);

    /// Same day of month: the time of day is ignored.
    if (x_day == y_day)
        return roundTo8IfNeed(round, month_diff);

    /// Both dates are the last day of their month: the time of day is ignored as well.
    int x_to_month_end = timezone.daysInMonth(x);
    x_to_month_end -= x_day;
    int y_to_month_end = timezone.daysInMonth(y);
    y_to_month_end -= y_day;
    if (x_to_month_end == 0 && y_to_month_end == 0)
        return roundTo8IfNeed(round, month_diff);

    /// Fractional part: difference in days plus difference in seconds since the start of the
    /// respective day, divided by the number of seconds in a 31-day month (31 * 86400 = 2678400).
    int day_diff = static_cast<int>(x_day) - y_day;
    auto x_seconds_in_day = x - timezone.makeDate(x_year, x_month, x_day);
    auto y_seconds_in_day = y - timezone.makeDate(y_year, y_month, y_day);
    auto seconds_diff = x_seconds_in_day - y_seconds_in_day;
    auto res = static_cast<Float64>(day_diff * 86400 + seconds_diff) / 2678400.0 + month_diff;
    return roundTo8IfNeed(round, res);
}

/// Evaluates the function row by row.
///
/// Arguments: [0] and [1] are the two timestamps, [2] is the per-row round-off flag and the
/// optional [3] is the timezone name, which is read from row 0 only. Without a timezone argument
/// (or with zero rows) an empty name is passed to DateLUT::instance.
/// A row whose first or second value is NULL yields the result column's default value.
///
/// NOTE(review): values are extracted with safeGet<DateTime64>, while getReturnTypeImpl also admits
/// Date, Date32 and DateTime — confirm those are converted before they reach this function.
DB::ColumnPtr SparkFunctionMonthsBetween::executeImpl(
    const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t input_rows_count) const
{
    const IColumn & lhs_column = *arguments[0].column;
    const IColumn & rhs_column = *arguments[1].column;
    const IColumn & round_off_column = *arguments[2].column;

    auto result = result_type->createColumn();
    result->reserve(input_rows_count);

    /// We know that timezone is constant, so the first row is representative.
    std::string timezone_name;
    if (arguments.size() == 4 && input_rows_count != 0)
        timezone_name = arguments[3].column->getDataAt(0).toString();
    const auto & timezone = DateLUT::instance(timezone_name);

    for (size_t row = 0; row != input_rows_count; ++row)
    {
        DB::Field lhs;
        DB::Field rhs;
        lhs_column.get(row, lhs);
        rhs_column.get(row, rhs);

        if (lhs.isNull() || rhs.isNull()) [[unlikely]]
        {
            result->insertDefault();
            continue;
        }

        /// The round-off flag is only fetched for rows that are actually computed.
        DB::Field round_flag;
        round_off_column.get(row, round_flag);

        const auto lhs_value = static_cast<DateTime64>(lhs.safeGet<DateTime64>());
        const auto rhs_value = static_cast<DateTime64>(rhs.safeGet<DateTime64>());
        const bool need_round = static_cast<bool>(round_flag.safeGet<UInt8>());
        result->insert(monthsBetween(lhs_value, rhs_value, timezone, need_round));
    }
    return result;
}

/// Registers SparkFunctionMonthsBetween with the `factory` object made available by the
/// REGISTER_FUNCTION macro.
REGISTER_FUNCTION(SparkFunctionMonthsBetween)
{
    factory.registerFunction<SparkFunctionMonthsBetween>();
}
}