# Names exported by `from <this module> import *`; only the argument parser is public.
__all__ = ["parse_args"]
from torch.distributed import run as torch_run
from torch.distributed.argparse_util import check_env, env
from torch.distributed.run import get_args_parser
from torch.distributed.elastic.multiprocessing.errors import record
import torch_npu
def parse_args(args):
    """Parse launcher arguments, adding the tiered parallel TCPStore switch.

    Starts from the parser returned by ``torch.distributed.run.get_args_parser``
    and registers one extra option, ``--enable_tiered_parallel_tcpstore``
    (also accepted as ``--enable-tiered-parallel-tcpstore``).

    Args:
        args: Sequence of command-line strings, or ``None`` to let argparse
            fall back to ``sys.argv[1:]``.

    Returns:
        The namespace produced by the parser. The extra option is stored as a
        string under ``enable_tiered_parallel_tcpstore`` (default ``"false"``).
    """
    parser = get_args_parser()
    parser.add_argument(
        # The underscore spelling stays first so argparse keeps deriving the
        # same ``dest`` from it; the hyphenated alias replaces what used to be
        # an exact duplicate of the first flag.
        "--enable_tiered_parallel_tcpstore",
        "--enable-tiered-parallel-tcpstore",
        action=env,
        type=str,
        default="false",
        help="Turn parallel tcpstore tiered optimization, if true, The agent adds a proxy role, "
        "the worker on this node will connect to the server through the proxy.",
    )
    return parser.parse_args(args)
@record
def _main(args=None):
    """Parse the command line and launch through ``torch_run.run``.

    The rendezvous backend is always overridden to ``'parallel'``. When no
    rendezvous endpoint was supplied, one is built from the parsed
    ``master_addr`` and ``master_port``.

    Args:
        args: Optional list of command-line strings forwarded to
            ``parse_args``; ``None`` is passed through unchanged.
    """
    options = parse_args(args)
    options.rdzv_backend = 'parallel'
    # Keep an explicitly provided endpoint; otherwise fall back to addr:port.
    options.rdzv_endpoint = (
        options.rdzv_endpoint
        or f"{options.master_addr}:{options.master_port}"
    )
    torch_run.run(options)
# Run the launcher only when this file is executed as a script, not on import.
if __name__ == "__main__":
    _main()