import argparse
import multiprocessing
import subprocess
import sys
def run_script(script_name, args, *, python=None):
    """Run a Python script in a child interpreter and block until it exits.

    Args:
        script_name: Path of the script to execute.
        args: Command-line arguments for the script. Any iterable of strings
            is accepted (list, tuple, ...).
        python: Interpreter to launch. Defaults to ``sys.executable`` so the
            child runs under the same interpreter (and virtualenv) as this
            wrapper, rather than whatever ``python`` happens to resolve to on
            PATH (which may be missing or a different installation).

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero
            status (``check=True``).
    """
    # sys.executable may be empty/None in embedded interpreters; fall back to
    # the previous behaviour of resolving "python" via PATH in that case.
    interpreter = python or sys.executable or "python"
    cmd = [interpreter, script_name, *args]
    print(f"Running: {' '.join(cmd)}")
    subprocess.run(cmd, check=True)
# Operations executed by arith_demo.py, selected via --operate_type.
_ARITH_OPS = frozenset({
    "ADD", "SUB", "MUL", "DIV",
    "LESS", "LESS_EQUAL", "GREATER", "GREATER_EQUAL", "EQUAL", "NO_EQUAL",
    "SUM", "AVG", "MAX", "MIN",
    "ASCEND_SORT", "DESCEND_SORT",
})

# Operations that map one-to-one onto a dedicated demo script.
_SCRIPT_BY_OP = {
    "PSI_UB": "test/yacl/psi_ub_demo.py",
    "MAKE_SHARE": "test/yacl/make_share_demo.py",
    "REVEAL_SHARE": "test/yacl/reveal_share_demo.py",
    "PIR": "test/yacl/pir_demo.py",
}


def _build_base_args(rank, mode, work_dir, use_sm_alg, n_part):
    """Return the command-line arguments forwarded to every demo script."""
    base_args = ["--rank", str(rank), "--work_dir", work_dir]
    if use_sm_alg:
        base_args.append("--use_sm_alg")
    if n_part:
        base_args.append("--n_part")
    if mode:
        base_args += ["--mode", mode]
    return base_args


def start_process(rank, mode, work_dir, operate_type, use_sm_alg, n_part):
    """Run the demo script for one party (``rank``) and one operation.

    Args:
        rank: Party index, forwarded as ``--rank``.
        mode: ``"memory"`` or ``"file"``, forwarded as ``--mode`` when truthy.
        work_dir: Working directory, forwarded as ``--work_dir``.
        operate_type: Operation name; selects which demo script is run.
        use_sm_alg: When true, ``--use_sm_alg`` is forwarded.
        n_part: When true, ``--n_part`` is forwarded.

    Raises:
        SystemExit: If ``operate_type`` is ``"PSI"`` and ``mode`` is not
            ``"memory"``, or if ``operate_type`` is unknown. The message is
            printed to stderr and the process exit status becomes 1, so the
            parent can detect the failure (previously these cases only printed
            and exited 0).
        subprocess.CalledProcessError: Propagated from ``run_script`` when the
            demo script fails.
    """
    base_args = _build_base_args(rank, mode, work_dir, use_sm_alg, n_part)

    if operate_type == "PSI":
        if mode != "memory":
            raise SystemExit("PSI 操作仅支持memory")
        run_script("test/yacl/psi_demo.py", base_args)
    elif operate_type in _ARITH_OPS:
        run_script(
            "test/yacl/arith_demo.py",
            base_args + ["--operate_type", operate_type],
        )
    elif operate_type in _SCRIPT_BY_OP:
        run_script(_SCRIPT_BY_OP[operate_type], base_args)
    else:
        raise SystemExit(f"未知操作类型:{operate_type}")
def main():
    """Parse CLI options, start one process per party, and wait for them.

    Two parties (ranks 0 and 1) are launched, or three (ranks 0-2) when
    ``--n_part`` is given. Each process runs ``start_process`` with the parsed
    options.

    Exits with a non-zero status (via ``sys.exit``) if any party process ends
    with a non-zero exit code, so callers and CI can detect failures.
    """
    parser = argparse.ArgumentParser(description="YCAL python wrapper demo.")
    parser.add_argument("--mode", type=str, required=True, choices=['memory', 'file'])
    parser.add_argument("--work_dir", type=str, default="./data")
    parser.add_argument("--operate_type", type=str, required=True, choices=[
        "PIR", "PSI_UB", "MAKE_SHARE", "REVEAL_SHARE", "PSI", "ADD", "SUB", "MUL", "DIV",
        "LESS", "LESS_EQUAL", "GREATER", "GREATER_EQUAL", "EQUAL", "NO_EQUAL", "SUM",
        "AVG", "MAX", "MIN", "ASCEND_SORT", "DESCEND_SORT"])
    parser.add_argument("--use_sm_alg", default=False, action="store_true")
    parser.add_argument("--n_part", required=False, action="store_true")
    args = parser.parse_args()

    ranks = [0, 1, 2] if args.n_part else [0, 1]

    processes = []
    for rank in ranks:
        p = multiprocessing.Process(
            target=start_process,
            args=(rank, args.mode, args.work_dir, args.operate_type,
                  args.use_sm_alg, args.n_part),
        )
        processes.append(p)
        p.start()
    for p in processes:
        p.join()

    # join() does not raise when a child fails; inspect exit codes explicitly
    # (non-zero for errors, negative when killed by a signal).
    failed = [rank for rank, p in zip(ranks, processes) if p.exitcode != 0]
    if failed:
        sys.exit(f"party process(es) failed for rank(s): {failed}")
if __name__ == "__main__":
    # Run only when executed as a script, not when this module is imported.
    # main() returns None, so SystemExit(None) yields exit status 0 unless
    # main() itself exits or raises.
    raise SystemExit(main())