"""Tests for op_replay/common.py — pure functions (no NPU needed)."""
import argparse
import unittest
from tools.perf_data_collection.op_replay.common import (
SUPPORTED_DEVICES,
DEFAULT_DEVICE,
DEFAULT_REPLAY_REPEAT_COUNT,
SUPPORTED_UPDATE_MODES,
MICROBENCH_DURATION,
check_version,
normalize_device_name,
normalize_vllm_ascend_version,
parse_list_field,
split_metadata_field,
parse_shape,
parse_shape_or_none,
normalize_dtype_name,
normalize_op_name,
expand_fractal_nz_shape,
normalize_shape,
build_version_dir_name,
is_version_dir_name,
_normalize_stack_component,
INVALID_REPLAY_ROWS,
)
class TestConstants(unittest.TestCase):
    """Sanity checks on the module-level constants exported by common."""

    def test_supported_devices(self):
        # The default device has to be one of the supported devices, and
        # the supported collection is expected to hold more than three.
        self.assertIn(DEFAULT_DEVICE, SUPPORTED_DEVICES)
        device_count = len(SUPPORTED_DEVICES)
        self.assertGreater(device_count, 3)

    def test_default_replay_repeat_count(self):
        # The default repeat count must be strictly positive.
        lower_bound = 0
        self.assertGreater(DEFAULT_REPLAY_REPEAT_COUNT, lower_bound)

    def test_supported_update_modes(self):
        # Both update modes must be present; checked in a fixed order.
        for mode in ("all", "missing-only"):
            self.assertIn(mode, SUPPORTED_UPDATE_MODES)

    def test_microbench_duration_column(self):
        # Pins the exact column label stored in MICROBENCH_DURATION.
        expected_column = "Average Duration(us)"
        self.assertEqual(MICROBENCH_DURATION, expected_column)

    def test_invalid_replay_rows_is_list(self):
        # Only the container type is checked here, not its contents.
        expected_type = list
        self.assertIsInstance(INVALID_REPLAY_ROWS, expected_type)
class TestCheckVersion(unittest.TestCase):
    """Exercise check_version with accepted and rejected version strings."""

    def test_valid_simple(self):
        # A plain dotted version is expected to come back unchanged.
        version = "0.9.2"
        self.assertEqual(check_version(version), "0.9.2")

    def test_valid_with_v(self):
        # A leading "v" is accepted; only a non-None result is required.
        checked = check_version("v0.13.0")
        self.assertIsNotNone(checked)

    def test_valid_with_underscore(self):
        # An underscore-joined stack string is accepted as well.
        checked = check_version("vllm0.18.0_torch2.9.0_cann8.5")
        self.assertIsNotNone(checked)

    def test_invalid_raises(self):
        # A string containing spaces must be rejected with the argparse
        # error type.
        self.assertRaises(
            argparse.ArgumentTypeError,
            check_version,
            "bad version with spaces",
        )
class TestNormalizeDeviceName(unittest.TestCase):
    """Check normalize_device_name on a whitespace-padded device name."""

    def test_strips_whitespace(self):
        # Surrounding spaces are expected to be removed from the name.
        padded_name = " ATLAS_800 "
        normalized = normalize_device_name(padded_name)
        self.assertEqual(normalized, "ATLAS_800")
class TestNormalizeVllmAscendVersion(unittest.TestCase):
    """Check normalize_vllm_ascend_version on a padded version string."""

    def test_strips_whitespace(self):
        # Surrounding spaces are expected to be removed from the version.
        padded_version = " 0.13.0 "
        normalized = normalize_vllm_ascend_version(padded_version)
        self.assertEqual(normalized, "0.13.0")
class TestNormalizeStackComponent(unittest.TestCase):
    """Check _normalize_stack_component for each prefixed component form."""

    def test_vllm_prefix(self):
        # The "vllm" prefix is expected to be dropped from the value.
        normalized = _normalize_stack_component("vllm", "vllm0.18.0")
        self.assertEqual(normalized, "0.18.0")

    def test_v_prefix(self):
        # A bare "v" prefix on the vllm component is dropped as well.
        normalized = _normalize_stack_component("vllm", "v0.18.0")
        self.assertEqual(normalized, "0.18.0")

    def test_torch_prefix(self):
        # Only containment is asserted: this test does not pin how the
        # "+cpu" suffix is treated.
        normalized = _normalize_stack_component("torch", "torch2.9.0+cpu")
        self.assertIn("2.9.0", normalized)

    def test_cann_prefix(self):
        # The "cann" prefix is expected to be dropped from the value.
        normalized = _normalize_stack_component("cann", "cann8.5")
        self.assertEqual(normalized, "8.5")
class TestBuildVersionDirName(unittest.TestCase):
    """Check the directory name build_version_dir_name assembles."""

    def test_standard(self):
        # Bare version numbers are joined into the full directory name.
        components = {
            "vllm_ascend_version": "0.18.0",
            "torch_version": "2.9.0",
            "cann_version": "8.5",
        }
        dir_name = build_version_dir_name(**components)
        self.assertEqual(dir_name, "vllm0.18.0_torch2.9.0_cann8.5")

    def test_with_prefixes(self):
        # Inputs that already carry prefixes must give the same name as
        # the bare version numbers do.
        components = {
            "vllm_ascend_version": "v0.18.0",
            "torch_version": "torch2.9.0",
            "cann_version": "cann8.5",
        }
        dir_name = build_version_dir_name(**components)
        self.assertEqual(dir_name, "vllm0.18.0_torch2.9.0_cann8.5")
