* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
* Description: Expression Verifier
*/
#include "expr_verifier.h"
#include "codegen/func_registry.h"
using namespace omniruntime::expressions;
using namespace omniruntime::type;
namespace omniruntime {
namespace expressions {
bool ExprVerifier::VisitExpr(const Expr &e)
{
e.Accept(*this);
return this->supportedFlag;
}
/**
 * Runs this verifier over an expression held through a shared pointer.
 *
 * The pointer is dereferenced without a null check, so callers must pass a non-null pointer.
 *
 * @param e shared pointer to the expression that is asked to accept this visitor
 * @return the value of supportedFlag once the expression has been visited
 */
bool ExprVerifier::VisitExpr(const std::shared_ptr<const Expr> &e)
{
    // The matching Visit overload invoked through Accept updates supportedFlag.
    e->Accept(*this);
    return supportedFlag;
}
/**
 * Tells whether two type ids may NOT be used together.
 *
 * Two ids are compatible when they are equal, when both are string types, or when both are
 * decimal types (as reported by TypeUtil). Every other combination is invalid.
 *
 * @param type1 first type id
 * @param type2 second type id
 * @return true when the pair is incompatible, false otherwise
 */
bool ExprVerifier::AreInvalidDataTypes(DataTypeId type1, DataTypeId type2)
{
    if (type1 == type2) {
        return false;
    }
    // Different ids are still acceptable inside the same family (string or decimal).
    if (TypeUtil::IsStringType(type1) && TypeUtil::IsStringType(type2)) {
        return false;
    }
    if (TypeUtil::IsDecimalType(type1) && TypeUtil::IsDecimalType(type2)) {
        return false;
    }
    return true;
}
/**
 * Verifies a literal expression by looking at its return type id.
 *
 * supportedFlag becomes true for BYTE, SHORT, INT, DATE32, LONG, TIMESTAMP, DOUBLE, CHAR,
 * VARCHAR, BOOLEAN, DECIMAL64, DECIMAL128 and ARRAY literals. Any other type id clears the
 * flag and stores the offending id in unSupportedReason, in the same way the FieldExpr
 * overload reports unsupported field types.
 *
 * @param literalExpr literal expression being verified
 */
void ExprVerifier::Visit(const LiteralExpr &literalExpr)
{
    switch (literalExpr.GetReturnTypeId()) {
        case OMNI_BYTE:
        case OMNI_SHORT:
        case OMNI_INT:
        case OMNI_DATE32:
        case OMNI_LONG:
        case OMNI_TIMESTAMP:
        case OMNI_DOUBLE:
        case OMNI_CHAR:
        case OMNI_VARCHAR:
        case OMNI_BOOLEAN:
        case OMNI_DECIMAL64:
        case OMNI_DECIMAL128:
        case OMNI_ARRAY:
            this->supportedFlag = true;
            break;
        default:
            this->supportedFlag = false;
            // Record why verification failed so the rejection can be diagnosed.
            this->unSupportedReason = "unSupported LiteralExpr DataTypeId: "
                + std::to_string(static_cast<int>(literalExpr.GetReturnTypeId()));
            break;
    }
}
void ExprVerifier::Visit(const FieldExpr &fieldExpr)
{
switch (fieldExpr.GetReturnTypeId()) {
case OMNI_BYTE:
case OMNI_SHORT:
case OMNI_INT:
case OMNI_DATE32:
case OMNI_LONG:
case OMNI_TIMESTAMP:
case OMNI_DOUBLE:
case OMNI_CHAR:
case OMNI_VARCHAR:
case OMNI_BOOLEAN:
case OMNI_DECIMAL64:
case OMNI_DECIMAL128:
case OMNI_ARRAY:
this->supportedFlag = true;
break;
default:
this->supportedFlag = false;
this->unSupportedReason = "unSupported FieldExpr DataTypeId: "
+ std::to_string(static_cast<int>(fieldExpr.GetReturnTypeId()));
break;
}
}
void ExprVerifier::Visit(const UnaryExpr &unaryExpr)
{
if (!VisitExpr(*(unaryExpr.exp))) {
this->supportedFlag = false;
return;
}
switch (unaryExpr.op) {
case omniruntime::expressions::Operator::NOT:
this->supportedFlag = true;
break;
default:
this->supportedFlag = false;
break;
}
}
void ExprVerifier::Visit(const BinaryExpr &binaryExpr)
{
const type::DataType &leftType = *(binaryExpr.left->GetReturnType());
const type::DataType &rightType = *(binaryExpr.right->GetReturnType());
if (AreInvalidDataTypes(leftType.GetId(), rightType.GetId())) {
this->supportedFlag = false;
return;
}
if (!VisitExpr(*(binaryExpr.left))) {
this->supportedFlag = false;
return;
}
if (!VisitExpr(*(binaryExpr.right))) {
this->supportedFlag = false;
return;
}
if (binaryExpr.op == omniruntime::expressions::Operator::AND ||
binaryExpr.op == omniruntime::expressions::Operator::OR) {
this->supportedFlag = (binaryExpr.left->GetReturnTypeId() == binaryExpr.right->GetReturnTypeId() &&
binaryExpr.left->GetReturnTypeId() == DataTypeId::OMNI_BOOLEAN);
return;
}
if (binaryExpr.left->GetReturnTypeId() == OMNI_BYTE ||binaryExpr.left->GetReturnTypeId() == OMNI_SHORT ||
binaryExpr.left->GetReturnTypeId() == OMNI_INT || binaryExpr.left->GetReturnTypeId() == OMNI_LONG ||
binaryExpr.left->GetReturnTypeId() == OMNI_DATE32 || binaryExpr.left->GetReturnTypeId() == OMNI_DOUBLE) {
this->supportedFlag = true;
return;
} else if (TypeUtil::IsStringType(binaryExpr.left->GetReturnTypeId()) ||
binaryExpr.left->GetReturnTypeId() == OMNI_TIMESTAMP) {
switch (binaryExpr.op) {
case omniruntime::expressions::Operator::LT:
case omniruntime::expressions::Operator::GT:
case omniruntime::expressions::Operator::LTE:
case omniruntime::expressions::Operator::GTE:
case omniruntime::expressions::Operator::EQ:
case omniruntime::expressions::Operator::NEQ:
this->supportedFlag = true;
break;
default:
this->supportedFlag = false;
break;
}
return;
} else if (TypeUtil::IsDecimalType(binaryExpr.left->GetReturnTypeId())) {
this->supportedFlag = true;
return;
}
this->supportedFlag = false;
}
/**
 * Verifies an IN expression.
 *
 * arguments[0] is the value being tested and the remaining arguments are the candidates.
 * The expression is supported when:
 *   - the argument list is not empty (an empty list is reported as unsupported instead of
 *     indexing past the end of the container);
 *   - the tested value has one of the accepted type ids (BYTE, SHORT, INT, DATE32, LONG,
 *     TIMESTAMP, DOUBLE, CHAR, VARCHAR, DECIMAL64, DECIMAL128) and is itself supported;
 *   - every candidate is type-compatible with the tested value and is itself supported.
 *
 * @param inExpr IN expression being verified
 */
void ExprVerifier::Visit(const InExpr &inExpr)
{
    // Without at least the tested value there is nothing to verify; arguments[0] would be
    // out of bounds.
    if (inExpr.arguments.empty()) {
        this->supportedFlag = false;
        return;
    }

    Expr *toCompare = inExpr.arguments[0];
    switch (toCompare->GetReturnTypeId()) {
        case OMNI_BYTE:
        case OMNI_SHORT:
        case OMNI_INT:
        case OMNI_DATE32:
        case OMNI_LONG:
        case OMNI_TIMESTAMP:
        case OMNI_DOUBLE:
        case OMNI_CHAR:
        case OMNI_VARCHAR:
        case OMNI_DECIMAL64:
        case OMNI_DECIMAL128:
            break;
        default:
            this->supportedFlag = false;
            return;
    }

    if (!VisitExpr(*toCompare)) {
        this->supportedFlag = false;
        return;
    }

    // Each candidate must match the tested value's type family and be supported on its own.
    for (size_t i = 1; i < inExpr.arguments.size(); i++) {
        if (AreInvalidDataTypes(toCompare->GetReturnTypeId(), inExpr.arguments[i]->GetReturnTypeId())) {
            this->supportedFlag = false;
            return;
        }
        if (!VisitExpr(*(inExpr.arguments[i]))) {
            this->supportedFlag = false;
            return;
        }
    }
    this->supportedFlag = true;
}
/**
 * Verifies a BETWEEN expression.
 *
 * The expression is supported when both bounds are type-compatible with the tested value
 * (per AreInvalidDataTypes) and the value, lower bound and upper bound are each supported.
 *
 * @param betweenExpr BETWEEN expression being verified
 */
void ExprVerifier::Visit(const BetweenExpr &betweenExpr)
{
    DataTypeId valueTypeId = betweenExpr.value->GetReturnTypeId();
    // A single incompatible bound is enough to reject the expression: the value is compared
    // against each bound, so both comparisons need compatible operand types.
    if (AreInvalidDataTypes(valueTypeId, betweenExpr.lowerBound->GetReturnTypeId()) ||
        AreInvalidDataTypes(valueTypeId, betweenExpr.upperBound->GetReturnTypeId())) {
        this->supportedFlag = false;
        return;
    }
    if (!VisitExpr(*betweenExpr.value)) {
        this->supportedFlag = false;
        return;
    }
    if (!VisitExpr(*betweenExpr.lowerBound)) {
        this->supportedFlag = false;
        return;
    }
    if (!VisitExpr(*betweenExpr.upperBound)) {
        this->supportedFlag = false;
        return;
    }
    this->supportedFlag = true;
}
void ExprVerifier::Visit(const IfExpr &ifExpr)
{
Expr *cond = ifExpr.condition;
Expr *ifTrue = ifExpr.trueExpr;
Expr *ifFalse = ifExpr.falseExpr;
if (!VisitExpr(*cond)) {
this->supportedFlag = false;
return;
}
if (!VisitExpr(*ifTrue)) {
this->supportedFlag = false;
return;
}
if (!VisitExpr(*ifFalse)) {
this->supportedFlag = false;
return;
}
this->supportedFlag = true;
}
void ExprVerifier::Visit(const CoalesceExpr &coalesceExpr)
{
Expr *value1Expr = coalesceExpr.value1;
Expr *value2Expr = coalesceExpr.value2;
if (!VisitExpr(*value1Expr)) {
this->supportedFlag = false;
return;
}
if (!VisitExpr(*value2Expr)) {
this->supportedFlag = false;
return;
}
this->supportedFlag = true;
}
void ExprVerifier::Visit(const IsNullExpr &isNullExpr)
{
Expr *valueExpr = isNullExpr.value;
if (!VisitExpr(*valueExpr)) {
this->supportedFlag = false;
return;
}
this->supportedFlag = true;
}
void ExprVerifier::Visit(const FuncExpr &funcExpr)
{
if (funcExpr.funcName == "LIKE") {
this->supportedFlag = false;
return;
}
int numArgs = funcExpr.arguments.size();
std::vector<DataTypeId> params;
for (int i = 0; i < numArgs; i++) {
params.push_back(funcExpr.arguments[i]->GetReturnTypeId());
if (!VisitExpr(*funcExpr.arguments[i])) {
this->supportedFlag = false;
return;
}
}
auto signature = FunctionSignature(funcExpr.funcName, params, funcExpr.GetReturnTypeId());
auto function = codegen::FunctionRegistry::LookupFunction(&signature);
if (function == nullptr) {
this->supportedFlag = false;
return;
}
this->supportedFlag = true;
}
void ExprVerifier::Visit(const SwitchExpr &switchExpr)
{
std::vector<std::pair<Expr*, Expr*>> whenClause = switchExpr.whenClause;
auto size = whenClause.size();
for (size_t i = 0; i < size; i++) {
Expr *cond = whenClause[i].first;
Expr *resExpr = whenClause[i].second;
if (!VisitExpr(*cond)) {
this->supportedFlag = false;
return;
}
if (!VisitExpr(*resExpr)) {
this->supportedFlag = false;
return;
}
}
Expr *elseExpr = switchExpr.falseExpr;
if (!VisitExpr(*elseExpr)) {
this->supportedFlag = false;
return;
}
this->supportedFlag = true;
}
}
}