import unittest
from msprobe.visualization.graph.node_op import NodeOp
class TestNodeOp(unittest.TestCase):
    """Unit tests for ``NodeOp.get_node_op``, which maps a node-name string to a ``NodeOp`` member."""

    def test_get_node_op_valid(self):
        """A module-style name such as ``"ModuleTest"`` is classified as ``NodeOp.module``."""
        node_name = "ModuleTest"
        self.assertEqual(NodeOp.get_node_op(node_name), NodeOp.module)

    def test_get_node_op_invalid(self):
        """A name that matches no known category is expected to fall back to ``NodeOp.module``."""
        node_name = "InvalidNodeName"
        self.assertEqual(NodeOp.get_node_op(node_name), NodeOp.module)

    def test_get_node_op_all(self):
        """Table-driven check of every recognized name prefix against its expected ``NodeOp``."""
        test_cases = [
            ("ModuleTest", NodeOp.module),
            ("TensorTest", NodeOp.function_api),
            ("TorchTest", NodeOp.function_api),
            ("FunctionalTest", NodeOp.function_api),
            ("NPUTest", NodeOp.function_api),
            ("VFTest", NodeOp.function_api),
            ("DistributedTest", NodeOp.function_api),
            ("AtenTest", NodeOp.function_api),
        ]
        for node_name, expected_op in test_cases:
            # subTest reports each failing name individually instead of
            # stopping at the first mismatch and hiding the remaining cases.
            with self.subTest(node_name=node_name):
                self.assertEqual(NodeOp.get_node_op(node_name), expected_op)