/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "SparkFunctionDecimalBinaryOperator.h"

#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnsNumber.h>
#include <Core/DecimalFunctions.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <Functions/castTypeToEither.h>
#include <Common/CurrentThread.h>

#if USE_EMBEDDED_COMPILER
#include <DataTypes/Native.h>
#include <llvm/IR/IRBuilder.h>
#endif

/// Declarations of the ClickHouse error codes this file may pass to DB::Exception.
/// They are only declared here (`extern`); their definitions live outside this file.
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; /// used by getReturnTypeImpl() for a wrong argument count
extern const int ILLEGAL_TYPE_OF_ARGUMENT; /// used by getReturnTypeImpl() for a non-decimal argument
extern const int ILLEGAL_COLUMN;
extern const int TYPE_MISMATCH;
extern const int LOGICAL_ERROR; /// used when type dispatch or operation dispatch fails unexpectedly
}

}

namespace local_engine
{

/// Compile-time predicate: `value` is true exactly when Op1 and Op2 are the same type.
template <typename Op1, typename Op2>
struct IsSameOperation
{
    static constexpr bool value = std::is_same<Op1, Op2>::value;
};

template <typename Op>
struct SparkIsOperation
{
    static constexpr bool plus = IsSameOperation<Op, DecimalPlusImpl>::value;
    static constexpr bool minus = IsSameOperation<Op, DecimalMinusImpl>::value;
    static constexpr bool plus_minus = IsSameOperation<Op, DecimalPlusImpl>::value || IsSameOperation<Op, DecimalMinusImpl>::value;
    static constexpr bool multiply = IsSameOperation<Op, DecimalMultiplyImpl>::value;
    static constexpr bool division = IsSameOperation<Op, DecimalDivideImpl>::value;
    static constexpr bool modulo = IsSameOperation<Op, DecimalModuloImpl>::value;
};

using namespace DB;

namespace
{
/// Describes which operand is a full column and which is a single constant value.
/// Used as a template argument of SparkDecimalBinaryOperation::process() to select the row loop.
enum class OpCase : uint8_t
{
    Vector, /// both operands are full columns
    LeftConstant, /// left operand is a constant, right operand is a full column
    RightConstant /// left operand is a full column, right operand is a constant
};

/// Calculation mode, passed as a template parameter to SparkDecimalBinaryOperation.
enum class OpMode : uint8_t
{
    Default, /// executeDecimal() consults shouldPromoteTo256() and the division scale check to decide on Int256
    Effect /// executeDecimal() skips those two checks; only the operand-rescaling check can still select Int256
};


template <typename Operation, OpMode Mode>
struct SparkDecimalBinaryOperation
{
private:
    static constexpr bool is_plus_minus = SparkIsOperation<Operation>::plus_minus;
    static constexpr bool is_multiply = SparkIsOperation<Operation>::multiply;
    static constexpr bool is_division = SparkIsOperation<Operation>::division;
    static constexpr bool is_modulo = SparkIsOperation<Operation>::modulo;

public:
    /// Scale at which the raw operation is evaluated.
    /// Multiplication uses the sum of the operand scales; every other operation uses the
    /// largest of the left, right and result scales.
    static size_t getMaxScaled(size_t left_scale, size_t right_scale, size_t result_scale)
    {
        if constexpr (!is_multiply)
        {
            const size_t widest_operand_scale = std::max(left_scale, right_scale);
            return std::max(result_scale, widest_operand_scale);
        }
        else
        {
            return left_scale + right_scale;
        }
    }

