from __future__ import annotations
import sys
import os
import kcal
import csv
import argparse
import time
from kcal import link_config
def create_link_desc(is_3rd: bool) -> kcal.LinkDesc:
    """Build the loopback link description used by this local kcal demo.

    Args:
        is_3rd: When true, a third party ``carol`` is added after ``alice``
            and ``bob``; otherwise only the two-party pair is described.

    Returns:
        Whatever ``link_config.create_link_from_nodes`` produces for the node
        list, using link id ``kcal_arith_test``, 3 connect retries spaced
        1000 ms apart and a 60000 ms receive timeout.
    """
    participants = [
        ("alice", "127.0.0.1:41929"),
        ("bob", "127.0.0.1:56815"),
    ]
    if is_3rd:
        participants.append(("carol", "127.0.0.1:56816"))

    node_list = [{"party": party, "address": address} for party, address in participants]

    return link_config.create_link_from_nodes(
        node_list,
        link_id="kcal_arith_test",
        connect_retry_times=3,
        connect_retry_interval_ms=1000,
        recv_timeout_ms=60000,
    )
def create_context(args):
    """Create the kcal context for this party from parsed CLI arguments.

    The config uses this party's ``--rank`` as node id and a world size of 3
    with ``--n_part`` (2 otherwise). It uses fixed ``fixBits=3`` and
    ``threadCount=32``, and ``useSMAlg`` from ``--use_sm_alg``. The final
    boolean handed to ``create_with_link_config`` is whether ``--mode`` is
    ``"file"``.
    """
    three_party = args.n_part

    cfg = kcal.Config()
    cfg.nodeId = args.rank
    cfg.worldSize = 3 if three_party else 2
    cfg.fixBits = 3
    cfg.threadCount = 32
    cfg.useSMAlg = args.use_sm_alg

    link_desc = create_link_desc(three_party)
    file_mode = args.mode == "file"
    return kcal.Context.create_with_link_config(cfg, link_desc, args.rank, file_mode)
def test_basic_arithmetic(context, rank: int, file_path: str, type: str):
    """Run one two-party kcal operation on in-memory data and print the result.

    This party's input is read from ``<file_path>/arith_<rank>.csv``, taking
    the first column of every non-blank row as a float. The input is
    secret-shared, the requested ``kcal.AlgorithmsType`` is run, and the
    revealed output is printed.

    Args:
        context: kcal context (as built by ``create_context``).
        rank: This party's node id.
        file_path: Directory that holds ``arith_<rank>.csv``.
        type: Operation name.
            * Single-input operations: ``SUM``, ``AVG``, ``MAX``, ``MIN``,
              ``ASCEND_SORT``, ``DESCEND_SORT``.
            * Two-input operations: ``ADD``, ``SUB``, ``DIV``.

    Raises:
        ValueError: If ``type`` is not supported. This is checked before any
            share is created, so no MPC traffic happens for a bad type.
    """
    unary_ops = {'ASCEND_SORT', 'DESCEND_SORT', 'SUM', 'AVG', 'MAX', 'MIN'}
    binary_ops = {'ADD', 'SUB', 'DIV'}
    if type not in unary_ops and type not in binary_ops:
        raise ValueError(f"Unknown operation type: {type}")

    print("\n=== Testing Basic Arithmetic Operations ===")
    make_share_op = kcal.create_make_share(context)
    reveal_share_op = kcal.create_reveal_share(context)
    start_time = time.time()

    input_file = os.path.join(file_path, f"arith_{str(rank)}.csv")
    with open(input_file, 'r', newline='', encoding='utf-8') as f1:
        # Skip blank rows (e.g. a trailing empty line) instead of failing on row[0].
        input_data = [float(row[0]) for row in csv.reader(f1) if row]
    print(f"\ninput_data head 10: {input_data[:10]}, total_len: {len(input_data)}")

    share1 = kcal.MpcShare()
    share2 = kcal.MpcShare()
    if type in unary_ops:
        # Both parties feed their own data into share1; rank 0 passes 0, others pass 1.
        make_share_op.run(input_data, 0 if rank == 0 else 1, share1)
        inputs = [share1]
    else:
        # share1 is built from rank 0's data and share2 from the peer's data:
        # each party passes its real data only in the call for its own share.
        if rank == 0:
            make_share_op.run(input_data, 0, share1)
            make_share_op.run([], 1, share2)
        else:
            make_share_op.run([], 1, share1)
            make_share_op.run(input_data, 0, share2)
        inputs = [share1, share2]

    # `type` was validated above, so it names an existing AlgorithmsType member.
    mpc_op = kcal.create_mpc(context, getattr(kcal.AlgorithmsType, type))
    out_share = kcal.MpcShare()
    mpc_op.run(inputs, out_share)
    output = []
    reveal_share_op.run(out_share, output)
    print(f"{type} result: {output}")

    end_time = time.time()
    duration_ms = (end_time - start_time) * 1000
    print(f"Basics arithmetic test completed in: {duration_ms:.2f} ms")
