* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "expr_parser.h"
#include "common/checker.h"
#include "symengine/real_double.h"
namespace af {
// Entry point of the grammar: parses a complete expression starting at the
// current lookahead token.
//
// Delegates to ParserAddSubtract(), the rule this parser starts from, and
// asserts that it produced a node.
// Returns the parsed expression; a null result from the delegate is rejected
// by GE_ASSERT_NOTNULL (presumably by returning early with an error result).
ExpressionImplPtr ExprParser::ParserExpression() {
  ExpressionImplPtr ret = ParserAddSubtract();
  GE_ASSERT_NOTNULL(ret);
  return ret;
}
// Primes the parser by fetching the first token from the scanner into
// currentToken_, so that the Parser* methods have a lookahead to inspect.
//
// Returns ge::GRAPH_SUCCESS once a token has been fetched; a scanner failure
// is handled by GE_ASSERT_SUCCESS.
graphStatus ExprParser::Init() {
  GE_ASSERT_SUCCESS(scanner_.GetNextToken(currentToken_));
  return ge::GRAPH_SUCCESS;
}
// Consumes the current lookahead token.
//
// `type` is the token type the caller expects to be current. The lookahead
// is checked against it first; only then is the next token requested from
// the scanner and stored in currentToken_.
//
// Returns ge::GRAPH_SUCCESS when the token matched and the scanner advanced;
// a mismatch or scanner failure is handled by the GE_ASSERT* macros.
graphStatus ExprParser::Eat(TokenType type) {
  // Step 1: the lookahead must be exactly what the grammar rule expects.
  GE_ASSERT(currentToken_.type == type);

  // Step 2: advance the lookahead.
  GE_ASSERT_SUCCESS(scanner_.GetNextToken(currentToken_));
  return ge::GRAPH_SUCCESS;
}
// Parses one factor: the operand level of the grammar, used on both sides of
// the multiply/divide operators.
//
// Dispatches on the type of the current lookahead token to the matching
// Parser* routine (identifier, parenthesised sub-expression, function-style
// operators, numeric/boolean literals, unary minus).
//
// Returns the node produced by the selected routine, or nullptr (after
// logging) when the lookahead token cannot start a factor.
ExpressionImplPtr ExprParser::ParserFactor() {
  switch (currentToken_.type) {
    // Primary operands.
    case TokenType::kIdentifier:
      return ParserIdentifier();
    case TokenType::kLparen:
      return ParserLParen();
    case TokenType::kNumber:
      return ParserNumber();
    case TokenType::kTrue:
    case TokenType::kFalse:
      return ParseConstBoolen();

    // Unary minus.
    case TokenType::kMinus:
      return ParserMinus();

    // Function-style arithmetic operators.
    case TokenType::kMax:
      return ParserMaxFunction();
    case TokenType::kMin:
      return ParserMinFunction();
    case TokenType::kPow:
      return ParserPowFunction();
    case TokenType::kMod:
      return ParserModFunction();
    case TokenType::kLog:
      return ParserLogFunction();
    case TokenType::kCeil:
      return ParserCeilFunction();
    case TokenType::kFloor:
      return ParserFloorFunction();
    case TokenType::kAbs:
      return ParserAbsFunction();
    case TokenType::kRational:
      return ParserRationalFunction();

    // Function-style relational and logical operators.
    case TokenType::kEq:
      return ParserEqual();
    case TokenType::kNe:
      return ParserUnequal();
    case TokenType::kLe:
      return ParserLessEqual();
    case TokenType::kLt:
      return ParserLessThan();
    case TokenType::kLogicalAnd:
      return ParserLogicalAnd();
    case TokenType::kLogicalOr:
      return ParserLogicalOr();

    default:
      // The enumerator is converted explicitly: handing an enum object to a
      // printf-style "%d" relies on how the enum happens to be passed through
      // varargs, whereas an int always matches the conversion specifier.
      GELOGE(ge::PARAM_INVALID, "Unsupported operator %d when Parser factor.",
             static_cast<int>(currentToken_.type));
      return nullptr;
  }
}
// Parses a left-associative chain of terms joined by kPlus / kMinus:
//   term { (kPlus | kMinus) term }
// where each term is parsed by ParserMulDivide().
//
// Returns the accumulated expression. Parsing stops with a failure as soon as
// a term cannot be parsed or combining two operands yields no node.
ExpressionImplPtr ExprParser::ParserAddSubtract() {
  auto node = ParserMulDivide();
  GE_ASSERT_NOTNULL(node);
  while (currentToken_.type == TokenType::kPlus || currentToken_.type == TokenType::kMinus) {
    // Remember the operator before consuming it: Eat() moves currentToken_ on
    // to the token that follows.
    const TokenType op = currentToken_.type;
    GE_ASSERT_SUCCESS(Eat(op));
    auto right = ParserMulDivide();
    GE_ASSERT_NOTNULL(right);
    switch (op) {
      case TokenType::kPlus:
        node = Add(node, right);
        break;
      case TokenType::kMinus:
        node = Sub(node, right);
        break;
      default:
        // Defensive only; the loop condition admits just the two cases above.
        // Report the operator that was consumed (op), not the lookahead that
        // came after it.
        GELOGE(ge::PARAM_INVALID, "unsupported operator %d when parsing add and sub.", static_cast<int>(op));
        return nullptr;
    }
    // Same guard as the multiply/divide level: never carry a null
    // accumulator into the next iteration or back to the caller.
    GE_ASSERT_NOTNULL(node);
  }
  return node;
}
// Parses a left-associative chain of factors joined by kMultiply / kDivide:
//   factor { (kMultiply | kDivide) factor }
// where each factor is parsed by ParserFactor().
//
// Returns the accumulated expression. Parsing stops with a failure as soon as
// a factor cannot be parsed or combining two operands yields no node.
ExpressionImplPtr ExprParser::ParserMulDivide() {
  auto node = ParserFactor();
  GE_ASSERT_NOTNULL(node);
  while (currentToken_.type == TokenType::kMultiply || currentToken_.type == TokenType::kDivide) {
    // Remember the operator before consuming it: Eat() moves currentToken_ on
    // to the token that follows.
    const TokenType op = currentToken_.type;
    GE_ASSERT_SUCCESS(Eat(op));
    auto right = ParserFactor();
    GE_ASSERT_NOTNULL(right);
    switch (op) {
      case TokenType::kMultiply:
        node = Mul(node, right);
        break;
      case TokenType::kDivide:
        node = Div(node, right);
        break;
      default:
        // Defensive only; the loop condition admits just the two cases above.
        // Report the operator that was consumed (op), not the lookahead that
        // came after it.
        GELOGE(ge::PARAM_INVALID, "unsupported operator %d when parsing mul and divide.", static_cast<int>(op));
        return nullptr;
    }
    // Mul()/Div() may fail to build a node; stop before reusing it.
    GE_ASSERT_NOTNULL(node);
  }
  return node;
}
// Parses the two-operand maximum form:
//   kMax kLparen <expr> kComma <expr> kRparen
//
// Returns Max(arg1, arg2). Fails (through the GE_ASSERT_* macros) when an
// expected token is missing or when either operand could not be parsed.
ExpressionImplPtr ExprParser::ParserMaxFunction() {
  GE_ASSERT_SUCCESS(Eat(TokenType::kMax));
  GE_ASSERT_SUCCESS(Eat(TokenType::kLparen));
  auto arg1 = ParserExpression();
  // A failed operand must stop the rule here; otherwise a following comma
  // could still be matched and nullptr would be forwarded to Max().
  GE_ASSERT_NOTNULL(arg1);
  GE_ASSERT_SUCCESS(Eat(TokenType::kComma));
  auto arg2 = ParserExpression();
  GE_ASSERT_NOTNULL(arg2);
  GE_ASSERT_SUCCESS(Eat(TokenType::kRparen));
  return Max(arg1, arg2);
}
// Parses the two-operand minimum form:
//   kMin kLparen <expr> kComma <expr> kRparen
//
// Returns Min(arg1, arg2). Fails (through the GE_ASSERT_* macros) when an
// expected token is missing or when either operand could not be parsed.
ExpressionImplPtr ExprParser::ParserMinFunction() {
  GE_ASSERT_SUCCESS(Eat(TokenType::kMin));
  GE_ASSERT_SUCCESS(Eat(TokenType::kLparen));
  auto arg1 = ParserExpression();
  // A failed operand must stop the rule here; otherwise a following comma
  // could still be matched and nullptr would be forwarded to Min().
  GE_ASSERT_NOTNULL(arg1);
  GE_ASSERT_SUCCESS(Eat(TokenType::kComma));
  auto arg2 = ParserExpression();
  GE_ASSERT_NOTNULL(arg2);
  GE_ASSERT_SUCCESS(Eat(TokenType::kRparen));
  return Min(arg1, arg2);
}
// Parses the function-style equality form:
//   kEq kLparen <expr> kComma <expr> kRparen
//
// Returns Eq(arg1, arg2). Fails (through the GE_ASSERT_* macros) when an
// expected token is missing or when either operand could not be parsed.
ExpressionImplPtr ExprParser::ParserEqual() {
  GE_ASSERT_SUCCESS(Eat(TokenType::kEq));
  GE_ASSERT_SUCCESS(Eat(TokenType::kLparen));
  auto arg1 = ParserExpression();
  // A failed operand must stop the rule here; otherwise a following comma
  // could still be matched and nullptr would be forwarded to Eq().
  GE_ASSERT_NOTNULL(arg1);
  GE_ASSERT_SUCCESS(Eat(TokenType::kComma));
  auto arg2 = ParserExpression();
  GE_ASSERT_NOTNULL(arg2);
  GE_ASSERT_SUCCESS(Eat(TokenType::kRparen));
  return Eq(arg1, arg2);
}
// Parses the function-style inequality form:
//   kNe kLparen <expr> kComma <expr> kRparen
//
// Returns Ne(arg1, arg2). Fails (through the GE_ASSERT_* macros) when an
// expected token is missing or when either operand could not be parsed.
ExpressionImplPtr ExprParser::ParserUnequal() {
  GE_ASSERT_SUCCESS(Eat(TokenType::kNe));
  GE_ASSERT_SUCCESS(Eat(TokenType::kLparen));
  auto arg1 = ParserExpression();
  // A failed operand must stop the rule here; otherwise a following comma
  // could still be matched and nullptr would be forwarded to Ne().
  GE_ASSERT_NOTNULL(arg1);
  GE_ASSERT_SUCCESS(Eat(TokenType::kComma));
  auto arg2 = ParserExpression();
  GE_ASSERT_NOTNULL(arg2);
  GE_ASSERT_SUCCESS(Eat(TokenType::kRparen));
  return Ne(arg1, arg2);
}
// Parses the function-style less-or-equal form:
//   kLe kLparen <expr> kComma <expr> kRparen
//
// Returns Le(arg1, arg2). Fails (through the GE_ASSERT_* macros) when an
// expected token is missing or when either operand could not be parsed.
ExpressionImplPtr ExprParser::ParserLessEqual() {
  GE_ASSERT_SUCCESS(Eat(TokenType::kLe));
  GE_ASSERT_SUCCESS(Eat(TokenType::kLparen));
  auto arg1 = ParserExpression();
  // A failed operand must stop the rule here; otherwise a following comma
  // could still be matched and nullptr would be forwarded to Le().
  GE_ASSERT_NOTNULL(arg1);
  GE_ASSERT_SUCCESS(Eat(TokenType::kComma));
  auto arg2 = ParserExpression();
  GE_ASSERT_NOTNULL(arg2);
  GE_ASSERT_SUCCESS(Eat(TokenType::kRparen));
  return Le(arg1, arg2);
}
// Parses the function-style less-than form:
//   kLt kLparen <expr> kComma <expr> kRparen
//
// Returns Lt(arg1, arg2). Fails (through the GE_ASSERT_* macros) when an
// expected token is missing or when either operand could not be parsed.
ExpressionImplPtr ExprParser::ParserLessThan() {
  GE_ASSERT_SUCCESS(Eat(TokenType::kLt));
  GE_ASSERT_SUCCESS(Eat(TokenType::kLparen));
  auto arg1 = ParserExpression();
  // A failed operand must stop the rule here; otherwise a following comma
  // could still be matched and nullptr would be forwarded to Lt().
  GE_ASSERT_NOTNULL(arg1);
  GE_ASSERT_SUCCESS(Eat(TokenType::kComma));
  auto arg2 = ParserExpression();
  GE_ASSERT_NOTNULL(arg2);
  GE_ASSERT_SUCCESS(Eat(TokenType::kRparen));
  return Lt(arg1, arg2);
}
// Parses the function-style two-operand logical AND form:
//   kLogicalAnd kLparen <expr> kComma <expr> kRparen
//
// Returns LogicalAnd() applied to a two-element operand list. Fails (through
// the GE_ASSERT_* macros) when an expected token is missing or when either
// operand could not be parsed.
ExpressionImplPtr ExprParser::ParserLogicalAnd() {
  GE_ASSERT_SUCCESS(Eat(TokenType::kLogicalAnd));
  GE_ASSERT_SUCCESS(Eat(TokenType::kLparen));
  auto arg1 = ParserExpression();
  // A failed operand must stop the rule here; otherwise a following comma
  // could still be matched and a null entry would end up in the operand list.
  GE_ASSERT_NOTNULL(arg1);
  GE_ASSERT_SUCCESS(Eat(TokenType::kComma));
  auto arg2 = ParserExpression();
  GE_ASSERT_NOTNULL(arg2);
  GE_ASSERT_SUCCESS(Eat(TokenType::kRparen));

  // LogicalAnd() takes its operands as a list; exactly two are supplied.
  std::vector<ExpressionImplPtr> args;
  args.reserve(2U);
  args.push_back(std::move(arg1));
  args.push_back(std::move(arg2));
  return LogicalAnd(args);
}
// Parses the function-style two-operand logical OR form:
//   kLogicalOr kLparen <expr> kComma <expr> kRparen
//
// Returns LogicalOr() applied to a two-element operand list. Fails (through
// the GE_ASSERT_* macros) when an expected token is missing or when either
// operand could not be parsed.
ExpressionImplPtr ExprParser::ParserLogicalOr() {
  GE_ASSERT_SUCCESS(Eat(TokenType::kLogicalOr));
  GE_ASSERT_SUCCESS(Eat(TokenType::kLparen));
  auto arg1 = ParserExpression();
  // A failed operand must stop the rule here; otherwise a following comma
  // could still be matched and a null entry would end up in the operand list.
  GE_ASSERT_NOTNULL(arg1);
  GE_ASSERT_SUCCESS(Eat(TokenType::kComma));
  auto arg2 = ParserExpression();
  GE_ASSERT_NOTNULL(arg2);
  GE_ASSERT_SUCCESS(Eat(TokenType::kRparen));

  // LogicalOr() takes its operands as a list; exactly two are supplied.
  std::vector<ExpressionImplPtr> args;
  args.reserve(2U);
  args.push_back(std::move(arg1));
  args.push_back(std::move(arg2));
  return LogicalOr(args);
}
// Parses the two-operand power form:
//   kPow kLparen <expr> kComma <expr> kRparen
//
// Returns Pow(arg1, arg2), with the operands passed in source order. Fails
// (through the GE_ASSERT_* macros) when an expected token is missing or when
// either operand could not be parsed.
ExpressionImplPtr ExprParser::ParserPowFunction() {
  GE_ASSERT_SUCCESS(Eat(TokenType::kPow));
  GE_ASSERT_SUCCESS(Eat(TokenType::kLparen));
  auto arg1 = ParserExpression();
  // A failed operand must stop the rule here; otherwise a following comma
  // could still be matched and nullptr would be forwarded to Pow().
  GE_ASSERT_NOTNULL(arg1);
  GE_ASSERT_SUCCESS(Eat(TokenType::kComma));
  auto arg2 = ParserExpression();
  GE_ASSERT_NOTNULL(arg2);
  GE_ASSERT_SUCCESS(Eat(TokenType::kRparen));
  return Pow(arg1, arg2);
}
// Parses the two-operand modulo form:
//   kMod kLparen <expr> kComma <expr> kRparen
//
// The operands are parsed with ParserAddSubtract() directly, which does not
// itself guarantee a non-null result for every failure path, so both operands
// are checked here.
//
// Returns Mod(arg1, arg2), with the operands passed in source order. Fails
// (through the GE_ASSERT_* macros) when an expected token is missing or when
// either operand could not be parsed.
ExpressionImplPtr ExprParser::ParserModFunction() {
  GE_ASSERT_SUCCESS(Eat(TokenType::kMod));
  GE_ASSERT_SUCCESS(Eat(TokenType::kLparen));
  auto arg1 = ParserAddSubtract();
  // A failed operand must stop the rule here; otherwise a following comma
  // could still be matched and nullptr would be forwarded to Mod().
  GE_ASSERT_NOTNULL(arg1);
  GE_ASSERT_SUCCESS(Eat(TokenType::kComma));
  auto arg2 = ParserAddSubtract();
  GE_ASSERT_NOTNULL(arg2);
  GE_ASSERT_SUCCESS(Eat(TokenType::kRparen));
  return Mod(arg1, arg2);
}
// Parses the single-operand logarithm form:
//   kLog kLparen <expr> kRparen
//
// Returns Log(arg1). Fails (through the GE_ASSERT_* macros) when an expected
// token is missing or when the operand could not be parsed.
ExpressionImplPtr ExprParser::ParserLogFunction() {
  GE_ASSERT_SUCCESS(Eat(TokenType::kLog));
  GE_ASSERT_SUCCESS(Eat(TokenType::kLparen));
  auto arg1 = ParserExpression();
  // Do not forward a failed operand to Log().
  GE_ASSERT_NOTNULL(arg1);
  GE_ASSERT_SUCCESS(Eat(TokenType::kRparen));
  return Log(arg1);
}
// Parses the single-operand ceiling form:
//   kCeil kLparen <expr> kRparen
//
// Returns Ceiling(arg1). Fails (through the GE_ASSERT_* macros) when an
// expected token is missing or when the operand could not be parsed.
ExpressionImplPtr ExprParser::ParserCeilFunction() {
  GE_ASSERT_SUCCESS(Eat(TokenType::kCeil));
  GE_ASSERT_SUCCESS(Eat(TokenType::kLparen));
  auto arg1 = ParserExpression();
  // Do not forward a failed operand to Ceiling().
  GE_ASSERT_NOTNULL(arg1);
  GE_ASSERT_SUCCESS(Eat(TokenType::kRparen));
  return Ceiling(arg1);
}
// Parses the single-operand floor form:
//   kFloor kLparen <expr> kRparen
//
// Returns Floor(arg1). Fails (through the GE_ASSERT_* macros) when an
// expected token is missing or when the operand could not be parsed.
ExpressionImplPtr ExprParser::ParserFloorFunction() {
  GE_ASSERT_SUCCESS(Eat(TokenType::kFloor));
  GE_ASSERT_SUCCESS(Eat(TokenType::kLparen));
  auto arg1 = ParserExpression();
  // Do not forward a failed operand to Floor().
  GE_ASSERT_NOTNULL(arg1);
  GE_ASSERT_SUCCESS(Eat(TokenType::kRparen));
  return Floor(arg1);
}
// Parses the single-operand absolute-value form:
//   kAbs kLparen <expr> kRparen
//
// Returns Abs(arg1). Fails (through the GE_ASSERT_* macros) when an expected
// token is missing or when the operand could not be parsed.
ExpressionImplPtr ExprParser::ParserAbsFunction() {
  GE_ASSERT_SUCCESS(Eat(TokenType::kAbs));
  GE_ASSERT_SUCCESS(Eat(TokenType::kLparen));
  auto arg1 = ParserExpression();
  // Do not forward a failed operand to Abs().
  GE_ASSERT_NOTNULL(arg1);
  GE_ASSERT_SUCCESS(Eat(TokenType::kRparen));
  return Abs(arg1);
}
// Parses the two-operand rational form:
//   kRational kLparen <expr> kComma <expr> kRparen
//
// Returns Rational(arg1, arg2), with the operands passed in source order.
// Fails (through the GE_ASSERT_* macros) when an expected token is missing or
// when either operand could not be parsed.
ExpressionImplPtr ExprParser::ParserRationalFunction() {
  GE_ASSERT_SUCCESS(Eat(TokenType::kRational));
  GE_ASSERT_SUCCESS(Eat(TokenType::kLparen));
  auto arg1 = ParserExpression();
  // A failed operand must stop the rule here; otherwise a following comma
  // could still be matched and nullptr would be forwarded to Rational().
  GE_ASSERT_NOTNULL(arg1);
  GE_ASSERT_SUCCESS(Eat(TokenType::kComma));
  auto arg2 = ParserExpression();
  GE_ASSERT_NOTNULL(arg2);
  GE_ASSERT_SUCCESS(Eat(TokenType::kRparen));
  return Rational(arg1, arg2);
}
// Parses a numeric literal from the current kNumber token.
//
// A lexeme containing '.' is converted with std::stod and yields a double
// constant; any other lexeme is converted with std::stoll and yields an
// int64_t constant. The token is consumed only after a successful conversion.
//
// Returns the constant expression, or nullptr (after a warning log) when the
// lexeme is not a valid number or does not fit the target type.
ExpressionImplPtr ExprParser::ParserNumber() {
  // Take a copy of the lexeme rather than a reference: Eat() hands
  // currentToken_ to the scanner to be refilled, so a reference would refer to
  // the *next* token's text by the time a handler below logs it.
  const std::string numberStr{currentToken_.value};
  try {
    if (numberStr.find('.') != std::string::npos) {
      double value = std::stod(numberStr);
      GE_ASSERT_SUCCESS(Eat(TokenType::kNumber));
      return ExpressionImpl::CreateExpressionImpl(value);
    }
    int64_t value = std::stoll(numberStr);
    GE_ASSERT_SUCCESS(Eat(TokenType::kNumber));
    return ExpressionImpl::CreateExpressionImpl(value);
  } catch (const std::invalid_argument &) {
    // Thrown by std::stod / std::stoll when no conversion could be performed.
    GELOGW("number str:%s is invalid", numberStr.c_str());
    return nullptr;
  } catch (const std::out_of_range &) {
    // Thrown by std::stod / std::stoll when the value does not fit the result type.
    GELOGW("number str:%s is out_of_range", numberStr.c_str());
    return nullptr;
  }
}
// Parses an identifier token into an expression created from its name.
//
// The lexeme is copied out of currentToken_ before the token is consumed,
// because consuming it replaces currentToken_ with the next token.
//
// Returns the expression built by ExpressionImpl::CreateExpressionImpl from
// the identifier text.
ExpressionImplPtr ExprParser::ParserIdentifier() {
  const std::string symbolName(currentToken_.value);
  GE_ASSERT_SUCCESS(Eat(TokenType::kIdentifier));
  return ExpressionImpl::CreateExpressionImpl(symbolName);
}
// Parses a boolean literal (reached from ParserFactor for kTrue / kFalse).
//
// The value is derived from the token text: exactly "True" means true, any
// other text means false. The token is then consumed, whatever its type.
// NOTE(review): this relies on the scanner spelling the true literal as
// "True" - confirm against the scanner.
//
// Returns the boolean constant expression.
ExpressionImplPtr ExprParser::ParseConstBoolen() {
  bool sym_value = (currentToken_.value == "True");
  GE_ASSERT_SUCCESS(Eat(currentToken_.type));
  return ExpressionImpl::CreateExpressionImpl(sym_value);
}
// Parses a unary minus: consumes kMinus, parses the operand and negates it.
//
// The operand is a parenthesised expression when the next token is kLparen,
// otherwise a single factor. Because only a factor is taken, the negation
// applies to that factor alone and not to a following multiply/divide chain.
//
// Returns Neg(operand); fails when the operand could not be parsed.
ExpressionImplPtr ExprParser::ParserMinus() {
  GE_ASSERT_SUCCESS(Eat(TokenType::kMinus));
  const bool isParenthesized = (currentToken_.type == TokenType::kLparen);
  ExpressionImplPtr node = isParenthesized ? ParserLParen() : ParserFactor();
  GE_ASSERT_NOTNULL(node);
  return Neg(node);
}
// Parses a parenthesised sub-expression:
//   kLparen <expr> kRparen
//
// Returns the inner expression unchanged (no wrapper node is created); fails
// when either parenthesis is missing or the inner expression cannot be parsed.
ExpressionImplPtr ExprParser::ParserLParen() {
  // Opening parenthesis.
  GE_ASSERT_SUCCESS(Eat(TokenType::kLparen));

  // Inner expression, restarting from the top of the grammar.
  ExpressionImplPtr node = ParserExpression();
  GE_ASSERT_NOTNULL(node);

  // Closing parenthesis.
  GE_ASSERT_SUCCESS(Eat(TokenType::kRparen));
  return node;
}
}