    /// Tells whether the operation may need more digits than a Decimal128 can hold.
    ///
    /// The required number of digits is derived from the operand precisions (p1, p2) and
    /// scales (s1, s2) with an operation-specific formula, then compared against
    /// DataTypeDecimal128::maxPrecision(). `result_type` is accepted but not consulted.
    ///
    /// @return true when the required precision exceeds the Decimal128 maximum precision.
    /// @throws Exception(LOGICAL_ERROR) when Operation is none of plus/minus/multiply/divide/modulo.
    template <typename LeftDataType, typename RightDataType, typename ResultDataType>
    static bool shouldPromoteTo256(const LeftDataType & left_type, const RightDataType & right_type, const ResultDataType & result_type)
    {
        const auto p1 = left_type.getPrecision();
        const auto s1 = left_type.getScale();
        const auto p2 = right_type.getPrecision();
        const auto s2 = right_type.getScale();

        size_t required_digits;
        if constexpr (is_plus_minus)
        {
            /// widest fractional part + widest integral part + one carry digit
            required_digits = std::max<size_t>(s1, s2) + std::max<size_t>(p1 - s1, p2 - s2) + 1;
        }
        else if constexpr (is_multiply)
        {
            required_digits = p1 + p2 + 1;
        }
        else if constexpr (is_division)
        {
            required_digits = p1 - s1 + s2 + std::max<size_t>(6, s1 + p2 + 1);
        }
        else if constexpr (is_modulo)
        {
            required_digits = std::min<size_t>(p1 - s1, p2 - s2) + std::max<size_t>(s1, s2);
        }
        else
        {
            throw Exception(ErrorCodes::LOGICAL_ERROR, "Unknown decimal binary operation");
        }

        return required_digits > DataTypeDecimal128::maxPrecision();
    }

