import os
import unittest
from msprobe.visualization.utils import (load_json_file, load_data_json_file, str2float, check_directory_content,
GraphConst, SerializableArgs)
class TestMappingConfig(unittest.TestCase):
def setUp(self):
self.yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mapping.yaml")
self.input = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
def test_load_json_file(self):
result = load_json_file(self.yaml_path)
self.assertEqual(result, {})
def test_load_data_json_file(self):
result = load_data_json_file(self.yaml_path)
self.assertEqual(result, {})
def test_str2float(self):
result = str2float('23.4%')
self.assertAlmostEqual(result, 0.234)
result = str2float('2.3.4%')
self.assertAlmostEqual(result, 0)
def test_check_directory_content(self):
input_type = check_directory_content(self.input)
self.assertEqual(input_type, GraphConst.STEPS)
input_type = check_directory_content(os.path.join(self.input, "step0"))
self.assertEqual(input_type, GraphConst.RANKS)
with self.assertRaises(ValueError):
check_directory_content(os.path.join(self.input, "step1"))
input_type = check_directory_content(os.path.join(self.input, "step0", "rank0"))
self.assertEqual(input_type, GraphConst.FILES)
def test_serializable_args(self):
class TmpArgs:
def __init__(self, a, b, c):
self.a = a
self.b = b
self.c = c
input_args1 = TmpArgs('a', 123, [1, 2, 3])
serializable_args1 = SerializableArgs(input_args1)
self.assertEqual(serializable_args1.__dict__, input_args1.__dict__)
input_args2 = TmpArgs('a', 123, lambda x: print(x))
serializable_args2 = SerializableArgs(input_args2)
self.assertNotEqual(serializable_args2.__dict__, input_args2.__dict__)
if __name__ == '__main__':
unittest.main()