import multiprocessing as mp
import os
import sys
import torch
import torch_npu
from mutil_process import read_worker, write_worker
if __name__ == "__main__":
    # Benchmark launcher: start `process_count` processes that all run either
    # `read_worker` or `write_worker` (from mutil_process) with the same
    # parameters, then wait for every one of them to finish.
    #
    # Positional command-line arguments:
    #   1 testcase       "read" or "write"; selects the worker function
    #   2 process_count  number of worker processes; also the barrier size
    #   3 batch_size     int, forwarded to the worker
    #   4 block_size     int, forwarded to the worker
    #   5 call_count     int, forwarded to the worker
    #   6 data_dim       int, forwarded to the worker
    #   7 backend        str, forwarded to the worker
    #   8 local_type     str, forwarded to the worker
    #
    # Exit status: 2 on missing/invalid arguments, 1 on an unknown testcase or
    # when any worker exits with a non-zero code, 0 otherwise.
    if len(sys.argv) < 9:
        print(f"usage: {sys.argv[0]} <read|write> <process_count> <batch_size> "
              "<block_size> <call_count> <data_dim> <backend> <local_type>",
              file=sys.stderr)
        sys.exit(2)
    testcase = sys.argv[1]
    process_count = int(sys.argv[2])
    batch_size = int(sys.argv[3])
    block_size = int(sys.argv[4])
    call_count = int(sys.argv[5])
    data_dim = int(sys.argv[6])
    backend = sys.argv[7]
    local_type = sys.argv[8]
    print(f"主进程 PID: {os.getpid()}, {testcase=}, {process_count=}, {batch_size=}, {block_size=}, "
          f"{call_count=}, {data_dim=}, {backend=}, {local_type=}")

    # Resolve the worker once, before any process is started. Validating inside
    # the spawn loop would report the error once per iteration and then fail
    # with an IndexError when joining an empty process list.
    workers = {"read": read_worker, "write": write_worker}
    target = workers.get(testcase)
    if target is None:
        print(f"{testcase=} error", file=sys.stderr)
        sys.exit(1)
    # The barrier needs at least one party, and zero workers would do nothing.
    if process_count < 1:
        print(f"{process_count=} error: must be >= 1", file=sys.stderr)
        sys.exit(2)

    # One barrier shared by all workers and sized to the number of workers.
    # Presumably each worker waits on it to align its measured section —
    # TODO confirm in mutil_process.
    # NOTE(review): if a worker dies before reaching the barrier, the remaining
    # workers may block in wait() forever and the join loop below would hang.
    sync = mp.Barrier(process_count)
    processes = []
    for index in range(process_count):
        p = mp.Process(target=target, args=(index, batch_size, block_size, call_count, data_dim,
                                            backend, local_type, process_count, sync))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
    print("所有子进程结束。")

    # Propagate worker failures through the launcher's exit status so callers
    # (shell scripts, CI) can detect them. A negative exitcode means the worker
    # was terminated by that signal.
    failed = [p for p in processes if p.exitcode != 0]
    for p in failed:
        print(f"worker PID {p.pid} exited with code {p.exitcode}", file=sys.stderr)
    if failed:
        sys.exit(1)