    /// Entry point: evaluates the operation over the first two argument columns.
    ///
    /// Chooses the integer type used for the intermediate computation:
    ///  - Int256 when promotion is required (see the checks in the body),
    ///  - otherwise Int128 for division,
    ///  - otherwise the native type of the result decimal.
    ///
    /// @param arguments    argument columns; [0] is the left operand, [1] the right operand.
    /// @param left_type    concrete decimal type of the left operand.
    /// @param right_type   concrete decimal type of the right operand.
    /// @param result_type  concrete decimal type of the result.
    /// @return the column produced by executeDecimalImpl().
    template <typename LeftDataType, typename RightDataType, typename ResultDataType>
    static ColumnPtr executeDecimal(
        const ColumnsWithTypeAndName & arguments,
        const LeftDataType & left_type,
        const RightDataType & right_type,
        const ResultDataType & result_type)
    {
        using LeftFieldType = typename LeftDataType::FieldType;
        using RightFieldType = typename RightDataType::FieldType;
        using ResultFieldType = typename ResultDataType::FieldType;
        using LeftColumn = ColumnDecimal<LeftFieldType>;
        using RightColumn = ColumnDecimal<RightFieldType>;

        const ColumnPtr lhs_column = arguments[0].column;
        const ColumnPtr rhs_column = arguments[1].column;

        /// For each operand at most one of the two views (constant / full column) is non-null.
        const ColumnConst * lhs_const = checkAndGetColumnConst<LeftColumn>(lhs_column.get());
        const ColumnConst * rhs_const = checkAndGetColumnConst<RightColumn>(rhs_column.get());
        const LeftColumn * lhs_vector = checkAndGetColumn<LeftColumn>(lhs_column.get());
        const RightColumn * rhs_vector = checkAndGetColumn<RightColumn>(rhs_column.get());

        const size_t rows = lhs_column->size();
        const size_t max_scale = getMaxScaled(left_type.getScale(), right_type.getScale(), result_type.getScale());

        /// Precision-based promotion is skipped entirely in OpMode::Effect.
        bool use_int256 = false;
        if constexpr (Mode != OpMode::Effect)
        {
            use_int256 = shouldPromoteTo256(left_type, right_type, result_type)
                || (is_division && max_scale - left_type.getScale() + max_scale > ResultDataType::maxPrecision());
        }

        /// Rescaling an operand up to max_scale may need more digits than its own decimal type offers.
        const bool left_rescale_too_wide
            = DataTypeDecimal<LeftFieldType>::maxPrecision() < left_type.getPrecision() + max_scale - left_type.getScale();
        const bool right_rescale_too_wide
            = DataTypeDecimal<RightFieldType>::maxPrecision() < right_type.getPrecision() + max_scale - right_type.getScale();
        use_int256 = use_int256 || left_rescale_too_wide || right_rescale_too_wide;

        if (use_int256)
        {
            return executeDecimalImpl<LeftDataType, RightDataType, ResultDataType, Int256>(
                left_type, right_type, lhs_const, rhs_const, lhs_vector, rhs_vector, rows, result_type);
        }
        else if constexpr (is_division)
        {
            return executeDecimalImpl<LeftDataType, RightDataType, ResultDataType, Int128>(
                left_type, right_type, lhs_const, rhs_const, lhs_vector, rhs_vector, rows, result_type);
        }
        else
        {
            return executeDecimalImpl<LeftDataType, RightDataType, ResultDataType, NativeType<ResultFieldType>>(
                left_type, right_type, lhs_const, rhs_const, lhs_vector, rhs_vector, rows, result_type);
        }
    }

private:
    /// Runs the operation with `ScaledNativeType` as the intermediate integer type.
    ///
    /// Pre-computes the multipliers that bring both operands to the common scale
    /// (getMaxScaled()) and the divisor that brings the raw result back to the result
    /// scale, then dispatches to process() according to which operand is a constant.
    ///
    /// @param col_left_const / col_right_const  non-null when that operand is a constant column.
    /// @param col_left_vec / col_right_vec      non-null when that operand is a full decimal column.
    /// @param rows                              number of rows to produce.
    /// @return a ColumnNullable built from the result column and a null map; a row is
    ///         marked null when calculate() returns false for it.
    /// @throws Exception(LOGICAL_ERROR) when the operands are neither vector/vector,
    ///         constant/vector nor vector/constant.
    template <typename LeftDataType, typename RightDataType, typename ResultDataType, typename ScaledNativeType>
    static ColumnPtr executeDecimalImpl(
        const LeftDataType & left_type,
        const RightDataType & right_type,
        const ColumnConst * col_left_const,
        const ColumnConst * col_right_const,
        const ColumnDecimal<typename LeftDataType::FieldType> * col_left_vec,
        const ColumnDecimal<typename RightDataType::FieldType> * col_right_vec,
        size_t rows,
        const ResultDataType & result_type)
    {
        using LeftFieldType = typename LeftDataType::FieldType;
        using RightFieldType = typename RightDataType::FieldType;
        using ResultFieldType = typename ResultDataType::FieldType;
        using ColVecResult = ColumnVectorOrDecimal<ResultFieldType>;

        const size_t max_scale = getMaxScaled(left_type.getScale(), right_type.getScale(), result_type.getScale());

        /// Multiplication needs no operand rescaling. Division additionally pre-multiplies
        /// the dividend by 10^max_scale.
        const ScaledNativeType scale_left = [&]
        {
            if constexpr (is_multiply)
            {
                return ScaledNativeType{1};
            }
            else
            {
                const auto diff = max_scale - left_type.getScale();
                if constexpr (is_division)
                    return DecimalUtils::scaleMultiplier<ScaledNativeType>(diff + max_scale);
                else
                    return DecimalUtils::scaleMultiplier<ScaledNativeType>(diff);
            }
        }();

        const ScaledNativeType scale_right = [&]
        {
            if constexpr (is_multiply)
                return ScaledNativeType{1};
            else
                return DecimalUtils::scaleMultiplier<ScaledNativeType>(max_scale - right_type.getScale());
        }();

        const ScaledNativeType unscale_result = [&]
        {
            const auto result_scale = result_type.getScale();
            /// The scales are unsigned, so the check has to happen before the subtraction:
            /// a result scale above max_scale would wrap around instead of going negative.
            chassert(max_scale >= result_scale);
            return DecimalUtils::scaleMultiplier<ScaledNativeType>(max_scale - result_scale);
        }();

        /// Exclusive bound on the magnitude of a value that fits the result precision.
        const ScaledNativeType max_value = intExp10OfSize<ScaledNativeType>(result_type.getPrecision());

        auto res_vec = ColVecResult::create(rows, result_type.getScale());
        auto & res_vec_data = res_vec->getData();
        auto res_null_map = ColumnUInt8::create(rows, 0);
        auto & res_nullmap_data = res_null_map->getData();

        if (col_left_vec && col_right_vec)
        {
            process<OpCase::Vector>(
                col_left_vec->getData().data(),
                col_right_vec->getData().data(),
                res_vec_data,
                res_nullmap_data,
                rows,
                scale_left,
                scale_right,
                unscale_result,
                max_value);
        }
        else if (col_left_const && col_right_vec)
        {
            LeftFieldType left_value = col_left_const->getValue<LeftFieldType>();
            process<OpCase::LeftConstant>(
                &left_value,
                col_right_vec->getData().data(),
                res_vec_data,
                res_nullmap_data,
                rows,
                scale_left,
                scale_right,
                unscale_result,
                max_value);
        }
        else if (col_left_vec && col_right_const)
        {
            RightFieldType right_value = col_right_const->getValue<RightFieldType>();
            process<OpCase::RightConstant>(
                col_left_vec->getData().data(),
                &right_value,
                res_vec_data,
                res_nullmap_data,
                rows,
                scale_left,
                scale_right,
                unscale_result,
                max_value);
        }
        else
        {
            throw Exception(
                ErrorCodes::LOGICAL_ERROR,
                "Unexpected argument types {} {} {}",
                left_type.getName(),
                right_type.getName(),
                result_type.getName());
        }

        return ColumnNullable::create(std::move(res_vec), std::move(res_null_map));
    }

