"""Pytest 配置控制
"""
import os
from typing import List, Optional
import pytest
def duration_estimate(seconds: float):
"""
Decorator: annotate a test case with estimated duration (seconds).
This decorator marks test cases with their expected execution time,
allowing pytest to reorder tests for optimal parallel execution.
Args:
seconds: Estimated execution time in seconds
Example:
@duration_estimate(120)
def test_something():
...
"""
def decorator(func):
func.duration_estimate = seconds
return func
return decorator
def _set_process_desc(desc: str):
try:
import setproctitle
setproctitle.setproctitle(desc)
except ModuleNotFoundError:
pass
def pytest_addoption(parser: pytest.Parser):
"""向 pytest 注册自定义参数
:param parser: pytest.Parser 类型
"""
parser.addoption("--device", nargs="+", type=int,
help="Device ID, default 0")
parser.addoption(
"--test_case_info", action="store", default="", help="Test case info."
)
parser.addoption(
"--cards-per-case", type=int, default=1,
help="Number of cards required for each test case. Default is 1 (single-card cases)."
)
def _is_case_match_cards(item, target_cards) -> bool:
"""
判断测试用例是否匹配目标卡数
"""
cards_marker = item.get_closest_marker("world_size")
if cards_marker is None:
return True
required_cards = cards_marker.args
if not required_cards:
return True
if isinstance(required_cards[0], int):
return target_cards == required_cards[0]
return True
def pytest_configure_node(node):
"""pytest-xdist 回调函数, 在 pytest 主进程 fork 出 worker 进程之前被调用.
:param node: worker 节点
"""
device_id_lst: Optional[List[int]] = node.config.getoption("--device")
cards_per_case: int = node.config.getoption("--cards-per-case", 1)
if device_id_lst:
if cards_per_case > 1:
if len(device_id_lst) % cards_per_case != 0:
raise ValueError(
f"Cannot divide {len(device_id_lst)} devices into groups of {cards_per_case}"
)
num_groups = len(device_id_lst) // cards_per_case
worker_idx = int(str(node.gateway.id).lstrip("gw"))
if worker_idx >= num_groups:
node.gateway.id = "NoDevices"
node.gateway.remote_exec('import os; os.environ.pop("TILE_FWK_DEVICE_ID", None)')
node.gateway.remote_exec('import os; os.environ.pop("TILE_FWK_DEVICE_ID_LIST", None)')
return
start_idx = worker_idx * cards_per_case
end_idx = start_idx + cards_per_case
device_group = device_id_lst[start_idx:end_idx]
device_group_str = ",".join(map(str, device_group))
node.gateway.id = f"Devices[{device_group_str}]"
node.gateway.remote_exec(
f'import os; os.environ["TILE_FWK_DEVICE_ID_LIST"] = "{device_group_str}"'
)
else:
worker_idx = int(str(node.gateway.id).lstrip("gw"))
if worker_idx >= len(device_id_lst):
raise ValueError(f"WorkerIdx[{worker_idx}] out of DeviceIdLst{device_id_lst} range.")
device_id: int = device_id_lst[worker_idx]
node.gateway.id = f"Device[{device_id}]"
node.gateway.remote_exec(f'import os; os.environ["TILE_FWK_DEVICE_ID"] = "{device_id}"')
else:
node.gateway.remote_exec(f'import os; os.environ.pop("TILE_FWK_DEVICE_ID", None)')
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_protocol(item, nextitem):
device_list_str: Optional[str] = os.environ.get("TILE_FWK_DEVICE_ID_LIST", None)
if device_list_str is not None:
device_list = device_list_str.split(",")
_set_process_desc(f"Devices[{','.join(device_list)}]")
else:
device_id: Optional[str] = os.environ.get("TILE_FWK_DEVICE_ID", None)
if device_id is not None:
_set_process_desc(f"Device[{device_id}]")
return None
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_setup(item):
"""case 进程启动后被调用"""
device_list_str: Optional[str] = os.environ.get("TILE_FWK_DEVICE_ID_LIST", None)
if device_list_str is not None:
device_list = device_list_str.split(",")
else:
device_id: Optional[str] = os.environ.get("TILE_FWK_DEVICE_ID", None)
case_name: str = str(item.name)
if device_list_str is not None:
device_list = device_list_str.split(",")
_set_process_desc(f"Case(Devices[{','.join(device_list)}]::{case_name})")
else:
device_id: Optional[str] = os.environ.get("TILE_FWK_DEVICE_ID", None)
if device_id is not None:
_set_process_desc(f"Case(Device[{device_id}]::{case_name})")
return None
def _get_test_time_cost(item):
"""
获取测试用例的耗时信息
Args:
item: pytest 测试项
Returns:
int or None: 耗时秒数, 如果未标记则返回None
"""
if hasattr(item.function, 'duration_estimate'):
return item.function.duration_estimate
if hasattr(item, 'cls') and item.cls and hasattr(item.cls, 'duration_estimate'):
return item.cls.duration_estimate
time_marker = item.get_closest_marker("duration_estimate")
if time_marker and time_marker.args:
return time_marker.args[0]
return None
def _get_soc_version():
"""
从torch_npu获取soc version
"""
try:
import torch_npu
soc_version = torch_npu.npu.get_soc_version()
return soc_version
except Exception as e:
pytest.exit(f"Error: Failed to get soc version, error info: {str(e)}", returncode=1)
return None
def _is_case_match_soc(item, target_soc):
"""
判断测试用例是否匹配目标soc版本
"""
soc_marker = item.get_closest_marker("soc")
if soc_marker is None:
supported_socs = ["910"]
else:
supported_socs = soc_marker.args
if isinstance(supported_socs[0], str):
supported_socs = [soc.strip() for soc in supported_socs]
elif isinstance(supported_socs[0], list):
supported_socs = [soc.strip() for soc in supported_socs[0]]
if target_soc == 260:
target_tag = "950"
else:
target_tag = "910"
return target_tag in supported_socs
@pytest.hookimpl(trylast=True)
def pytest_collection_modifyitems(config, items):
"""
在所有conftest.py作用域处理完成后进行全局重排序
"""
if not items:
return
first_item = items[0]
item_path = str(first_item.fspath)
has_ut = "ut" in item_path.lower()
if has_ut:
filtered_items = items
else:
target_soc = _get_soc_version()
filtered_items = [item for item in items if _is_case_match_soc(item, target_soc)]
cards_per_case = config.getoption("--cards-per-case", 1)
card_filtered_items = [item for item in filtered_items
if _is_case_match_cards(item, cards_per_case)]
timed_tests = []
untimed_tests = []
for item in card_filtered_items:
time_cost = _get_test_time_cost(item)
if time_cost is not None:
timed_tests.append((item, time_cost))
else:
untimed_tests.append(item)
timed_tests.sort(key=lambda x: x[1], reverse=True)
reordered_items = [item for item, _ in timed_tests] + untimed_tests
items[:] = reordered_items