* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "SubstraitToOmniExpr.h"

#include <memory>
#include <string>
#include <vector>

#include "expression/parserhelper.h"
// Error code passed to the OmniExceptions that this converter constructs directly.
constexpr char SUBSTRAIT_PARSE_ERROR[] = "SUBSTRAIT_PARSE_ERROR";
namespace omniruntime {
// Maps the kind of a Substrait literal to the omni scalar data type that represents it.
//
// Decimal literals whose precision does not exceed DECIMAL64_DEFAULT_PRECISION map to
// Decimal64, wider ones to Decimal128 (precision and scale are taken from the literal).
// String and varchar literals both map to the default VarcharType().
// Any other literal kind raises GET_SCALAR_TYPE_ERROR.
DataTypePtr GetScalarType(const ::substrait::Expression::Literal &literal)
{
    using LiteralCase = ::substrait::Expression_Literal::LiteralTypeCase;
    const auto kind = literal.literal_type_case();
    switch (kind) {
        case LiteralCase::kBoolean:
            return BooleanType();
        case LiteralCase::kI16:
            return ShortType();
        case LiteralCase::kI32:
            return IntType();
        case LiteralCase::kI64:
            return LongType();
        case LiteralCase::kFp64:
            return DoubleType();
        case LiteralCase::kDecimal: {
            const auto &decimalLit = literal.decimal();
            if (decimalLit.precision() > DECIMAL64_DEFAULT_PRECISION) {
                return Decimal128Type(decimalLit.precision(), decimalLit.scale());
            }
            return Decimal64Type(decimalLit.precision(), decimalLit.scale());
        }
        case LiteralCase::kDate:
            return Date32Type();
        case LiteralCase::kTimestamp:
            return TimestampType();
        case LiteralCase::kString:
        case LiteralCase::kVarChar:
            return VarcharType();
        default:
            OMNI_THROW(
                "GET_SCALAR_TYPE_ERROR:", "the given typeCase is not supported: '{}' ", std::to_string(kind));
    }
}
// Decides whether a failed Substrait cast should yield NULL instead of raising an error.
//
// FAILURE_BEHAVIOR_RETURN_NULL -> true.
// FAILURE_BEHAVIOR_THROW_EXCEPTION and FAILURE_BEHAVIOR_UNSPECIFIED -> false.
// Any other enum value is rejected with SUBSTRAIT_ERROR.
bool IsNullOnFailure(::substrait::Expression::Cast::FailureBehavior failureBehavior)
{
    if (failureBehavior == ::substrait::Expression_Cast_FailureBehavior_FAILURE_BEHAVIOR_RETURN_NULL) {
        return true;
    }
    const bool raisesOnFailure =
        failureBehavior == ::substrait::Expression_Cast_FailureBehavior_FAILURE_BEHAVIOR_THROW_EXCEPTION ||
        failureBehavior == ::substrait::Expression_Cast_FailureBehavior_FAILURE_BEHAVIOR_UNSPECIFIED;
    if (raisesOnFailure) {
        return false;
    }
    OMNI_THROW("SUBSTRAIT_ERROR:", "The given failure behavior is NOT supported: '{}'",
        std::to_string(failureBehavior));
}
// Converts a Substrait field reference into an omni FieldExpr.
//
// Only direct references whose segment is a struct field are supported. The struct
// field's index selects the column in inputType, and that column's type becomes the
// FieldExpr's type. Nested child segments of the struct field are not inspected.
// Other reference forms, and negative indices, are rejected with SUBSTRAIT_ERROR.
// Previously, reading struct_field() from an unset protobuf oneof returned a default
// message whose field() is 0, which silently referenced column 0.
TypedExprPtr SubstraitOmniExprConverter::ToOmniExpr(
    const ::substrait::Expression::FieldReference &substraitField, const DataTypesPtr &inputType)
{
    auto typeCase = substraitField.reference_type_case();
    switch (typeCase) {
        case ::substrait::Expression::FieldReference::ReferenceTypeCase::kDirectReference: {
            const auto &directRef = substraitField.direct_reference();
            if (!directRef.has_struct_field()) {
                OMNI_THROW("SUBSTRAIT_ERROR:", "Only struct field direct references are supported");
            }
            auto idx = directRef.struct_field().field();
            if (idx < 0) {
                OMNI_THROW("SUBSTRAIT_ERROR:", "The struct field index '{}' is negative", std::to_string(idx));
            }
            return new FieldExpr(idx, inputType->GetType(idx));
        }
        default:
            OMNI_THROW(
                "SUBSTRAIT_ERROR:", "Substrait conversion not supported for Reference '{}'", std::to_string(typeCase));
    }
}
// Converts a Substrait ScalarFunction call into an omni expression.
//
// The function reference is resolved through functionMap_ to a pair of
// (omni expression kind, function/operator name). Every argument is converted
// recursively against inputType, and the result is assembled by expression kind:
//   IS_NOT_NULL -> UnaryExpr(NOT, IsNullExpr(arg0)) returning boolean
//   IS_NULL     -> IsNullExpr(arg0)
//   UNARY       -> UnaryExpr(op, arg0) returning boolean
//   BINARY      -> BinaryExpr(op, arg0, arg1) returning the parsed output type
//   FUNCTION    -> FuncExpr, with special handling for concat, MakeDecimal,
//                  RLike and the mm3hash/xxhash64 hash chains
//   COALESCE    -> CoalesceExpr of exactly COALESCE_INPUT arguments
// Hive UDFs and every other kind are rejected.
//
// Ownership: the converted arguments belong to this function until they are
// attached to the returned expression. If anything throws, including the
// conversion of a later argument or a failed precondition, every argument
// converted so far is deleted before the exception is rethrown.
TypedExprPtr SubstraitOmniExprConverter::ToOmniExpr(
    const ::substrait::Expression::ScalarFunction &substraitFunc, const DataTypesPtr &inputType)
{
    const auto &omniFunction = SubstraitParser::FindOmniFunction(functionMap_, substraitFunc.function_reference());
    auto outputType = SubstraitParser::ParseType(substraitFunc.output_type());
    auto type = omniFunction.first;
    auto funcName = omniFunction.second;
    Operator op = StringToOperator(funcName);

    std::vector<Expr *> args;
    args.reserve(substraitFunc.arguments().size());
    try {
        for (const auto &sArg : substraitFunc.arguments()) {
            args.emplace_back(ToOmniExpr(sArg.value(), inputType));
        }

        if (type == IS_NOT_NULL_OMNI_EXPR_TYPE) {
            OMNI_CHECK(!args.empty() && args[0] != nullptr, "args[0] is null");
            auto isNullExpr = new IsNullExpr(args[0]);
            return new UnaryExpr(Operator::NOT, isNullExpr, std::make_shared<BooleanDataType>());
        }
        if (type == IS_NULL_OMNI_EXPR_TYPE) {
            OMNI_CHECK(!args.empty() && args[0] != nullptr, "args[0] is null");
            return new IsNullExpr(args[0]);
        }
        if (type == UNARY_OMNI_EXPR_TYPE) {
            OMNI_CHECK(!args.empty() && args[0] != nullptr, "args[0] is null");
            OMNI_CHECK(op != Operator::INVALIDOP, "the operator is INVALIDOP");
            return new UnaryExpr(op, args[0], std::make_shared<BooleanDataType>());
        }
        if (type == BINARY_OMNI_EXPR_TYPE) {
            OMNI_CHECK(outputType != nullptr, "outputType is null");
            OMNI_CHECK(!args.empty() && args[0] != nullptr, "args[0] is null");
            OMNI_CHECK(op != Operator::INVALIDOP, "the operator is INVALIDOP");
            if (args.size() < 2 || args[1] == nullptr) {
                OMNI_THROW("SUBSTRAIT_ERROR:", "The args[1] in ScalarFunction is nullptr");
            }
            return new BinaryExpr(op, args[0], args[1], std::move(outputType));
        }
        if (type == FUNCTION_OMNI_EXPR_TYPE) {
            if (funcName == "concat") {
                return UnfoldConcatStringFunc(args, outputType);
            }
            if (funcName == "MakeDecimal" && args.size() == 2) {
                // Only the first operand is forwarded to the omni function. The second
                // converted operand is not referenced by the result, so free it here
                // instead of leaking it.
                auto *makeDecimal = new FuncExpr(funcName, {args[0]}, std::move(outputType));
                delete args[1];
                return makeDecimal;
            }
            if (funcName == "RLike" && args.size() == RLIKE_INPUT) {
                // RLike's second argument must be a literal expression.
                if (args[1] == nullptr || args[1]->GetType() != ExprType::LITERAL_E) {
                    OMNI_THROW("SUBSTRAIT_ERROR:", "The type of args[1] is not equal to LITERAL_E");
                }
            }
            if (funcName == "mm3hash" || funcName == "xxhash64") {
                // The last argument seeds the chain: hash(args[0], last) is computed first,
                // then each of args[1..n-2] is hashed with the previous result.
                if (args.size() < 2) {
                    OMNI_THROW("SUBSTRAIT_ERROR:", "{} expects at least one value and a seed, but got {} argument(s)",
                        funcName, std::to_string(args.size()));
                }
                auto *func = new FuncExpr(funcName, {args[0], args.back()}, outputType);
                for (size_t i = 1; i + 1 < args.size(); ++i) {
                    func = new FuncExpr(funcName, {args[i], func}, outputType);
                }
                return func;
            }
            return new FuncExpr(funcName, args, std::move(outputType));
        }
        if (type == COALESCE_OMNI_EXPR_TYPE) {
            if (args.size() != COALESCE_INPUT) {
                OMNI_THROW("SUBSTRAIT_ERROR:", "coalesce expression only support two input parameters");
            }
            OMNI_CHECK(args[0] != nullptr, "args[0] is null");
            if (args[1] == nullptr) {
                OMNI_THROW("SUBSTRAIT_ERROR:", "The args[1] in COALESCE_OMNI_EXPR_TYPE is nullptr");
            }
            return new CoalesceExpr(args[0], args[1]);
        }
        if (type == HIVE_UDF_FUNCTION_OMNI_EXPR_TYPE) {
            throw omniruntime::exception::OmniException(SUBSTRAIT_PARSE_ERROR, "The UDF function Unsupported yet");
        }
        OMNI_THROW(
            "SUBSTRAIT_ERROR:", "function type {} and function {} is unsupported yet", std::to_string(type), funcName);
    } catch (...) {
        // No expression took ownership of the arguments: release them all.
        Expr::DeleteExprs(args);
        throw;
    }
}
// Lowers an n-ary concat into a right-nested chain of binary "concat" FuncExprs:
//   concat(a0, concat(a1, ... concat(a[n-2], a[n-1])))
// Every level of the chain uses outputType as its return type, and the argument
// pointers become the operands of the FuncExprs.
//
// Throws SUBSTRAIT_ERROR when fewer than two arguments are supplied. Previously that
// case indexed or iterated past the end of the vector.
TypedExprPtr SubstraitOmniExprConverter::UnfoldConcatStringFunc(std::vector<Expr *> args,
    DataTypePtr outputType)
{
    constexpr size_t concatParams = 2;
    if (args.size() < concatParams) {
        OMNI_THROW("SUBSTRAIT_ERROR:", "concat expects at least two arguments, but got {}",
            std::to_string(args.size()));
    }
    // Build from the innermost (right-most) pair outwards. This yields the same
    // nesting as a recursive split but avoids copying the tail vector per level.
    size_t i = args.size() - concatParams;
    TypedExprPtr chain = new FuncExpr("concat", {args[i], args[i + 1]}, outputType);
    while (i > 0) {
        --i;
        chain = new FuncExpr("concat", {args[i], chain}, outputType);
    }
    return chain;
}
// Converts a Substrait SingularOrList ("value IN (options...)") into an omni InExpr.
//
// The InExpr's first argument is the converted tested value. It is followed by one
// expression per option, each converted through the option's literal() accessor.
// If any conversion yields nullptr or throws, the arguments converted so far are
// deleted and SUBSTRAIT_ERROR (or the original exception) propagates.
TypedExprPtr SubstraitOmniExprConverter::ToOmniExpr(
    const ::substrait::Expression::SingularOrList &singularOrList, const DataTypesPtr &inputType)
{
    std::vector<Expr *> args;
    args.reserve(static_cast<size_t>(singularOrList.options_size()) + 1);
    try {
        Expr *value = ToOmniExpr(singularOrList.value(), inputType);
        if (value == nullptr) {
            OMNI_THROW("SUBSTRAIT_ERROR:", "The OmniExpression of the singularOrList.value here is null");
        }
        args.push_back(value);
        for (const auto &option : singularOrList.options()) {
            Expr *arg = ToOmniExpr(option.literal());
            if (arg == nullptr) {
                OMNI_THROW("SUBSTRAIT_ERROR:", "The OmniExpression of the singularOrList.literal here is null");
            }
            args.push_back(arg);
        }
    } catch (...) {
        // InExpr was never built, so the converted arguments are still ours to free.
        Expr::DeleteExprs(args);
        throw;
    }
    return new InExpr(args);
}
// Converts a Substrait Cast into an omni expression.
//
// The input is converted with the cast's target type passed as defaultType. For a
// null literal input, defaultType takes precedence over the literal's own type.
// The cast is elided, and the converted input returned unchanged, when the source
// and target type ids match and one of these holds:
//   - strings: the source width is not larger than the target width;
//   - decimals: the scales are equal and the source precision is not larger;
//   - any other type: always.
// Otherwise the input is wrapped in a FuncExpr named "CAST" returning the target type.
// The Substrait failure behavior of the cast is not consulted here.
// Throws SUBSTRAIT_ERROR if the input converts to nullptr.
TypedExprPtr SubstraitOmniExprConverter::ToOmniExpr(
    const ::substrait::Expression::Cast &castExpr, const DataTypesPtr &inputType)
{
    auto retType = SubstraitParser::ParseType(castExpr.type());
    auto expr = ToOmniExpr(castExpr.input(), inputType, retType);
    if (expr == nullptr) {
        OMNI_THROW("SUBSTRAIT_ERROR:", "The input expression of Cast is null");
    }
    auto retTypeId = retType->GetId();
    auto argReturnType = expr->GetReturnType();
    if (retTypeId == argReturnType->GetId()) {
        if (TypeUtil::IsStringType(argReturnType->GetId())) {
            auto argWidth = static_cast<VarcharDataType *>(argReturnType.get())->GetWidth();
            auto retWidth = static_cast<VarcharDataType *>(retType.get())->GetWidth();
            if (argWidth <= retWidth) {
                return expr;
            }
        } else if (TypeUtil::IsDecimalType(retTypeId)) {
            auto *argDecimal = static_cast<DecimalDataType *>(argReturnType.get());
            auto *retDecimal = static_cast<DecimalDataType *>(retType.get());
            if (argDecimal->GetScale() == retDecimal->GetScale() &&
                argDecimal->GetPrecision() <= retDecimal->GetPrecision()) {
                return expr;
            }
        } else {
            return expr;
        }
    }
    std::vector<Expr *> castArgs{expr};
    return new FuncExpr("CAST", castArgs, std::move(retType));
}
// Converts a Substrait literal into an omni LiteralExpr.
//
// String literals and 128-bit decimal literals are passed to LiteralExpr as newly
// allocated std::string objects. A string literal's varchar width is its byte length.
// Decimal literals must carry the 16-byte unscaled value the Substrait spec
// prescribes; any other length is rejected instead of being over-read.
// For a null literal the type is defaultType when provided, otherwise the literal's
// declared null type. The result is that type's default value with isNull set.
// Unsupported literal kinds raise SUBSTRAIT_PARSE_ERROR.
TypedExprPtr SubstraitOmniExprConverter::ToOmniExpr(const ::substrait::Expression::Literal &substraitLit, const DataTypePtr defaultType)
{
    auto typeCase = substraitLit.literal_type_case();
    switch (typeCase) {
        case ::substrait::Expression_Literal::LiteralTypeCase::kBoolean:
            return new LiteralExpr(substraitLit.boolean(), BooleanType());
        case ::substrait::Expression_Literal::LiteralTypeCase::kI16:
            return new LiteralExpr(static_cast<int16_t>(substraitLit.i16()), ShortType());
        case ::substrait::Expression_Literal::LiteralTypeCase::kI32:
            return new LiteralExpr(substraitLit.i32(), IntType());
        case ::substrait::Expression_Literal::LiteralTypeCase::kI64:
            return new LiteralExpr(substraitLit.i64(), LongType());
        case ::substrait::Expression_Literal::LiteralTypeCase::kFp64:
            return new LiteralExpr(substraitLit.fp64(), DoubleType());
        case ::substrait::Expression_Literal::LiteralTypeCase::kDate:
            return new LiteralExpr(substraitLit.date(), Date32Type());
        case ::substrait::Expression_Literal::LiteralTypeCase::kTimestamp:
            return new LiteralExpr(substraitLit.timestamp(), TimestampType());
        case ::substrait::Expression_Literal::LiteralTypeCase::kString: {
            auto *stringVal = new std::string(substraitLit.string());
            return new LiteralExpr(stringVal, VarcharType(stringVal->length()));
        }
        case ::substrait::Expression_Literal::LiteralTypeCase::kDecimal: {
            const auto &decimal = substraitLit.decimal();
            const std::string &bytes = decimal.value();
            // The raw copy below reads exactly sizeof(int128_t) bytes, so a shorter
            // payload must be rejected rather than read past its end.
            if (bytes.size() != sizeof(int128_t)) {
                OMNI_THROW("SUBSTRAIT_ERROR:", "Decimal literal value must be {} bytes, but got {}",
                    std::to_string(sizeof(int128_t)), std::to_string(bytes.size()));
            }
            auto precision = decimal.precision();
            auto scale = decimal.scale();
            int128_t decimalValue;
            // Substrait stores the value little-endian; this byte copy assumes a
            // little-endian host.
            memcpy_s(&decimalValue, sizeof(int128_t), bytes.data(), sizeof(int128_t));
            if (precision <= DECIMAL64_DEFAULT_PRECISION) {
                return new LiteralExpr(static_cast<int64_t>(decimalValue), Decimal64Type(precision, scale));
            }
            auto *dec128String = new std::string(Uint128ToStr(decimalValue));
            return new LiteralExpr(dec128String, Decimal128Type(precision, scale));
        }
        case ::substrait::Expression_Literal::LiteralTypeCase::kNull: {
            DataTypePtr dataType;
            if (defaultType != nullptr) {
                dataType = defaultType;
            } else {
                dataType = SubstraitParser::ParseType(substraitLit.null());
            }
            LiteralExpr *expr = nullptr;
            if (TypeUtil::IsDecimalType(dataType->GetId())) {
                auto decimalType = std::dynamic_pointer_cast<DecimalDataType>(dataType);
                expr = ParserHelper::GetDefaultValueForType(
                    dataType->GetId(), decimalType->GetPrecision(), decimalType->GetScale());
            } else {
                expr = ParserHelper::GetDefaultValueForType(dataType->GetId());
            }
            if (expr == nullptr) {
                OMNI_THROW("SUBSTRAIT_ERROR:", "The LiteralExpr in kNull case here is null");
            }
            expr->isNull = true;
            return expr;
        }
        default:
            // Plain string concatenation: the OmniException constructor does not
            // substitute '{}' placeholders.
            throw omniruntime::exception::OmniException(SUBSTRAIT_PARSE_ERROR,
                "Substrait conversion not supported for type case '" + std::to_string(typeCase) + "'");
    }
}
// Converts IfThen clauses [index, ifs.size()) into a right-nested chain of IfExprs:
//   IF(cond[index], then[index], rest)
// where rest is the conversion of clause index + 1 or, for the last clause, of the
// else branch.
//
// Returns nullptr when a condition, branch or nested clause converts to nullptr. Any
// sub-expressions already built are freed in that case. If the false branch is the
// string literal "null", it is treated as NULL: it is replaced by a null literal of
// the true branch's type. For decimals, that literal keeps the true branch's
// precision and scale.
// Throws SUBSTRAIT_ERROR when index does not name an existing clause.
TypedExprPtr SubstraitOmniExprConverter::ToOmniExpr(
    const ::substrait::Expression::IfThen &substraitIfThen, const DataTypesPtr &inputType, const int32_t index)
{
    const auto &ifs = substraitIfThen.ifs();
    if (index < 0 || index >= ifs.size()) {
        OMNI_THROW("SUBSTRAIT_ERROR:", "IfThen clause index {} is out of range for {} clause(s)",
            std::to_string(index), std::to_string(ifs.size()));
    }
    const auto &clause = ifs.Get(index);

    // Each sub-expression is owned by a unique_ptr until it is handed to the IfExpr.
    // Early returns and exceptions from nested conversions therefore do not leak.
    std::unique_ptr<Expr> cond(ToOmniExpr(clause.if_(), inputType));
    if (cond == nullptr) {
        return nullptr;
    }
    std::unique_ptr<Expr> trueExpr(ToOmniExpr(clause.then(), inputType));
    if (trueExpr == nullptr) {
        return nullptr;
    }
    std::unique_ptr<Expr> falseExpr(index == ifs.size() - 1
        ? ToOmniExpr(substraitIfThen.else_(), inputType)
        : ToOmniExpr(substraitIfThen, inputType, index + 1));
    if (falseExpr == nullptr) {
        return nullptr;
    }

    if (TypeUtil::IsStringType(falseExpr->GetReturnTypeId()) && falseExpr->GetType() == ExprType::LITERAL_E) {
        auto *falseLiteral = static_cast<LiteralExpr *>(falseExpr.get());
        if (falseLiteral->stringVal != nullptr && falseLiteral->stringVal->compare("null") == 0) {
            auto trueTypeId = trueExpr->GetReturnTypeId();
            LiteralExpr *nullLiteral = nullptr;
            if (TypeUtil::IsDecimalType(trueTypeId)) {
                auto decimalType = std::dynamic_pointer_cast<DecimalDataType>(trueExpr->GetReturnType());
                nullLiteral = ParserHelper::GetDefaultValueForType(
                    trueTypeId, decimalType->GetPrecision(), decimalType->GetScale());
            } else {
                nullLiteral = ParserHelper::GetDefaultValueForType(trueTypeId);
            }
            if (nullLiteral == nullptr) {
                OMNI_THROW("substrait_error", "the literal expression in substraitIfThen case is null here");
            }
            // The default value only supplies the type; mark it as NULL. The original
            // code dereferenced the null pointer here and never set this flag.
            nullLiteral->isNull = true;
            falseExpr.reset(nullLiteral);
        }
    }
    return new IfExpr(cond.release(), trueExpr.release(), falseExpr.release());
}
// Entry point: dispatches a generic Substrait expression to the converter for its
// concrete kind.
//
// defaultType is forwarded only to literal conversion, where it overrides the type
// of a null literal. inputType describes the input columns and is forwarded to
// every other kind. Unsupported expression kinds raise Substrait_Error.
TypedExprPtr SubstraitOmniExprConverter::ToOmniExpr(
    const substrait::Expression &substraitExpr, const DataTypesPtr &inputType, DataTypePtr defaultType)
{
    using Rex = ::substrait::Expression::RexTypeCase;
    const auto rexKind = substraitExpr.rex_type_case();
    switch (rexKind) {
        case Rex::kSelection:
            return ToOmniExpr(substraitExpr.selection(), inputType);
        case Rex::kLiteral:
            return ToOmniExpr(substraitExpr.literal(), defaultType);
        case Rex::kCast:
            return ToOmniExpr(substraitExpr.cast(), inputType);
        case Rex::kScalarFunction:
            return ToOmniExpr(substraitExpr.scalar_function(), inputType);
        case Rex::kSingularOrList:
            return ToOmniExpr(substraitExpr.singular_or_list(), inputType);
        case Rex::kIfThen:
            return ToOmniExpr(substraitExpr.if_then(), inputType);
        default:
            OMNI_THROW(
                "Substrait_Error:", "Substrait conversion not supported for Expression '{}'", std::to_string(rexKind));
    }
}
}