        /// Fills `res_vec_data` and `res_nullmap_data` for `rows` rows.
        ///
        /// `op_case` tells which operand pointer refers to a single constant value instead of
        /// an array of `rows` values. A row's null-map entry is set when calculate() returns
        /// false for that row.
        template <
            OpCase op_case,
            typename LeftFieldType,
            typename RightFieldType,
            typename ResultFieldType,
            typename ScaledNativeType>
        static void NO_INLINE process(
            const LeftFieldType * __restrict left_data, // maybe scalar or vector
            const RightFieldType * __restrict right_data, // maybe scalar or vector
            PaddedPODArray<ResultFieldType> & __restrict res_vec_data, // should be vector
            NullMap & res_nullmap_data,
            size_t rows,
            const ScaledNativeType & scale_left,
            const ScaledNativeType & scale_right,
            const ScaledNativeType & unscale_result,
            const ScaledNativeType & max_value)
        {
            if constexpr (op_case == OpCase::LeftConstant)
            {
                /// The constant is rescaled once; calculate() then gets a multiplier of 1 for it
                /// and therefore leaves it untouched on every row.
                const ScaledNativeType neutral_scale = static_cast<ScaledNativeType>(1);
                const ScaledNativeType left_once
                    = applyScaled(static_cast<ScaledNativeType>(unwrap<true>(left_data, 0)), scale_left);

                for (size_t row = 0; row < rows; ++row)
                {
                    const bool ok = calculate(
                        left_once,
                        static_cast<ScaledNativeType>(unwrap<false>(right_data, row)),
                        neutral_scale,
                        scale_right,
                        unscale_result,
                        max_value,
                        res_vec_data[row].value);
                    res_nullmap_data[row] = !ok;
                }
            }
            else if constexpr (op_case == OpCase::RightConstant)
            {
                /// Mirror image of the left-constant case.
                const ScaledNativeType neutral_scale = static_cast<ScaledNativeType>(1);
                const ScaledNativeType right_once
                    = applyScaled(static_cast<ScaledNativeType>(unwrap<true>(right_data, 0)), scale_right);

                for (size_t row = 0; row < rows; ++row)
                {
                    const bool ok = calculate(
                        static_cast<ScaledNativeType>(unwrap<false>(left_data, row)),
                        right_once,
                        scale_left,
                        neutral_scale,
                        unscale_result,
                        max_value,
                        res_vec_data[row].value);
                    res_nullmap_data[row] = !ok;
                }
            }
            else if constexpr (op_case == OpCase::Vector)
            {
                /// Both operands vary per row, so both are rescaled inside calculate().
                for (size_t row = 0; row < rows; ++row)
                {
                    const bool ok = calculate(
                        static_cast<ScaledNativeType>(unwrap<false>(left_data, row)),
                        static_cast<ScaledNativeType>(unwrap<false>(right_data, row)),
                        scale_left,
                        scale_right,
                        unscale_result,
                        max_value,
                        res_vec_data[row].value);
                    res_nullmap_data[row] = !ok;
                }
            }
        }

    /// Computes one result cell.
    ///
    /// Each operand is passed through applyScaled() when its scale factor is greater than 1,
    /// then Operation::apply combines them. The raw result goes through applyUnscaled() when
    /// `unscale_result` is greater than 1 and is finally cast into `res`.
    ///
    /// @return false when Operation::apply reports failure, or - for Int256 calculations and
    ///         for division - when the result lies outside (-max_value, max_value);
    ///         true otherwise. `res` is written even when the range check fails.
    template <
        typename ScaledNativeType,
        typename ResultNativeType>
    static ALWAYS_INLINE bool calculate(
        const ScaledNativeType & left,
        const ScaledNativeType & right,
        const ScaledNativeType & scale_left,
        const ScaledNativeType & scale_right,
        const ScaledNativeType & unscale_result,
        const ScaledNativeType & max_value,
        ResultNativeType & res)
    {
        /// Only the Int256 and the division paths verify the result against the target precision.
        constexpr bool check_result_range = std::is_same_v<ScaledNativeType, Int256> || is_division;

        ScaledNativeType lhs = left;
        if (scale_left > 1)
            lhs = applyScaled(left, scale_left);

        ScaledNativeType rhs = right;
        if (scale_right > 1)
            rhs = applyScaled(right, scale_right);

        ScaledNativeType raw_result = 0;
        if (!Operation::template apply<>(lhs, rhs, raw_result))
            return false;

        if (unscale_result > 1)
            raw_result = applyUnscaled(raw_result, unscale_result);

        res = static_cast<ResultNativeType>(raw_result);

        if constexpr (check_result_range)
            return raw_result > -max_value && raw_result < max_value;
        else
            return true;
    }

