from __future__ import annotations
import os
import queue as py_queue
import shutil
import tempfile
import threading
from functools import wraps
from typing import Any, Callable, List, Optional, Tuple
import torch.multiprocessing as mp
def _dts_prepare_spawn() -> None:
try:
mp.set_start_method("spawn", force=True)
except RuntimeError:
pass
def spawn_queue():
    """Create a parent-side ``Queue`` bound to the ``spawn`` context.

    The queue must belong to the same ``spawn`` context that
    ``torch.multiprocessing.spawn`` uses; per the original author's note, a
    queue created under the Linux default (``fork``) triggers SemLock errors.
    """
    spawn_ctx = mp.get_context("spawn")
    return spawn_ctx.Queue()
def _dts_spawn_entry(
    rank: int,
    world_size: int,
    backend: str,
    init_file_abs: str,
    user_fn: Callable[..., Any],
    user_args: Tuple[Any, ...],
    user_kwargs: Optional[dict],
) -> None:
    """Per-process entry point: set up the process group, then run ``user_fn``.

    Args:
        rank: Rank of this process, passed through to ``init_process_group``.
        world_size: Total number of processes in the group.
        backend: ``torch.distributed`` backend name.
        init_file_abs: Path of the rendezvous file; it is turned into a
            ``file://`` URL with backslashes replaced by forward slashes.
        user_fn: Callable invoked as
            ``user_fn(rank, world_size, *user_args, **user_kwargs)``.
        user_args: Extra positional arguments for ``user_fn``.
        user_kwargs: Extra keyword arguments for ``user_fn``; ``None`` is
            treated as an empty mapping.

    The process group is destroyed on the way out whether or not ``user_fn``
    raised, as long as it is still initialized.
    """
    import torch.distributed as dist

    store_path = init_file_abs.replace("\\", "/")
    dist.init_process_group(
        backend=backend,
        init_method="file://" + store_path,
        rank=rank,
        world_size=world_size,
    )
    try:
        user_fn(rank, world_size, *user_args, **(user_kwargs or {}))
        if dist.is_initialized():
            if world_size > 1:
                # Synchronize all ranks after the user function returns and
                # before the group is torn down below.
                dist.barrier()
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()
def run_distributed_spawn(
    world_size: int,
    fn: Callable[..., Any],
    fn_args: Tuple[Any, ...] = (),
    fn_kwargs: Optional[dict] = None,
    *,
    backend: str = "gloo",
    init_dir_prefix: str = "dts_dist_spawn_",
) -> None:
    """Launch ``world_size`` processes with ``torch.multiprocessing.spawn``.

    Each process initializes the process group and then calls ``fn``.

    Calling convention: ``fn(rank, world_size, *fn_args, **fn_kwargs)``
    (``rank`` / ``world_size`` follow the PyTorch spawn convention).

    A temporary directory is created solely to hold the ``file://`` store;
    it is removed once ``mp.spawn`` returns or raises.

    Args:
        world_size: Number of processes to spawn.
        fn: Callable executed in every process.
        fn_args: Extra positional arguments forwarded to ``fn``.
        fn_kwargs: Extra keyword arguments forwarded to ``fn``; ``None`` is
            treated as an empty mapping.
        backend: ``torch.distributed`` backend name.
        init_dir_prefix: Prefix of the temporary rendezvous directory.
    """
    kwargs_for_fn = fn_kwargs or {}

    scratch_dir = tempfile.mkdtemp(prefix=init_dir_prefix)
    store_file = os.path.join(scratch_dir, "store")
    # Create the (empty) rendezvous file up front.
    with open(store_file, "w"):
        pass
    store_file_abs = os.path.abspath(store_file)

    spawn_args = (world_size, backend, store_file_abs, fn, fn_args, kwargs_for_fn)
    try:
        _dts_prepare_spawn()
        mp.spawn(_dts_spawn_entry, args=spawn_args, nprocs=world_size, join=True)
    finally:
        # Best-effort cleanup of the rendezvous directory.
        shutil.rmtree(scratch_dir, ignore_errors=True)
