* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "SparkFunctionCheckDecimalOverflow.h"
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnsNumber.h>
#include <Core/DecimalFunctions.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
extern const int TYPE_MISMATCH;
}
}
namespace local_engine
{
using namespace DB;
struct CheckDecimalOverflowSpark
{
static constexpr auto name = "checkDecimalOverflowSpark";
};
struct CheckDecimalOverflowSparkOrNull
{
static constexpr auto name = "checkDecimalOverflowSparkOrNull";
};
enum class CheckExceptionMode
{
Throw,
Null
};
namespace
{
template <typename Name, CheckExceptionMode mode>
class FunctionCheckDecimalOverflow : public IFunction
{
public:
static constexpr auto name = Name::name;
static constexpr auto exception_mode = mode;
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionCheckDecimalOverflow>(); }
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 3; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & ) const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1, 2}; }
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (!isDecimal(arguments[0].type) || !isInteger(arguments[1].type) || !isInteger(arguments[2].type))
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} {} {} of argument of function {}",
arguments[0].type->getName(),
arguments[1].type->getName(),
arguments[2].type->getName(),
getName());
UInt32 precision = extractArgument(arguments[1]);
UInt32 scale = extractArgument(arguments[2]);
auto return_type = createDecimal<DataTypeDecimal>(precision, scale);
if constexpr (exception_mode == CheckExceptionMode::Null)
{
if (!arguments[0].type->isNullable())
return std::make_shared<DataTypeNullable>(return_type);
}
return return_type;
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
const auto & src_column = arguments[0];
UInt32 precision = extractArgument(arguments[1]);
UInt32 scale = extractArgument(arguments[2]);
ColumnPtr result_column;
auto call = [&](const auto & types) -> bool
{
using Types = std::decay_t<decltype(types)>;
using FromDataType = typename Types::LeftType;
using ToDataType = typename Types::RightType;
if constexpr (IsDataTypeDecimal<FromDataType>)
{
using FromFieldType = typename FromDataType::FieldType;
using ColVecType = ColumnDecimal<FromFieldType>;
if (const ColVecType * col_vec = checkAndGetColumn<ColVecType>(src_column.column.get()))
{
executeInternal<FromFieldType, ToDataType>(*col_vec, result_column, input_rows_count, precision, scale);
return true;
}
}
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal column while execute function {}", getName());
};
if (precision <= DecimalUtils::max_precision<Decimal32>)
callOnIndexAndDataType<DataTypeDecimal<Decimal32>>(src_column.type->getTypeId(), call);
else if (precision <= DecimalUtils::max_precision<Decimal64>)
callOnIndexAndDataType<DataTypeDecimal<Decimal64>>(src_column.type->getTypeId(), call);
else if (precision <= DecimalUtils::max_precision<Decimal128>)
callOnIndexAndDataType<DataTypeDecimal<Decimal128>>(src_column.type->getTypeId(), call);
else
callOnIndexAndDataType<DataTypeDecimal<Decimal256>>(src_column.type->getTypeId(), call);
if (!result_column)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Wrong call for {} with {}", getName(), src_column.type->getName());
return result_column;
}
private:
template <typename T, typename ToDataType>
static void executeInternal(
const ColumnDecimal<T> & col_source, ColumnPtr & result_column, size_t input_rows_count, UInt32 precision, UInt32 scale_to)
{
using ToFieldType = typename ToDataType::FieldType;
using ToColumnType = typename ToDataType::ColumnType;
ColumnUInt8::MutablePtr col_null_map_to;
ColumnUInt8::Container * vec_null_map_to [[maybe_unused]] = nullptr;
auto scale_from = col_source.getScale();
if constexpr (exception_mode == CheckExceptionMode::Null)
{
col_null_map_to = ColumnUInt8::create(input_rows_count, false);
vec_null_map_to = &col_null_map_to->getData();
}
typename ToColumnType::MutablePtr col_to = ToColumnType::create(input_rows_count, scale_to);
auto & vec_to = col_to->getData();
vec_to.resize_exact(input_rows_count);
auto & datas = col_source.getData();
for (size_t i = 0; i < input_rows_count; ++i)
{
ToFieldType result;
bool success = convertToDecimalImpl<T, ToDataType>(datas[i], precision, scale_from, scale_to, result);
if (success)
vec_to[i] = static_cast<ToFieldType>(result);
else
{
vec_to[i] = static_cast<ToFieldType>(0);
if constexpr (exception_mode == CheckExceptionMode::Null)
(*vec_null_map_to)[i] = static_cast<UInt8>(1);
else
throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal value is overflow.");
}
}
if constexpr (exception_mode == CheckExceptionMode::Null)
result_column = ColumnNullable::create(std::move(col_to), std::move(col_null_map_to));
else
result_column = std::move(col_to);
}
template <is_decimal FromFieldType, typename ToDataType>
requires(IsDataTypeDecimal<ToDataType>)
static bool convertToDecimalImpl(
const FromFieldType & decimal, UInt32 precision_to, UInt32 scale_from, UInt32 scale_to, typename ToDataType::FieldType & result)
{
if constexpr (std::is_same_v<FromFieldType, Decimal32>)
return convertDecimalsImpl<DataTypeDecimal<Decimal32>, ToDataType>(decimal, precision_to, scale_from, scale_to, result);
else if constexpr (std::is_same_v<FromFieldType, Decimal64>)
return convertDecimalsImpl<DataTypeDecimal<Decimal64>, ToDataType>(decimal, precision_to, scale_from, scale_to, result);
else if constexpr (std::is_same_v<FromFieldType, Decimal128>)
return convertDecimalsImpl<DataTypeDecimal<Decimal128>, ToDataType>(decimal, precision_to, scale_from, scale_to, result);
else
return convertDecimalsImpl<DataTypeDecimal<Decimal256>, ToDataType>(decimal, precision_to, scale_from, scale_to, result);
}
template <typename FromDataType, typename ToDataType>
requires(IsDataTypeDecimal<FromDataType> && IsDataTypeDecimal<ToDataType>)
static bool convertDecimalsImpl(
const typename FromDataType::FieldType & value,
UInt32 precision_to,
UInt32 scale_from,
UInt32 scale_to,
typename ToDataType::FieldType & result)
{
using FromFieldType = typename FromDataType::FieldType;
using ToFieldType = typename ToDataType::FieldType;
using MaxFieldType = std::conditional_t<(sizeof(FromFieldType) > sizeof(ToFieldType)), FromFieldType, ToFieldType>;
using MaxNativeType = typename MaxFieldType::NativeType;
auto false_value = []() -> bool
{
if constexpr (exception_mode == CheckExceptionMode::Null)
return false;
else
throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal value is overflow.");
};
MaxNativeType converted_value;
if (scale_to > scale_from)
{
converted_value = DecimalUtils::scaleMultiplier<MaxNativeType>(scale_to - scale_from);
if (common::mulOverflow(static_cast<MaxNativeType>(value.value), converted_value, converted_value))
return false_value();
}
else if (scale_to == scale_from)
converted_value = value.value;
else
converted_value = value.value / DecimalUtils::scaleMultiplier<MaxNativeType>(scale_from - scale_to);
MaxNativeType pow10 = intExp10OfSize<MaxNativeType>(precision_to);
if (converted_value <= -pow10 || converted_value >= pow10)
return false_value();
result = static_cast<typename ToFieldType::NativeType>(converted_value);
return true;
}
};
using FunctionCheckDecimalOverflowThrow = FunctionCheckDecimalOverflow<CheckDecimalOverflowSpark, CheckExceptionMode::Throw>;
using FunctionCheckDecimalOverflowOrNull = FunctionCheckDecimalOverflow<CheckDecimalOverflowSparkOrNull, CheckExceptionMode::Null>;
}
REGISTER_FUNCTION(CheckDecimalOverflowSpark)
{
factory.registerFunction<FunctionCheckDecimalOverflowThrow>(FunctionDocumentation{.description = R"(
Check decimal precision is overflow. If overflow throws exception.
)"});
factory.registerFunction<FunctionCheckDecimalOverflowOrNull>(FunctionDocumentation{.description = R"(
Check decimal precision is overflow. If overflow return `NULL`.
)"});
}
}