    /// Reads the `value` member of a decimal element.
    /// When `is_scalar` is true, `elem` points at a single element and `i` is ignored;
    /// otherwise the i-th element of the array is read.
    template <bool is_scalar, typename E>
    static auto unwrap(const E * elem, size_t i)
    {
        const E & picked = elem[is_scalar ? 0 : i];
        return picked.value;
    }


    /// Scales `n` up by `scale` by delegating to DecimalMultiplyImpl::apply.
    /// The value returned by apply() itself is not inspected here.
    template <typename T>
    static ALWAYS_INLINE T applyScaled(T n, T scale)
    {
        chassert(scale != 0);

        T scaled_up;
        DecimalMultiplyImpl::apply(n, scale, scaled_up);
        return scaled_up;
    }

    /// Scales `n` down by `scale` by delegating to DecimalDivideImpl::apply.
    /// The value returned by apply() itself is not inspected here.
    template <typename T>
    static ALWAYS_INLINE T applyUnscaled(T n, T scale)
    {
        chassert(scale != 0);

        T scaled_down;
        DecimalDivideImpl::apply(n, scale, scaled_down);
        return scaled_down;
    }
};

/// TODO(taiyang-li): implement JIT for binary decimal arithmetic functions
template <class Operation, typename Name, OpMode mode = OpMode::Default>
class SparkFunctionDecimalBinaryArithmetic final : public IFunction
{
    static constexpr bool is_plus = SparkIsOperation<Operation>::plus;
    static constexpr bool is_minus = SparkIsOperation<Operation>::minus;
    static constexpr bool is_plus_minus = SparkIsOperation<Operation>::plus || SparkIsOperation<Operation>::minus;
    static constexpr bool is_multiply = SparkIsOperation<Operation>::multiply;
    static constexpr bool is_division = SparkIsOperation<Operation>::division;
    static constexpr bool is_modulo = SparkIsOperation<Operation>::modulo;

public:
    static constexpr auto name = Name::name;

    /// Factory: builds a shared instance, forwarding the context to the constructor.
    static FunctionPtr create(ContextPtr context_) { return std::make_shared<SparkFunctionDecimalBinaryArithmetic>(context_); }

    /// Stores the given context in the `context` member.
    explicit SparkFunctionDecimalBinaryArithmetic(ContextPtr context_) : context(context_) { }

    /// Function name, taken from the `Name` template parameter.
    String getName() const override { return name; }
    /// Three arguments: left operand, right operand, and a third argument whose type defines the result type.
    size_t getNumberOfArguments() const override { return 3; }
    /// Always reports false, regardless of the argument types.
    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
    /// Opts in to the framework's default handling of constant arguments.
    bool useDefaultImplementationForConstants() const override { return true; }
    /// Declares the third argument (index 2) as always constant.
    ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {2}; }

    /// Validates the argument types and derives the return type.
    ///
    /// Exactly three decimal arguments are required. The type of the third one is the
    /// result type, which is returned wrapped in Nullable.
    ///
    /// @throws Exception(NUMBER_OF_ARGUMENTS_DOESNT_MATCH) when the argument count is not 3.
    /// @throws Exception(ILLEGAL_TYPE_OF_ARGUMENT) when any argument is not a decimal.
    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
    {
        constexpr size_t expected_argument_count = 3;
        if (arguments.size() != expected_argument_count)
            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function '{}' expects 3 arguments", getName());

        const auto & left = arguments[0];
        const auto & right = arguments[1];
        const auto & result = arguments[2];

        const bool all_decimal = isDecimal(left) && isDecimal(right) && isDecimal(result);
        if (!all_decimal)
        {
            throw Exception(
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                "Illegal type {} {} {} of argument of function {}",
                left->getName(),
                right->getName(),
                result->getName(),
                getName());
        }

        return makeNullable(result);
    }