class QueueResultFuture:
    """Parent-side future that collects results from a queue.

    A daemon thread drains ``results_queue`` until ``expected_results`` items
    have been received or the future is stopped; :meth:`result` waits for
    that and returns the collected items.

    When ``expected_results`` is zero or negative there is nothing to
    collect, so no consumer thread is started and the queue is left
    untouched.
    """

    def __init__(self, results_queue, expected_results: int, poll_interval_s: float = 0.2):
        """Start collecting from ``results_queue``.

        Args:
            results_queue: Object exposing ``get(timeout=...)`` that raises
                ``queue.Empty`` when no item arrives in time.
            expected_results: Number of items to collect before completing.
            poll_interval_s: Timeout of each ``get`` call; it bounds how long
                the consumer takes to notice a stop request.
        """
        self._queue = results_queue
        self._expected = int(expected_results)
        self._poll_s = float(poll_interval_s)
        self._results: List[Any] = []
        self._lock = threading.Lock()
        self._done = threading.Event()
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._consume, name="dts-results-consumer", daemon=True)
        if self._expected > 0:
            self._thread.start()
        else:
            # Nothing to wait for: complete immediately instead of leaving a
            # thread polling forever and taking items meant for other readers.
            self._done.set()
            self._stop.set()

    def _consume(self) -> None:
        """Thread body: pull items until enough were collected or stop is set."""
        while not self._stop.is_set():
            try:
                item = self._queue.get(timeout=self._poll_s)
            except py_queue.Empty:
                # Wake up periodically so a stop request is noticed.
                continue
            with self._lock:
                self._results.append(item)
                if len(self._results) >= self._expected:
                    self._done.set()
                    self._stop.set()
                    return

    def result(self, timeout_s: float = 120.0) -> List[Any]:
        """Wait for the expected number of results and return them.

        Args:
            timeout_s: Maximum time to wait; values below one second are
                raised to one second.

        Returns:
            A new list with the first ``expected_results`` collected items,
            or an empty list when nothing was expected.

        Raises:
            TimeoutError: If not enough results arrived in time. The consumer
                thread is stopped before raising.
        """
        if self._expected <= 0:
            return []
        ok = self._done.wait(timeout=max(float(timeout_s), 1.0))
        # Stop the consumer whether or not it finished.
        self._stop.set()
        self._thread.join(timeout=max(1.0, self._poll_s * 5))
        if not ok:
            # The consumer may still be appending if the join timed out.
            with self._lock:
                collected = len(self._results)
            raise TimeoutError(
                f"QueueResultFuture timeout after {timeout_s}s "
                f"(collected={collected}/{self._expected})."
            )
        with self._lock:
            return list(self._results[: self._expected])
def start_queue_result_collector(results_queue, expected_results: int, *, poll_interval_s: float = 0.2) -> QueueResultFuture:
    """Start the parent-side consumer thread and return its result future.

    Args:
        results_queue: Queue that the future reads results from.
        expected_results: Number of items the future should collect.
        poll_interval_s: Polling interval handed to the future.
    """
    collector = QueueResultFuture(
        results_queue,
        expected_results=expected_results,
        poll_interval_s=poll_interval_s,
    )
    return collector
def _dts_run_decorated(
    rank: int,
    world_size: int,
    module_name: str,
    qualname: str,
    call_args: Tuple[Any, ...],
    call_kwargs: dict,
) -> None:
    """Child-side trampoline used by :func:`distributed_test`.

    The decorated test cannot be sent to a spawned process directly: its
    module attribute is the wrapper, so pickling the original function by
    reference fails. Only the module name and qualified name are shipped,
    and the original function is looked up again here.

    Raises:
        RuntimeError: If the resolved object is not a ``distributed_test``
            wrapper.
    """
    import importlib

    obj: Any = importlib.import_module(module_name)
    for part in qualname.split("."):
        obj = getattr(obj, part)
    # Other decorators may sit on top of the wrapper; follow ``__wrapped__``
    # until the object carrying the original test function is found.
    while obj is not None and not hasattr(obj, "_dts_target"):
        obj = getattr(obj, "__wrapped__", None)
    if obj is None:
        raise RuntimeError(
            f"{module_name}.{qualname} is not a distributed_test-decorated function"
        )
    obj._dts_target(rank, world_size, *call_args, **call_kwargs)


def distributed_test(world_size: int = 2, backend: str = "gloo"):
    """pytest decorator for tests shaped ``def test_xxx(rank, world_size, ...):``.

    Calling the decorated test spawns ``world_size`` processes through
    :func:`run_distributed_spawn`; each one runs the original function with
    its own ``rank`` and the ``world_size``, followed by the arguments the
    wrapper was called with.

    Notes:
        * The decorated function must be importable by module and qualified
          name (defined at module or class level), because spawned processes
          look it up again instead of unpickling it.
        * Arguments passed to the wrapper are sent to the child processes and
          therefore must be picklable.
        * Intended for use under pytest; when mixing with
          ``unittest.TestCase``, call :func:`run_distributed_spawn` directly.

    Args:
        world_size: Number of processes to spawn.
        backend: ``torch.distributed`` backend name.
    """
    import inspect

    def decorator(func: Callable[..., Any]):
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> None:
            if "<locals>" in func.__qualname__:
                raise RuntimeError(
                    f"distributed_test cannot run {func.__qualname__!r}: "
                    "the test must be defined at module or class level"
                )
            run_distributed_spawn(
                world_size,
                _dts_run_decorated,
                fn_args=(func.__module__, func.__qualname__, args, kwargs),
                backend=backend,
            )

        # Read by the child-side trampoline to reach the undecorated test.
        wrapper._dts_target = func

        # ``wraps`` exposes the original signature, so ``rank`` and
        # ``world_size`` would be treated as fixtures; hide them because the
        # spawned processes supply those two values.
        try:
            sig = inspect.signature(func)
        except (TypeError, ValueError):
            sig = None
        if sig is not None:
            params = list(sig.parameters.values())
            positional = (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
            )
            drop = 0
            while drop < 2 and drop < len(params) and params[drop].kind in positional:
                drop += 1
            wrapper.__signature__ = sig.replace(parameters=params[drop:])
        return wrapper

    return decorator