# Copyright (c) Huawei Technologies Co., Ltd. 2025-2026. All rights reserved.

from __future__ import annotations

import sys
import os
import kcal
import csv
import argparse
import time

from kcal import link_config


def create_link_desc() -> kcal.LinkDesc:
    """Build the link description for the two-party (alice/bob) demo.

    Both parties listen on the loopback interface. The link is created with
    3 connection retries spaced 1000 ms apart and a 60000 ms receive timeout.
    """
    alice = {"party": "alice", "address": "127.0.0.1:41929"}
    bob = {"party": "bob", "address": "127.0.0.1:56815"}

    link_options = {
        "link_id": "kcal_psi_ub_test",
        "connect_retry_times": 3,
        "connect_retry_interval_ms": 1000,
        "recv_timeout_ms": 60000,
    }
    return link_config.create_link_from_nodes([alice, bob], **link_options)


def psi_ub_demo(context, rank: int, file_path: str):
    """Run the PSI-UB operator for one party and report the outcome.

    The operator's ``run`` receives ``<file_path>/<rank>.csv`` as the input
    path and ``<file_path>/output_<rank>.csv`` as the output path. The return
    code, count and elapsed time are printed to stdout.

    Args:
        context: KCAL context created for this party.
        rank: this party's rank; used to pick the per-party file names.
        file_path: directory holding the input and output CSV files.

    Raises:
        Exception: any error from building the paths or running the operator.
            It is reported on stderr and then re-raised unchanged.
    """
    try:
        input_file = os.path.join(file_path, f"{rank}.csv")
        output_file = os.path.join(file_path, f"output_{rank}.csv")
        op = kcal.create_psi_ub(context)
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock and can jump (NTP/manual adjustments), skewing timings.
        start_time = time.perf_counter()
        ret, count = op.run(input_file, output_file)
        duration_ms = (time.perf_counter() - start_time) * 1000
        print(f"psi_ub completed with return code: {ret}, count: {count}")
        print(f"psi_ub run cost: {duration_ms:.2f} ms")
    except Exception as e:
        # Report on stderr, consistent with main()'s error reporting.
        print(f"Error: {e}", file=sys.stderr)
        raise

def main(argv=None):
    """Parse CLI options, create this rank's KCAL context and run the PSI-UB demo.

    Args:
        argv: argument list to parse; ``None`` lets argparse use ``sys.argv[1:]``.

    Exits with status 1 when ``--rank`` is not 0 or 1. argparse itself exits
    with status 2 on missing or invalid options.
    """
    cli = argparse.ArgumentParser(description="KCAL python wrapper demo.")
    cli.add_argument("--mode", type=str, required=True, choices=['memory', 'file'])
    cli.add_argument("--work_dir", type=str, default="./data")
    cli.add_argument("--use_sm_alg", default=False, action="store_true")
    cli.add_argument("--rank", type=int, required=True)
    opts = cli.parse_args(argv)

    # The link description in this demo defines exactly two parties.
    if not 0 <= opts.rank <= 1:
        print("Error: --rank must be 0 or 1", file=sys.stderr)
        sys.exit(1)

    party_rank = opts.rank
    file_mode = opts.mode == "file"

    kcal_config = kcal.Config()
    kcal_config.nodeId = party_rank
    kcal_config.worldSize = 2
    kcal_config.fixBits = 3
    kcal_config.threadCount = 32
    kcal_config.useSMAlg = opts.use_sm_alg

    link_desc = create_link_desc()
    print(f"[Rank {party_rank}] Creating context with yacl...")
    ctx = kcal.Context.create_with_link_config(kcal_config, link_desc, party_rank, file_mode)
    psi_ub_demo(ctx, party_rank, opts.work_dir)


if __name__ == "__main__":
    # Pass the CLI arguments explicitly; argparse would otherwise read
    # sys.argv[1:] itself, so the behavior is the same.
    main(sys.argv[1:])