    /// Executes the function over the given argument columns.
    ///
    /// Resolves the concrete decimal types of the two operands and of the third argument
    /// (with Nullable removed), then delegates to SparkDecimalBinaryOperation::executeDecimal.
    ///
    /// @throws Exception(LOGICAL_ERROR) when the types cannot be resolved or no column is produced.
    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override
    {
        const auto & left_argument = arguments[0];
        const auto & right_argument = arguments[1];
        const DataTypePtr plain_result_type = removeNullable(arguments[2].type);

        ColumnPtr computed;
        auto run_typed = [&](const auto & left, const auto & right, const auto & result)
        {
            computed = SparkDecimalBinaryOperation<Operation, mode>::template executeDecimal<>(arguments, left, right, result);
            return computed != nullptr;
        };

        const bool resolved
            = castTripleTypes(left_argument.type.get(), right_argument.type.get(), plain_result_type.get(), run_typed);
        if (resolved)
            return computed;

        // Reaching this point is a logical error: getReturnTypeImpl() should already
        // have validated the argument types.
        throw Exception(
            ErrorCodes::LOGICAL_ERROR,
            "Arguments of '{}' have incorrect data types: '{}' of type '{}',"
            " '{}' of type '{}'",
            getName(),
            left_argument.name,
            left_argument.type->getName(),
            right_argument.name,
            right_argument.type->getName());
    }

#if USE_EMBEDDED_COMPILER
    /// Reports argument index 2 as not taking part in JIT compilation; compileImpl() below
    /// only reads arguments 0 and 1.
    /// NOTE(review): declared `virtual` without `override` - confirm whether the IFunction base declares this hook.
    virtual ColumnNumbers getArgumentsThatDontParticipateInCompilation(const DataTypes & /*types*/) const { return {2}; }

    bool isCompilableImpl(const DataTypes & arguments, const DataTypePtr & result_type) const override
    {
        const auto & denull_left_type = arguments[0];
        const auto & denull_right_type = arguments[1];
        const auto & denull_result_type = removeNullable(result_type);
        if (!canBeNativeType(denull_left_type) || !canBeNativeType(denull_right_type) || !canBeNativeType(denull_result_type))
            return false;

        return castTripleTypes(
            denull_left_type.get(),
            denull_right_type.get(),
            denull_result_type.get(),
            [&](const auto & left_type, const auto & right_type, const auto & result_type)
            {
                using LeftDataType = std::decay_t<decltype(left_type)>;
                using RightDataType = std::decay_t<decltype(right_type)>;
                using ResultDataType = std::decay_t<decltype(result_type)>;
                using LeftFieldType = typename LeftDataType::FieldType;
                using RightFieldType = typename RightDataType::FieldType;
                using ResultFieldType = typename ResultDataType::FieldType;

                size_t max_scale = SparkDecimalBinaryOperation<Operation, mode>::getMaxScaled(
                    left_type.getScale(), right_type.getScale(), result_type.getScale());
                auto p1 = left_type.getPrecision();
                auto p2 = right_type.getPrecision();
                if (DataTypeDecimal<LeftFieldType>::maxPrecision() < p1 + max_scale - left_type.getScale()
                    || DataTypeDecimal<RightFieldType>::maxPrecision() < p2 + max_scale - right_type.getScale())
                    return false;

                if (SparkDecimalBinaryOperation<Operation, mode>::shouldPromoteTo256(left_type, right_type, result_type)
                    || (is_division && max_scale - left_type.getScale() + max_scale > ResultDataType::maxPrecision()))
                    return false;

                return true;
            });
    }

