/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
 * Description: validation of Substrait plans before conversion to Omni plans
 */
#include "SubstraitToOmniPlanValidator.h"
#include <google/protobuf/wrappers.pb.h>
#include <re2/re2.h>
#include <string>
#include <algorithm>
#include "expression/expr_verifier.h"
#include "expression/expr_printer.h"
#include "expression/expressions.h"
namespace omniruntime {
namespace {
// Returns the part of `file` that follows its last '/', or `file` itself when it contains no '/'.
const char *ExtractFileName(const char *file)
{
    const char *lastSlash = strrchr(file, '/');
    if (lastSlash == nullptr) {
        return file;
    }
    return lastSlash + 1;
}
std::unique_ptr<re2::RE2> CompilePattern(const std::string &pattern)
{
return std::make_unique<re2::RE2>(re2::StringPiece(pattern), RE2::Quiet);
}
// Scans `pattern` for a '[' that appears inside an already-open character class
// (e.g. [a[b]], [a&&[b]], [a&&[^b]]), which this validator rejects.
//
// @param pattern regular expression text to inspect.
// @param error   receives a description when a nested '[' is found; left untouched otherwise.
// @return false when a nested character class is found, true otherwise.
bool EnsureRegexIsCompatible(const std::string &pattern, std::string &error)
{
    const size_t notInClass = std::string::npos;
    const size_t length = pattern.size();
    size_t classStart = notInClass;
    for (size_t pos = 0; pos < length; ++pos) {
        const char ch = pattern[pos];
        if (ch == '\\') {
            // Skip the escaped character. With index arithmetic a trailing backslash
            // just ends the scan instead of stepping an iterator beyond end().
            ++pos;
        } else if (ch == '[') {
            if (classStart != notInClass) {
                error = "Not support character class union, intersection, "
                        "or difference ([a[b]], [a&&[b]], [a&&[^b]])";
                return false;
            }
            classStart = pos;
        } else if (ch == ']' && classStart != notInClass && classStart + 1 != pos) {
            // A ']' directly after the opening '[' does not close the class.
            classStart = notInClass;
        }
    }
    return true;
}
// Checks that `pattern` compiles in RE2 and passes EnsureRegexIsCompatible.
// On failure `error` receives the reason and false is returned.
bool ValidatePattern(const std::string &pattern, std::string &error)
{
    const auto compiled = CompilePattern(pattern);
    if (compiled->ok()) {
        return EnsureRegexIsCompatible(pattern, error);
    }
    error = "Pattern " + pattern + " compilation failed in RE2. Reason: " + compiled->error();
    return false;
}
// Logs a validation failure caused by a caught exception `err`: records the catch site
// (file name, line, function) followed by err.message().
// NOTE(review): the "thrown from" slots are filled with the string literals "err.file()",
// "err.line()" and "err.function()", not with values read from `err` - confirm whether
// that is intentional.
#define LOG_VALIDATION_MSG_FROM_EXCEPTION(err) \
LogValidateMsg(Format("Validation failed due to exception caught at " \
"file:{} line:{} function:{}," \
"thrown from file:{} line:{} function:{}, reason:{}", \
ExtractFileName(__FILE__), __LINE__, __FUNCTION__, "err.file()", "err.line()", "err.function()", \
err.message()))
// Logs a validation failure through LogValidateMsg, tagging `reason` with the file name,
// line and function of the expansion site.
#define LOG_VALIDATION_MSG(reason) \
LogValidateMsg(Format("Validation failed at file:{}, line:{}, function:{}, reason:{}", ExtractFileName(__FILE__), \
__LINE__, __FUNCTION__, reason))
// Function names for which ValidateScalarFunction delegates to ValidateRegexExpr.
const std::unordered_set<std::string> kRegexFunctions = {"rlike", "like"};
// Function names that ValidateScalarFunction rejects as unsupported.
const std::unordered_set<std::string> kBlackList = {"get_struct_field", "rpad"};
}
// Unpacks the Substrait type carried in the enhancement of `extension` and converts it
// with SubstraitParser::ParseType.
// Returns false (after logging) when the enhancement is missing or cannot be unpacked;
// otherwise stores the parsed type in `out` and returns true.
bool SubstraitToOmniPlanValidator::ParseOmniType(
    const ::substrait::extensions::AdvancedExtension &extension, DataTypePtr &out)
{
    if (!extension.has_enhancement()) {
        LOG_VALIDATION_MSG("Input type is not wrapped in enhancement.");
        return false;
    }

    ::substrait::Type unpackedType;
    if (!extension.enhancement().UnpackTo(&unpackedType)) {
        LOG_VALIDATION_MSG("Enhancement can't be unpacked to inputType.");
        return false;
    }

    out = SubstraitParser::ParseType(unpackedType);
    return true;
}
// Same as the two-argument ParseOmniType, but additionally returns false when the id of
// the parsed type is listed in `excludes`. `out` is still set in that case.
bool SubstraitToOmniPlanValidator::ParseOmniType(
    const ::substrait::extensions::AdvancedExtension &extension,
    DataTypePtr &out, std::vector<DataTypeId> &excludes)
{
    if (!ParseOmniType(extension, out)) {
        return false;
    }
    const DataTypeId parsedId = out->GetId();
    const bool isExcluded = std::find(excludes.begin(), excludes.end(), parsedId) != excludes.end();
    return !isExcluded;
}
// Appends the direct children of the row type `type` to `out`.
// Returns false (after logging) when `type` is not a row type. When `excludes` is non-null
// and non-empty, returns false as soon as a child whose id is listed there is met; children
// appended before that point remain in `out`.
bool SubstraitToOmniPlanValidator::FlattenSingleLevel(const DataTypePtr &type,
    std::vector<DataTypePtr> &out, std::vector<DataTypeId> *excludes)
{
    if (type->GetId() != OMNI_ROW) {
        LOG_VALIDATION_MSG("Type is not a RowType.");
        return false;
    }
    const auto asRow = std::dynamic_pointer_cast<const RowType>(type);
    if (asRow == nullptr) {
        LOG_VALIDATION_MSG("Failed to cast to RowType.");
        return false;
    }

    const auto &fields = asRow->Children();
    const bool hasExcludes = (excludes != nullptr) && !excludes->empty();
    if (!hasExcludes) {
        // Nothing to filter: copy all children in one go.
        out.insert(out.end(), fields.begin(), fields.end());
        return true;
    }

    for (const auto &child : fields) {
        const bool isExcluded =
            std::find(excludes->begin(), excludes->end(), child->GetId()) != excludes->end();
        if (isExcluded) {
            return false;
        }
        out.emplace_back(child);
    }
    return true;
}
// Validates the scale argument (second argument) of a round function: it must be an
// i32 or i64 literal that is not negative. `inputType` is not used.
// Returns false without logging when fewer than two arguments are present.
bool SubstraitToOmniPlanValidator::ValidateRound(
    const ::substrait::Expression::ScalarFunction &scalarFunction, const DataTypesPtr &inputType)
{
    const auto &arguments = scalarFunction.arguments();
    if (arguments.size() < 2) {
        return false;
    }

    const auto &scaleArg = arguments[1].value();
    if (!scaleArg.has_literal()) {
        LOG_VALIDATION_MSG("Round scale is expected.");
        return false;
    }

    const auto &scaleLiteral = scaleArg.literal();
    const auto literalKind = scaleLiteral.literal_type_case();
    if (literalKind == ::substrait::Expression_Literal::LiteralTypeCase::kI32) {
        return scaleLiteral.i32() >= 0;
    }
    if (literalKind == ::substrait::Expression_Literal::LiteralTypeCase::kI64) {
        return scaleLiteral.i64() >= 0;
    }
    LOG_VALIDATION_MSG("Round scale validation is not supported for type case " + std::to_string(literalKind));
    return false;
}
bool SubstraitToOmniPlanValidator::ValidateExtractExpr(const std::vector<TypedExprPtr> ¶ms)
{
if (params.size() != 2) {
LOG_VALIDATION_MSG("Value expected in variant in ExtractExpr.");
return false;
}
auto functionArg = dynamic_cast<const LiteralExpr *>(params[0]);
if (functionArg) {
return true;
}
LOG_VALIDATION_MSG("Constant is expected to be the first parameter in extract.");
return false;
}
// Validates the pattern argument (second argument) of the regex function `name`:
// it must be a string literal that passes ValidatePattern.
//
// @param name           function name, used in log messages.
// @param scalarFunction the Substrait scalar function call to inspect.
// @return true when the pattern is acceptable; false (after logging) otherwise.
bool SubstraitToOmniPlanValidator::ValidateRegexExpr(
    const std::string &name, const ::substrait::Expression::ScalarFunction &scalarFunction)
{
    if (scalarFunction.arguments().size() < 2) {
        LOG_VALIDATION_MSG("Wrong number of arguments for " + name);
        // Must bail out here: the code below reads arguments()[1].
        return false;
    }
    const auto &patternArg = scalarFunction.arguments()[1].value();
    if (!patternArg.has_literal() || !patternArg.literal().has_string()) {
        LOG_VALIDATION_MSG("Pattern is not string literal for " + name);
        return false;
    }
    const auto &pattern = patternArg.literal().string();
    std::string error;
    if (!ValidatePattern(pattern, error)) {
        LOG_VALIDATION_MSG(name + " due to " + error);
        return false;
    }
    return true;
}
// Validates a scalar function call: every value argument is validated recursively, the
// function name must not be in kBlackList, and functions in kRegexFunctions are further
// checked by ValidateRegexExpr.
//
// @param scalarFunction the Substrait scalar function call.
// @param inputType      input row types, forwarded to ValidateExpression.
// @return true when the call is accepted; false otherwise.
bool SubstraitToOmniPlanValidator::ValidateScalarFunction(
    const ::substrait::Expression::ScalarFunction &scalarFunction, const DataTypesPtr &inputType)
{
    for (const auto &argument : scalarFunction.arguments()) {
        if (argument.has_value() && !ValidateExpression(argument.value(), inputType)) {
            return false;
        }
    }
    const std::string funcSpec = SubstraitParser::FindFunctionSpec(planConverter.GetFunctionMap(),
        scalarFunction.function_reference());
    const std::string funcName = SubstraitParser::GetNameBeforeDelimiter(funcSpec);
    if (kBlackList.find(funcName) != kBlackList.end()) {
        LOG_VALIDATION_MSG("Function is not support: " + funcName);
        return false;
    }
    if (kRegexFunctions.find(funcName) != kRegexFunctions.end()) {
        return ValidateRegexExpr(funcName, scalarFunction);
    }
    return true;
}
// Recursively walks list and map literals, validating every nested literal.
// Literals that are neither a list nor a map are accepted as-is.
//
// @param literal   the Substrait literal to inspect.
// @param inputType input row types, forwarded to the recursive calls.
// @return false as soon as a nested literal fails validation, true otherwise.
bool SubstraitToOmniPlanValidator::ValidateLiteral(
    const ::substrait::Expression_Literal &literal, const DataTypesPtr &inputType)
{
    if (literal.has_list()) {
        // Bind by reference: iterating by value would copy each child message.
        for (const auto &child : literal.list().values()) {
            if (!ValidateLiteral(child, inputType)) {
                return false;
            }
        }
    } else if (literal.has_map()) {
        for (const auto &entry : literal.map().key_values()) {
            if (!ValidateLiteral(entry.key(), inputType) || !ValidateLiteral(entry.value(), inputType)) {
                return false;
            }
        }
    }
    return true;
}
// A cast is valid exactly when its input expression is valid.
bool SubstraitToOmniPlanValidator::ValidateCast(
    const ::substrait::Expression::Cast &castExpr, const DataTypesPtr &inputType)
{
    const bool inputIsValid = ValidateExpression(castExpr.input(), inputType);
    return inputIsValid;
}
// Validates every if/then clause of an IfThen expression and, when present, its else branch.
//
// @param ifThen    the Substrait IfThen expression.
// @param inputType input row types, forwarded to ValidateExpression.
// @return false as soon as any condition, result, or the else branch fails validation.
bool SubstraitToOmniPlanValidator::ValidateIfThen(
    const ::substrait::Expression_IfThen &ifThen, const DataTypesPtr &inputType)
{
    // Bind by reference: `auto ifs = ifThen.ifs()` would copy the whole repeated field.
    const auto &clauses = ifThen.ifs();
    for (const auto &clause : clauses) {
        if (!ValidateExpression(clause.if_(), inputType) || !ValidateExpression(clause.then(), inputType)) {
            return false;
        }
    }
    if (ifThen.has_else_() && !ValidateExpression(ifThen.else_(), inputType)) {
        return false;
    }
    return true;
}
// Validates a SingularOrList: every option must be a valid literal, and the tested value
// must be a valid expression. Options are checked before the value.
bool SubstraitToOmniPlanValidator::ValidateSingularOrList(
    const ::substrait::Expression::SingularOrList &singularOrList, const DataTypesPtr &inputType)
{
    for (const auto &candidate : singularOrList.options()) {
        if (candidate.has_literal()) {
            if (!ValidateLiteral(candidate.literal(), inputType)) {
                return false;
            }
            continue;
        }
        LOG_VALIDATION_MSG("Option is expected as Literal.");
        return false;
    }
    return ValidateExpression(singularOrList.value(), inputType);
}
// Dispatches to the validator matching the expression kind. Kinds without a dedicated
// validator here are accepted.
bool SubstraitToOmniPlanValidator::ValidateExpression(
    const ::substrait::Expression &expression, const DataTypesPtr &inputType)
{
    using RexKind = ::substrait::Expression::RexTypeCase;
    switch (expression.rex_type_case()) {
        case RexKind::kScalarFunction: {
            return ValidateScalarFunction(expression.scalar_function(), inputType);
        }
        case RexKind::kLiteral: {
            return ValidateLiteral(expression.literal(), inputType);
        }
        case RexKind::kCast: {
            return ValidateCast(expression.cast(), inputType);
        }
        case RexKind::kIfThen: {
            return ValidateIfThen(expression.if_then(), inputType);
        }
        case RexKind::kSingularOrList: {
            return ValidateSingularOrList(expression.singular_or_list(), inputType);
        }
        default: {
            return true;
        }
    }
}
// Validates a WriteRel: its input relation (when present) and, when a named table is
// present, that the type in the table's advanced extension parses and flattens as a row.
bool SubstraitToOmniPlanValidator::Validate(const ::substrait::WriteRel &writeRel)
{
    if (writeRel.has_input() && !Validate(writeRel.input())) {
        LOG_VALIDATION_MSG("Validation failed for input type validation in WriteRel.");
        return false;
    }
    if (!writeRel.has_named_table()) {
        return true;
    }

    DataTypePtr inputRowType;
    std::vector<DataTypePtr> types;
    const auto &extension = writeRel.named_table().advanced_extension();
    const bool typesAreValid = ParseOmniType(extension, inputRowType) && FlattenSingleLevel(inputRowType, types);
    if (typesAreValid) {
        return true;
    }
    LOG_VALIDATION_MSG("Validation failed for input type validation in WriteRel.");
    return false;
}
// Validates a FetchRel: the extension type (when present) must parse and flatten as a
// row, and both offset and count must be non-negative.
bool SubstraitToOmniPlanValidator::Validate(const ::substrait::FetchRel &fetchRel)
{
    if (fetchRel.has_advanced_extension()) {
        DataTypePtr inputRowType;
        std::vector<DataTypePtr> types;
        const auto &extension = fetchRel.advanced_extension();
        const bool typesAreValid =
            ParseOmniType(extension, inputRowType) && FlattenSingleLevel(inputRowType, types);
        if (!typesAreValid) {
            LOG_VALIDATION_MSG("Unsupported input types in FetchRel.");
            return false;
        }
    }

    const bool rangeIsValid = fetchRel.offset() >= 0 && fetchRel.count() >= 0;
    if (!rangeIsValid) {
        LOG_VALIDATION_MSG("Offset and count should be valid in FetchRel.");
        return false;
    }
    return true;
}
// Validates a TopNRel: the extension type (when present) must parse and flatten as a row
// that contains no OMNI_ARRAY, and n must be non-negative.
//
// @param topNRel the Substrait TopN relation.
// @return true when the relation is accepted; false (after logging) otherwise.
bool SubstraitToOmniPlanValidator::Validate(const ::substrait::TopNRel &topNRel)
{
    if (topNRel.has_advanced_extension()) {
        const auto &extension = topNRel.advanced_extension();
        DataTypePtr inputRowType;
        std::vector<DataTypePtr> types;
        std::vector<DataTypeId> excludes{DataTypeId::OMNI_ARRAY};
        if (!ParseOmniType(extension, inputRowType, excludes) || !FlattenSingleLevel(inputRowType, types, &excludes)) {
            LOG_VALIDATION_MSG("Unsupported input types in TopNRel.");
            return false;
        }
    }
    if (topNRel.n() < 0) {
        LOG_VALIDATION_MSG("N should be valid in TopNRel.");
        return false;
    }
    return true;
}
// Validates an ExpandRel.
//
// Checks, in order: the input relation (when present); the extension type (when present),
// which must parse and flatten as a row without OMNI_ARRAY; and every field set, which must
// be a SwitchingField whose expression count matches the first non-empty one and whose
// expressions are selections, literals or scalar functions. When the row type is known,
// each expression is also converted with the expression converter and checked by ExprVerifier.
//
// @param expandRel the Substrait expand relation.
// @return true when the relation is accepted; false otherwise.
bool SubstraitToOmniPlanValidator::Validate(const ::substrait::ExpandRel &expandRel)
{
    if (expandRel.has_input() && !Validate(expandRel.input())) {
        LOG_VALIDATION_MSG("Input validation fails in ExpandRel.");
        return false;
    }

    DataTypesPtr rowType = nullptr;
    if (expandRel.has_advanced_extension()) {
        const auto &extension = expandRel.advanced_extension();
        DataTypePtr inputRowType;
        std::vector<DataTypePtr> types;
        std::vector<DataTypeId> excludes{DataTypeId::OMNI_ARRAY};
        if (!ParseOmniType(extension, inputRowType, excludes) || !FlattenSingleLevel(inputRowType, types, &excludes)) {
            LOG_VALIDATION_MSG("Unsupported input types in ExpandRel.");
            return false;
        }
        rowType = std::make_shared<DataTypes>(std::move(types));
    }

    int32_t projectSize = 0;
    for (const auto &fields : expandRel.fields()) {
        if (!fields.has_switching_field()) {
            LOG_VALIDATION_MSG("Only SwitchingField is supported in ExpandRel.");
            return false;
        }

        // Bind by reference to avoid copying the repeated expression field.
        const auto &projectExprs = fields.switching_field().duplicates();
        if (projectSize == 0) {
            projectSize = projectExprs.size();
        } else if (projectSize != projectExprs.size()) {
            LOG_VALIDATION_MSG("SwitchingField expressions size should be constant in ExpandRel.");
            return false;
        }

        std::vector<TypedExprPtr> expressions;
        expressions.reserve(projectExprs.size());
        for (const auto &projectExpr : projectExprs) {
            switch (projectExpr.rex_type_case()) {
                case ::substrait::Expression::RexTypeCase::kSelection:
                case ::substrait::Expression::RexTypeCase::kLiteral:
                case ::substrait::Expression::RexTypeCase::kScalarFunction:
                    break;
                default:
                    LOG_VALIDATION_MSG("Only field or literal is supported in project of ExpandRel.");
                    return false;
            }
            if (rowType) {
                expressions.emplace_back(exprConverter_->ToOmniExpr(projectExpr, rowType));
            }
        }

        // `expressions` is only filled when the row type is known.
        if (rowType) {
            ExprVerifier ev;
            for (const auto &expression : expressions) {
                if (!ev.VisitExpr(*expression)) {
                    return false;
                }
            }
        }
    }
    return true;
}
// Returns true only for the unbounded-following, unbounded-preceding and current-row
// bound kinds; every other kind (including following/preceding) is rejected.
bool ValidateBoundType(::substrait::Expression_WindowFunction_Bound boundType)
{
    using Bound = ::substrait::Expression_WindowFunction_Bound;
    const auto kind = boundType.kind_case();
    return kind == Bound::kUnboundedFollowing ||
           kind == Bound::kUnboundedPreceding ||
           kind == Bound::kCurrentRow;
}
// Validates a WindowRel.
// Checks, in order: the input relation (when present), the input row type carried in the
// advanced extension (required, must be a non-empty row), every window measure (argument
// kinds, window type, bound kinds), the partition expressions, and the sort fields.
// Returns true only when all checks pass.
bool SubstraitToOmniPlanValidator::Validate(const ::substrait::WindowRel &windowRel)
{
if (windowRel.has_input() && !Validate(windowRel.input())) {
LOG_VALIDATION_MSG("WindowRel input fails to Validate.");
return false;
}
if (!windowRel.has_advanced_extension()) {
LOG_VALIDATION_MSG("Input types are expected in WindowRel.");
return false;
}
const auto &extension = windowRel.advanced_extension();
DataTypePtr inputRowType;
std::vector<DataTypePtr> types;
if (!ParseOmniType(extension, inputRowType) || !FlattenSingleLevel(inputRowType, types)) {
LOG_VALIDATION_MSG("Validation failed for input types in WindowRel.");
return false;
}
if (types.empty()) {
LOG_VALIDATION_MSG("Validation failed for empty input schema in WindowRel.");
return false;
}
auto rowType = std::make_shared<DataTypes>(std::move(types));
// NOTE(review): funcSpecs is filled but not read later in this function, and the result
// of ParseType below is discarded; presumably both calls are kept so that a failing
// lookup/parse surfaces as an exception - confirm.
std::vector<std::string> funcSpecs;
funcSpecs.reserve(windowRel.measures().size());
for (const auto &smea : windowRel.measures()) {
const auto &windowFunction = smea.measure();
funcSpecs.emplace_back(planConverter.FindFuncSpec(windowFunction.function_reference()));
SubstraitParser::ParseType(windowFunction.output_type());
// Arguments may only be field selections, literals, or scalar functions that validate.
for (const auto &arg : windowFunction.arguments()) {
auto typeCase = arg.value().rex_type_case();
switch (typeCase) {
case ::substrait::Expression::RexTypeCase::kSelection:
case ::substrait::Expression::RexTypeCase::kLiteral:
break;
case ::substrait::Expression::RexTypeCase::kScalarFunction: {
if (!ValidateScalarFunction(arg.value().scalar_function(), rowType)) return false;
break;
}
default:
LOG_VALIDATION_MSG("Only field, constant, scalar function is supported in window functions.");
return false;
}
}
switch (windowFunction.window_type()) {
case ::substrait::WindowType::ROWS:
case ::substrait::WindowType::RANGE:
break;
default:
LOG_VALIDATION_MSG("the window type only support ROWS and RANGE, and the "
"input type is " +
std::to_string(windowFunction.window_type()));
return false;
}
// Both bounds must be of a kind accepted by ValidateBoundType.
bool boundTypeSupported =
ValidateBoundType(windowFunction.upper_bound()) && ValidateBoundType(windowFunction.lower_bound());
if (!boundTypeSupported) {
LOG_VALIDATION_MSG("Found unsupported Bound Type: upper " +
std::to_string(windowFunction.upper_bound().kind_case()) + ", lower " +
std::to_string(windowFunction.lower_bound().kind_case()));
return false;
}
}
// Partition keys: each must convert to a FieldExpr; the converted expressions are then
// checked by ExprVerifier.
const auto &groupByExprs = windowRel.partition_expressions();
std::vector<TypedExprPtr> expressions;
expressions.reserve(groupByExprs.size());
for (const auto &expr : groupByExprs) {
auto expression = exprConverter_->ToOmniExpr(expr, rowType);
auto exprField = dynamic_cast<const FieldExpr *>(expression);
if (exprField == nullptr) {
LOG_VALIDATION_MSG("Only field is supported for partition key in Window Operator!");
return false;
} else {
expressions.emplace_back(expression);
}
}
ExprVerifier ev;
for (const auto &expression : expressions) {
if (!ev.VisitExpr(*expression)) {
return false;
}
}
// Sort fields: the direction must be one of the four ASC/DESC x NULLS FIRST/LAST values,
// and a sort expression (when present) must convert to a FieldExpr that passes ExprVerifier.
const auto &sorts = windowRel.sorts();
for (const auto &sort : sorts) {
switch (sort.direction()) {
case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST:
case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST:
case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST:
case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST:
break;
default:
LOG_VALIDATION_MSG("in windowRel, unsupported Sort direction " + std::to_string(sort.direction()));
return false;
}
if (sort.has_expr()) {
auto expression = exprConverter_->ToOmniExpr(sort.expr(), rowType);
auto exprField = dynamic_cast<const FieldExpr *>(expression);
if (!exprField) {
LOG_VALIDATION_MSG("in windowRel, the sorting key in Sort Operator "
"only support field.");
return false;
}
if (!ev.VisitExpr(*expression)) {
return false;
}
}
}
return true;
}
// Validates a SetRel. Only UNION ALL is supported: every input relation must validate,
// and the type in the advanced extension (required) must parse and must not be OMNI_ARRAY.
// NOTE(review): the output types of the individual inputs are not compared with each other
// here; only the extension type is parsed.
//
// @param setRel the Substrait set relation.
// @return true when the relation is accepted; false (after logging) otherwise.
bool SubstraitToOmniPlanValidator::Validate(const ::substrait::SetRel &setRel)
{
    if (setRel.op() != ::substrait::SetRel_SetOp::SetRel_SetOp_SET_OP_UNION_ALL) {
        LOG_VALIDATION_MSG("Unsupported SetRel op: " + std::to_string(setRel.op()));
        return false;
    }

    for (int32_t i = 0; i < setRel.inputs_size(); ++i) {
        if (!Validate(setRel.inputs(i))) {
            LOG_VALIDATION_MSG("Input validation fails in SetRel.");
            return false;
        }
    }

    if (!setRel.has_advanced_extension()) {
        LOG_VALIDATION_MSG("Input types are expected in SetRel.");
        return false;
    }

    std::vector<DataTypeId> excludes{DataTypeId::OMNI_ARRAY};
    DataTypePtr inputRowType;
    if (!ParseOmniType(setRel.advanced_extension(), inputRowType, excludes)) {
        LOG_VALIDATION_MSG("Validation failed for input types in SetRel.");
        return false;
    }
    return true;
}
// Validates a SortRel: the input relation (when present), the required extension row type
// (non-empty, every column of a supported type id), and each sort field (a known direction
// and, when an expression is present, one that passes ExprVerifier after conversion).
bool SubstraitToOmniPlanValidator::Validate(const ::substrait::SortRel &sortRel)
{
    if (sortRel.has_input() && !Validate(sortRel.input())) {
        return false;
    }
    if (!sortRel.has_advanced_extension()) {
        LOG_VALIDATION_MSG("Input types are expected in SortRel.");
        return false;
    }

    DataTypePtr inputRowType;
    std::vector<DataTypePtr> types;
    const auto &extension = sortRel.advanced_extension();
    const bool typesParsed = ParseOmniType(extension, inputRowType) && FlattenSingleLevel(inputRowType, types);
    if (!typesParsed) {
        LOG_VALIDATION_MSG("Validation failed for input types in SortRel.");
        return false;
    }
    if (types.empty()) {
        LOG_VALIDATION_MSG("Validation failed for empty input schema in SortRel.");
        return false;
    }

    // Column type ids accepted for the sort input.
    const auto isSupportedColumnType = [](DataTypeId id) {
        switch (id) {
            case OMNI_SHORT:
            case OMNI_INT:
            case OMNI_DATE32:
            case OMNI_LONG:
            case OMNI_TIMESTAMP:
            case OMNI_DOUBLE:
            case OMNI_CHAR:
            case OMNI_VARCHAR:
            case OMNI_BOOLEAN:
            case OMNI_DECIMAL64:
            case OMNI_DECIMAL128:
            case OMNI_ARRAY:
                return true;
            default:
                return false;
        }
    };
    for (const auto &columnType : types) {
        if (!isSupportedColumnType(columnType->GetId())) {
            LOG_VALIDATION_MSG(
                "Validation failed for input types " + std::to_string(columnType->GetId()) + " in SortRel.");
            return false;
        }
    }

    auto rowType = std::make_shared<DataTypes>(std::move(types));
    for (const auto &sortField : sortRel.sorts()) {
        const auto direction = sortField.direction();
        const bool directionIsKnown =
            direction == ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST ||
            direction == ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST ||
            direction == ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST ||
            direction == ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST;
        if (!directionIsKnown) {
            LOG_VALIDATION_MSG("unsupported Sort direction " + std::to_string(sortField.direction()));
            return false;
        }
        if (!sortField.has_expr()) {
            continue;
        }
        auto sortKey = exprConverter_->ToOmniExpr(sortField.expr(), rowType);
        ExprVerifier ev;
        if (!ev.VisitExpr(*sortKey)) {
            return false;
        }
    }
    return true;
}
// Validates a ProjectRel: the input relation (when present), the required extension row
// type (no OMNI_ARRAY), and every projection expression, first at the Substrait level and
// then, after conversion, with ExprVerifier.
//
// @param projectRel the Substrait project relation.
// @return true when the relation is accepted; false (after logging) otherwise.
bool SubstraitToOmniPlanValidator::Validate(const ::substrait::ProjectRel &projectRel)
{
    if (projectRel.has_input() && !Validate(projectRel.input())) {
        LOG_VALIDATION_MSG("ProjectRel input");
        return false;
    }
    if (!projectRel.has_advanced_extension()) {
        LOG_VALIDATION_MSG("Input types are expected in ProjectRel.");
        return false;
    }
    const auto &extension = projectRel.advanced_extension();
    DataTypePtr inputRowType;
    std::vector<DataTypePtr> types;
    std::vector<DataTypeId> excludes{DataTypeId::OMNI_ARRAY};
    if (!ParseOmniType(extension, inputRowType, excludes) || !FlattenSingleLevel(inputRowType, types, &excludes)) {
        LOG_VALIDATION_MSG("Validation failed for input types in ProjectRel.");
        return false;
    }
    auto rowType = std::make_shared<DataTypes>(std::move(types));

    // Phase 1: Substrait-level validation and conversion of every projection.
    const auto &projectExprs = projectRel.expressions();
    std::vector<TypedExprPtr> expressions;
    expressions.reserve(projectExprs.size());
    for (const auto &expr : projectExprs) {
        if (!ValidateExpression(expr, rowType)) {
            LOG_VALIDATION_MSG("substrait validation fail!");
            return false;
        }
        expressions.emplace_back(exprConverter_->ToOmniExpr(expr, rowType));
    }

    // Phase 2: verification of the converted expressions.
    ExprVerifier ev;
    for (const auto &expression : expressions) {
        if (!ev.VisitExpr(*expression)) {
            LOG_VALIDATION_MSG("OmniExpr validation fail! " + ev.GetUnSupportedReason());
            return false;
        }
    }
    return true;
}
// Validates a FilterRel: the input relation (when present), the required extension row
// type (no OMNI_ARRAY), and the filter condition, first at the Substrait level and then,
// after conversion, with ExprVerifier.
//
// @param filterRel the Substrait filter relation.
// @return true when the relation is accepted; false otherwise.
bool SubstraitToOmniPlanValidator::Validate(const ::substrait::FilterRel &filterRel)
{
    if (filterRel.has_input() && !Validate(filterRel.input())) {
        LOG_VALIDATION_MSG("input of FilterRel validation fails");
        return false;
    }
    if (!filterRel.has_advanced_extension()) {
        LOG_VALIDATION_MSG("Input types are expected in FilterRel.");
        return false;
    }
    const auto &extension = filterRel.advanced_extension();
    DataTypePtr inputRowType;
    std::vector<DataTypePtr> types;
    std::vector<DataTypeId> excludes{DataTypeId::OMNI_ARRAY};
    if (!ParseOmniType(extension, inputRowType, excludes) || !FlattenSingleLevel(inputRowType, types, &excludes)) {
        LOG_VALIDATION_MSG("Validation failed for input types in FilterRel.");
        return false;
    }
    auto rowType = std::make_shared<DataTypes>(std::move(types));
    if (!ValidateExpression(filterRel.condition(), rowType)) {
        return false;
    }
    auto condition = exprConverter_->ToOmniExpr(filterRel.condition(), rowType);
    ExprVerifier ev;
    return ev.VisitExpr(*condition);
}
// Validates a JoinRel: both inputs (when present), the join type (with a separate
// allow-list when the "isSMJ=" optimization flag is set, and a rejection of semi joins
// flagged "isExistenceJoin="), the required extension row type (no OMNI_ARRAY), the join
// key extraction, and the post-join filter (when present).
bool SubstraitToOmniPlanValidator::Validate(const ::substrait::JoinRel &joinRel)
{
    if (joinRel.has_left() && !Validate(joinRel.left())) {
        LOG_VALIDATION_MSG("Validation fails for join left input.");
        return false;
    }
    if (joinRel.has_right() && !Validate(joinRel.right())) {
        LOG_VALIDATION_MSG("Validation fails for join right input.");
        return false;
    }

    const bool isSortMergeJoin = joinRel.has_advanced_extension() &&
        SubstraitParser::ConfigSetInOptimization(joinRel.advanced_extension(), "isSMJ=");
    if (isSortMergeJoin) {
        switch (joinRel.type()) {
            case ::substrait::JoinRel_JoinType_JOIN_TYPE_INNER:
            case ::substrait::JoinRel_JoinType_JOIN_TYPE_OUTER:
            case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT:
            case ::substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT:
            case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI:
            case ::substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT_SEMI:
            case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_ANTI:
                break;
            default:
                LOG_VALIDATION_MSG("Sort merge join type is not supported: " + std::to_string(joinRel.type()));
                return false;
        }
    }

    switch (joinRel.type()) {
        case ::substrait::JoinRel_JoinType_JOIN_TYPE_INNER:
        case ::substrait::JoinRel_JoinType_JOIN_TYPE_OUTER:
        case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT:
        case ::substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT:
        case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_ANTI:
            break;
        case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI:
        case ::substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT_SEMI: {
            const bool isExistenceJoin = joinRel.has_advanced_extension() &&
                SubstraitParser::ConfigSetInOptimization(joinRel.advanced_extension(), "isExistenceJoin=");
            if (isExistenceJoin) {
                LOG_VALIDATION_MSG("ExistenceJoin is not supported: " + std::to_string(joinRel.type()));
                return false;
            }
            break;
        }
        default:
            LOG_VALIDATION_MSG("Join type is not supported: " + std::to_string(joinRel.type()));
            return false;
    }

    if (!joinRel.has_advanced_extension()) {
        LOG_VALIDATION_MSG("Input types are expected in JoinRel.");
        return false;
    }

    DataTypePtr inputRowType;
    std::vector<DataTypePtr> types;
    std::vector<DataTypeId> excludes{DataTypeId::OMNI_ARRAY};
    const auto &extension = joinRel.advanced_extension();
    const bool typesAreValid = ParseOmniType(extension, inputRowType, excludes) &&
        FlattenSingleLevel(inputRowType, types, &excludes);
    if (!typesAreValid) {
        LOG_VALIDATION_MSG("Validation failed for input types in JoinRel.");
        return false;
    }
    auto rowType = std::make_shared<DataTypes>(std::move(types));

    if (joinRel.has_expression()) {
        // The extracted key lists are not inspected further in this function.
        std::vector<const ::substrait::Expression *> leftKeys;
        std::vector<const ::substrait::Expression *> rightKeys;
        planConverter.ExtractJoinKeys(joinRel.expression(), leftKeys, rightKeys);
    }

    if (!joinRel.has_post_join_filter()) {
        return true;
    }
    auto postFilter = exprConverter_->ToOmniExpr(joinRel.post_join_filter(), rowType);
    ExprVerifier ev;
    if (!ev.VisitExpr(*postFilter)) {
        return false;
    }
    return true;
}
// Validates a CrossRel: both inputs (when present), the presence of the advanced
// extension, the join type (inner, left or right), the extension row type (no
// OMNI_ARRAY), and the join expression (when present).
bool SubstraitToOmniPlanValidator::Validate(const ::substrait::CrossRel &crossRel)
{
    if (crossRel.has_left() && !Validate(crossRel.left())) {
        LogValidateMsg("Native validation failed due to: validation fails for cross join left input. ");
        return false;
    }
    if (crossRel.has_right() && !Validate(crossRel.right())) {
        LogValidateMsg("Native validation failed due to: validation fails for cross join right input. ");
        return false;
    }
    if (!crossRel.has_advanced_extension()) {
        LogValidateMsg("Native validation failed due to: Input types are expected in CrossRel.");
        return false;
    }

    const auto joinType = crossRel.type();
    const bool joinTypeSupported = joinType == ::substrait::CrossRel_JoinType_JOIN_TYPE_INNER ||
        joinType == ::substrait::CrossRel_JoinType_JOIN_TYPE_LEFT ||
        joinType == ::substrait::CrossRel_JoinType_JOIN_TYPE_RIGHT;
    if (!joinTypeSupported) {
        LOG_VALIDATION_MSG("Unsupported Join type in CrossRel");
        return false;
    }

    DataTypePtr inputRowType;
    std::vector<DataTypePtr> types;
    std::vector<DataTypeId> excludes{DataTypeId::OMNI_ARRAY};
    const auto &extension = crossRel.advanced_extension();
    const bool typesAreValid = ParseOmniType(extension, inputRowType, excludes) &&
        FlattenSingleLevel(inputRowType, types, &excludes);
    if (!typesAreValid) {
        LogValidateMsg("Native validation failed due to: Validation failed for input types in CrossRel");
        return false;
    }
    auto rowType = std::make_shared<DataTypes>(std::move(types));

    if (!crossRel.has_expression()) {
        return true;
    }
    auto joinCondition = exprConverter_->ToOmniExpr(crossRel.expression(), rowType);
    ExprVerifier ev;
    if (!ev.VisitExpr(*joinCondition)) {
        return false;
    }
    return true;
}
// Validates an AggregateRel: the input relation (when present), the required extension
// row type (no OMNI_ARRAY), and for every measure its non-empty filter and each
// aggregate-function argument. Arguments that convert to a LiteralExpr or FieldExpr skip
// the ExprVerifier pass.
bool SubstraitToOmniPlanValidator::Validate(const ::substrait::AggregateRel &aggRel)
{
    if (aggRel.has_input() && !Validate(aggRel.input())) {
        LOG_VALIDATION_MSG("Input validation fails in AggregateRel.");
        return false;
    }
    if (!aggRel.has_advanced_extension()) {
        LOG_VALIDATION_MSG("Input types are expected in AggregateRel.");
        return false;
    }

    DataTypePtr inputRowType;
    std::vector<DataTypePtr> types;
    std::vector<DataTypeId> excludes{DataTypeId::OMNI_ARRAY};
    const auto &extension = aggRel.advanced_extension();
    const bool typesAreValid = ParseOmniType(extension, inputRowType, excludes) &&
        FlattenSingleLevel(inputRowType, types, &excludes);
    if (!typesAreValid) {
        LOG_VALIDATION_MSG("Validation failed for input types in AggregateRel.");
        return false;
    }
    auto rowType = std::make_shared<DataTypes>(std::move(types));

    ExprVerifier ev;
    for (const auto &measure : aggRel.measures()) {
        // Only a filter with a non-zero serialized size is validated.
        if (measure.has_filter() && measure.filter().ByteSizeLong() > 0) {
            const ::substrait::Expression &measureFilter = measure.filter();
            if (!ValidateExpression(measureFilter, rowType)) {
                LOG_VALIDATION_MSG("substraitFilter validation fail!");
                return false;
            }
            auto convertedFilter = exprConverter_->ToOmniExpr(measureFilter, rowType);
            if (!ev.VisitExpr(*convertedFilter)) {
                LOG_VALIDATION_MSG("omniFilter validation fail!");
                return false;
            }
        }

        for (const auto &arg : measure.measure().arguments()) {
            const auto &argValue = arg.value();
            if (!ValidateExpression(argValue, rowType)) {
                LOG_VALIDATION_MSG("substrait aggFunExpression validation fail!");
                return false;
            }
            auto convertedArg = exprConverter_->ToOmniExpr(argValue, rowType);
            const bool isLiteralOrField = dynamic_cast<const LiteralExpr *>(convertedArg) != nullptr ||
                dynamic_cast<const FieldExpr *>(convertedArg) != nullptr;
            if (!isLiteralOrField && !ev.VisitExpr(*convertedArg)) {
                LOG_VALIDATION_MSG("aggFunExpression validation fail!");
                return false;
            }
        }
    }
    return true;
}
// Validates a ReadRel: when a filter is present it is validated against the types parsed
// from the base schema (an empty type list when no base schema is present). A ReadRel
// without a filter is accepted.
//
// @param readRel the Substrait read relation.
// @return false when the filter fails validation, true otherwise.
bool SubstraitToOmniPlanValidator::Validate(const ::substrait::ReadRel &readRel)
{
    if (!readRel.has_filter()) {
        return true;
    }
    std::vector<DataTypePtr> omniTypeList;
    if (readRel.has_base_schema()) {
        omniTypeList = SubstraitParser::ParseNamedStruct(readRel.base_schema());
    }
    auto rowType = std::make_shared<DataTypes>(std::move(omniTypeList));
    return ValidateExpression(readRel.filter(), rowType);
}
// Dispatches a generic Rel to the validator of the concrete relation it holds.
// Relation kinds not listed here are logged and rejected.
bool SubstraitToOmniPlanValidator::Validate(const ::substrait::Rel &rel)
{
    if (rel.has_aggregate()) {
        return Validate(rel.aggregate());
    }
    if (rel.has_project()) {
        return Validate(rel.project());
    }
    if (rel.has_filter()) {
        return Validate(rel.filter());
    }
    if (rel.has_join()) {
        return Validate(rel.join());
    }
    if (rel.has_cross()) {
        return Validate(rel.cross());
    }
    if (rel.has_read()) {
        return Validate(rel.read());
    }
    if (rel.has_sort()) {
        return Validate(rel.sort());
    }
    if (rel.has_expand()) {
        return Validate(rel.expand());
    }
    if (rel.has_fetch()) {
        return Validate(rel.fetch());
    }
    if (rel.has_top_n()) {
        return Validate(rel.top_n());
    }
    if (rel.has_window()) {
        return Validate(rel.window());
    }
    if (rel.has_write()) {
        return Validate(rel.write());
    }
    if (rel.has_set()) {
        return Validate(rel.set());
    }
    LOG_VALIDATION_MSG("Unsupported relation type: " + rel.GetTypeName());
    return false;
}
// Validates the input relation of a RelRoot; a root without an input is rejected.
bool SubstraitToOmniPlanValidator::Validate(const ::substrait::RelRoot &relRoot)
{
    if (!relRoot.has_input()) {
        return false;
    }
    return Validate(relRoot.input());
}
// Entry point: builds the function map for `plan`, fetches the expression converter, and
// returns the validation result of the first relation that holds a root or a rel.
// Returns false when no such relation exists or when an OmniException is caught (its
// message is logged).
bool SubstraitToOmniPlanValidator::Validate(const ::substrait::Plan &plan)
{
    try {
        planConverter.ConstructFunctionMap(plan);
        exprConverter_ = planConverter.GetExprConverter();
        for (const auto &relation : plan.relations()) {
            if (relation.has_root()) {
                return Validate(relation.root());
            }
            if (relation.has_rel()) {
                return Validate(relation.rel());
            }
        }
    } catch (const OmniException &err) {
        LOG_VALIDATION_MSG(err.what());
        return false;
    }
    return false;
}
}