/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/*!
 * \file test_ir.cpp
 * \brief Unit tests ported from Python test_ir.py: string representations, IRCHECK / INTERNAL_CHECK macros,
 *        and structural comparison.
 */
#include "gtest/gtest.h"
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <vector>
#include "core/dtype.h"
#include "core/error.h"
#include "core/logging.h"
#include "ir/core.h"
#include "ir/expr.h"
#include "ir/function.h"
#include "ir/memref.h"
#include "ir/program.h"
#include "ir/scalar_expr.h"
#include "ir/stmt.h"
#include "ir/transforms/printer.h"
#include "ir/transforms/structural_comparison.h"
#include "ir/type.h"
namespace pypto {
namespace ir {
// Test helper: wraps `dt` in a freshly allocated ScalarType and returns it as a TypePtr.
static TypePtr Scalar(DataType dt)
{
    return std::make_shared<ScalarType>(dt);
}
class IRCheckTest : public testing::Test {};
// IRCHECK on a true condition: the test passes only if nothing is thrown.
TEST_F(IRCheckTest, TestCheckPass)
{
    IRCHECK(true) << "should not throw";
}
// IRCHECK on a false condition is expected to throw ValueError.
TEST_F(IRCheckTest, TestCheckFail)
{
    ASSERT_THROW(IRCHECK(false) << "test check message", ValueError);
}
// INTERNAL_CHECK on a true condition: the test passes only if nothing is thrown.
TEST_F(IRCheckTest, TestInternalCheckPass)
{
    INTERNAL_CHECK(true) << "should not throw";
}
// INTERNAL_CHECK on a false condition is expected to throw InternalError.
TEST_F(IRCheckTest, TestInternalCheckFail)
{
    ASSERT_THROW(
        INTERNAL_CHECK(false) << "test internal check message",
        InternalError);
}
class IRTypeStrTest : public testing::Test {};
// UnknownType is expected to print as "pl.Unknown".
TEST_F(IRTypeStrTest, TestUnknownTypeStr)
{
    auto unknownType = std::make_shared<UnknownType>();
    // The explicit cast to `const Type` is kept from the original call site.
    const auto printed = PythonPrint(std::static_pointer_cast<const Type>(unknownType));
    ASSERT_EQ(printed, "pl.Unknown");
}
// ScalarType over INT32 is expected to print as "pl.Scalar[pl.INT32]".
TEST_F(IRTypeStrTest, TestScalarTypeStr)
{
    auto scalarType = std::make_shared<ScalarType>(DataType::INT32);
    const auto printed = PythonPrint(std::static_pointer_cast<const Type>(scalarType));
    ASSERT_EQ(printed, "pl.Scalar[pl.INT32]");
}
// TensorType built from two constant extents (16, 32) and FP32 prints the extent list then the dtype.
TEST_F(IRTypeStrTest, TestTensorTypeStr)
{
    auto extent0 = std::make_shared<ConstInt>(16, DataType::INT64, Span::Unknown());
    auto extent1 = std::make_shared<ConstInt>(32, DataType::INT64, Span::Unknown());
    auto tensorType = std::make_shared<TensorType>(std::vector<ExprPtr>{extent0, extent1}, DataType::FP32);

    const auto printed = PythonPrint(std::static_pointer_cast<const Type>(tensorType));
    ASSERT_EQ(printed, "pl.Tensor[[16, 32], pl.FP32]");
}
// TensorType constructed with a MemRef prints the MemRef as a third element after the dtype.
TEST_F(IRTypeStrTest, TestTensorTypeWithMemRefStr)
{
    auto extent0 = std::make_shared<ConstInt>(16, DataType::INT64, Span::Unknown());
    auto extent1 = std::make_shared<ConstInt>(32, DataType::INT64, Span::Unknown());
    auto zeroOffset = std::make_shared<ConstInt>(0, DataType::INT64, Span::Unknown());
    MemRefPtr backing = std::make_shared<MemRef>(MemorySpace::DDR, zeroOffset, 1024);
    auto tensorType = std::make_shared<TensorType>(std::vector<ExprPtr>{extent0, extent1}, DataType::FP16, backing);

    const auto printed = PythonPrint(std::static_pointer_cast<const Type>(tensorType));
    ASSERT_EQ(printed, "pl.Tensor[[16, 32], pl.FP16, pl.MemRef(pl.MemorySpace.DDR, 0, 1024)]");
}
// TupleType of (Scalar INT32, Scalar FP32) prints both element types inside "pl.Tuple[...]".
TEST_F(IRTypeStrTest, TestTupleTypeStr)
{
    auto intElem = std::make_shared<ScalarType>(DataType::INT32);
    auto floatElem = std::make_shared<ScalarType>(DataType::FP32);
    auto tupleType = std::make_shared<TupleType>(std::vector<TypePtr>{intElem, floatElem});

    const auto printed = PythonPrint(std::static_pointer_cast<const Type>(tupleType));
    ASSERT_EQ(printed, "pl.Tuple[pl.Scalar[pl.INT32], pl.Scalar[pl.FP32]]");
}
// PtrType constructed with FP32 is expected to print as plain "pl.Ptr".
TEST_F(IRTypeStrTest, TestPtrTypeStr)
{
    auto ptrType = std::make_shared<PtrType>(DataType::FP32);
    const auto printed = PythonPrint(std::static_pointer_cast<const Type>(ptrType));
    ASSERT_EQ(printed, "pl.Ptr");
}
class IRExprStrTest : public testing::Test {
protected:
Span sp = Span("test", 1, 1);
};
// ConstInt 42 is expected to print as "42".
TEST_F(IRExprStrTest, TestConstIntStr)
{
    auto literal = std::make_shared<ConstInt>(42, DataType::INT32, sp);
    const auto printed = PythonPrint(std::static_pointer_cast<const IRNode>(literal));
    ASSERT_EQ(printed, "42");
}
// ConstFloat 3.14 is expected to print as "3.14".
TEST_F(IRExprStrTest, TestConstFloatStr)
{
    auto literal = std::make_shared<ConstFloat>(3.14, DataType::FP32, sp);
    const auto printed = PythonPrint(std::static_pointer_cast<const IRNode>(literal));
    ASSERT_EQ(printed, "3.14");
}
// ConstBool true is expected to print as "True".
TEST_F(IRExprStrTest, TestConstBoolStr)
{
    auto literal = std::make_shared<ConstBool>(true, sp);
    const auto printed = PythonPrint(std::static_pointer_cast<const IRNode>(literal));
    ASSERT_EQ(printed, "True");
}
// A Var is expected to print as its bare name.
TEST_F(IRExprStrTest, TestVarStr)
{
    auto variable = std::make_shared<Var>("x", Scalar(DataType::INT32), sp);
    const auto printed = PythonPrint(std::static_pointer_cast<const IRNode>(variable));
    ASSERT_EQ(printed, "x");
}
// Table-driven check of how every binary operator node over the constants 1 and 2 is printed.
//
// EXPECT_EQ (not ASSERT_EQ) is used inside the loop on purpose: a fatal assertion would leave the
// test on the first mismatching operator and hide the results for every entry after it.
TEST_F(IRExprStrTest, TestBinaryOpsStr)
{
    auto lhs = std::make_shared<ConstInt>(1, DataType::INT32, sp);
    auto rhs = std::make_shared<ConstInt>(2, DataType::INT32, sp);

    // One table row: operator label (used only in failure output), node to print, expected text.
    struct BinOpCase {
        std::string opName;
        ExprPtr expr;
        std::string expected;
    };

    const std::vector<BinOpCase> cases = {
        {"Add", std::make_shared<Add>(lhs, rhs, DataType::INT32, sp), "1 + 2"},
        {"Sub", std::make_shared<Sub>(lhs, rhs, DataType::INT32, sp), "1 - 2"},
        {"Mul", std::make_shared<Mul>(lhs, rhs, DataType::INT32, sp), "1 * 2"},
        {"FloorDiv", std::make_shared<FloorDiv>(lhs, rhs, DataType::INT32, sp), "1 // 2"},
        {"FloatDiv", std::make_shared<FloatDiv>(lhs, rhs, DataType::INT32, sp), "1 / 2"},
        {"FloorMod", std::make_shared<FloorMod>(lhs, rhs, DataType::INT32, sp), "1 % 2"},
        {"Pow", std::make_shared<Pow>(lhs, rhs, DataType::INT32, sp), "1 ** 2"},
        {"Eq", std::make_shared<Eq>(lhs, rhs, DataType::INT32, sp), "1 == 2"},
        {"Ne", std::make_shared<Ne>(lhs, rhs, DataType::INT32, sp), "1 != 2"},
        {"Lt", std::make_shared<Lt>(lhs, rhs, DataType::INT32, sp), "1 < 2"},
        {"Le", std::make_shared<Le>(lhs, rhs, DataType::INT32, sp), "1 <= 2"},
        {"Gt", std::make_shared<Gt>(lhs, rhs, DataType::INT32, sp), "1 > 2"},
        {"Ge", std::make_shared<Ge>(lhs, rhs, DataType::INT32, sp), "1 >= 2"},
        {"And", std::make_shared<And>(lhs, rhs, DataType::INT32, sp), "1 and 2"},
        {"Or", std::make_shared<Or>(lhs, rhs, DataType::INT32, sp), "1 or 2"},
        {"Xor", std::make_shared<Xor>(lhs, rhs, DataType::INT32, sp), "1 xor 2"},
        {"BitAnd", std::make_shared<BitAnd>(lhs, rhs, DataType::INT32, sp), "1 & 2"},
        {"BitOr", std::make_shared<BitOr>(lhs, rhs, DataType::INT32, sp), "1 | 2"},
        {"BitXor", std::make_shared<BitXor>(lhs, rhs, DataType::INT32, sp), "1 ^ 2"},
        {"BitShiftLeft", std::make_shared<BitShiftLeft>(lhs, rhs, DataType::INT32, sp), "1 << 2"},
        {"BitShiftRight", std::make_shared<BitShiftRight>(lhs, rhs, DataType::INT32, sp), "1 >> 2"},
    };

    for (const auto& testCase : cases) {
        auto node = std::static_pointer_cast<const IRNode>(testCase.expr);
        EXPECT_EQ(PythonPrint(node), testCase.expected) << "Failed for op: " << testCase.opName;
    }
}
// Min and Max nodes are expected to print as pl.min(...) and pl.max(...) calls.
TEST_F(IRExprStrTest, TestMinMaxStr)
{
    auto lhs = std::make_shared<ConstInt>(1, DataType::INT32, sp);
    auto rhs = std::make_shared<ConstInt>(2, DataType::INT32, sp);

    auto minimum = std::make_shared<Min>(lhs, rhs, DataType::INT32, sp);
    const auto printedMin = PythonPrint(std::static_pointer_cast<const IRNode>(minimum));
    ASSERT_EQ(printedMin, "pl.min(1, 2)");

    auto maximum = std::make_shared<Max>(lhs, rhs, DataType::INT32, sp);
    const auto printedMax = PythonPrint(std::static_pointer_cast<const IRNode>(maximum));
    ASSERT_EQ(printedMax, "pl.max(1, 2)");
}
// Checks the printed form of the unary nodes Neg, Not, Abs and Cast.
TEST_F(IRExprStrTest, TestUnaryOpsStr)
{
    auto intOperand = std::make_shared<ConstInt>(1, DataType::INT32, sp);
    auto boolOperand = std::make_shared<ConstBool>(true, sp);

    auto negated = std::make_shared<Neg>(intOperand, DataType::INT32, sp);
    const auto printedNeg = PythonPrint(std::static_pointer_cast<const IRNode>(negated));
    ASSERT_EQ(printedNeg, "-1");

    auto inverted = std::make_shared<Not>(boolOperand, DataType::BOOL, sp);
    const auto printedNot = PythonPrint(std::static_pointer_cast<const IRNode>(inverted));
    ASSERT_EQ(printedNot, "not True");

    auto absolute = std::make_shared<Abs>(intOperand, DataType::INT32, sp);
    const auto printedAbs = PythonPrint(std::static_pointer_cast<const IRNode>(absolute));
    ASSERT_EQ(printedAbs, "pl.abs(1)");

    auto converted = std::make_shared<Cast>(intOperand, DataType::FP32, sp);
    const auto printedCast = PythonPrint(std::static_pointer_cast<const IRNode>(converted));
    ASSERT_EQ(printedCast, "pl.cast(1, pl.FP32)");
}
// MakeTuple prints as a bracketed list; GetItemExpr on it appends a "[index]" suffix.
TEST_F(IRExprStrTest, TestMakeTupleAndGetItemStr)
{
    auto first = std::make_shared<ConstInt>(1, DataType::INT32, sp);
    auto second = std::make_shared<ConstInt>(2, DataType::INT32, sp);

    auto tuple = std::make_shared<MakeTuple>(std::vector<ExprPtr>{first, second}, sp);
    const auto printedTuple = PythonPrint(std::static_pointer_cast<const IRNode>(tuple));
    ASSERT_EQ(printedTuple, "[1, 2]");

    auto index = std::make_shared<ConstInt>(0, DataType::INDEX, sp);
    auto element = std::make_shared<GetItemExpr>(tuple, index, sp);
    const auto printedElement = PythonPrint(std::static_pointer_cast<const IRNode>(element));
    ASSERT_EQ(printedElement, "[1, 2][0]");
}
// A Call to "my_op" with two constant arguments prints as "pl.call @my_op(1, 2)".
TEST_F(IRExprStrTest, TestCallStr)
{
    auto arg0 = std::make_shared<ConstInt>(1, DataType::INT32, sp);
    auto arg1 = std::make_shared<ConstInt>(2, DataType::INT32, sp);
    auto invocation = std::make_shared<Call>("my_op", std::vector<ExprPtr>{arg0, arg1}, sp);

    const auto printed = PythonPrint(std::static_pointer_cast<const IRNode>(invocation));
    ASSERT_EQ(printed, "pl.call @my_op(1, 2)");
}
// A MemRef printed as an IR node shows its memory space, offset and the trailing integer argument.
TEST_F(IRExprStrTest, TestMemRefStr)
{
    auto zeroOffset = std::make_shared<ConstInt>(0, DataType::INT64, sp);
    auto region = std::make_shared<MemRef>(MemorySpace::Vec, zeroOffset, 2048, sp);

    const auto printed = PythonPrint(std::static_pointer_cast<const IRNode>(region));
    ASSERT_EQ(printed, "pl.MemRef(pl.MemorySpace.Vec, 0, 2048)");
}
class IRStmtStrTest : public testing::Test {
protected:
Span sp = Span("test", 1, 1);
TypePtr st = Scalar(DataType::INT32);
};
// AssignStmt prints as an annotated assignment: "<name>: <type> = <value>".
TEST_F(IRStmtStrTest, TestAssignStmtStr)
{
    auto target = std::make_shared<Var>("x", st, sp);
    auto value = std::make_shared<ConstInt>(42, DataType::INT32, sp);
    auto assignment = std::make_shared<AssignStmt>(target, value, sp);

    const auto printed = PythonPrint(std::static_pointer_cast<const IRNode>(assignment));
    ASSERT_EQ(printed, "x: pl.Scalar[pl.INT32] = 42");
}
// SeqStmts prints its statements in order, separated by a newline.
TEST_F(IRStmtStrTest, TestSeqStmtsStr)
{
    auto varX = std::make_shared<Var>("x", st, sp);
    auto varY = std::make_shared<Var>("y", st, sp);
    auto firstAssign = std::make_shared<AssignStmt>(varX, std::make_shared<ConstInt>(42, DataType::INT32, sp), sp);
    auto secondAssign = std::make_shared<AssignStmt>(varY, std::make_shared<ConstInt>(0, DataType::INT32, sp), sp);
    auto sequence = std::make_shared<SeqStmts>(std::vector<StmtPtr>{firstAssign, secondAssign}, sp);

    const auto printed = PythonPrint(std::static_pointer_cast<const IRNode>(sequence));
    ASSERT_EQ(printed, "x: pl.Scalar[pl.INT32] = 42\ny: pl.Scalar[pl.INT32] = 0");
}
// IfStmt without an else branch prints only the "if" header and the then-body.
// The absent else branch is passed as std::nullopt, which is declared in <optional>.
TEST_F(IRStmtStrTest, TestIfStmtStr)
{
    auto target = std::make_shared<Var>("x", st, sp);
    auto condition = std::make_shared<ConstBool>(true, sp);
    auto thenBody = std::make_shared<AssignStmt>(target, std::make_shared<ConstInt>(42, DataType::INT32, sp), sp);
    auto conditional = std::make_shared<IfStmt>(condition, thenBody, std::nullopt, std::vector<VarPtr>{}, sp);

    const auto printed = PythonPrint(std::static_pointer_cast<const IRNode>(conditional));
    ASSERT_EQ(printed, "if True:\n x: pl.Scalar[pl.INT32] = 42");
}
// IfStmt with both branches prints the then-body, an "else:" line, and the else-body.
TEST_F(IRStmtStrTest, TestIfElseStmtStr)
{
    auto varX = std::make_shared<Var>("x", st, sp);
    auto varY = std::make_shared<Var>("y", st, sp);
    auto condition = std::make_shared<ConstBool>(true, sp);
    auto thenBranch = std::make_shared<AssignStmt>(varX, std::make_shared<ConstInt>(42, DataType::INT32, sp), sp);
    auto elseBranch = std::make_shared<AssignStmt>(varY, std::make_shared<ConstInt>(0, DataType::INT32, sp), sp);
    auto conditional = std::make_shared<IfStmt>(condition, thenBranch, elseBranch, std::vector<VarPtr>{}, sp);

    const auto printed = PythonPrint(std::static_pointer_cast<const IRNode>(conditional));
    ASSERT_EQ(printed, "if True:\n x: pl.Scalar[pl.INT32] = 42\nelse:\n y: pl.Scalar[pl.INT32] = 0");
}
// ForStmt with one iter-arg and one return variable: checks the printed pl.range header
// (including init_values) and that the yield in the body is printed as an assignment to the return var.
TEST_F(IRStmtStrTest, TestForStmtStr)
{
    auto loopVar = std::make_shared<Var>("i", st, sp);
    auto initValue = std::make_shared<ConstInt>(0, DataType::INT32, sp);
    auto carried = std::make_shared<IterArg>("sum", st, initValue, sp);
    auto result = std::make_shared<Var>("sum_out", st, sp);
    auto yieldBody =
        std::make_shared<YieldStmt>(std::vector<ExprPtr>{std::make_shared<ConstInt>(1, DataType::INT32, sp)}, sp);

    // Loop bounds, named for readability: range(0, 10, 1).
    auto start = std::make_shared<ConstInt>(0, DataType::INT32, sp);
    auto stop = std::make_shared<ConstInt>(10, DataType::INT32, sp);
    auto step = std::make_shared<ConstInt>(1, DataType::INT32, sp);
    auto loop = std::make_shared<ForStmt>(
        loopVar, start, stop, step, std::vector<IterArgPtr>{carried}, yieldBody, std::vector<VarPtr>{result}, sp);

    const auto printed = PythonPrint(std::static_pointer_cast<const IRNode>(loop));
    ASSERT_EQ(
        printed,
        "for i, (sum,) in pl.range(0, 10, 1, init_values=(0,)):\n"
        " sum_out: pl.Scalar[pl.INT32] = pl.yield_(1)");
}
// WhileStmt with no iter-args and no return variables prints a "while" header followed by its body.
TEST_F(IRStmtStrTest, TestWhileStmtStr)
{
    auto target = std::make_shared<Var>("x", st, sp);
    auto condition = std::make_shared<ConstBool>(true, sp);
    auto loopBody = std::make_shared<AssignStmt>(target, std::make_shared<ConstInt>(42, DataType::INT32, sp), sp);
    auto loop =
        std::make_shared<WhileStmt>(condition, std::vector<IterArgPtr>{}, loopBody, std::vector<VarPtr>{}, sp);

    const auto printed = PythonPrint(std::static_pointer_cast<const IRNode>(loop));
    ASSERT_EQ(printed, "while True:\n x: pl.Scalar[pl.INT32] = 42");
}
// A standalone YieldStmt prints as pl.yield_(...), both with one value and with none.
TEST_F(IRStmtStrTest, TestYieldStmtStr)
{
    auto value = std::make_shared<ConstInt>(42, DataType::INT32, sp);

    auto yieldStmt = std::make_shared<YieldStmt>(std::vector<ExprPtr>{value}, sp);
    const auto printedWithValue = PythonPrint(std::static_pointer_cast<const IRNode>(yieldStmt));
    ASSERT_EQ(printedWithValue, "pl.yield_(42)");

    auto bareYield = std::make_shared<YieldStmt>(std::vector<ExprPtr>{}, sp);
    const auto printedBare = PythonPrint(std::static_pointer_cast<const IRNode>(bareYield));
    ASSERT_EQ(printedBare, "pl.yield_()");
}
// ReturnStmt prints as "return <value>" with a value and as bare "return" without one.
TEST_F(IRStmtStrTest, TestReturnStmtStr)
{
    auto value = std::make_shared<ConstInt>(42, DataType::INT32, sp);

    auto returnStmt = std::make_shared<ReturnStmt>(std::vector<ExprPtr>{value}, sp);
    const auto printedWithValue = PythonPrint(std::static_pointer_cast<const IRNode>(returnStmt));
    ASSERT_EQ(printedWithValue, "return 42");

    auto bareReturn = std::make_shared<ReturnStmt>(std::vector<ExprPtr>{}, sp);
    const auto printedBare = PythonPrint(std::static_pointer_cast<const IRNode>(bareReturn));
    ASSERT_EQ(printedBare, "return");
}
// BreakStmt and ContinueStmt print as the keywords "break" and "continue".
TEST_F(IRStmtStrTest, TestBreakContinueStmtStr)
{
    auto breakStmt = std::make_shared<BreakStmt>(sp);
    const auto printedBreak = PythonPrint(std::static_pointer_cast<const IRNode>(breakStmt));
    ASSERT_EQ(printedBreak, "break");

    auto continueStmt = std::make_shared<ContinueStmt>(sp);
    const auto printedContinue = PythonPrint(std::static_pointer_cast<const IRNode>(continueStmt));
    ASSERT_EQ(printedContinue, "continue");
}
// EvalStmt wraps the printed form of its expression in pl.eval(...).
TEST_F(IRStmtStrTest, TestEvalStmtStr)
{
    auto argument = std::make_shared<ConstInt>(42, DataType::INT32, sp);
    auto invocation = std::make_shared<Call>("some_op", std::vector<ExprPtr>{argument}, sp);
    auto evalStmt = std::make_shared<EvalStmt>(invocation, sp);

    const auto printed = PythonPrint(std::static_pointer_cast<const IRNode>(evalStmt));
    ASSERT_EQ(printed, "pl.eval(pl.call @some_op(42))");
}
// A Function prints as a "@pl.function" decorated def with typed parameter, return type and body.
TEST_F(IRStmtStrTest, TestFunctionStr)
{
    auto param = std::make_shared<Var>("x", st, sp);
    auto bodyStmt = std::make_shared<AssignStmt>(param, std::make_shared<ConstInt>(42, DataType::INT32, sp), sp);
    auto function =
        std::make_shared<Function>("test_func", std::vector<VarPtr>{param}, std::vector<TypePtr>{st}, bodyStmt, sp);

    const auto printed = PythonPrint(std::static_pointer_cast<const IRNode>(function));
    ASSERT_EQ(
        printed,
        "@pl.function\ndef test_func(x: pl.Scalar[pl.INT32]) -> pl.Scalar[pl.INT32]:\n"
        " x: pl.Scalar[pl.INT32] = 42");
}
// A Program of two functions prints a "# ir.program: <name>" header followed by each function,
// and GetFunction returns a non-null result for a function name that was added.
TEST_F(IRStmtStrTest, TestProgramStr)
{
    auto param = std::make_shared<Var>("x", st, sp);
    auto bodyStmt = std::make_shared<AssignStmt>(param, std::make_shared<ConstInt>(42, DataType::INT32, sp), sp);
    auto firstFunc =
        std::make_shared<Function>("test_func", std::vector<VarPtr>{param}, std::vector<TypePtr>{st}, bodyStmt, sp);
    auto secondFunc =
        std::make_shared<Function>("test_func2", std::vector<VarPtr>{param}, std::vector<TypePtr>{st}, bodyStmt, sp);
    auto program = std::make_shared<Program>(std::vector<FunctionPtr>{firstFunc, secondFunc}, "test_prog", sp);

    const std::string expected =
        "# ir.program: test_prog\n"
        "@pl.function\n"
        "def test_func(x: pl.Scalar[pl.INT32]) -> pl.Scalar[pl.INT32]:\n"
        " x: pl.Scalar[pl.INT32] = 42\n"
        "@pl.function\n"
        "def test_func2(x: pl.Scalar[pl.INT32]) -> pl.Scalar[pl.INT32]:\n"
        " x: pl.Scalar[pl.INT32] = 42";

    const auto printed = PythonPrint(std::static_pointer_cast<const IRNode>(program));
    ASSERT_EQ(printed, expected);
    ASSERT_NE(program->GetFunction("test_func"), nullptr);
}
class IRStructuralTest : public testing::Test {
protected:
Span sp = Span("test", 1, 1);
Span sp2 = Span("other", 5, 6);
};
// Two separately allocated Vars with the same name, type and span must hash identically.
TEST_F(IRStructuralTest, TestHashIdenticalNodes)
{
    auto first = std::make_shared<Var>("x", Scalar(DataType::INT32), sp);
    auto second = std::make_shared<Var>("x", Scalar(DataType::INT32), sp);

    const auto firstHash = structural_hash(std::static_pointer_cast<const IRNode>(first));
    const auto secondHash = structural_hash(std::static_pointer_cast<const IRNode>(second));
    ASSERT_EQ(firstHash, secondHash);
}
// Vars that differ only in name ("x" vs "y") are expected to hash differently.
TEST_F(IRStructuralTest, TestHashDifferentNodes)
{
    auto varX = std::make_shared<Var>("x", Scalar(DataType::INT32), sp);
    auto varY = std::make_shared<Var>("y", Scalar(DataType::INT32), sp);

    const auto hashX = structural_hash(std::static_pointer_cast<const IRNode>(varX));
    const auto hashY = structural_hash(std::static_pointer_cast<const IRNode>(varY));
    ASSERT_NE(hashX, hashY);
}
// Vars that differ only in their Span must still produce the same structural hash.
TEST_F(IRStructuralTest, TestHashIgnoresSpan)
{
    auto fromFileA = std::make_shared<Var>("x", Scalar(DataType::INT32), Span("file_a", 1, 1));
    auto fromFileB = std::make_shared<Var>("x", Scalar(DataType::INT32), Span("file_b", 99, 99));

    const auto hashA = structural_hash(std::static_pointer_cast<const IRNode>(fromFileA));
    const auto hashB = structural_hash(std::static_pointer_cast<const IRNode>(fromFileB));
    ASSERT_EQ(hashA, hashB);
}
// Two separately allocated ConstInt(42) nodes must compare structurally equal.
TEST_F(IRStructuralTest, TestEqualIdenticalNodes)
{
    auto first = std::make_shared<ConstInt>(42, DataType::INT32, sp);
    auto second = std::make_shared<ConstInt>(42, DataType::INT32, sp);

    const auto areEqual = structural_equal(
        std::static_pointer_cast<const IRNode>(first), std::static_pointer_cast<const IRNode>(second));
    ASSERT_TRUE(areEqual);
}
// ConstInt nodes holding different values (42 vs 99) must not compare structurally equal.
TEST_F(IRStructuralTest, TestEqualDifferentNodes)
{
    auto fortyTwo = std::make_shared<ConstInt>(42, DataType::INT32, sp);
    auto ninetyNine = std::make_shared<ConstInt>(99, DataType::INT32, sp);

    const auto areEqual = structural_equal(
        std::static_pointer_cast<const IRNode>(fortyTwo), std::static_pointer_cast<const IRNode>(ninetyNine));
    ASSERT_FALSE(areEqual);
}
// ConstInt nodes that differ only in their Span must still compare structurally equal.
TEST_F(IRStructuralTest, TestEqualIgnoresSpan)
{
    auto fromFirstFile = std::make_shared<ConstInt>(42, DataType::INT32, Span("f1", 1, 1));
    auto fromSecondFile = std::make_shared<ConstInt>(42, DataType::INT32, Span("f2", 99, 99));

    const auto areEqual = structural_equal(
        std::static_pointer_cast<const IRNode>(fromFirstFile),
        std::static_pointer_cast<const IRNode>(fromSecondFile));
    ASSERT_TRUE(areEqual);
}
// Compares "x + 1" against "y + 1": unequal when the third argument of structural_equal is false,
// equal when it is true (per the test name, that argument enables automatic variable mapping).
TEST_F(IRStructuralTest, TestEqualAutoMapping)
{
    auto varX = std::make_shared<Var>("x", Scalar(DataType::INT32), sp);
    auto varY = std::make_shared<Var>("y", Scalar(DataType::INT32), sp);
    auto increment = std::make_shared<ConstInt>(1, DataType::INT32, sp);
    auto sumWithX = std::make_shared<Add>(varX, increment, DataType::INT32, sp);
    auto sumWithY = std::make_shared<Add>(varY, increment, DataType::INT32, sp);

    const auto equalWithoutMapping = structural_equal(
        std::static_pointer_cast<const IRNode>(sumWithX), std::static_pointer_cast<const IRNode>(sumWithY), false);
    ASSERT_FALSE(equalWithoutMapping);

    const auto equalWithMapping = structural_equal(
        std::static_pointer_cast<const IRNode>(sumWithX), std::static_pointer_cast<const IRNode>(sumWithY), true);
    ASSERT_TRUE(equalWithMapping);
}
// Hashes Vars "x" and "y": different when the second argument of structural_hash is false,
// identical when it is true (per the test name, that argument enables automatic variable mapping).
TEST_F(IRStructuralTest, TestHashAutoMapping)
{
    auto varX = std::make_shared<Var>("x", Scalar(DataType::INT32), sp);
    auto varY = std::make_shared<Var>("y", Scalar(DataType::INT32), sp);

    const auto strictHashX = structural_hash(std::static_pointer_cast<const IRNode>(varX), false);
    const auto strictHashY = structural_hash(std::static_pointer_cast<const IRNode>(varY), false);
    ASSERT_NE(strictHashX, strictHashY);

    const auto mappedHashX = structural_hash(std::static_pointer_cast<const IRNode>(varX), true);
    const auto mappedHashY = structural_hash(std::static_pointer_cast<const IRNode>(varY), true);
    ASSERT_EQ(mappedHashX, mappedHashY);
}
// assert_structural_equal must not throw for two equal ConstInt(42) nodes.
TEST_F(IRStructuralTest, TestAssertStructuralEqualPass)
{
    auto first = std::make_shared<ConstInt>(42, DataType::INT32, sp);
    auto second = std::make_shared<ConstInt>(42, DataType::INT32, sp);

    ASSERT_NO_THROW(assert_structural_equal(
        std::static_pointer_cast<const IRNode>(first), std::static_pointer_cast<const IRNode>(second)));
}
// assert_structural_equal is expected to throw ValueError for ConstInt nodes with different values.
TEST_F(IRStructuralTest, TestAssertStructuralEqualFail)
{
    auto fortyTwo = std::make_shared<ConstInt>(42, DataType::INT32, sp);
    auto ninetyNine = std::make_shared<ConstInt>(99, DataType::INT32, sp);

    ASSERT_THROW(
        assert_structural_equal(
            std::static_pointer_cast<const IRNode>(fortyTwo), std::static_pointer_cast<const IRNode>(ninetyNine)),
        ValueError);
}
// structural_hash on types: equal for two INT32 scalar types, different for INT32 vs FP32.
TEST_F(IRStructuralTest, TestHashType)
{
    auto int32TypeA = std::make_shared<ScalarType>(DataType::INT32);
    auto int32TypeB = std::make_shared<ScalarType>(DataType::INT32);
    auto fp32Type = std::make_shared<ScalarType>(DataType::FP32);

    ASSERT_EQ(structural_hash(int32TypeA), structural_hash(int32TypeB));
    ASSERT_NE(structural_hash(int32TypeA), structural_hash(fp32Type));
}
// structural_equal on types: true for two INT32 scalar types, false for INT32 vs FP32.
TEST_F(IRStructuralTest, TestEqualType)
{
    auto int32TypeA = std::make_shared<ScalarType>(DataType::INT32);
    auto int32TypeB = std::make_shared<ScalarType>(DataType::INT32);
    auto fp32Type = std::make_shared<ScalarType>(DataType::FP32);

    ASSERT_TRUE(structural_equal(int32TypeA, int32TypeB));
    ASSERT_FALSE(structural_equal(int32TypeA, fp32Type));
}
// assert_structural_equal must not throw for two INT32 scalar types.
TEST_F(IRStructuralTest, TestAssertStructuralEqualTypePass)
{
    auto int32TypeA = std::make_shared<ScalarType>(DataType::INT32);
    auto int32TypeB = std::make_shared<ScalarType>(DataType::INT32);

    ASSERT_NO_THROW(assert_structural_equal(int32TypeA, int32TypeB));
}
// assert_structural_equal is expected to throw ValueError for INT32 vs FP32 scalar types.
TEST_F(IRStructuralTest, TestAssertStructuralEqualTypeFail)
{
    auto int32Type = std::make_shared<ScalarType>(DataType::INT32);
    auto fp32Type = std::make_shared<ScalarType>(DataType::FP32);

    ASSERT_THROW(assert_structural_equal(int32Type, fp32Type), ValueError);
}
// Two Add nodes over the same operands compare equal; an Add and a Sub over them do not.
TEST_F(IRStructuralTest, TestEqualComplexExpressions)
{
    auto lhs = std::make_shared<ConstInt>(1, DataType::INT32, sp);
    auto rhs = std::make_shared<ConstInt>(2, DataType::INT32, sp);
    auto sumA = std::make_shared<Add>(lhs, rhs, DataType::INT32, sp);
    auto sumB = std::make_shared<Add>(lhs, rhs, DataType::INT32, sp);
    auto difference = std::make_shared<Sub>(lhs, rhs, DataType::INT32, sp);

    const auto sumsEqual = structural_equal(
        std::static_pointer_cast<const IRNode>(sumA), std::static_pointer_cast<const IRNode>(sumB));
    ASSERT_TRUE(sumsEqual);

    const auto sumEqualsDifference = structural_equal(
        std::static_pointer_cast<const IRNode>(sumA), std::static_pointer_cast<const IRNode>(difference));
    ASSERT_FALSE(sumEqualsDifference);
}
// Two SeqStmts wrapping the same AssignStmt instance must compare structurally equal.
TEST_F(IRStructuralTest, TestEqualNestedStatements)
{
    auto target = std::make_shared<Var>("x", Scalar(DataType::INT32), sp);
    auto value = std::make_shared<ConstInt>(42, DataType::INT32, sp);
    auto assignment = std::make_shared<AssignStmt>(target, value, sp);
    auto firstSeq = std::make_shared<SeqStmts>(std::vector<StmtPtr>{assignment}, sp);
    auto secondSeq = std::make_shared<SeqStmts>(std::vector<StmtPtr>{assignment}, sp);

    const auto areEqual = structural_equal(
        std::static_pointer_cast<const IRNode>(firstSeq), std::static_pointer_cast<const IRNode>(secondSeq));
    ASSERT_TRUE(areEqual);
}
}
}