* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/symbolizer/symbolic.h"
#include "graph/symbolizer/symbolic_utils.h"
#include "common/checker.h"
#include "attribute_group/attr_group_shape_env.h"
namespace ge {
// Placeholder text produced by the string-rendering helpers in this file when an expression cannot be
// rendered (ToString returns it when Expression::Str() yields nullptr); AsNumerDenomToString compares
// against it to detect a failed conversion.
// NOTE(review): only the pointee is const, so the pointer is mutable and has external linkage. Consider
// `const char_t *const` once it is confirmed that no other translation unit declares it `extern`.
const char_t *kInvalidExpr = "invalid expression";
// Renders `e` using the C++-flavoured string form (StrType::kStrCpp).
// Returns kInvalidExpr when the expression yields no string buffer.
std::string SymbolicUtils::ToString(const Expression &e) {
  const auto rendered = e.Str(StrType::kStrCpp);
  if (rendered == nullptr) {
    return kInvalidExpr;
  }
  return rendered.get();
}
// Statically checks `e1 == e2`: both operands are simplified, and the resulting equality is handed to
// StaticCheckBool, whose TriBool verdict is returned unchanged.
TriBool SymbolicUtils::StaticCheckEq(const Expression &e1, const Expression &e2) {
  const auto lhs = e1.Simplify();
  const auto rhs = e2.Simplify();
  return StaticCheckBool(sym::Eq(lhs, rhs));
}
// Statically checks `e1 != e2`: both operands are simplified, and the resulting inequality is handed to
// StaticCheckBool, whose TriBool verdict is returned unchanged.
TriBool SymbolicUtils::StaticCheckNe(const Expression &e1, const Expression &e2) {
  const auto lhs = e1.Simplify();
  const auto rhs = e2.Simplify();
  return StaticCheckBool(sym::Ne(lhs, rhs));
}
// Statically checks `e1 < e2`: both operands are simplified, and the resulting relation is handed to
// StaticCheckBool, whose TriBool verdict is returned unchanged.
TriBool SymbolicUtils::StaticCheckLt(const Expression &e1, const Expression &e2) {
  const auto lhs = e1.Simplify();
  const auto rhs = e2.Simplify();
  return StaticCheckBool(sym::Lt(lhs, rhs));
}
// Statically checks `e1 <= e2`: both operands are simplified, and the resulting relation is handed to
// StaticCheckBool, whose TriBool verdict is returned unchanged.
TriBool SymbolicUtils::StaticCheckLe(const Expression &e1, const Expression &e2) {
  const auto lhs = e1.Simplify();
  const auto rhs = e2.Simplify();
  return StaticCheckBool(sym::Le(lhs, rhs));
}
// Statically checks `e1 > e2`: both operands are simplified, and the resulting relation is handed to
// StaticCheckBool, whose TriBool verdict is returned unchanged.
TriBool SymbolicUtils::StaticCheckGt(const Expression &e1, const Expression &e2) {
  const auto lhs = e1.Simplify();
  const auto rhs = e2.Simplify();
  return StaticCheckBool(sym::Gt(lhs, rhs));
}
// Statically checks `e1 >= e2`: both operands are simplified, and the resulting relation is handed to
// StaticCheckBool, whose TriBool verdict is returned unchanged.
TriBool SymbolicUtils::StaticCheckGe(const Expression &e1, const Expression &e2) {
  const auto lhs = e1.Simplify();
  const auto rhs = e2.Simplify();
  return StaticCheckBool(sym::Ge(lhs, rhs));
}
std::string SymbolicUtils::AsNumerDenomToString(const Expression &x) {
Expression numer;
Expression denom;
x.AsNumerDenom(numer, denom);
auto numer_str = ToString(numer);
auto denom_str = ToString(denom);
if (numer_str == kInvalidExpr || denom_str == kInvalidExpr) {
return kInvalidExpr;
}
return "(" + numer_str + ")" + "/" + "(" + denom_str + ")";
}
TriBool SymbolicUtils::StaticCheckBool(const Expression &expr) {
GE_ASSERT_TRUE(expr.IsBooleanExpr(), "Only boolean expr can do static check, expr: %s",
expr.Serialize().get());
bool value = false;
if (expr.IsConstExpr()) {
GE_ASSERT_TRUE(expr.GetConstValue(value));
return value ? TriBool::kTrue : TriBool::kFalse;
}
if (GetCurShapeEnvContext() == nullptr) {
GELOGW("Shape env is nullptr, cannot do static check, expr: %s", expr.Serialize().get());
return TriBool::kUnknown;
}
if (GetCurShapeEnvContext()->HasSymbolInfo(expr) == TriBool::kTrue) {
GELOGI("Find check info of expr: %s, no need simplify guard", SymbolicUtils::ToString(expr).c_str());
return TriBool::kTrue;
}
const auto simplify_expr = expr.Simplify();
value = false;
if (simplify_expr.IsConstExpr()) {
GE_ASSERT_TRUE(simplify_expr.GetConstValue(value));
return value ? TriBool::kTrue : TriBool::kFalse;
}
return GetCurShapeEnvContext()->HasSymbolInfo(simplify_expr);
}
}