class TestIsVersionDirName(unittest.TestCase):
    """Check is_version_dir_name on a matching and a non-matching name."""

    def test_valid(self):
        # A full vllm/torch/cann directory name is recognised.
        candidate = "vllm0.18.0_torch2.9.0_cann8.5"
        self.assertTrue(is_version_dir_name(candidate))

    def test_invalid(self):
        # An arbitrary string is not recognised as a version directory.
        candidate = "not_a_version"
        self.assertFalse(is_version_dir_name(candidate))
class TestParseListField(unittest.TestCase):
    """Check parse_list_field on plain, quoted and empty inputs."""

    def test_semicolon(self):
        # Semicolon-separated items are split into a list.
        items = parse_list_field("a;b;c")
        self.assertEqual(items, ["a", "b", "c"])

    def test_quoted(self):
        # Surrounding double quotes do not end up in the parsed items.
        items = parse_list_field('"a;b;c"')
        self.assertEqual(items, ["a", "b", "c"])

    def test_empty(self):
        # An empty string is expected to give an empty list.
        items = parse_list_field("")
        self.assertEqual(items, [])
class TestSplitMetadataField(unittest.TestCase):
    """Check split_metadata_field on plain, quoted and empty inputs."""

    def test_semicolon(self):
        # Semicolon-separated parts are split into a list.
        parts = split_metadata_field("a;b")
        self.assertEqual(parts, ["a", "b"])

    def test_quoted(self):
        # Surrounding double quotes do not end up in the split parts.
        parts = split_metadata_field('"a;b"')
        self.assertEqual(parts, ["a", "b"])

    def test_empty(self):
        # An empty string is expected to give a list holding one empty
        # string, not an empty list.
        parts = split_metadata_field("")
        self.assertEqual(parts, [""])
class TestParseShape(unittest.TestCase):
    """Check parse_shape on comma-separated dimension strings."""

    def test_simple(self):
        # Two comma-separated dimensions become a 2-tuple of ints.
        shape = parse_shape("128,5120")
        self.assertEqual(shape, (128, 5120))

    def test_single_dim(self):
        # A single dimension still comes back as a tuple.
        shape = parse_shape("4096")
        self.assertEqual(shape, (4096,))
class TestParseShapeOrNone(unittest.TestCase):
    """Check parse_shape_or_none on a valid shape and on blank input."""

    def test_valid(self):
        # A well-formed shape string becomes a tuple of ints.
        shape = parse_shape_or_none("128,5120")
        self.assertEqual(shape, (128, 5120))

    def test_empty_returns_none(self):
        # Whitespace-only input is expected to give None.
        shape = parse_shape_or_none(" ")
        self.assertIsNone(shape)
class TestNormalizeDtypeName(unittest.TestCase):
    """Check normalize_dtype_name with and without the "DT_" prefix."""

    def test_with_prefix(self):
        # A name that already has the prefix is returned unchanged.
        dtype_name = normalize_dtype_name("DT_BF16")
        self.assertEqual(dtype_name, "DT_BF16")

    def test_without_prefix(self):
        # A bare dtype name gains the "DT_" prefix.
        dtype_name = normalize_dtype_name("BF16")
        self.assertEqual(dtype_name, "DT_BF16")

    def test_empty_returns_undefined(self):
        # An empty name is expected to map to "DT_UNDEFINED".
        dtype_name = normalize_dtype_name("")
        self.assertEqual(dtype_name, "DT_UNDEFINED")
class TestNormalizeOpName(unittest.TestCase):
    """Check the suffixes that normalize_op_name strips from op names."""

    def test_removes_run_py(self):
        # The "_run.py" suffix is removed.
        op_name = normalize_op_name("MatMulV2_run.py")
        self.assertEqual(op_name, "MatMulV2")

    def test_removes_run(self):
        # The "_run" suffix is removed.
        op_name = normalize_op_name("PadV3_run")
        self.assertEqual(op_name, "PadV3")

    def test_removes_csv(self):
        # The ".csv" suffix is removed.
        op_name = normalize_op_name("SoftmaxV2.csv")
        self.assertEqual(op_name, "SoftmaxV2")

    def test_passthrough(self):
        # A name without any of those suffixes is returned unchanged.
        op_name = normalize_op_name("MatMulV2")
        self.assertEqual(op_name, "MatMulV2")
class TestExpandFractalNzShape(unittest.TestCase):
    """Check expand_fractal_nz_shape on a 4-element and a 2-element shape."""

    def test_valid(self):
        # The 4-element shape (2, 3, 4, 5) is expected to give (12, 10).
        fractal_shape = (2, 3, 4, 5)
        expanded = expand_fractal_nz_shape(fractal_shape)
        self.assertEqual(expanded, (12, 10))

    def test_invalid_dims_raises(self):
        # A 2-element shape must be rejected with ValueError.
        self.assertRaises(ValueError, expand_fractal_nz_shape, (2, 3))
class TestNormalizeShape(unittest.TestCase):
    """Check normalize_shape for the "ND" and "FRACTAL_NZ" formats."""

    def test_regular_passthrough(self):
        # With the "ND" format the shape is returned unchanged.
        shape = normalize_shape((128, 5120), "ND")
        self.assertEqual(shape, (128, 5120))

    def test_fractal_nz_expands(self):
        # With "FRACTAL_NZ" the shape (2, 3, 4, 5) is expected to give
        # (12, 10).
        shape = normalize_shape((2, 3, 4, 5), "FRACTAL_NZ")
        self.assertEqual(shape, (12, 10))
# Run every test case in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()