    /// Emits LLVM IR for one application of the operation.
    ///
    /// Resolves the concrete decimal types, picks the integer type for the computation
    /// (Int256 when any promotion condition holds, else Int128 for division, else the native
    /// type of the result) and delegates to compileHelper().
    ///
    /// @return the value produced by compileHelper(); it stays nullptr if the type dispatch
    ///         never invokes the lambda below.
    llvm::Value *
    compileImpl(llvm::IRBuilderBase & builder, const ValuesWithType & arguments, const DataTypePtr & result_type) const override
    {
        const DataTypePtr plain_result_ptr = removeNullable(result_type);
        llvm::Value * compiled = nullptr;

        auto emit = [&](const auto & left, const auto & right, const auto & result)
        {
            using LeftFieldType = typename std::decay_t<decltype(left)>::FieldType;
            using RightFieldType = typename std::decay_t<decltype(right)>::FieldType;
            using ResultDataType = std::decay_t<decltype(result)>;
            using ResultNativeType = NativeType<typename ResultDataType::FieldType>;

            const size_t max_scale
                = SparkDecimalBinaryOperation<Operation, mode>::getMaxScaled(left.getScale(), right.getScale(), result.getScale());

            const bool operand_rescale_too_wide
                = DataTypeDecimal<LeftFieldType>::maxPrecision() < left.getPrecision() + max_scale - left.getScale()
                || DataTypeDecimal<RightFieldType>::maxPrecision() < right.getPrecision() + max_scale - right.getScale();

            const bool needs_int256 = SparkDecimalBinaryOperation<Operation, mode>::shouldPromoteTo256(left, right, result)
                || (is_division && max_scale - left.getScale() + max_scale > ResultDataType::maxPrecision())
                || operand_rescale_too_wide;

            if (needs_int256)
                compiled = compileHelper<Int256>(builder, arguments, left, right, result);
            else if (is_division)
                compiled = compileHelper<Int128>(builder, arguments, left, right, result);
            else
                compiled = compileHelper<ResultNativeType>(builder, arguments, left, right, result);

            return true;
        };

        castTripleTypes(arguments[0].type.get(), arguments[1].type.get(), plain_result_ptr.get(), emit);
        return compiled;
    }

    /// Emits the IR that evaluates the operation with `CalculateType` as the intermediate
    /// integer type and packs the outcome into a nullable value.
    ///
    /// Steps: cast both operands to CalculateType, multiply each by its scale factor, apply
    /// the operation, divide the raw result down to the result scale, and - for Int256
    /// calculations and for division - flag results outside (-10^precision, 10^precision).
    ///
    /// @return a struct value of the nullable result type: field 0 holds the value,
    ///         field 1 the null flag (set on a zero divisor or on a detected overflow).
    template <typename CalculateType, typename LeftDataType, typename RightDataType, typename ResultDataType>
    static llvm::Value * compileHelper(
        llvm::IRBuilderBase & builder,
        const ValuesWithType & arguments,
        const LeftDataType & left_type,
        const RightDataType & right_type,
        const ResultDataType & result_type)
    {
        auto & b = static_cast<llvm::IRBuilder<> &>(builder);
        DataTypePtr calculate_type = std::make_shared<DataTypeNumber<CalculateType>>();

        auto * left = nativeCast(b, arguments[0], calculate_type);
        auto * right = nativeCast(b, arguments[1], calculate_type);

        size_t max_scale = SparkDecimalBinaryOperation<Operation, mode>::getMaxScaled(
            left_type.getScale(), right_type.getScale(), result_type.getScale());

        /// Multiplication needs no operand rescaling. Division additionally pre-multiplies
        /// the dividend by 10^max_scale.
        CalculateType scale_left = [&]
        {
            if constexpr (is_multiply)
            {
                return CalculateType{1};
            }
            else
            {
                auto diff = max_scale - left_type.getScale();
                if constexpr (is_division)
                    return DecimalUtils::scaleMultiplier<CalculateType>(diff + max_scale);
                else
                    return DecimalUtils::scaleMultiplier<CalculateType>(diff);
            }
        }();

        CalculateType scale_right = [&]
        {
            if constexpr (is_multiply)
                return CalculateType{1};
            else
                return DecimalUtils::scaleMultiplier<CalculateType>(max_scale - right_type.getScale());
        }();

        auto * scaled_left = b.CreateMul(left, getNativeConstant(b, scale_left));
        auto * scaled_right = b.CreateMul(right, getNativeConstant(b, scale_right));

        llvm::Value * scaled_result = nullptr;
        llvm::Value * is_null = llvm::ConstantInt::getFalse(b.getContext());
        if constexpr (is_plus)
            scaled_result = b.CreateAdd(scaled_left, scaled_right);
        else if constexpr (is_minus)
            scaled_result = b.CreateSub(scaled_left, scaled_right);
        else if constexpr (is_multiply)
            scaled_result = b.CreateMul(scaled_left, scaled_right);
        else if constexpr (is_division)
        {
            auto * zero = getNativeConstant(b, static_cast<CalculateType>(0));
            auto * one = getNativeConstant(b, static_cast<CalculateType>(1));
            auto * is_zero = b.CreateICmpEQ(scaled_right, zero);

            /// sdiv by zero is undefined behaviour in LLVM IR even if the result is discarded,
            /// so divide by 1 instead whenever the divisor is zero; the row is flagged null anyway.
            auto * safe_divisor = b.CreateSelect(is_zero, one, scaled_right);
            scaled_result = b.CreateSDiv(scaled_left, safe_divisor);
            is_null = is_zero;
        }
        else if constexpr (is_modulo)
        {
            auto * zero = getNativeConstant(b, static_cast<CalculateType>(0));
            auto * one = getNativeConstant(b, static_cast<CalculateType>(1));
            auto * is_zero = b.CreateICmpEQ(scaled_right, zero);

            /// srem by zero is undefined behaviour as well; same substitution as for division.
            auto * safe_divisor = b.CreateSelect(is_zero, one, scaled_right);
            scaled_result = b.CreateSRem(scaled_left, safe_divisor);
            is_null = is_zero;
        }

        /// Bring the raw result from max_scale down to the result scale.
        auto result_scale = result_type.getScale();
        auto scale_diff = max_scale - result_scale;
        auto * unscaled_result = scaled_result;
        if (scale_diff)
        {
            auto scaled_diff = DecimalUtils::scaleMultiplier<CalculateType>(scale_diff);
            unscaled_result = b.CreateSDiv(scaled_result, getNativeConstant(b, scaled_diff));
        }

        /// check overflow
        if constexpr (std::is_same_v<CalculateType, Int256> || is_division)
        {
            auto max_value = intExp10OfSize<CalculateType>(result_type.getPrecision());
            auto * max_value_const = getNativeConstant(b, max_value);
            auto * is_overflow = b.CreateOr(
                b.CreateICmpSGE(unscaled_result, max_value_const), b.CreateICmpSLE(unscaled_result, b.CreateNeg(max_value_const)));
            is_null = b.CreateOr(is_null, is_overflow);
        }

        auto * result = nativeCast(b, calculate_type, unscaled_result, result_type.getPtr());
        auto * nullable_type = toNativeType(b, makeNullable(result_type.getPtr()));
        auto * nullable_result = llvm::Constant::getNullValue(nullable_type);
        auto * nullable_result_with_value = b.CreateInsertValue(nullable_result, result, {0});
        return b.CreateInsertValue(nullable_result_with_value, is_null, {1});
    }

