* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Functions/FunctionHelpers.h>
#include <Functions/FunctionFactory.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnFunction.h>
#include <Columns/ColumnNullable.h>
#include <Common/Exception.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeFunction.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>
#include <base/sort.h>
namespace DB::ErrorCodes
{
extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int TYPE_MISMATCH;
extern const int ILLEGAL_COLUMN;
}
namespace local_engine
{
struct LambdaLess
{
const DB::IColumn & column;
DB::DataTypePtr type;
const DB::ColumnFunction & lambda;
explicit LambdaLess(const DB::IColumn & column_, DB::DataTypePtr type_, const DB::ColumnFunction & lambda_)
: column(column_), type(type_), lambda(lambda_) {}
bool operator()(size_t lhs, size_t rhs) const
{
auto left_value_col = DB::ColumnWithTypeAndName(oneRowColumn(lhs), type, "left");
auto right_value_col = DB::ColumnWithTypeAndName(oneRowColumn(rhs), type, "right");
auto cloned_lambda = lambda.cloneResized(1);
auto * lambda_ = typeid_cast<DB::ColumnFunction *>(cloned_lambda.get());
lambda_->appendArguments({std::move(left_value_col), std::move(right_value_col)});
auto compare_res_col = lambda_->reduce();
DB::Field field;
compare_res_col.column->get(0, field);
return field.safeGet<Int32>() < 0;
}
private:
ALWAYS_INLINE DB::ColumnPtr oneRowColumn(size_t i) const
{
auto res = column.cloneEmpty();
res->insertFrom(column, i);
return std::move(res);
}
};
struct Less
{
const DB::IColumn & column;
explicit Less(const DB::IColumn & column_) : column(column_) { }
bool operator()(size_t lhs, size_t rhs) const
{
return column.compareAt(lhs, rhs, column, 1) < 0;
}
};
class FunctionSparkArraySort : public DB::IFunction
{
public:
static constexpr auto name = "arraySortSpark";
static DB::FunctionPtr create(DB::ContextPtr ) { return std::make_shared<FunctionSparkArraySort>(); }
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
bool isSuitableForShortCircuitArgumentsExecution(const DB::DataTypesWithConstInfo &) const override { return true; }
bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForLowCardinalityColumns() const override { return false; }
void getLambdaArgumentTypes(DB::DataTypes & arguments) const override
{
if (arguments.size() < 2)
throw DB::Exception(DB::ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION, "Function {} requires as arguments a lambda function and an array", getName());
if (arguments.size() > 1)
{
const auto * lambda_function_type = DB::checkAndGetDataType<DB::DataTypeFunction>(arguments[0].get());
if (!lambda_function_type || lambda_function_type->getArgumentTypes().size() != 2)
throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function {} must be a lambda function with 2 arguments, found {} instead.",
getName(), arguments[0]->getName());
auto array_nesteed_type = DB::checkAndGetDataType<DB::DataTypeArray>(arguments.back().get())->getNestedType();
DB::DataTypes lambda_args = {array_nesteed_type, array_nesteed_type};
arguments[0] = std::make_shared<DB::DataTypeFunction>(lambda_args);
}
}
DB::DataTypePtr getReturnTypeImpl(const DB::ColumnsWithTypeAndName & arguments) const override
{
if (arguments.size() > 1)
{
const auto * lambda_function_type = checkAndGetDataType<DB::DataTypeFunction>(arguments[0].type.get());
if (!lambda_function_type)
throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());
}
return arguments.back().type;
}
DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr &, size_t input_rows_count) const override
{
auto array_col = arguments.back().column;
auto array_type = arguments.back().type;
DB::ColumnPtr null_map = nullptr;
if (const auto * null_col = typeid_cast<const DB::ColumnNullable *>(array_col.get()))
{
null_map = null_col->getNullMapColumnPtr();
array_col = null_col->getNestedColumnPtr();
array_type = typeid_cast<const DB::DataTypeNullable *>(array_type.get())->getNestedType();
}
const auto * array_col_concrete = DB::checkAndGetColumn<DB::ColumnArray>(array_col.get());
if (!array_col_concrete)
{
const auto * aray_col_concrete_const = DB::checkAndGetColumnConst<DB::ColumnArray>(array_col.get());
if (!aray_col_concrete_const)
{
throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "Expected array column, found {}", array_col->getName());
}
array_col = DB::recursiveRemoveLowCardinality(aray_col_concrete_const->convertToFullColumn());
array_col_concrete = DB::checkAndGetColumn<DB::ColumnArray>(array_col.get());
}
auto array_nested_type = DB::checkAndGetDataType<DB::DataTypeArray>(array_type.get())->getNestedType();
DB::ColumnPtr sorted_array_col = nullptr;
if (arguments.size() > 1)
sorted_array_col = executeWithLambda(*array_col_concrete, array_nested_type, *checkAndGetColumn<DB::ColumnFunction>(arguments[0].column.get()));
else
sorted_array_col = executeWithoutLambda(*array_col_concrete);
if (null_map)
{
sorted_array_col = DB::ColumnNullable::create(sorted_array_col, null_map);
}
return sorted_array_col;
}
private:
static DB::ColumnPtr executeWithLambda(const DB::ColumnArray & array_col, DB::DataTypePtr array_nested_type, const DB::ColumnFunction & lambda)
{
const auto & offsets = array_col.getOffsets();
auto rows = array_col.size();
size_t nested_size = array_col.getData().size();
DB::IColumn::Permutation permutation(nested_size);
for (size_t i = 0; i < nested_size; ++i)
permutation[i] = i;
DB::ColumnArray::Offset current_offset = 0;
for (size_t i = 0; i < rows; ++i)
{
auto next_offset = offsets[i];
::sort(&permutation[current_offset],
&permutation[next_offset],
LambdaLess(array_col.getData(),
array_nested_type,
lambda));
current_offset = next_offset;
}
auto res = DB::ColumnArray::create(array_col.getData().permute(permutation, 0), array_col.getOffsetsPtr());
return res;
}
static DB::ColumnPtr executeWithoutLambda(const DB::ColumnArray & array_col)
{
const auto & offsets = array_col.getOffsets();
auto rows = array_col.size();
size_t nested_size = array_col.getData().size();
DB::IColumn::Permutation permutation(nested_size);
for (size_t i = 0; i < nested_size; ++i)
permutation[i] = i;
DB::ColumnArray::Offset current_offset = 0;
for (size_t i = 0; i < rows; ++i)
{
auto next_offset = offsets[i];
::sort(&permutation[current_offset],
&permutation[next_offset],
Less(array_col.getData()));
current_offset = next_offset;
}
auto res = DB::ColumnArray::create(array_col.getData().permute(permutation, 0), array_col.getOffsetsPtr());
return res;
}
String getName() const override
{
return name;
}
};
REGISTER_FUNCTION(ArraySortSpark)
{
factory.registerFunction<FunctionSparkArraySort>();
}
}