def test_basic_arithmetic_file(context, rank: int, file_path: str, type: str):
    """Run one two-party kcal operation through kcal's file-based API.

    Files used under ``file_path``:
      * ``arith_<rank>.csv`` - this party's input.
      * ``arithmetic_<rank>1_result.csv`` / ``arithmetic_<rank>2_result.csv`` -
        share files written by make_share.
      * ``arithmetic_calculate_result_<rank>.csv`` - output of the MPC op.
      * ``arithmetic_reveal_result_<rank>.csv`` - revealed result.

    Args:
        context: kcal context (as built by ``create_context``).
        rank: This party's node id.
        file_path: Directory for all of the files above.
        type: Operation name.
            * Single-input operations: ``SUM``, ``AVG``, ``MAX``, ``MIN``,
              ``ASCEND_SORT``, ``DESCEND_SORT``.
            * Two-input operations: ``ADD``, ``SUB``, ``MUL``, ``DIV``,
              ``LESS``, ``LESS_EQUAL``, ``GREATER``, ``GREATER_EQUAL``,
              ``EQUAL``, ``NO_EQUAL``.

    Raises:
        ValueError: If ``type`` is not supported. This is checked before any
            share file is written.
    """
    unary_ops = {'ASCEND_SORT', 'DESCEND_SORT', 'SUM', 'AVG', 'MAX', 'MIN'}
    binary_ops = {
        'ADD', 'SUB', 'MUL', 'DIV',
        'LESS', 'LESS_EQUAL', 'GREATER', 'GREATER_EQUAL', 'EQUAL', 'NO_EQUAL',
    }
    if type not in unary_ops and type not in binary_ops:
        raise ValueError(f"Unknown operation type: {type}")

    print("\n=== Testing Basic Arithmetic File Operations ===")
    make_share_op = kcal.create_make_share(context)
    reveal_share_op = kcal.create_reveal_share(context)
    start_time = time.time()

    input_file = os.path.join(file_path, f"arith_{str(rank)}.csv")
    share_file0 = os.path.join(file_path, f"arithmetic_{str(rank)}1_result.csv")
    share_file1 = os.path.join(file_path, f"arithmetic_{str(rank)}2_result.csv")
    calculate_file = os.path.join(file_path, f"arithmetic_calculate_result_{str(rank)}.csv")
    reveal_file = os.path.join(file_path, f"arithmetic_reveal_result_{str(rank)}.csv")

    if type in unary_ops:
        make_share_op.run(input_file, 0 if rank == 0 else 1, share_file0)
        inputs = [share_file0]
    else:
        # The 0/1 argument order is mirrored between rank 0 and the peer.
        if rank == 0:
            make_share_op.run(input_file, 0, share_file0)
            make_share_op.run(input_file, 1, share_file1)
        else:
            make_share_op.run(input_file, 1, share_file0)
            make_share_op.run(input_file, 0, share_file1)
        inputs = [share_file0, share_file1]

    # `type` was validated above, so it names an existing AlgorithmsType member.
    mpc_op = kcal.create_mpc(context, getattr(kcal.AlgorithmsType, type))
    mpc_op.run(inputs, calculate_file)
    reveal_share_op.run(calculate_file, reveal_file)

    end_time = time.time()
    duration_ms = (end_time - start_time) * 1000
    print(f"Basics arithmetic test completed in: {duration_ms:.2f} ms")
