/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <any>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
#include "core/any_cast.h"
#include "core/dtype.h"
#include "core/error.h"
#include "core/logging.h"
#include "ir/core.h"
#include "ir/expr.h"
#include "ir/function.h"
#include "ir/kind_traits.h"
#include "ir/memref.h"
#include "ir/program.h"
#include "ir/reflection/field_visitor.h"
#include "ir/scalar_expr.h"
#include "ir/span.h"
#include "ir/stmt.h"
#include "ir/transforms/printer.h"
#include "ir/transforms/structural_comparison.h"
#include "ir/type.h"
namespace pypto {
namespace ir {
namespace {
// Reports whether the two dtypes form the INT64/INDEX pair, in either order.
// Two identical dtypes are NOT reported as equivalent by this helper; callers
// handle plain equality separately.
bool AreForSyntaxScalarDtypesEquivalent(const DataType& lhs, const DataType& rhs)
{
    if (lhs == DataType::INT64 && rhs == DataType::INDEX) {
        return true;
    }
    return lhs == DataType::INDEX && rhs == DataType::INT64;
}
/**
 * \brief Tolerant equality for double values used by structural comparison.
 *
 * Two values are equivalent when they compare equal, when both are NaN, or
 * when their absolute difference is at most 1e-10.
 *
 * The exact-equality and NaN checks are required because the subtraction-based
 * test alone misbehaves on non-finite values: inf - inf is NaN, so two equal
 * infinities would otherwise be reported as different, and NaN never satisfies
 * an ordered comparison, which would make the relation non-reflexive.
 *
 * \param lhs Left-hand side value
 * \param rhs Right-hand side value
 * \return true if the values are considered equivalent
 */
bool AreDoublesEquivalent(double lhs, double rhs)
{
    // Exact match; this is the only path that accepts equal infinities.
    if (lhs == rhs) {
        return true;
    }
    // Keep the relation reflexive for NaN payloads.
    if (std::isnan(lhs) && std::isnan(rhs)) {
        return true;
    }
    return std::abs(lhs - rhs) <= 1e-10;
}
}
/**
 * \brief Unified structural equality checker for IR nodes
 *
 * The template parameter controls the behavior on mismatch:
 * - AssertMode=false: Returns false (for structural_equal)
 * - AssertMode=true: Throws ValueError with a detailed error message (for assert_structural_equal)
 *
 * This class is not part of the public API - use structural_equal() or assert_structural_equal().
 *
 * Implements the FieldIterator visitor interface for generic field-based comparison.
 * Uses the dual-node Visit overload, which calls visitor methods with two field arguments.
 */
template <bool AssertMode>
class StructuralEqualImpl {
public:
using ResultType = bool;
// Builds a checker. The flag is stored in enable_auto_mapping_, which EqualVar
// consults to decide whether distinct variables may be paired up.
explicit StructuralEqualImpl(bool enable_auto_mapping) : enable_auto_mapping_(enable_auto_mapping) {}
// Entry point for comparing two IR nodes.
// Non-assert mode forwards the result of Equal. Assert mode discards that
// result and always returns true; mismatches are expected to surface as
// exceptions raised through ThrowMismatch.
bool operator()(const IRNodePtr& lhs, const IRNodePtr& rhs)
{
    if constexpr (!AssertMode) {
        return Equal(lhs, rhs);
    } else {
        Equal(lhs, rhs);
        return true;
    }
}
// Entry point for comparing two types.
// Non-assert mode forwards the result of EqualType. Assert mode discards that
// result and always returns true; mismatches are expected to surface as
// exceptions raised through ThrowMismatch.
bool operator()(const TypePtr& lhs, const TypePtr& rhs)
{
    if constexpr (!AssertMode) {
        return EqualType(lhs, rhs);
    } else {
        EqualType(lhs, rhs);
        return true;
    }
}
// Starting value of the per-node accumulator: a node is considered equal until
// CombineResult folds in a field that is not.
[[nodiscard]] ResultType InitResult() const { return true; }
// Compares a mandatory IR node field. Both sides must be non-null; a null
// pointer here is reported through the INTERNAL_CHECK macros rather than as
// an ordinary mismatch.
template <typename IRNodePtrType>
ResultType VisitIRNodeField(const IRNodePtrType& lhs, const IRNodePtrType& rhs)
{
    INTERNAL_CHECK(lhs) << "structural_equal encountered null lhs IR node field";
    INTERNAL_CHECK_SPAN(rhs, lhs->span_) << "structural_equal encountered null rhs IR node field";
    const ResultType fields_match = Equal(lhs, rhs);
    return fields_match;
}
// Compares an optional IR node field.
// Presence is checked first (nullopt vs. value), then nullness of the held
// pointers, and only then the pointed-to nodes. Two absent values, or two
// null pointers, are equal.
template <typename IRNodePtrType>
ResultType VisitIRNodeField(const std::optional<IRNodePtrType>& lhs, const std::optional<IRNodePtrType>& rhs)
{
    const bool lhs_present = lhs.has_value();
    const bool rhs_present = rhs.has_value();
    if (lhs_present != rhs_present) {
        if constexpr (AssertMode) {
            ThrowMismatch(
                "Optional field presence mismatch", lhs_present ? *lhs : IRNodePtr(),
                rhs_present ? *rhs : IRNodePtr(), lhs_present ? "has value" : "nullopt",
                rhs_present ? "has value" : "nullopt");
        }
        return false;
    }
    if (!lhs_present) {
        // Both sides are nullopt.
        return true;
    }

    const bool lhs_null = !*lhs;
    const bool rhs_null = !*rhs;
    if (lhs_null != rhs_null) {
        if constexpr (AssertMode) {
            ThrowMismatch(
                "Optional field nullptr mismatch", *lhs, *rhs, *lhs ? "has value" : "nullptr",
                *rhs ? "has value" : "nullptr");
        }
        return false;
    }
    if (lhs_null) {
        // Both sides hold a null pointer.
        return true;
    }
    return Equal(*lhs, *rhs);
}
// Compares two vectors of IR nodes element by element.
// Sizes must match and every element must be non-null. In assert mode the
// element index is appended to path_ while that element is being compared so
// that error messages can point at it.
template <typename IRNodePtrType>
ResultType VisitIRNodeVectorField(const std::vector<IRNodePtrType>& lhs, const std::vector<IRNodePtrType>& rhs)
{
    const size_t count = lhs.size();
    if (count != rhs.size()) {
        if constexpr (AssertMode) {
            std::ostringstream msg;
            msg << "Vector size mismatch (" << lhs.size() << " items != " << rhs.size() << " items)";
            ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
        }
        return false;
    }

    for (size_t idx = 0; idx < count; ++idx) {
        INTERNAL_CHECK(lhs[idx]) << "structural_equal encountered null lhs IR node in vector at index " << idx;
        INTERNAL_CHECK_SPAN(rhs[idx], lhs[idx]->span_)
            << "structural_equal encountered null rhs IR node in vector at index " << idx;

        if constexpr (AssertMode) {
            std::ostringstream index_label;
            index_label << "[" << idx << "]";
            path_.emplace_back(index_label.str());
        }
        const bool element_matches = Equal(lhs[idx], rhs[idx]);
        // The path entry is removed whether or not the element matched.
        if constexpr (AssertMode) {
            path_.pop_back();
        }
        if (!element_matches) {
            return false;
        }
    }
    return true;
}
// Shared implementation for comparing two maps whose values are IR nodes.
// Both maps are walked in lock-step in their iteration order. For each pair
// of entries check_entry is invoked first, then the keys are compared through
// their format_key string form, then the values through Equal.
//   check_entry: callable taking the two current iterators (lhs, rhs)
//   format_key:  callable turning a key into the string used for comparison
//                and for the path entry in assert mode
template <typename MapType, typename EntryChecker, typename KeyFormatter>
ResultType VisitIRNodeMapFieldImpl(
    const MapType& lhs, const MapType& rhs, EntryChecker&& check_entry, KeyFormatter&& format_key)
{
    if (lhs.size() != rhs.size()) {
        if constexpr (AssertMode) {
            std::ostringstream msg;
            msg << "Map size mismatch (" << lhs.size() << " items != " << rhs.size() << " items)";
            ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
        }
        return false;
    }

    for (auto left = lhs.begin(), right = rhs.begin(); left != lhs.end(); ++left, ++right) {
        check_entry(left, right);

        const std::string left_key = format_key(left->first);
        const std::string right_key = format_key(right->first);
        if (left_key != right_key) {
            if constexpr (AssertMode) {
                std::ostringstream msg;
                msg << "Map key mismatch ('" << left_key << "' != '" << right_key << "')";
                ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
            }
            return false;
        }

        if constexpr (AssertMode) {
            path_.emplace_back("['" + left_key + "']");
        }
        const bool value_matches = Equal(left->second, right->second);
        // The path entry is removed whether or not the value matched.
        if constexpr (AssertMode) {
            path_.pop_back();
        }
        if (!value_matches) {
            return false;
        }
    }
    return true;
}
// Compares two ordered maps keyed by pointer-like objects exposing name_.
// Keys and values must be non-null on both sides; keys are matched by their
// name_ member.
template <typename KeyType, typename ValueType, typename Compare>
ResultType VisitIRNodeMapField(
    const std::map<KeyType, ValueType, Compare>& lhs, const std::map<KeyType, ValueType, Compare>& rhs)
{
    const auto require_non_null = [](const auto& lhs_it, const auto& rhs_it) {
        INTERNAL_CHECK(lhs_it->first) << "structural_equal encountered null lhs key in map";
        INTERNAL_CHECK(lhs_it->second) << "structural_equal encountered null lhs value in map";
        INTERNAL_CHECK(rhs_it->first) << "structural_equal encountered null rhs key in map";
        INTERNAL_CHECK_SPAN(rhs_it->second, lhs_it->second->span_)
            << "structural_equal encountered null rhs value in map";
    };
    const auto key_name = [](const auto& key) { return key->name_; };
    return VisitIRNodeMapFieldImpl(lhs, rhs, require_non_null, key_name);
}
// Compares two ordered maps keyed by std::string. Values must be non-null on
// both sides; keys are compared verbatim.
template <typename ValueType>
ResultType VisitIRNodeMapField(
    const std::map<std::string, ValueType>& lhs, const std::map<std::string, ValueType>& rhs)
{
    const auto require_non_null = [](const auto& lhs_it, const auto& rhs_it) {
        INTERNAL_CHECK(lhs_it->second) << "structural_equal encountered null lhs value in map";
        INTERNAL_CHECK_SPAN(rhs_it->second, lhs_it->second->span_)
            << "structural_equal encountered null rhs value in map";
    };
    const auto identity_key = [](const std::string& key) { return key; };
    return VisitIRNodeMapFieldImpl(lhs, rhs, require_non_null, identity_key);
}
// Compares two int leaf fields; in assert mode a difference is reported via
// ThrowMismatch with both values in the message.
ResultType VisitLeafField(const int& lhs, const int& rhs)
{
    if (lhs == rhs) {
        return true;
    }
    if constexpr (AssertMode) {
        std::ostringstream msg;
        msg << "Integer value mismatch (" << lhs << " != " << rhs << ")";
        ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
    }
    return false;
}
// Compares two int64_t leaf fields; in assert mode a difference is reported
// via ThrowMismatch with both values in the message.
ResultType VisitLeafField(const int64_t& lhs, const int64_t& rhs)
{
    if (lhs == rhs) {
        return true;
    }
    if constexpr (AssertMode) {
        std::ostringstream msg;
        msg << "int64_t value mismatch (" << lhs << " != " << rhs << ")";
        ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
    }
    return false;
}
// Compares two uint64_t leaf fields; in assert mode a difference is reported
// via ThrowMismatch with both values in the message.
ResultType VisitLeafField(const uint64_t& lhs, const uint64_t& rhs)
{
    if (lhs == rhs) {
        return true;
    }
    if constexpr (AssertMode) {
        std::ostringstream msg;
        msg << "uint64_t value mismatch (" << lhs << " != " << rhs << ")";
        ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
    }
    return false;
}
// Compares two double leaf fields using AreDoublesEquivalent rather than
// exact equality; in assert mode a difference is reported via ThrowMismatch.
ResultType VisitLeafField(const double& lhs, const double& rhs)
{
    if (AreDoublesEquivalent(lhs, rhs)) {
        return true;
    }
    if constexpr (AssertMode) {
        std::ostringstream msg;
        msg << "double value mismatch (" << lhs << " != " << rhs << ")";
        ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
    }
    return false;
}
// Compares two string leaf fields; in assert mode a difference is reported
// via ThrowMismatch with both strings quoted in the message.
ResultType VisitLeafField(const std::string& lhs, const std::string& rhs)
{
    if (lhs == rhs) {
        return true;
    }
    if constexpr (AssertMode) {
        std::ostringstream msg;
        msg << "String value mismatch (\"" << lhs << "\" != \"" << rhs << "\")";
        ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
    }
    return false;
}
// Compares two DataType leaf fields; in assert mode a difference is reported
// via ThrowMismatch using the ToString() form of both dtypes.
ResultType VisitLeafField(const DataType& lhs, const DataType& rhs)
{
    const bool differs = lhs != rhs;
    if (!differs) {
        return true;
    }
    if constexpr (AssertMode) {
        std::ostringstream msg;
        msg << "DataType mismatch (" << lhs.ToString() << " != " << rhs.ToString() << ")";
        ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
    }
    return false;
}
// Compares two FunctionType leaf fields; in assert mode a difference is
// reported via ThrowMismatch using FunctionTypeToString for both values.
ResultType VisitLeafField(const FunctionType& lhs, const FunctionType& rhs)
{
    const bool differs = lhs != rhs;
    if (!differs) {
        return true;
    }
    if constexpr (AssertMode) {
        std::ostringstream msg;
        msg << "FunctionType mismatch (" << FunctionTypeToString(lhs) << " != " << FunctionTypeToString(rhs)
            << ")";
        ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
    }
    return false;
}
// Compares two SectionKind leaf fields; in assert mode a difference is
// reported via ThrowMismatch using SectionKindToString for both values.
[[nodiscard]] ResultType VisitLeafField(const SectionKind& lhs, const SectionKind& rhs)
{
    const bool differs = lhs != rhs;
    if (!differs) {
        return true;
    }
    if constexpr (AssertMode) {
        std::ostringstream msg;
        msg << "SectionKind mismatch (" << SectionKindToString(lhs) << " != " << SectionKindToString(rhs)
            << ")";
        ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
    }
    return false;
}
// Compares two keyword-argument lists (ordered name -> std::any pairs).
// The lists must have the same length, and at every position the names, the
// dynamic types of the payloads and the payload values must agree. Payloads
// of type int, bool, std::string, double (tolerant comparison), DataType and
// std::vector<int> are compared by value.
// NOTE(review): a payload of any other type leaves values_equal at its
// initial value of true, i.e. such entries compare equal as soon as their
// dynamic types match - confirm this is intended.
ResultType VisitLeafField(
    const std::vector<std::pair<std::string, std::any>>& lhs,
    const std::vector<std::pair<std::string, std::any>>& rhs)
{
    if (lhs.size() != rhs.size()) {
        if constexpr (AssertMode) {
            std::ostringstream msg;
            msg << "Kwargs size mismatch (" << lhs.size() << " != " << rhs.size() << ")";
            ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
        }
        return false;
    }

    for (size_t idx = 0; idx < lhs.size(); ++idx) {
        const std::string& key = lhs[idx].first;
        if (key != rhs[idx].first) {
            if constexpr (AssertMode) {
                std::ostringstream msg;
                msg << "Kwargs key mismatch at index " << idx << " ('" << key << "' != '" << rhs[idx].first
                    << "')";
                ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
            }
            return false;
        }

        const std::any& left = lhs[idx].second;
        const std::any& right = rhs[idx].second;
        const auto& held = left.type();
        if (held != right.type()) {
            if constexpr (AssertMode) {
                std::ostringstream msg;
                msg << "Kwargs value type mismatch for key '" << key << "'";
                ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
            }
            return false;
        }

        // Context string handed to AnyCast for its diagnostics.
        const std::string context = "comparing kwarg: " + key;
        bool values_equal = true;
        if (held == typeid(int)) {
            values_equal = (AnyCast<int>(left, context) == AnyCast<int>(right, context));
        } else if (held == typeid(bool)) {
            values_equal = (AnyCast<bool>(left, context) == AnyCast<bool>(right, context));
        } else if (held == typeid(std::string)) {
            values_equal = (AnyCast<std::string>(left, context) == AnyCast<std::string>(right, context));
        } else if (held == typeid(double)) {
            values_equal = AreDoublesEquivalent(AnyCast<double>(left, context), AnyCast<double>(right, context));
        } else if (held == typeid(DataType)) {
            values_equal = (AnyCast<DataType>(left, context) == AnyCast<DataType>(right, context));
        } else if (held == typeid(std::vector<int>)) {
            values_equal =
                (AnyCast<std::vector<int>>(left, context) == AnyCast<std::vector<int>>(right, context));
        }

        if (!values_equal) {
            if constexpr (AssertMode) {
                std::ostringstream msg;
                msg << "Kwargs value mismatch for key '" << key << "'";
                ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
            }
            return false;
        }
    }
    return true;
}
// Compares two MemorySpace leaf fields; in assert mode a difference is
// reported via ThrowMismatch using MemorySpaceToString for both values.
ResultType VisitLeafField(const MemorySpace& lhs, const MemorySpace& rhs)
{
    const bool differs = lhs != rhs;
    if (!differs) {
        return true;
    }
    if constexpr (AssertMode) {
        std::ostringstream msg;
        msg << "MemorySpace mismatch (" << MemorySpaceToString(lhs) << " != " << MemorySpaceToString(rhs)
            << ")";
        ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
    }
    return false;
}
// Compares a single type-valued leaf field by delegating to EqualType.
// No null check is performed here before the delegation.
ResultType VisitLeafField(const TypePtr& lhs, const TypePtr& rhs) { return EqualType(lhs, rhs); }
// Compares two vectors of types element by element via EqualType.
// Sizes must match and every element must be non-null (null elements are
// reported through INTERNAL_CHECK, not as a mismatch).
ResultType VisitLeafField(const std::vector<TypePtr>& lhs, const std::vector<TypePtr>& rhs)
{
    const size_t count = lhs.size();
    if (count != rhs.size()) {
        if constexpr (AssertMode) {
            std::ostringstream msg;
            msg << "Type vector size mismatch (" << lhs.size() << " types != " << rhs.size() << " types)";
            ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
        }
        return false;
    }
    for (size_t idx = 0; idx < count; ++idx) {
        INTERNAL_CHECK(lhs[idx]) << "structural_equal encountered null lhs TypePtr in vector at index " << idx;
        INTERNAL_CHECK(rhs[idx]) << "structural_equal encountered null rhs TypePtr in vector at index " << idx;
        const bool types_match = EqualType(lhs[idx], rhs[idx]);
        if (!types_match) {
            return false;
        }
    }
    return true;
}
// Span fields are not expected to reach the comparison: the check below is
// hard-wired to fail, using the lhs span as the location for the diagnostic.
// The trailing return only satisfies the signature.
[[nodiscard]] ResultType VisitLeafField(const Span& lhs, const Span&) const
{
    INTERNAL_CHECK_SPAN(false, lhs) << "structural_equal should not visit Span field";
    const ResultType unreachable_result = true;
    return unreachable_result;
}
// Compares two vectors of iteration arguments element by element via
// EqualIterArg. Sizes must match and every element must be non-null (null
// elements are reported through INTERNAL_CHECK, not as a mismatch).
ResultType VisitLeafField(const std::vector<IterArgPtr>& lhs, const std::vector<IterArgPtr>& rhs)
{
    const size_t count = lhs.size();
    if (count != rhs.size()) {
        if constexpr (AssertMode) {
            std::ostringstream msg;
            msg << "IterArg vector size mismatch (" << lhs.size() << " != " << rhs.size() << ")";
            ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
        }
        return false;
    }
    for (size_t idx = 0; idx < count; ++idx) {
        INTERNAL_CHECK(lhs[idx]) << "structural_equal encountered null lhs IterArgPtr in vector at index " << idx;
        INTERNAL_CHECK(rhs[idx]) << "structural_equal encountered null rhs IterArgPtr in vector at index " << idx;
        const bool args_match = EqualIterArg(lhs[idx], rhs[idx]);
        if (!args_match) {
            return false;
        }
    }
    return true;
}
// Hook for fields marked as ignored: the supplied visit operation is never
// invoked, so such fields do not take part in the comparison.
template <typename FVisitOp>
void VisitIgnoreField([[maybe_unused]] FVisitOp&& visit_op)
{
}
template <typename FVisitOp>
void VisitDefField(FVisitOp&& visit_op)
{
bool enable_auto_mapping = true;
std::swap(enable_auto_mapping, enable_auto_mapping_);
visit_op();
std::swap(enable_auto_mapping, enable_auto_mapping_);
}
// Hook for ordinary fields: runs the visit operation as-is, leaving the
// current auto-mapping setting untouched.
template <typename FVisitOp>
void VisitUsualField(FVisitOp&& visit_op)
{
visit_op();
}
// Records the name of the field about to be visited.
// Nothing is recorded while transparent_depth_ is non-zero. Otherwise the
// name goes onto field_name_stack_ and, in assert mode, onto path_ as well.
void PushFieldName(const char* name)
{
    if (transparent_depth_ != 0) {
        return;
    }
    field_name_stack_.emplace_back(name);
    if constexpr (AssertMode) {
        path_.emplace_back(name);
    }
}
// Undoes the matching PushFieldName.
// Nothing is removed while transparent_depth_ is non-zero. Otherwise the
// top of field_name_stack_ is dropped and, in assert mode, the top of path_.
void PopFieldName()
{
    if (transparent_depth_ != 0) {
        return;
    }
    field_name_stack_.pop_back();
    if constexpr (AssertMode) {
        path_.pop_back();
    }
}
// Folds one field's result into the running result for the node: the
// accumulator stays true only while every field has compared equal.
// The descriptor is not used.
template <typename Desc>
void CombineResult(ResultType& accumulator, ResultType field_result, [[maybe_unused]] const Desc& desc)
{
    if (!field_result) {
        accumulator = false;
    }
}
private:
// Core comparison routines, defined out of line after the class body.
// Compares two IR nodes of any kind.
bool Equal(const IRNodePtr& lhs, const IRNodePtr& rhs);
// Compares two variables, honoring the current auto-mapping setting.
bool EqualVar(const VarPtr& lhs, const VarPtr& rhs);
// Compares two memory references via MemRef::SameAllocation.
bool EqualMemRef(const MemRefPtr& lhs, const MemRefPtr& rhs);
// Compares two iteration arguments (iteration variable and initial value).
bool EqualIterArg(const IterArgPtr& lhs, const IterArgPtr& rhs);
// Compares two types.
bool EqualType(const TypePtr& lhs, const TypePtr& rhs);
// True when the innermost recorded field name is "loop_var".
bool IsLoopVarFieldContext() const
{
    if (field_name_stack_.empty()) {
        return false;
    }
    return field_name_stack_.back() == "loop_var";
}
// True when the innermost node being compared is a "ConstInt" and the
// innermost recorded field name is "type".
bool IsConstIntTypeContext() const
{
    if (node_type_stack_.empty() || node_type_stack_.back() != "ConstInt") {
        return false;
    }
    if (field_name_stack_.empty()) {
        return false;
    }
    return field_name_stack_.back() == "type";
}
/**
 * \brief Generic field-based equality check for IR nodes using FieldIterator
 *
 * Uses the dual-node Visit overload, which passes two fields to each visitor method.
 *
 * \tparam NodePtr Shared pointer type to the node
 * \param lhs_op Left-hand side node
 * \param rhs_op Right-hand side node
 * \return true if all fields are equal
 */
template <typename NodePtr>
bool EqualWithFields(const NodePtr& lhs_op, const NodePtr& rhs_op)
{
using NodeType = typename NodePtr::element_type;
auto descriptors = NodeType::GetFieldDescriptors();
return std::apply(
[&](auto&&... descs) {
return reflection::FieldIterator<NodeType, StructuralEqualImpl<AssertMode>, decltype(descs)...>::Visit(
*lhs_op, *rhs_op, *this, descs...);
},
descriptors);
}
void ThrowMismatch(
const std::string& reason, const IRNodePtr& lhs, const IRNodePtr& rhs, const std::string& lhs_desc = "",
const std::string& rhs_desc = "")
{
if constexpr (AssertMode) {
std::ostringstream msg;
msg << "Structural equality assertion failed";
if (!path_.empty()) {
msg << " at: ";
for (size_t i = 0; i < path_.size(); ++i) {
msg << path_[i];
if (i < path_.size() - 1 && path_[i + 1][0] != '[') {
msg << ".";
}
}
}
if (!node_type_stack_.empty()) {
msg << "\nNode stack: ";
for (size_t i = 0; i < node_type_stack_.size(); ++i) {
if (i > 0) {
msg << " > ";
}
msg << node_type_stack_[i];
}
}
msg << "\n\n";
if (lhs || rhs) {
msg << "Left-hand side:\n";
if (lhs) {
std::string lhs_str = PythonPrint(lhs, "pl");
std::istringstream iss(lhs_str);
std::string line;
while (std::getline(iss, line)) {
msg << " " << line << "\n";
}
} else {
msg << " (null)\n";
}
msg << "\nRight-hand side:\n";
if (rhs) {
std::string rhs_str = PythonPrint(rhs, "pl");
std::istringstream iss(rhs_str);
std::string line;
while (std::getline(iss, line)) {
msg << " " << line << "\n";
}
} else {
msg << " (null)\n";
}
msg << "\n";
} else if (!lhs_desc.empty() || !rhs_desc.empty()) {
msg << "Left: " << lhs_desc << "\n";
msg << "Right: " << rhs_desc << "\n\n";
}
msg << "Reason: " << reason;
throw pypto::ir::ValueError(msg.str());
}
}
// Whether EqualVar may pair up distinct variables; temporarily forced to true by VisitDefField.
bool enable_auto_mapping_;
// Variable pairings established so far, lhs -> rhs.
std::unordered_map<VarPtr, VarPtr> lhs_to_rhs_var_map_;
// Reverse direction of the same pairings, rhs -> lhs.
std::unordered_map<VarPtr, VarPtr> rhs_to_lhs_var_map_;
// Field/index path to the location being compared; only maintained in assert mode.
std::vector<std::string> path_;
// Names of the fields currently being visited, innermost last.
std::vector<std::string> field_name_stack_;
// Type names of the nodes currently being compared, innermost last.
std::vector<std::string> node_type_stack_;
// Non-zero while inside a node dispatched as transparent; suppresses field-name tracking.
int transparent_depth_ = 0;
};
// Dispatch helper for Equal(): if lhs is a `Type`, compare it with rhs field by
// field and return the result from the enclosing function. While the fields are
// compared, the type name is on node_type_stack_ and transparent_depth_ is reset
// to 0 (and restored afterwards) so that field names of this node are recorded.
// Expects `lhs` and `rhs` to be in scope. A failed cast of rhs yields false.
#define EQUAL_DISPATCH(Type) \
if (auto lhs_##Type = As<Type>(lhs)) { \
auto rhs_##Type = As<Type>(rhs); \
node_type_stack_.emplace_back(#Type); \
int saved_depth = transparent_depth_; \
transparent_depth_ = 0; \
bool result = rhs_##Type && EqualWithFields(lhs_##Type, rhs_##Type); \
transparent_depth_ = saved_depth; \
node_type_stack_.pop_back(); \
return result; \
}
// Variant of the dispatch helper for "transparent" node types: transparent_depth_
// is incremented while the node's fields are compared, which makes
// PushFieldName/PopFieldName skip recording this node's own field names.
// Expects `lhs` and `rhs` to be in scope. A failed cast of rhs yields false.
#define EQUAL_DISPATCH_TRANSPARENT(Type) \
if (auto lhs_##Type = As<Type>(lhs)) { \
transparent_depth_++; \
auto rhs_##Type = As<Type>(rhs); \
node_type_stack_.emplace_back(#Type); \
bool result = rhs_##Type && EqualWithFields(lhs_##Type, rhs_##Type); \
node_type_stack_.pop_back(); \
transparent_depth_--; \
return result; \
}
// Structural comparison of two IR nodes.
// Returns true when the nodes are structurally equal. On a mismatch it returns
// false, or in assert mode raises ValueError through ThrowMismatch.
// Throws TypeError when lhs is of a node type that none of the dispatch
// branches below recognizes.
template <bool AssertMode>
bool StructuralEqualImpl<AssertMode>::Equal(const IRNodePtr& lhs, const IRNodePtr& rhs)
{
// Same object (this also covers two null pointers): trivially equal.
if (lhs.get() == rhs.get())
return true;
// Exactly one side is null at this point.
if (!lhs || !rhs) {
if constexpr (AssertMode)
ThrowMismatch("One node is null, the other is not", lhs, rhs);
return false;
}
// Nodes of different kinds can never be equal; every branch below relies on
// the type names having matched here.
if (lhs->TypeName() != rhs->TypeName()) {
if constexpr (AssertMode) {
std::ostringstream msg;
msg << "Node type mismatch (" << lhs->TypeName() << " != " << rhs->TypeName() << ")";
ThrowMismatch(msg.str(), lhs, rhs);
}
return false;
}
// MemRef and Var have dedicated comparison routines instead of the generic
// field-based one. MemRef is tested before Var.
if (auto lhs_memref = As<MemRef>(lhs)) {
node_type_stack_.emplace_back("MemRef");
auto rhs_memref = std::static_pointer_cast<const MemRef>(rhs);
bool result = rhs_memref && EqualMemRef(lhs_memref, rhs_memref);
node_type_stack_.pop_back();
return result;
}
if (auto lhs_var = As<Var>(lhs)) {
node_type_stack_.emplace_back("Var");
bool result = EqualVar(lhs_var, std::static_pointer_cast<const Var>(rhs));
node_type_stack_.pop_back();
return result;
}
// Every remaining node kind is compared field by field. Each macro line
// returns from this function when lhs is of the named type.
EQUAL_DISPATCH(ConstInt)
EQUAL_DISPATCH(ConstFloat)
EQUAL_DISPATCH(ConstBool)
EQUAL_DISPATCH(Call)
EQUAL_DISPATCH(MakeTuple)
EQUAL_DISPATCH(GetItemExpr)
EQUAL_DISPATCH(BinaryExpr)
EQUAL_DISPATCH(UnaryExpr)
EQUAL_DISPATCH(AssignStmt)
EQUAL_DISPATCH(IfStmt)
EQUAL_DISPATCH(YieldStmt)
EQUAL_DISPATCH(ReturnStmt)
EQUAL_DISPATCH(ForStmt)
EQUAL_DISPATCH(WhileStmt)
EQUAL_DISPATCH(SectionStmt)
EQUAL_DISPATCH_TRANSPARENT(SeqStmts)
EQUAL_DISPATCH(EvalStmt)
EQUAL_DISPATCH(BreakStmt)
EQUAL_DISPATCH(ContinueStmt)
EQUAL_DISPATCH(Function)
EQUAL_DISPATCH_TRANSPARENT(Program)
// No branch matched: the node kind is not supported by this checker.
throw pypto::ir::TypeError("Unknown IR node type in StructuralEqualImpl::Equal: " + lhs->TypeName());
}
#undef EQUAL_DISPATCH
#undef EQUAL_DISPATCH_TRANSPARENT
/**
 * \brief Structural comparison of two types.
 *
 * Null handling mirrors Equal(): two null pointers are equal, and exactly one
 * null pointer is a mismatch. Without this guard a null TypePtr, which can
 * reach this function through the public structural_equal(TypePtr, TypePtr)
 * overload or through a TypePtr leaf field, would be dereferenced.
 *
 * For non-null inputs the type names must match, then:
 * - ScalarType: dtypes must match; INT64 and INDEX are interchangeable when the
 *   type belongs to a "loop_var" field or to the "type" field of a ConstInt.
 * - TensorType: dtype, rank and every shape expression must match.
 * - TileType: dtype, rank, shape expressions, the optional tile view
 *   (valid shape, stride, start offset) and the optional hardware info
 *   (blayout, slayout, fractal, pad) must match.
 * - TupleType: element count and every element type must match.
 * - MemRefType, UnknownType, LogicalTensorType: equal once the names match.
 * Any other type is reported through INTERNAL_UNREACHABLE.
 *
 * \param lhs Left-hand side type (may be null)
 * \param rhs Right-hand side type (may be null)
 * \return true if the types are structurally equal; false otherwise. In assert
 *         mode a mismatch raises ValueError through ThrowMismatch instead.
 */
template <bool AssertMode>
bool StructuralEqualImpl<AssertMode>::EqualType(const TypePtr& lhs, const TypePtr& rhs)
{
    // Guard against null pointers before anything is dereferenced.
    if (!lhs || !rhs) {
        if (!lhs && !rhs) {
            return true;
        }
        if constexpr (AssertMode) {
            ThrowMismatch(
                "One type is null, the other is not", IRNodePtr(), IRNodePtr(), lhs ? "has value" : "nullptr",
                rhs ? "has value" : "nullptr");
        }
        return false;
    }

    if (lhs->TypeName() != rhs->TypeName()) {
        if constexpr (AssertMode) {
            std::ostringstream msg;
            msg << "Type name mismatch (" << lhs->TypeName() << " != " << rhs->TypeName() << ")";
            ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
        }
        return false;
    }

    // Element-wise comparison of two expression lists already known to have the same length.
    const auto exprs_equal = [this](const auto& lhs_exprs, const auto& rhs_exprs) {
        for (size_t i = 0; i < lhs_exprs.size(); ++i) {
            if (!this->Equal(lhs_exprs[i], rhs_exprs[i])) {
                return false;
            }
        }
        return true;
    };

    if (auto lhs_scalar = As<ScalarType>(lhs)) {
        auto rhs_scalar = As<ScalarType>(rhs);
        if (!rhs_scalar) {
            if constexpr (AssertMode) {
                ThrowMismatch("Type cast failed for ScalarType", IRNodePtr(), IRNodePtr(), "", "");
            }
            return false;
        }
        // INT64 and INDEX are treated as the same dtype in these two contexts only.
        if ((IsLoopVarFieldContext() || IsConstIntTypeContext()) &&
            AreForSyntaxScalarDtypesEquivalent(lhs_scalar->dtype_, rhs_scalar->dtype_)) {
            return true;
        }
        if (lhs_scalar->dtype_ != rhs_scalar->dtype_) {
            if constexpr (AssertMode) {
                std::ostringstream msg;
                msg << "ScalarType dtype mismatch (" << lhs_scalar->dtype_.ToString()
                    << " != " << rhs_scalar->dtype_.ToString() << ")";
                ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
            }
            return false;
        }
        return true;
    } else if (auto lhs_tensor = As<TensorType>(lhs)) {
        auto rhs_tensor = As<TensorType>(rhs);
        if (!rhs_tensor) {
            if constexpr (AssertMode) {
                ThrowMismatch("Type cast failed for TensorType", IRNodePtr(), IRNodePtr(), "", "");
            }
            return false;
        }
        if (lhs_tensor->dtype_ != rhs_tensor->dtype_) {
            if constexpr (AssertMode) {
                std::ostringstream msg;
                msg << "TensorType dtype mismatch (" << lhs_tensor->dtype_.ToString()
                    << " != " << rhs_tensor->dtype_.ToString() << ")";
                ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
            }
            return false;
        }
        if (lhs_tensor->shape_.size() != rhs_tensor->shape_.size()) {
            if constexpr (AssertMode) {
                std::ostringstream msg;
                msg << "TensorType shape rank mismatch (" << lhs_tensor->shape_.size()
                    << " != " << rhs_tensor->shape_.size() << ")";
                ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
            }
            return false;
        }
        return exprs_equal(lhs_tensor->shape_, rhs_tensor->shape_);
    } else if (auto lhs_tile = As<TileType>(lhs)) {
        auto rhs_tile = As<TileType>(rhs);
        if (!rhs_tile) {
            if constexpr (AssertMode) {
                ThrowMismatch("Type cast failed for TileType", IRNodePtr(), IRNodePtr(), "", "");
            }
            return false;
        }
        if (lhs_tile->dtype_ != rhs_tile->dtype_) {
            if constexpr (AssertMode) {
                std::ostringstream msg;
                msg << "TileType dtype mismatch (" << lhs_tile->dtype_.ToString()
                    << " != " << rhs_tile->dtype_.ToString() << ")";
                ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
            }
            return false;
        }
        if (lhs_tile->shape_.size() != rhs_tile->shape_.size()) {
            if constexpr (AssertMode) {
                std::ostringstream msg;
                msg << "TileType shape rank mismatch (" << lhs_tile->shape_.size() << " != " << rhs_tile->shape_.size()
                    << ")";
                ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
            }
            return false;
        }
        if (!exprs_equal(lhs_tile->shape_, rhs_tile->shape_)) {
            return false;
        }

        // Optional tile view: presence first, then its three components.
        if (lhs_tile->tileView_.has_value() != rhs_tile->tileView_.has_value()) {
            if constexpr (AssertMode) {
                ThrowMismatch("TileType tile_view presence mismatch", IRNodePtr(), IRNodePtr(), "", "");
            }
            return false;
        }
        if (lhs_tile->tileView_.has_value()) {
            const auto& lhs_tv = lhs_tile->tileView_.value();
            const auto& rhs_tv = rhs_tile->tileView_.value();
            if (lhs_tv.validShape.size() != rhs_tv.validShape.size()) {
                if constexpr (AssertMode) {
                    std::ostringstream msg;
                    msg << "TileView valid_shape size mismatch (" << lhs_tv.validShape.size()
                        << " != " << rhs_tv.validShape.size() << ")";
                    ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
                }
                return false;
            }
            if (!exprs_equal(lhs_tv.validShape, rhs_tv.validShape)) {
                return false;
            }
            if (lhs_tv.stride.size() != rhs_tv.stride.size()) {
                if constexpr (AssertMode) {
                    std::ostringstream msg;
                    msg << "TileView stride size mismatch (" << lhs_tv.stride.size() << " != " << rhs_tv.stride.size()
                        << ")";
                    ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
                }
                return false;
            }
            if (!exprs_equal(lhs_tv.stride, rhs_tv.stride)) {
                return false;
            }
            if (!Equal(lhs_tv.startOffset, rhs_tv.startOffset)) {
                return false;
            }
        }

        // Optional hardware info: presence first, then each attribute in turn.
        if (lhs_tile->hardwareInfo_.has_value() != rhs_tile->hardwareInfo_.has_value()) {
            if constexpr (AssertMode) {
                ThrowMismatch("TileType hardware_info presence mismatch", IRNodePtr(), IRNodePtr(), "", "");
            }
            return false;
        }
        if (lhs_tile->hardwareInfo_.has_value()) {
            const auto& lhs_hw = lhs_tile->hardwareInfo_.value();
            const auto& rhs_hw = rhs_tile->hardwareInfo_.value();
            if (lhs_hw.blayout != rhs_hw.blayout) {
                if constexpr (AssertMode) {
                    ThrowMismatch("HardwareInfo blayout mismatch", IRNodePtr(), IRNodePtr(), "", "");
                }
                return false;
            }
            if (lhs_hw.slayout != rhs_hw.slayout) {
                if constexpr (AssertMode) {
                    ThrowMismatch("HardwareInfo slayout mismatch", IRNodePtr(), IRNodePtr(), "", "");
                }
                return false;
            }
            if (lhs_hw.fractal != rhs_hw.fractal) {
                if constexpr (AssertMode) {
                    ThrowMismatch("HardwareInfo fractal mismatch", IRNodePtr(), IRNodePtr(), "", "");
                }
                return false;
            }
            if (lhs_hw.pad != rhs_hw.pad) {
                if constexpr (AssertMode) {
                    ThrowMismatch("HardwareInfo pad mismatch", IRNodePtr(), IRNodePtr(), "", "");
                }
                return false;
            }
        }
        return true;
    } else if (auto lhs_tuple = As<TupleType>(lhs)) {
        auto rhs_tuple = As<TupleType>(rhs);
        if (!rhs_tuple) {
            if constexpr (AssertMode) {
                ThrowMismatch("Type cast failed for TupleType", IRNodePtr(), IRNodePtr(), "", "");
            }
            return false;
        }
        if (lhs_tuple->types_.size() != rhs_tuple->types_.size()) {
            if constexpr (AssertMode) {
                std::ostringstream msg;
                msg << "TupleType size mismatch (" << lhs_tuple->types_.size() << " != " << rhs_tuple->types_.size()
                    << ")";
                ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
            }
            return false;
        }
        for (size_t i = 0; i < lhs_tuple->types_.size(); ++i) {
            if (!EqualType(lhs_tuple->types_[i], rhs_tuple->types_[i])) {
                return false;
            }
        }
        return true;
    } else if (IsA<MemRefType>(lhs) || IsA<UnknownType>(lhs)) {
        // These types carry nothing further to compare once the names matched.
        return true;
    } else if (IsA<LogicalTensorType>(lhs) && IsA<LogicalTensorType>(rhs)) {
        return true;
    }
    INTERNAL_UNREACHABLE << "EqualType encountered unhandled Type: " << lhs->TypeName();
    return false;
}
// Compares two variables.
// Without auto-mapping: if both variables already have a recorded pairing,
// the pairings must point at each other; otherwise the two must be the very
// same object.
// With auto-mapping: the variable types must match, and the pair is accepted
// if it agrees with the pairings recorded so far; a previously unseen pair is
// recorded in both directions. The mapping stays one-to-one.
// Returns false on mismatch, or raises ValueError through ThrowMismatch in
// assert mode.
template <bool AssertMode>
bool StructuralEqualImpl<AssertMode>::EqualVar(const VarPtr& lhs, const VarPtr& rhs)
{
    if (!enable_auto_mapping_) {
        const auto forward = lhs_to_rhs_var_map_.find(lhs);
        const auto backward = rhs_to_lhs_var_map_.find(rhs);
        const bool both_mapped = forward != lhs_to_rhs_var_map_.end() && backward != rhs_to_lhs_var_map_.end();
        if (both_mapped) {
            if (forward->second == rhs && backward->second == lhs) {
                return true;
            }
            if constexpr (AssertMode) {
                ThrowMismatch(
                    "Variable mapping inconsistent (without auto-mapping)",
                    std::static_pointer_cast<const IRNode>(lhs), std::static_pointer_cast<const IRNode>(rhs),
                    "var " + lhs->name_, "var " + rhs->name_);
            }
            return false;
        }

        // No complete pairing on record: only pointer identity is accepted.
        if (lhs.get() == rhs.get()) {
            return true;
        }
        if constexpr (AssertMode) {
            ThrowMismatch(
                "Variable pointer mismatch (without auto-mapping)", std::static_pointer_cast<const IRNode>(lhs),
                std::static_pointer_cast<const IRNode>(rhs), "var " + lhs->name_, "var " + rhs->name_);
        }
        return false;
    }

    if (!EqualType(lhs->GetType(), rhs->GetType())) {
        if constexpr (AssertMode) {
            std::ostringstream msg;
            msg << "Variable type mismatch (" << lhs->GetType()->TypeName() << " != " << rhs->GetType()->TypeName()
                << ")";
            ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", "");
        }
        return false;
    }

    // lhs was seen before: it must still map to the same rhs.
    const auto known = lhs_to_rhs_var_map_.find(lhs);
    if (known != lhs_to_rhs_var_map_.end()) {
        if (known->second == rhs) {
            return true;
        }
        if constexpr (AssertMode) {
            std::ostringstream msg;
            msg << "Variable mapping inconsistent ('" << lhs->name_ << "' cannot map to both '" << known->second->name_
                << "' and '" << rhs->name_ << "')";
            ThrowMismatch(
                msg.str(), std::static_pointer_cast<const IRNode>(lhs),
                std::static_pointer_cast<const IRNode>(rhs));
        }
        return false;
    }

    // lhs is new: rhs must not already be claimed by a different lhs.
    const auto claimed = rhs_to_lhs_var_map_.find(rhs);
    const bool rhs_taken = claimed != rhs_to_lhs_var_map_.end() && claimed->second != lhs;
    if (rhs_taken) {
        if constexpr (AssertMode) {
            std::ostringstream msg;
            msg << "Variable mapping inconsistent ('" << rhs->name_ << "' is already mapped from '"
                << claimed->second->name_ << "', cannot map from '" << lhs->name_ << "')";
            ThrowMismatch(
                msg.str(), std::static_pointer_cast<const IRNode>(lhs), std::static_pointer_cast<const IRNode>(rhs));
        }
        return false;
    }

    lhs_to_rhs_var_map_[lhs] = rhs;
    rhs_to_lhs_var_map_[rhs] = lhs;
    return true;
}
// Compares two memory references: they are equal exactly when
// MemRef::SameAllocation says so. Returns false on mismatch, or raises
// ValueError through ThrowMismatch in assert mode.
template <bool AssertMode>
bool StructuralEqualImpl<AssertMode>::EqualMemRef(const MemRefPtr& lhs, const MemRefPtr& rhs)
{
    if (MemRef::SameAllocation(lhs, rhs)) {
        return true;
    }
    if constexpr (AssertMode) {
        ThrowMismatch(
            "MemRef base mismatch", std::static_pointer_cast<const IRNode>(lhs),
            std::static_pointer_cast<const IRNode>(rhs));
    }
    return false;
}
// Compares two iteration arguments: first their iteration variables via
// EqualVar, then their initial values via Equal. Returns false on mismatch,
// or raises ValueError through ThrowMismatch in assert mode.
template <bool AssertMode>
bool StructuralEqualImpl<AssertMode>::EqualIterArg(const IterArgPtr& lhs, const IterArgPtr& rhs)
{
    const bool vars_match = EqualVar(lhs->iterVar_, rhs->iterVar_);
    if (!vars_match) {
        return false;
    }
    const bool init_values_match = Equal(lhs->initValue_, rhs->initValue_);
    if (init_values_match) {
        return true;
    }
    if constexpr (AssertMode) {
        ThrowMismatch(
            "IterArg initValue mismatch", std::static_pointer_cast<const IRNode>(lhs->initValue_),
            std::static_pointer_cast<const IRNode>(rhs->initValue_));
    }
    return false;
}
// Explicitly instantiate both modes of the checker in this translation unit.
template class StructuralEqualImpl<false>;
template class StructuralEqualImpl<true>;
// Boolean-returning checker (mismatch -> false).
using StructuralEqual = StructuralEqualImpl<false>;
// Throwing checker (mismatch -> ValueError).
using StructuralEqualAssert = StructuralEqualImpl<true>;
bool structural_equal(const IRNodePtr& lhs, const IRNodePtr& rhs, bool enable_auto_mapping)
{
StructuralEqual checker(enable_auto_mapping);
return checker(lhs, rhs);
}
bool structural_equal(const TypePtr& lhs, const TypePtr& rhs, bool enable_auto_mapping)
{
StructuralEqual checker(enable_auto_mapping);
return checker(lhs, rhs);
}
void assert_structural_equal(const IRNodePtr& lhs, const IRNodePtr& rhs, bool enable_auto_mapping)
{
StructuralEqualAssert checker(enable_auto_mapping);
checker(lhs, rhs);
}
void assert_structural_equal(const TypePtr& lhs, const TypePtr& rhs, bool enable_auto_mapping)
{
StructuralEqualAssert checker(enable_auto_mapping);
checker(lhs, rhs);
}
}
}