* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "substrait/algebra.pb.h"
#include "substrait/capabilities.pb.h"
#include "substrait/extensions/extensions.pb.h"
#include "substrait/function.pb.h"
#include "substrait/parameterized_types.pb.h"
#include "substrait/plan.pb.h"
#include "substrait/type.pb.h"
#include "substrait/type_expressions.pb.h"
#include "type/data_types.h"
#include "util/omni_exception.h"
#include "operator/util/function_type.h"
#include "util/type_util.h"
namespace omniruntime {
enum SubstraitToOmniExprType {
IS_NULL_OMNI_EXPR_TYPE = 0,
IS_NOT_NULL_OMNI_EXPR_TYPE,
UNARY_OMNI_EXPR_TYPE,
BINARY_OMNI_EXPR_TYPE,
FUNCTION_OMNI_EXPR_TYPE,
COALESCE_OMNI_EXPR_TYPE,
HIVE_UDF_FUNCTION_OMNI_EXPR_TYPE
};
constexpr const char *SUBSTRAIT_PARSE_ERROR = "SUBSTRAIT_PARSE_ERROR";
class SubstraitParser {
public:
static std::vector<type::DataTypePtr> ParseNamedStruct(
const ::substrait::NamedStruct &namedStruct, bool asLowerCase = false);
static std::pair<SubstraitToOmniExprType, std::string> FindOmniFunction(
const std::unordered_map<uint64_t, std::string> &functionMap, uint64_t id);
static type::DataTypePtr ParseType(const ::substrait::Type &substraitType, bool asLowerCase = false, bool isNest = false);
static std::vector<std::string> MakeNames(const std::string &prefix, int size);
static std::string MakeNodeName(int nodeId, int colIdx);
static int GetIdxFromNodeName(const std::string &nodeName);
static std::string FindFunctionSpec(const std::unordered_map<uint64_t, std::string> &functionMap, uint64_t id);
static std::string GetNameBeforeDelimiter(const std::string &signature, const std::string &delimiter = ":");
static std::vector<std::string> GetSubFunctionTypes(const std::string &subFuncSpec);
static std::pair<SubstraitToOmniExprType, std::string> MapToOmniFunction(const std::string &substraitFunction);
static bool ConfigSetInOptimization(const ::substrait::extensions::AdvancedExtension &, const std::string &config);
static bool ConfigExistInOptimization(
const ::substrait::extensions::AdvancedExtension &, const std::string &config);
static std::vector<type::DataTypePtr> SigToTypes(const std::string &functionSig);
template <typename T>
static T GetLiteralValue(const ::substrait::Expression::Literal & );
static type::DataTypesPtr ParseStructType(const ::substrait::Type &substraitType);
static op::FunctionType ParseFunctionType(
const std::string &funcName, std::vector<substrait::Expression> &expressionNodes, bool isMergeCount);
static void AddStructDataType(
const ::substrait::Type &substraitType, std::vector<omniruntime::type::DataTypePtr> &outputDataTypes);
private:
static std::unordered_map<std::string, std::pair<SubstraitToOmniExprType, std::string>> substraitOmniFunctionMap;
static const std::unordered_map<std::string, std::string> typeMap;
static const uint32_t MAX_PRECISION_64 = 18;
static const uint32_t MAX_PRECISION_128 = 38;
};
}