"""STest 分布式 STest 并性执行.
"""
import argparse
import datetime
import logging
import os
import subprocess
from typing import Any, List, Tuple, Dict
from stest_accelerate import STestAccelerate
class DistributedSTestAccelerate(STestAccelerate):
    """Distributed STest execution accelerator.

    Supports multi-device parallel execution; distributed testing is achieved
    by partitioning the devices into groups.  Inherits from
    ``STestAccelerate``, uses the parent's device_list argument, and splits
    the devices into groups of ``rank_size``.
    """

    def __init__(self, args: argparse.Namespace):
        """Initialise the parent and build the device groups.

        :param args: parsed command line; must carry ``rank_size`` plus
            whatever the parent class and its device-list helper read.
        :raises ValueError: if ``rank_size`` is smaller than 1 or there are
            fewer devices than ``rank_size`` (see
            ``_group_devices_by_rank_size``).
        """
        super().__init__(args, scene_mark="Distributed STest", cntr_name="Device Group")
        # _init_get_device_list is inherited; it is not defined in this class.
        device_list = self._init_get_device_list(args)
        self.rank_size = args.rank_size
        self.device_groups = self._group_devices_by_rank_size(device_list, self.rank_size)

    @staticmethod
    def reg_args(parser: argparse.ArgumentParser) -> None:
        """Register the command-line arguments of the distributed STest.

        The parent (``STestAccelerate``) arguments are registered first, then
        the distributed-only ``--rank_size`` option is added.

        :param parser: parser to extend in place.
        """
        STestAccelerate.reg_args(parser)
        parser.add_argument("--rank_size", type=int, required=True,
                            help="Number of devices per test group")

    @staticmethod
    def main() -> bool:
        """Main flow of the distributed run.

        Parses the command line, builds the controller and runs its
        ``prepare`` / ``process`` / ``post`` stages in that order.

        :return: the value returned by ``post()``.
        """
        parser = argparse.ArgumentParser(
            description="Distributed STest Execute Accelerate",
            epilog="Best Regards!",
        )
        DistributedSTestAccelerate.reg_args(parser=parser)
        args = parser.parse_args()
        ctrl = DistributedSTestAccelerate(args=args)
        ctrl.prepare()
        ctrl.process()
        return ctrl.post()

    @staticmethod
    def set_distributed_device_envs(p: Any) -> Dict[str, str]:
        """Build the environment variables describing a device group.

        Multi-device cases are told which device group to use through the
        ``TILE_FWK_DEVICE_ID_LIST`` environment variable.

        :param p: object whose ``custom`` mapping holds a ``"device_group"``
            entry (an iterable of device ids).
        :return: mapping with ``TILE_FWK_DEVICE_ID_LIST`` set to the
            comma-separated device ids.
        :raises KeyError: if ``"device_group"`` is missing from ``p.custom``.
        """
        custom_data = p.custom
        device_group = custom_data["device_group"]
        device_list_str = ",".join(str(device_id) for device_id in device_group)
        return {
            "TILE_FWK_DEVICE_ID_LIST": device_list_str,
        }

    @staticmethod
    def _group_devices_by_rank_size(devices: List[int], rank_size: int) -> List[List[int]]:
        """Split the devices, in ascending order, into groups of ``rank_size``.

        Only complete groups are returned; devices that do not fill a whole
        group are left out and reported with a warning.

        :param devices: device id list
        :param rank_size: number of devices per group
        :return: list of device groups, each of exactly ``rank_size`` devices
        :raises ValueError: if ``rank_size`` is smaller than 1, or if there
            are fewer devices than ``rank_size``.
        """
        # Reject non-positive sizes explicitly: 0 would otherwise surface as an
        # obscure range() error and a negative value would silently yield no
        # groups at all.
        if rank_size < 1:
            raise ValueError(f"rank_size must be a positive integer, got {rank_size}")
        if len(devices) < rank_size:
            raise ValueError(f"Available devices ({len(devices)}) are less than required rank_size ({rank_size})")
        sorted_devices = sorted(devices)
        full_group_count = len(sorted_devices) // rank_size
        device_groups = [
            sorted_devices[idx * rank_size:(idx + 1) * rank_size]
            for idx in range(full_group_count)
        ]
        leftover = sorted_devices[full_group_count * rank_size:]
        if leftover:
            # Make the dropped devices visible instead of discarding them silently.
            logging.warning("Devices %s do not fill a complete group of rank_size %s and will not be used.",
                            leftover, rank_size)
        return device_groups

    def _prepare_get_params(self) -> List[STestAccelerate.ExecParam]:
        """Create one execution parameter per device group.

        Each parameter uses the group index as its ``cntr_id``, installs
        ``set_distributed_device_envs`` as its ``envs_func`` and carries the
        group, ``rank_size`` and group index in its ``custom`` mapping.

        :return: list of execution parameters, one per device group.
        """
        params = []
        for group_id, device_group in enumerate(self.device_groups):
            param = STestAccelerate.ExecParam(
                cntr_id=group_id,
                envs_func=DistributedSTestAccelerate.set_distributed_device_envs,
                custom={
                    "device_group": device_group,
                    "rank_size": self.rank_size,
                    "group_id": group_id,
                },
            )
            params.append(param)
        return params

    def _execute_case(self, ctx: STestAccelerate.CaseContext,
                      param: STestAccelerate.ExecParam,
                      case_name: str) -> Tuple[subprocess.CompletedProcess, str, datetime.timedelta]:
        """Run a case in multi-device mode (overrides the parent method).

        :param ctx: case context
        :param param: execution parameter whose ``custom`` mapping provides
            ``rank_size`` and, optionally, ``device_group``
        :param case_name: name of the case, used as the gtest filter
        :return: completed process, command line string, elapsed time
        :raises ValueError: if ``rank_size`` is missing or not greater than 1.
        """
        rank_size = param.custom.get("rank_size")
        if rank_size is None:
            raise ValueError("Missing rank_size in custom config, run distribute case failed.")
        if rank_size <= 1:
            raise ValueError("Distribute case rank size need greater than 1, run distribute case failed.")
        # Without an explicit group, fall back to a single-entry group made of cntr_id.
        device_group = param.custom.get("device_group", [param.cntr_id])
        return self._run_multi_device_case(ctx, device_group, rank_size, case_name)

    def _run_multi_device_case(self, ctx: STestAccelerate.CaseContext,
                               device_group: List[int],
                               rank_size: int,
                               case_name: str) -> Tuple[subprocess.CompletedProcess, str, datetime.timedelta]:
        """Execute one multi-device distributed test case through ``mpirun``.

        :param ctx: case context; its ``exec_param.get_envs()`` supplies the
            per-group environment variables
        :param device_group: device group list (used here for logging only)
        :param rank_size: size of the device group, passed to ``mpirun -n``
        :param case_name: name of the case, used as the gtest filter
        :return: completed process, command line string, elapsed time
        """
        # Later updates win: process environment < executable envs < per-case envs.
        env_vars = os.environ.copy()
        env_vars.update(self.exe.envs)
        env_vars.update(ctx.exec_param.get_envs())
        command = [
            "mpirun", "-n", str(rank_size),
            str(self.exe.file),
            f"--gtest_filter={case_name}",
        ]
        device_info = f"DeviceGroup{device_group}"
        # Lazy %-style arguments: the message is only formatted if it is emitted.
        logging.info("Executing %s on %s with rank_size %s.", case_name, device_info, rank_size)
        ts = datetime.datetime.now(tz=datetime.timezone.utc)
        completed_process = subprocess.run(
            command,
            env=env_vars,
            capture_output=True,
            text=True,
        )
        return completed_process, ' '.join(command), datetime.datetime.now(tz=datetime.timezone.utc) - ts
if __name__ == "__main__":
    # Script entry point: configure logging to the console, run the distributed
    # flow and turn its boolean result into the process exit status
    # (0 on success, 1 on failure).
    logging.basicConfig(
        format='%(asctime)s - %(filename)s:%(lineno)d - PID[%(process)d] - %(levelname)s: %(message)s',
        level=logging.INFO,
        handlers=[logging.StreamHandler()],
    )
    # Raise SystemExit directly instead of calling the site-provided exit()
    # helper, which is meant for interactive use and is missing when the
    # interpreter runs without the site module.
    raise SystemExit(0 if DistributedSTestAccelerate.main() else 1)