"""STest 用例并行执行.
"""
import argparse
import logging
import os
import re
import subprocess
from typing import List, Any, Optional, Dict
from accelerate.tests_accelerate import TestsAccelerate
class STestAccelerate(TestsAccelerate):
    """Accelerate STest execution.

    Runs STest cases in parallel across multiple processes, one execution
    parameter per selected device. Before the parent accelerator is initialised,
    the requested cases are optionally reordered using per-test cost metadata
    reported by the test binary, so that the most expensive known cases come first.
    """

    # Matches one ``<test name>|<cost>`` line printed by ``--gtest_list_tests_with_meta``.
    # ``/`` is accepted in the name so that parameterised gtest names such as
    # ``Prefix/Suite.Test/0`` are recognised too. Compiled once for all calls.
    _META_LINE_PATTERN = re.compile(r'^([\w./]+)\|(\d+\.?\d*)$', re.MULTILINE)

    def __init__(self, args: argparse.Namespace, scene_mark: str = "STest", cntr_name: str = "Device"):
        """
        :param args: parsed command line arguments. ``args.cases`` may be replaced with a
            cost-aware ordering before the parent initialiser receives ``args``.
        :param scene_mark: scene identifier, forwarded to the parent accelerator.
        :param cntr_name: container name, forwarded to the parent accelerator.
        """
        binary_path = self._resolve_stest_binary_path(args)
        if args.cases:
            if binary_path:
                args.cases = self._reorder_cases_with_binary_meta(args.cases, binary_path)
            else:
                logging.warning("Binary path not found, skipping meta-based reordering")
        super().__init__(args, scene_mark=scene_mark, cntr_name=cntr_name)
        self.device_list: List[int] = self._init_get_device_list(args=args)

    @staticmethod
    def _resolve_stest_binary_path(args: argparse.Namespace) -> Optional[str]:
        """Return the test binary path: ``args.target[0]`` if set, else ``args.exe.file``, else None."""
        target = getattr(args, "target", None)
        if target:
            return target[0]
        exe = getattr(args, "exe", None)
        # hasattr(None, "file") is False, so a missing/None ``exe`` falls through to None.
        if hasattr(exe, "file"):
            return exe.file
        return None

    @staticmethod
    def reg_args(parser: argparse.ArgumentParser) -> None:
        """Register STest accelerator arguments.

        Registers the parent (TestsAccelerate) arguments first, then adds the
        STest-specific ``-d/--device`` option.
        """
        TestsAccelerate.reg_args(parser)
        parser.add_argument("-d", "--device", nargs="?", type=int, action="append",
                            help="Specific parallel accelerate device, "
                                 "If this parameter is not specified, 0 device will be used by default.")

    @staticmethod
    def main() -> bool:
        """Main flow: parse arguments, then run prepare -> process -> post.

        :return: the result of ``post()``; the ``__main__`` guard maps it to the exit code.
        """
        parser = argparse.ArgumentParser(description="STest Execute Accelerate", epilog="Best Regards!")
        STestAccelerate.reg_args(parser=parser)
        args = parser.parse_args()
        ctrl = STestAccelerate(args=args)
        ctrl.prepare()
        ctrl.process()
        return ctrl.post()

    @staticmethod
    def get_case_exec_update_envs(p: Any) -> Optional[Dict[str, str]]:
        """Return extra environment variables for a case run with execution param ``p``.

        Exports ``p.cntr_id`` (set to a device id in ``_prepare_get_params``) as
        ``TILE_FWK_DEVICE_ID``. Presumably the framework reads it to choose the
        device; confirm against the consumer.
        """
        return {"TILE_FWK_DEVICE_ID": f"{p.cntr_id}"}

    @staticmethod
    def _init_get_device_list(args) -> List[int]:
        """Resolve the device ids to run on from ``args.device``.

        ``-d/--device`` is declared with ``nargs="?"`` and ``action="append"``, so
        ``args.device`` is None when the flag is never given. Otherwise it is a list
        whose items are ints, or None for a bare ``-d`` without a value.
        Ids are de-duplicated and returned in ascending order. If no usable id
        remains, ``[0]`` is returned, as promised by the ``--device`` help text.
        """
        devices = getattr(args, "device", None)
        if not devices:
            return [0]
        device_list = sorted({int(d) for d in devices if d is not None and str(d) != ""})
        # A bare ``-d`` only appends None; without this fallback no device would be used at all.
        return device_list or [0]

    @staticmethod
    def _get_test_costs(binary: str, timeout: Optional[float] = 300.0) -> Dict[str, float]:
        """Collect per-test cost metadata from ``binary``.

        Runs ``<binary> --gtest_list_tests_with_meta`` (a custom flag the binary is
        expected to support) and parses every output line of the form
        ``TestSuite.TestName|<cost>``.

        :param binary: path to the test binary.
        :param timeout: seconds to wait for the binary; ``None`` waits indefinitely.
        :return: ``{"TestSuite.TestName": cost_seconds, ...}``. The dict is empty when the
            binary is missing, cannot be started, times out, exits non-zero, or prints no
            matching lines. Failures are logged as warnings rather than raised, because
            reordering is only an optimisation.
        """
        cost_map: Dict[str, float] = {}
        if not binary or not os.path.exists(binary):
            logging.warning("Binary file not found: %s", binary)
            return cost_map
        try:
            result = subprocess.run(
                [binary, '--gtest_list_tests_with_meta'],
                capture_output=True,
                text=True,
                encoding='utf-8',
                # Undecodable bytes in the binary's output must not abort startup.
                errors='replace',
                timeout=timeout,
            )
        except (subprocess.SubprocessError, OSError) as e:
            # SubprocessError includes TimeoutExpired. OSError includes FileNotFoundError and
            # PermissionError (e.g. the path exists but is not executable or is a directory).
            logging.warning("Failed to run binary %s to get meta info: %s", binary, e)
            return cost_map
        if result.returncode != 0:
            logging.warning("Failed to get test costs from binary %s: %s", binary, result.stderr)
            return cost_map
        for test_name, cost_str in STestAccelerate._META_LINE_PATTERN.findall(result.stdout):
            try:
                cost_map[test_name] = float(cost_str)
            except ValueError:
                continue
        return cost_map

    @staticmethod
    def _reorder_cases_with_binary_meta(cases: List[str], binary: str) -> List[str]:
        """Reorder STest cases using the cost metadata reported by ``binary``.

        - Cases with cost info come first, sorted by cost in descending order.
        - Cases without cost info follow, in their original order.

        :return: the reordered list, or ``cases`` itself when there is nothing to reorder
            (no cases, no binary, or no cost metadata).
        """
        if not cases or not binary:
            return cases
        cost_map = STestAccelerate._get_test_costs(binary)
        if not cost_map:
            logging.debug("No cost meta found for %s, keep original cases order", binary)
            return cases
        cost_cases = [cs for cs in cases if cs in cost_map]
        no_cost_cases = [cs for cs in cases if cs not in cost_map]
        # sorted() stays stable with reverse=True, so equal-cost cases keep their relative order.
        cost_cases_sorted = sorted(cost_cases, key=cost_map.__getitem__, reverse=True)
        logging.info(
            "STest(meta): Found %d tests with cost info, %d tests without.",
            len(cost_cases_sorted), len(no_cost_cases)
        )
        if cost_cases_sorted:
            logging.info("STest(meta): First few cost-aware tests(desc): %s", cost_cases_sorted[:5])
        return cost_cases_sorted + no_cost_cases

    def _prepare_get_params(self) -> List[TestsAccelerate.ExecParam]:
        """Build one execution param per selected device, each wired to the device env hook."""
        return [
            TestsAccelerate.ExecParam(cntr_id=_id, envs_func=STestAccelerate.get_case_exec_update_envs)
            for _id in self.device_list
        ]
if __name__ == "__main__":
    # Local import keeps this entry point self-contained.
    import sys

    logging.basicConfig(
        format='%(asctime)s - %(filename)s:%(lineno)d - PID[%(process)d] - %(levelname)s: %(message)s',
        level=logging.INFO,
        handlers=[
            logging.StreamHandler()
        ]
    )
    # Use sys.exit: the bare ``exit`` builtin is injected by the ``site`` module and is
    # missing under ``python -S`` and in some frozen/embedded interpreters.
    sys.exit(0 if STestAccelerate.main() else 1)