    /// Builds an LLVM integer constant of sizeof(T) * 8 bits holding `element`.
    /// Built-in integral types are passed as a 64-bit value with the signed flag set;
    /// any other integer type is converted through an APInt built from `element.items`.
    /// NOTE(review): assumes `items` holds the 64-bit words in the order APInt expects - confirm.
    template <is_integer T>
    static llvm::Constant * getNativeConstant(llvm::IRBuilderBase & builder, T element)
    {
        auto * int_type = llvm::Type::getIntNTy(builder.getContext(), sizeof(T) * 8);

        if constexpr (!std::is_integral_v<T>)
        {
            llvm::APInt wide_value(int_type->getIntegerBitWidth(), element.items);
            return llvm::ConstantInt::get(int_type, wide_value);
        }
        else
        {
            return llvm::ConstantInt::get(int_type, static_cast<uint64_t>(element), true);
        }
    }
#endif // USE_EMBEDDED_COMPILER

private:
    /// Resolves the concrete decimal types of `left`, `right` and `result` (in that order)
    /// and calls `f(left_, right_, result_)` with the typed references.
    /// @return the value produced by the nested castType() calls.
    template <typename F>
    static bool castTripleTypes(const IDataType * left, const IDataType * right, const IDataType * result, F && f)
    {
        auto resolve_left = [&](const auto & left_)
        {
            auto resolve_right = [&](const auto & right_)
            {
                auto resolve_result = [&](const auto & result_) { return f(left_, right_, result_); };
                return castType(result, resolve_result);
            };
            return castType(right, resolve_right);
        };
        return castType(left, resolve_left);
    }

    static bool castType(const IDataType * type, auto && f)
    {
        using Types = TypeList<DataTypeDecimal32, DataTypeDecimal64, DataTypeDecimal128, DataTypeDecimal256>;
        return castTypeToEither(Types{}, type, std::forward<decltype(f)>(f));
    }

    ContextPtr context;
};

}
}