from __future__ import annotations
import argparse
import logging
from types import SimpleNamespace
from unittest import TestCase
import pytest
from cli.utils import (
check_device_targets,
check_non_negative_integer,
check_prefix_cache_hit_rate,
check_positive_integer,
check_string_valid,
get_common_argparser,
parse_int_range,
)
from tensor_cast.device import DeviceProfile
class _DummyGrid:
def __init__(self, size: int) -> None:
self._size = size
def nelement(self) -> int:
return self._size
class _DummyProfile:
def __init__(self, size: int) -> None:
self.comm_grid = SimpleNamespace(grid=_DummyGrid(size))
class TestCliUtils(TestCase):
def test_common_argparser_reserved_memory_default_is_zero(self):
parser = get_common_argparser()
args = parser.parse_args(["Qwen/Qwen3-32B"])
self.assertEqual(args.reserved_memory_gb, 0.0)
def test_common_argparser_reserved_memory_default_can_be_overridden(self):
parser = get_common_argparser(reserved_memory_gb_default=10.0)
args = parser.parse_args(["Qwen/Qwen3-32B"])
self.assertEqual(args.reserved_memory_gb, 10.0)
@pytest.fixture
def device_profiles(monkeypatch: pytest.MonkeyPatch) -> dict[str, _DummyProfile]:
profiles = {
"TEST_DEVICE": _DummyProfile(4),
"NPU_A": _DummyProfile(8),
}
monkeypatch.setattr(DeviceProfile, "all_device_profiles", profiles, raising=False)
return profiles
def test_common_argparser_parses_device_and_num_devices(device_profiles: dict[str, _DummyProfile]) -> None:
parser = get_common_argparser(reserved_memory_gb_default=10.0)
args = parser.parse_args(
[
"Qwen/Qwen3-32B",
"--device",
"TEST_DEVICE",
"--num-devices",
"2",
"--log-level",
"debug",
]
)
assert args.device == "TEST_DEVICE"
assert args.num_devices == 2
assert args.log_level == "debug"
assert args.reserved_memory_gb == 10.0
@pytest.mark.parametrize("value", ["1", "100", 5])
def test_check_positive_integer_accepts_valid_values(value: int | str) -> None:
assert check_positive_integer(value) == int(value)
@pytest.mark.parametrize("value", ["abc", "0", "-1"])
def test_check_positive_integer_rejects_invalid_values(value: str) -> None:
with pytest.raises(argparse.ArgumentTypeError):
check_positive_integer(value)
@pytest.mark.parametrize("value", ["0", "1", "42"])
def test_check_non_negative_integer_accepts_valid_values(value: str) -> None:
assert check_non_negative_integer(value) == int(value)
@pytest.mark.parametrize("value", ["abc", "-1"])
def test_check_non_negative_integer_rejects_invalid_values(value: str) -> None:
with pytest.raises(argparse.ArgumentTypeError):
check_non_negative_integer(value)
@pytest.mark.parametrize("value", ["0", "0.5", "0.999999"])
def test_check_prefix_cache_hit_rate_accepts_valid_values(value: str) -> None:
assert check_prefix_cache_hit_rate(value) == pytest.approx(float(value))
@pytest.mark.parametrize("value", ["1", "-0.1", "abc"])
def test_check_prefix_cache_hit_rate_rejects_invalid_values(value: str) -> None:
with pytest.raises(argparse.ArgumentTypeError):
check_prefix_cache_hit_rate(value)
@pytest.mark.parametrize(
("value", "expected"),
[
("1,2", (1, 2)),
(" 0 , 4 ", (0, 4)),
],
)
def test_parse_int_range_accepts_valid_values(value: str, expected: tuple[int, int]) -> None:
assert parse_int_range(value, "--range") == expected
@pytest.mark.parametrize(
"value",
[
"1",
"1,",
",2",
"a,b",
"-1,2",
"3,2",
],
)
def test_parse_int_range_rejects_invalid_values(value: str) -> None:
with pytest.raises(ValueError):
parse_int_range(value, "--range")
@pytest.mark.parametrize("value", ["valid_string123/test-path.file", "abc/DEF-123.txt"])
def test_check_string_valid_accepts_valid_values(value: str) -> None:
assert check_string_valid(value, max_len=100) == value
@pytest.mark.parametrize("value", ["invalid value", "bad#value"])
def test_check_string_valid_rejects_invalid_characters(value: str) -> None:
with pytest.raises(argparse.ArgumentTypeError):
check_string_valid(value)
def test_check_string_valid_rejects_overlong_value() -> None:
with pytest.raises(argparse.ArgumentTypeError):
check_string_valid("x" * 257)
def test_check_device_targets_returns_default_and_dedupes(device_profiles: dict[str, _DummyProfile]) -> None:
args = argparse.Namespace(device=None, num_devices=2)
logger = logging.getLogger("cli.utils.test")
result = check_device_targets(args, logger)
assert result == ["TEST_DEVICE"]
assert args.device == ["TEST_DEVICE"]
args = argparse.Namespace(device=["NPU_A", "NPU_A", "TEST_DEVICE"], num_devices=2)
result = check_device_targets(args, logger)
assert result == ["NPU_A", "TEST_DEVICE"]
def test_check_device_targets_rejects_missing_profiles(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(DeviceProfile, "all_device_profiles", {}, raising=False)
args = argparse.Namespace(device=["TEST_DEVICE"], num_devices=1)
assert check_device_targets(args, logging.getLogger("cli.utils.test")) is None
@pytest.mark.parametrize(
"device_values",
[
[""],
["unknown-device"],
],
)
def test_check_device_targets_rejects_blank_and_unknown_devices(
device_profiles: dict[str, _DummyProfile],
device_values: list[str],
) -> None:
args = argparse.Namespace(device=device_values, num_devices=1)
assert check_device_targets(args, logging.getLogger("cli.utils.test")) is None
def test_check_device_targets_rejects_undersized_comm_grid(device_profiles: dict[str, _DummyProfile]) -> None:
args = argparse.Namespace(device=["TEST_DEVICE"], num_devices=8)
assert check_device_targets(args, logging.getLogger("cli.utils.test")) is None