from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from tools.perf_data_collection.op_replay import common
if TYPE_CHECKING:
from pathlib import Path
def test_get_target_data_dir_prefers_explicit_database_path(tmp_path: Path):
    """An explicit ``database_path`` is returned as-is by ``get_target_data_dir``."""
    explicit_dir = tmp_path / "custom_db"
    resolved = common.get_target_data_dir(database_path=explicit_dir)
    assert resolved == explicit_dir
def test_get_target_data_dir_preserves_full_version_dir_name():
    """A fully formatted version string is used verbatim as the version directory.

    The resolved path is ``DATA_DIR / <device> / "vllm_ascend" / <version>``.
    """
    device_name = "ATLAS_800_A3_752T_128G_DIE"
    version_dir = "vllm0.18.0_torch2.9.0_cann8.5"

    resolved = common.get_target_data_dir(
        device=device_name,
        vllm_ascend_version=version_dir,
    )

    expected = common.DATA_DIR / device_name / "vllm_ascend" / version_dir
    assert resolved == expected
def test_get_target_data_dir_builds_version_dir_from_components():
    """Separate vllm-ascend, torch and CANN versions are joined into one dir name.

    Given ``0.18.0`` / ``2.9.0`` / ``8.5``, the directory name is expected to be
    ``vllm0.18.0_torch2.9.0_cann8.5``.
    """
    device_name = "ATLAS_800_A3_752T_128G_DIE"

    resolved = common.get_target_data_dir(
        device=device_name,
        vllm_ascend_version="0.18.0",
        torch_version="2.9.0",
        cann_version="8.5",
    )

    expected = (
        common.DATA_DIR
        / device_name
        / "vllm_ascend"
        / "vllm0.18.0_torch2.9.0_cann8.5"
    )
    assert resolved == expected
def test_get_target_data_dir_raises_when_versions_cannot_be_detected(
    monkeypatch: pytest.MonkeyPatch,
):
    """``get_target_data_dir`` raises ``RuntimeError`` when detection yields nothing.

    Only ``device`` is passed, and ``detect_runtime_stack_versions`` is
    replaced with a fake that returns ``(None, None, None)``. The error
    message must mention ``--database-path``.
    """

    def _no_versions_detected():
        # Every detected value is missing.
        return None, None, None

    monkeypatch.setattr(common, "detect_runtime_stack_versions", _no_versions_detected)

    with pytest.raises(RuntimeError, match="Specify --database-path"):
        common.get_target_data_dir(device="ATLAS_800_A3_752T_128G_DIE")
def test_detect_cann_version_reads_ascend_toolkit_install_info(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
):
    """``detect_cann_version`` reads the version from ``ascend_toolkit_install.info``.

    The test isolates that code path as follows:

    - It removes the Ascend-related environment variables.
    - It makes every ``import_module`` call inside ``common`` fail.
    - It points ``Path.home()`` at ``tmp_path``.
    - It writes a fake install-info file under ``~/Ascend/cann/arm64-linux``
      whose ``version`` is ``8.5.0``.
    """
    # Presumably consulted by detect_cann_version; cleared so a real install on
    # the test machine cannot influence the result.
    for env_name in (
        "ASCEND_HOME_PATH",
        "ASCEND_TOOLKIT_HOME",
        "ASCEND_TOOLKIT_HOME_PATH",
        "ASCEND_INSTALL_PATH",
    ):
        monkeypatch.delenv(env_name, raising=False)

    cann_root = tmp_path / "Ascend" / "cann" / "arm64-linux"
    cann_root.mkdir(parents=True)
    (cann_root / "ascend_toolkit_install.info").write_text(
        "package_name=Ascend-cann-toolkit\nversion=8.5.0\ninnerversion=V100R001C25SPC001B232\n",
        encoding="utf-8",
    )

    def _raise_module_not_found(name: str, package: str | None = None):
        # Mirror importlib.import_module(name, package=None) so a call that
        # passes ``package`` still hits the intended ModuleNotFoundError path
        # instead of a TypeError raised by the stub itself. ``name=`` is set
        # the same way a real failed import sets it.
        raise ModuleNotFoundError(f"No module named {name!r}", name=name)

    monkeypatch.setattr(common, "import_module", _raise_module_not_found)
    # common.Path is the pathlib class; monkeypatch restores Path.home afterwards.
    monkeypatch.setattr(common.Path, "home", staticmethod(lambda: tmp_path))

    assert common.detect_cann_version() == "8.5.0"