def test_basic_3rd_arithmetic(context, rank: int, file_path: str, type: str):
    """Run one multiplication/comparison kcal operation on in-memory data.

    This is used for ``MUL``, ``LESS``, ``LESS_EQUAL``, ``GREATER``,
    ``GREATER_EQUAL``, ``EQUAL`` and ``NO_EQUAL``.

    Ranks 0 and 1 each read ``<file_path>/arith_<rank>.csv`` (first column of
    every non-blank row, as floats) and contribute it. Rank 0's data goes
    into the first share and rank 1's into the second. Any other rank passes
    empty data to both calls.

    Only the compute phase is timed, from share creation through reveal;
    reading the file is not included. The cost and the first 10 revealed
    values are printed.

    Args:
        context: kcal context (as built by ``create_context``).
        rank: This party's node id.
        file_path: Directory that holds ``arith_<rank>.csv``.
        type: One of the operation names listed above.

    Raises:
        ValueError: If ``type`` is not supported. Without this check an
            unknown type would silently print an empty result.
    """
    supported_ops = ('MUL', 'LESS', 'LESS_EQUAL', 'GREATER', 'GREATER_EQUAL', 'EQUAL', 'NO_EQUAL')
    if type not in supported_ops:
        raise ValueError(f"Unknown operation type: {type}")

    print("\n=== Testing Basic 3rd Arithmetic Operations ===")
    make_share_op = kcal.create_make_share(context)
    reveal_share_op = kcal.create_reveal_share(context)
    # Only the requested algorithm is instantiated.
    mpc_op = kcal.create_mpc(context, getattr(kcal.AlgorithmsType, type))

    input_file = os.path.join(file_path, f"arith_{str(rank)}.csv")
    with open(input_file, 'r', newline='', encoding='utf-8') as f1:
        # Skip blank rows (e.g. a trailing empty line) instead of failing on row[0].
        input_data = [float(row[0]) for row in csv.reader(f1) if row]
    print(f"\ninput_data head 10: {input_data[:10]}, total_len: {len(input_data)}")

    start_time = time.time()
    share1 = kcal.MpcShare()
    share2 = kcal.MpcShare()
    if rank == 0:
        print("\n rank 0: Processing arithmetic operations...")
        make_share_op.run(input_data, 0, share1)
        make_share_op.run([], 1, share2)
    elif rank == 1:
        print("\n rank 1: Processing arithmetic operations...")
        make_share_op.run([], 1, share1)
        make_share_op.run(input_data, 0, share2)
    else:
        print(f"\n rank {rank}: Processing arithmetic operations...")
        make_share_op.run([], 1, share1)
        make_share_op.run([], 1, share2)

    out_share = kcal.MpcShare()
    mpc_op.run([share1, share2], out_share)
    output = []
    reveal_share_op.run(out_share, output)

    end_time = time.time()
    print(f"\n{type} cost: {(end_time - start_time) * 1000:.2f} ms")
    print(f"\n{type} result top 10: {output[:10]}")
def test_basic_3rd_arithmetic_file(context, rank: int, file_path: str, type: str):
    """Run one multiplication/comparison kcal operation through the file API.

    This is used for ``MUL``, ``LESS``, ``LESS_EQUAL``, ``GREATER``,
    ``GREATER_EQUAL``, ``EQUAL`` and ``NO_EQUAL``.

    Files used under ``file_path``:
      * ``arith_<rank>.csv`` - this party's input.
      * ``arithmetic_<rank>1_result.csv`` / ``arithmetic_<rank>2_result.csv`` -
        share files.
      * ``arithmetic_calculate_result_<rank>.csv`` - output of the MPC op.
      * ``arithmetic_reveal_result_<rank>.csv`` - revealed result.

    Only the compute and reveal phase is timed; share creation is not
    included.

    Args:
        context: kcal context (as built by ``create_context``).
        rank: This party's node id.
        file_path: Directory for all of the files above.
        type: One of the operation names listed above.

    Raises:
        ValueError: If ``type`` is not supported. Without this check an
            unknown type would skip the computation yet still reveal
            whatever ``calculate_file`` already contained, possibly stale
            output from an earlier run.
    """
    supported_ops = ('MUL', 'LESS', 'LESS_EQUAL', 'GREATER', 'GREATER_EQUAL', 'EQUAL', 'NO_EQUAL')
    if type not in supported_ops:
        raise ValueError(f"Unknown operation type: {type}")

    print("\n=== Testing Basic Arithmetic File Operations ===")
    make_share_op = kcal.create_make_share(context)
    reveal_share_op = kcal.create_reveal_share(context)
    # Only the requested algorithm is instantiated.
    mpc_op = kcal.create_mpc(context, getattr(kcal.AlgorithmsType, type))

    input_file = os.path.join(file_path, f"arith_{str(rank)}.csv")
    share_file0 = os.path.join(file_path, f"arithmetic_{str(rank)}1_result.csv")
    share_file1 = os.path.join(file_path, f"arithmetic_{str(rank)}2_result.csv")
    calculate_file = os.path.join(file_path, f"arithmetic_calculate_result_{str(rank)}.csv")
    reveal_file = os.path.join(file_path, f"arithmetic_reveal_result_{str(rank)}.csv")

    # Log the actual rank (the fallback branch previously mislabelled itself "rank 0").
    print(f"\n rank {rank}: Processing arithmetic operations...")
    if rank == 0:
        make_share_op.run(input_file, 0, share_file0)
        make_share_op.run(input_file, 1, share_file1)
    elif rank == 1:
        make_share_op.run(input_file, 1, share_file0)
        make_share_op.run(input_file, 0, share_file1)
    else:
        make_share_op.run(input_file, 1, share_file0)
        make_share_op.run(input_file, 1, share_file1)

    start_time = time.time()
    mpc_op.run([share_file0, share_file1], calculate_file)
    reveal_share_op.run(calculate_file, reveal_file)
    end_time = time.time()
    duration_ms = (end_time - start_time) * 1000
    print(f"Basics arithmetic test completed in: {duration_ms:.2f} ms")
