import argparse
import multiprocessing
import subprocess
import sys

def run_script(script_name, args):
    """Run ``script_name`` with ``args`` under the current Python interpreter.

    The full command line is echoed before the script starts.

    Args:
        script_name: Path of the Python script to run.
        args: Command-line arguments for the script; any iterable of str.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero status.
    """
    # sys.executable (instead of a bare "python" looked up on PATH) runs the demo
    # under the same interpreter / virtualenv as this wrapper, and also works on
    # systems where only "python3" is installed.
    cmd = [sys.executable, script_name, *args]
    # Flush so the banner is not left sitting in this process's stdout buffer
    # and emitted after the child's own output.
    print(f"Running: {' '.join(cmd)}", flush=True)
    subprocess.run(cmd, check=True)

def start_process(rank, mode, work_dir, operate_type, use_sm_alg, n_part):
    """Run the demo script that implements ``operate_type`` for one party.

    Builds the shared command-line flags for this party and dispatches to the
    matching script under ``test/yacl/`` via ``run_script``.

    Args:
        rank: Party index, forwarded as ``--rank``.
        mode: Forwarded as ``--mode`` when truthy. PSI is only run when this
            equals ``"memory"``.
        work_dir: Forwarded as ``--work_dir``.
        operate_type: Operation name selecting which demo script to run.
        use_sm_alg: When truthy, ``--use_sm_alg`` is forwarded.
        n_part: When truthy, ``--n_part`` is forwarded.
    """
    # Operations handled by arith_demo.py, which receives the op via --operate_type.
    arith_ops = (
        "ADD", "SUB", "MUL", "DIV",
        "LESS", "LESS_EQUAL", "GREATER", "GREATER_EQUAL", "EQUAL", "NO_EQUAL",
        "SUM", "AVG", "MAX", "MIN", "ASCEND_SORT", "DESCEND_SORT",
    )
    # Operations that each map to a dedicated script taking only the shared flags.
    dedicated_scripts = {
        "PSI_UB": "test/yacl/psi_ub_demo.py",
        "MAKE_SHARE": "test/yacl/make_share_demo.py",
        "REVEAL_SHARE": "test/yacl/reveal_share_demo.py",
        "PIR": "test/yacl/pir_demo.py",
    }

    flags = ["--rank", str(rank), "--work_dir", work_dir]
    if use_sm_alg:
        flags.append("--use_sm_alg")
    if n_part:
        flags.append("--n_part")
    if mode:
        flags += ["--mode", mode]

    if operate_type == "PSI":
        if mode != "memory":
            # Message: "PSI operation only supports memory (mode)".
            print("PSI 操作仅支持memory")
            return
        run_script("test/yacl/psi_demo.py", flags)
    elif operate_type in arith_ops:
        run_script("test/yacl/arith_demo.py", flags + ["--operate_type", operate_type])
    elif operate_type in dedicated_scripts:
        run_script(dedicated_scripts[operate_type], flags)
    else:
        # Message: "Unknown operation type:<operate_type>".
        print(f"未知操作类型:{operate_type}")

def main():
    """Parse CLI options and run one demo process per party in parallel.

    Two parties (ranks 0 and 1) are started by default; ``--n_part`` adds a
    third (rank 2). Each party runs ``start_process`` in its own
    ``multiprocessing.Process``. After every party has finished, the wrapper
    exits with status 1 if any of them exited non-zero, so shells / CI can
    detect the failure.
    """
    parser = argparse.ArgumentParser(description="YACL python wrapper demo.")
    parser.add_argument("--mode", type=str, required=True, choices=["memory", "file"])
    parser.add_argument("--work_dir", type=str, default="./data")
    parser.add_argument("--operate_type", type=str, required=True, choices=[
        "PIR", "PSI_UB", "MAKE_SHARE", "REVEAL_SHARE", "PSI", "ADD", "SUB", "MUL", "DIV",
        "LESS", "LESS_EQUAL", "GREATER", "GREATER_EQUAL", "EQUAL", "NO_EQUAL", "SUM",
        "AVG", "MAX", "MIN", "ASCEND_SORT", "DESCEND_SORT"])
    parser.add_argument("--use_sm_alg", default=False, action="store_true")
    parser.add_argument("--n_part", default=False, action="store_true")

    args = parser.parse_args()

    ranks = [0, 1, 2] if args.n_part else [0, 1]

    # Start one process per party.
    processes = []
    for rank in ranks:
        p = multiprocessing.Process(
            target=start_process,
            args=(rank, args.mode, args.work_dir, args.operate_type,
                  args.use_sm_alg, args.n_part),
        )
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    # An exception in a child (e.g. CalledProcessError from a failing demo
    # script) only sets that child's exitcode; without this check the wrapper
    # would exit 0 even though a party failed.
    failed = [(rank, p.exitcode) for rank, p in zip(ranks, processes) if p.exitcode != 0]
    if failed:
        details = ", ".join(f"rank {rank} (exitcode {code})" for rank, code in failed)
        print(f"Failed parties: {details}", file=sys.stderr)
        sys.exit(1)

if __name__ == "__main__":
    main()