import argparse
import sys
import os
import torch
import torchlight
from torchlight import import_class
import torch.multiprocessing as mp
if torch.__version__>= '1.8':
import torch_npu
def as_device_list(device):
    """Normalize the ``device`` option into a list of device ids.

    A bare ``int`` becomes a one-element list; any other iterable is
    copied into a new list.
    """
    if isinstance(device, int):
        return [device]
    return list(device)


def needs_spawn(devices, use_gpu_npu):
    """Return True when training should run through ``mp.spawn``.

    More than one device always takes the spawn path; a ``use_gpu_npu``
    value containing ``"gpu"`` also forces it, even for a single device.
    """
    return len(devices) > 1 or "gpu" in use_gpu_npu


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Processor collection')

    # Map each sub-command name to the processor class that implements it.
    processors = dict()
    processors['recognition'] = import_class(
        'processor.recognition.REC_Processor')
    processors['demo_old'] = import_class('processor.demo_old.Demo')
    processors['demo'] = import_class('processor.demo_realtime.DemoRealtime')
    processors['demo_offline'] = import_class(
        'processor.demo_offline.DemoOffline')

    # Every processor contributes its own parser as a sub-command.
    subparsers = parser.add_subparsers(dest='processor')
    for k, p in processors.items():
        subparsers.add_parser(k, parents=[p.get_parser()])

    arg = parser.parse_args()
    if arg.processor is None:
        # No sub-command given: fail with a usage message instead of a
        # KeyError traceback from processors[None].
        parser.error('a processor must be specified: '
                     + ', '.join(processors))

    Processor = processors[arg.processor]
    # The processor re-parses the arguments that follow the sub-command name.
    p = Processor(sys.argv[2:])

    if p.arg.bin:
        # NOTE(review): torch.npu is only available when torch_npu was
        # imported at the top of the file (torch >= 1.8).
        torch.npu.set_compile_mode(jit_compile=False)

    # The device option may be a single int or a sequence of ids.
    devices = as_device_list(p.arg.device)

    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '59629'

    if needs_spawn(devices, p.arg.use_gpu_npu):
        # mp.spawn calls p.parallel_train(i) for i in range(nprocs).
        mp.spawn(p.parallel_train, nprocs=len(devices))
    else:
        # Index the normalized list: p.arg.device may be a bare int, which
        # p.arg.device[0] would reject with TypeError.
        p.parallel_train(devices[0])