from __future__ import annotations
import sys
import os
import kcal
import csv
import argparse
import time
from kcal import link_config
def create_link_desc(
    *,
    nodes: list[dict[str, str]] | None = None,
    link_id: str = "kcal_psi_ub_test",
    connect_retry_times: int = 3,
    connect_retry_interval_ms: int = 1000,
    recv_timeout_ms: int = 60000,
) -> kcal.LinkDesc:
    """Build the link descriptor used to connect the two demo parties.

    Called with no arguments it produces exactly the original demo
    configuration: parties "alice" and "bob" on 127.0.0.1 ports 41929 and
    56815. Every value is overridable so the demo can run on other hosts or
    ports without editing the source.

    Args:
        nodes: List of ``{"party": ..., "address": "host:port"}`` dicts, one per
            party. ``None`` selects the built-in two-party localhost layout.
        link_id: Identifier passed through to ``create_link_from_nodes``.
        connect_retry_times: Forwarded unchanged to ``create_link_from_nodes``.
        connect_retry_interval_ms: Forwarded unchanged to
            ``create_link_from_nodes``.
        recv_timeout_ms: Forwarded unchanged to ``create_link_from_nodes``.

    Returns:
        Whatever ``link_config.create_link_from_nodes`` returns (annotated as
        ``kcal.LinkDesc``).
    """
    if nodes is None:
        # Built fresh on each call so callers never share/mutate one default list.
        nodes = [
            {"party": "alice", "address": "127.0.0.1:41929"},
            {"party": "bob", "address": "127.0.0.1:56815"},
        ]
    return link_config.create_link_from_nodes(
        nodes,
        link_id=link_id,
        connect_retry_times=connect_retry_times,
        connect_retry_interval_ms=connect_retry_interval_ms,
        recv_timeout_ms=recv_timeout_ms,
    )
def psi_ub_demo(context, rank: int, file_path: str):
    """Run the KCAL PSI-UB operator for this party and report result and timing.

    The operator is created with ``kcal.create_psi_ub(context)``. Its ``run`` method
    is given ``<file_path>/<rank>.csv`` as the input path and
    ``<file_path>/output_<rank>.csv`` as the output path.

    Args:
        context: KCAL context for this party, e.g. from
            ``kcal.Context.create_with_link_config``.
        rank: This party's rank; used to pick the input/output file names.
        file_path: Directory holding the input CSV and receiving the output CSV.

    Raises:
        Exception: Any error from creating or running the operator is reported
            as ``Error: ...`` on stderr and then re-raised unchanged.
    """
    input_file = os.path.join(file_path, f"{rank}.csv")
    output_file = os.path.join(file_path, f"output_{rank}.csv")
    try:
        op = kcal.create_psi_ub(context)
        # perf_counter is monotonic and high-resolution; time.time() follows the
        # wall clock and can jump (NTP/DST adjustments), skewing the duration.
        start = time.perf_counter()
        # NOTE(review): ret looks like a status code but is only printed, never
        # checked; confirm against the kcal docs whether non-zero means failure.
        ret, count = op.run(input_file, output_file)
        duration_ms = (time.perf_counter() - start) * 1000
    except Exception as e:
        # Diagnostics go to stderr, consistent with main()'s error reporting.
        print(f"Error: {e}", file=sys.stderr)
        raise
    print(f"psi_ub completed with return code: {ret}, count: {count}")
    print(f"psi_ub run cost: {duration_ms:.2f} ms")
def main(argv=None):
    """Entry point: parse CLI options, build a two-party KCAL context, run the demo.

    Args:
        argv: Optional list of command-line tokens. When ``None``, argparse
            falls back to ``sys.argv[1:]``.

    Exits with status 1, printing a message to stderr, when ``--rank`` is
    neither 0 nor 1. argparse itself exits with status 2 on malformed arguments.
    """
    cli = argparse.ArgumentParser(description="KCAL python wrapper demo.")
    cli.add_argument("--mode", type=str, required=True, choices=['memory', 'file'])
    cli.add_argument("--work_dir", type=str, default="./data")
    cli.add_argument("--use_sm_alg", default=False, action="store_true")
    cli.add_argument("--rank", type=int, required=True)
    opts = cli.parse_args(argv)

    rank = opts.rank
    # worldSize is fixed at 2 below, so only ranks 0 and 1 are meaningful.
    if rank != 0 and rank != 1:
        print("Error: --rank must be 0 or 1", file=sys.stderr)
        sys.exit(1)

    cfg = kcal.Config()
    cfg.nodeId = rank
    cfg.worldSize = 2
    cfg.fixBits = 3
    cfg.threadCount = 32
    cfg.useSMAlg = opts.use_sm_alg

    link_desc = create_link_desc()
    print(f"[Rank {rank}] Creating context with yacl...")
    # The final positional flag is True only for --mode file.
    file_mode = opts.mode == "file"
    context = kcal.Context.create_with_link_config(cfg, link_desc, rank, file_mode)
    psi_ub_demo(context, rank, opts.work_dir)
# Run the demo only when executed as a script; main() parses sys.argv itself.
if __name__ == "__main__":
    main()