def create_parser():
    """Build the command-line parser for the KCAL demo.

    Options:
        --mode          required, ``memory`` or ``file``.
        --work_dir      data directory, default ``./data``.
        --operate_type  required operation name (validated by the test runners).
        --use_sm_alg    flag, default False.
        --n_part        flag selecting the three-party setup, default False.
        --rank          required integer party id.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="KCAL python wrapper demo.")
    try:
        parser.add_argument("--mode", type=str, required=True, choices=['memory', 'file'])
        parser.add_argument("--work_dir", type=str, default="./data")
        parser.add_argument("--operate_type", type=str, required=True)
        parser.add_argument("--use_sm_alg", default=False, action="store_true")
        parser.add_argument("--n_part", required=False, action="store_true")
        parser.add_argument("--rank", type=int, required=True)
    except argparse.ArgumentError:
        # Previously `except argparse.ArgumentParser:`, which is not an exception
        # class; any raised error would have turned into a TypeError instead.
        print("\nParam error.")
        sys.exit(1)
    return parser
def do_memory_func(args):
    """Run the in-memory test selected by ``--operate_type``.

    Multiplication and comparison operations go to
    ``test_basic_3rd_arithmetic``; every other type goes to
    ``test_basic_arithmetic``.
    """
    ctx = create_context(args)
    mul_and_compare = ("MUL", "LESS", "LESS_EQUAL", "GREATER", "GREATER_EQUAL", "EQUAL", "NO_EQUAL")
    runner = test_basic_3rd_arithmetic if args.operate_type in mul_and_compare else test_basic_arithmetic
    runner(ctx, args.rank, args.work_dir, args.operate_type)
def do_file_func(args):
    """Run the file-based test.

    Uses ``test_basic_arithmetic_file`` for two-party runs and
    ``test_basic_3rd_arithmetic_file`` when ``--n_part`` is set.
    """
    ctx = create_context(args)
    if not args.n_part:
        test_basic_arithmetic_file(ctx, args.rank, args.work_dir, args.operate_type)
        return
    test_basic_3rd_arithmetic_file(ctx, args.rank, args.work_dir, args.operate_type)
def main(argv=None):
    """Parse the CLI, validate the rank and dispatch to memory or file mode.

    Args:
        argv: Argument list to parse; ``None`` lets argparse use ``sys.argv[1:]``.

    Exits with status 1 when ``--rank`` is not 0, 1 or 2, or when it is 2
    without ``--n_part``.
    """
    parser = create_parser()
    args = parser.parse_args(argv)
    if args.rank not in (0, 1, 2):
        print("Error: --rank must be 0 or 1 or 2", file=sys.stderr)
        sys.exit(1)
    # Without --n_part the context gets worldSize 2 and the link description
    # lists only alice and bob, so node id 2 would not exist.
    if args.rank == 2 and not args.n_part:
        print("Error: --rank 2 requires --n_part", file=sys.stderr)
        sys.exit(1)
    if args.mode == "memory":
        do_memory_func(args)
    else:
        do_file_func(args)
# Script entry point: pass the command-line arguments (minus the program name)
# to main(), which is what argparse would read by default anyway.
if __name__ == "__main__":
    main(sys.argv[1:])