import logging
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import os
from detectron2.utils import comm
__all__ = ["launch"]
def _find_free_port():
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", 0))
port = sock.getsockname()[1]
sock.close()
return port
def _parse_tcp_endpoint(dist_url):
    """Return ``(host, port)`` for a ``tcp://host:port`` url, otherwise ``None``."""
    from urllib.parse import urlparse

    if not dist_url:
        return None
    parsed = urlparse(dist_url)
    try:
        port = parsed.port
    except ValueError:
        # Port component is not a valid integer in the 0-65535 range.
        return None
    if parsed.scheme != "tcp" or not parsed.hostname or port is None:
        return None
    return parsed.hostname, port


def launch(main_func, num_gpus_per_machine, num_machines=1, machine_rank=0, dist_url=None, args=()):
    """
    Launch multi-gpu or distributed training.
    This function must be called on all machines involved in the training.
    It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.

    When ``num_machines * num_gpus_per_machine`` is at most 1, no process is
    spawned and ``main_func(*args)`` runs in the current process.

    Args:
        main_func: a function that will be called by `main_func(*args)`
        num_gpus_per_machine (int): number of GPUs per machine
        num_machines (int): the total number of machines
        machine_rank (int): the rank of this machine
        dist_url (str): url to connect to for distributed jobs, including protocol
                       e.g. "tcp://127.0.0.1:8686".
                       Can be set to "auto" to automatically select a free port on localhost.
                       In single-machine jobs the rendezvous always uses a free
                       port on 127.0.0.1; in multi-machine jobs the host and port
                       of a ``tcp://`` url are exported as MASTER_ADDR/MASTER_PORT.
        args (tuple): arguments passed to main_func. In the single-process case,
                       if ``args[0].opts`` is a flat ``[key, value, ...]`` list
                       containing ``MODEL.DEVICE``, that device is selected
                       before ``main_func`` is called.
    """
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        logger = logging.getLogger(__name__)

        if dist_url == "auto":
            assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"
        # dist_url defaults to None, so guard before calling str methods on it.
        if num_machines > 1 and dist_url is not None and dist_url.startswith("file://"):
            logger.warning(
                "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
            )

        # The workers call init_process_group() without an init_method, so the
        # rendezvous address comes from MASTER_ADDR / MASTER_PORT. A loopback
        # address with a per-machine random port can never be reached from
        # another machine, so multi-machine jobs take the endpoint from dist_url.
        master_addr, master_port = "127.0.0.1", None
        if num_machines > 1:
            endpoint = _parse_tcp_endpoint(dist_url)
            if endpoint is not None:
                master_addr, master_port = endpoint
            else:
                logger.warning(
                    "Multi-machine job without a tcp://host:port dist_url; falling back to "
                    "a localhost rendezvous that other machines cannot reach."
                )
        if master_port is None:
            master_port = _find_free_port()

        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = str(master_port)
        os.environ["WORLD_SIZE"] = str(world_size)

        mp.spawn(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args),
            daemon=False,
        )
    else:
        # Single-process run: pin the NPU named by a MODEL.DEVICE override when
        # one is present. ``args`` may be empty (the default) or carry objects
        # without ``opts``; in that case main_func is simply called as-is.
        opts = getattr(args[0], "opts", None) if args else None
        device = None
        if opts:
            overrides = dict(zip(opts[0::2], opts[1::2]))
            device = overrides.get("MODEL.DEVICE")
        if device is not None:
            # KERNEL_NAME_ID receives the part after the last ':' (e.g. "0" for "npu:0").
            os.environ["KERNEL_NAME_ID"] = device.split(":")[-1]
            torch.npu.set_device(device)
        main_func(*args)
def _distributed_worker(
    local_rank,
    main_func,
    world_size,
    num_gpus_per_machine,
    machine_rank,
    dist_url,
    args
):
    """
    Entry point of every process spawned by ``launch``.

    Joins the global HCCL process group, selects this process's NPU, builds the
    per-machine process group and finally runs ``main_func(*args)``.

    Args:
        local_rank (int): index of this process on its machine (supplied by ``mp.spawn``)
        main_func: a function that will be called by `main_func(*args)`
        world_size (int): total number of processes across all machines
        num_gpus_per_machine (int): number of processes/devices per machine
        machine_rank (int): the rank of this machine
        dist_url (str): only used in the error log; the process group itself is
            initialised without an explicit init_method
        args (tuple): arguments passed to main_func. If ``args[0]`` has a
            non-empty ``device_ids`` sequence, entry ``local_rank`` of it is the
            NPU index used by this process.
    """
    assert torch.npu.is_available(), \
        "npu is not available. Please check your installation."
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    # RANK is the rank within the whole job, so it must be the global rank;
    # the local rank differs from it on every machine except machine 0.
    os.environ["RANK"] = str(global_rank)
    os.environ['KERNEL_NAME_ID'] = str(local_rank)
    try:
        dist.init_process_group(
            backend="hccl",
            world_size=world_size, rank=global_rank
        )
    except Exception:
        logger = logging.getLogger(__name__)
        logger.error("Process group URL: {}".format(dist_url))
        # Bare raise keeps the original traceback intact.
        raise

    assert num_gpus_per_machine <= torch.npu.device_count()
    # ``args`` may be empty or carry an object without ``device_ids``; fall
    # back to using the local rank as the device index in that case.
    device_ids = getattr(args[0], "device_ids", None) if args else None
    if device_ids:
        device = 'npu:{}'.format(device_ids[local_rank])
    else:
        device = 'npu:{}'.format(local_rank)
    torch.npu.set_device(device)

    # Set up the local process group (the processes on the same machine).
    assert comm._LOCAL_PROCESS_GROUP is None
    num_machines = world_size // num_gpus_per_machine
    for i in range(num_machines):
        ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
        # Every process creates every group; only its own machine's is kept.
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            comm._LOCAL_PROCESS_GROUP = pg

    main_func(*args)