# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.

from __future__ import annotations

import sys
import os
import kcal
import csv
import argparse
import time

from kcal import link_config

def create_link_desc(is_3rd: bool) -> kcal.LinkDesc:
    """Build the yacl link description for the local demo parties.

    Args:
        is_3rd: When truthy, a third party ("carol") is appended after
            "alice" and "bob".

    Returns:
        Whatever ``link_config.create_link_from_nodes`` returns for the
        assembled node list.
    """
    party_addresses = [
        ("alice", "127.0.0.1:41929"),
        ("bob", "127.0.0.1:56815"),
    ]
    if is_3rd:
        party_addresses.append(("carol", "127.0.0.1:56816"))
    node_list = [{"party": name, "address": addr} for name, addr in party_addresses]

    return link_config.create_link_from_nodes(
        node_list,
        link_id="kcal_arith_test",
        connect_retry_times=3,
        connect_retry_interval_ms=1000,
        recv_timeout_ms=60000,
    )

def create_context(args):
    """Create a KCAL context from the parsed command-line arguments.

    The world size is 3 when ``args.n_part`` is set and 2 otherwise, and the
    link description is built with the same party count. The final argument
    to ``kcal.Context.create_with_link_config`` is true only in file mode.
    """
    three_party = args.n_part
    cfg = kcal.Config()
    cfg.nodeId = args.rank
    cfg.worldSize = 3 if three_party else 2
    cfg.fixBits = 3
    cfg.threadCount = 32
    cfg.useSMAlg = args.use_sm_alg
    link_desc = create_link_desc(three_party)
    file_mode = args.mode == "file"
    return kcal.Context.create_with_link_config(cfg, link_desc, args.rank, file_mode)

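# Two-party operators, in-memory invocation (excludes multiplication and comparison operators, which are handled by the three-party in-memory path)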
def test_basic_arithmetic(context, rank: int, file_path: str, type: str):
    """Run one two-party operator on in-memory shares and print the revealed result.

    Reads the first column of ``<file_path>/arith_<rank>.csv`` as floats,
    secret-shares it, runs the operator named by ``type``, reveals the output
    and prints it together with the elapsed time.

    Unary operators (SUM, AVG, MAX, MIN, ASCEND_SORT, DESCEND_SORT) run on a
    single share. Binary operators (ADD, SUB, DIV) run on two shares: rank 0
    supplies the data for the first and the other rank for the second.

    Args:
        context: KCAL context used to create the share/reveal/MPC operators.
        rank: This party's rank; also selects the input CSV file.
        file_path: Directory that holds the input CSV.
        type: Operator name; used as the ``kcal.AlgorithmsType`` member name.

    Raises:
        ValueError: If ``type`` is not a supported two-party in-memory operator.
    """
    print("\n=== Testing Basic Arithmetic Operations ===")

    unary_ops = {'ASCEND_SORT', 'DESCEND_SORT', 'SUM', 'AVG', 'MAX', 'MIN'}
    binary_ops = {'ADD', 'SUB', 'DIV'}
    # Reject unknown operators before any share is exchanged, so a bad
    # --operate_type does not cost a full round of secret sharing first.
    if type not in unary_ops and type not in binary_ops:
        raise ValueError(f"Unknown operation type: {type}")

    make_share_op = kcal.create_make_share(context)
    reveal_share_op = kcal.create_reveal_share(context)
    start_time = time.time()
    input_file = os.path.join(file_path, f"arith_{rank}.csv")
    with open(input_file, 'r', newline='', encoding='utf-8') as f1:
        # csv.reader yields [] for blank lines (e.g. a trailing newline); skip them.
        input_data = [float(row[0]) for row in csv.reader(f1) if row]
    print(f"\ninput_data head 10: {input_data[:10]}, total_len: {len(input_data)}")
    share1 = kcal.MpcShare()
    share2 = kcal.MpcShare()

    if type in unary_ops:
        # Each rank shares its own data into share1; the second argument is
        # 0 on rank 0 and 1 on every other rank.
        make_share_op.run(input_data, 0 if rank == 0 else 1, share1)
    elif rank == 0:
        make_share_op.run(input_data, 0, share1)
        make_share_op.run([], 1, share2)
    else:
        make_share_op.run([], 1, share1)
        make_share_op.run(input_data, 0, share2)

    # Every supported name is also a kcal.AlgorithmsType member name.
    mpc_op = kcal.create_mpc(context, getattr(kcal.AlgorithmsType, type))
    operands = [share1] if type in unary_ops else [share1, share2]
    out_share = kcal.MpcShare()
    mpc_op.run(operands, out_share)
    output = []
    reveal_share_op.run(out_share, output)
    print(f"{type} result: {output}")

    duration_ms = (time.time() - start_time) * 1000
    print(f"Basics arithmetic test completed in: {duration_ms:.2f} ms")
    
    
# Two-party operators, file-based invocation
def test_basic_arithmetic_file(context, rank: int, file_path: str, type: str):
    """Run one two-party operator on file-backed shares and reveal into a file.

    Shares ``<file_path>/arith_<rank>.csv`` into
    ``arithmetic_<rank>1_result.csv`` (and, for binary operators,
    ``arithmetic_<rank>2_result.csv``). It then runs the operator named by
    ``type`` into ``arithmetic_calculate_result_<rank>.csv`` and reveals that
    into ``arithmetic_reveal_result_<rank>.csv``, printing the elapsed time.

    Args:
        context: KCAL context used to create the share/reveal/MPC operators.
        rank: This party's rank; also part of every file name.
        file_path: Directory holding the input and all generated files.
        type: Operator name; used as the ``kcal.AlgorithmsType`` member name.

    Raises:
        ValueError: If ``type`` is not a supported two-party operator.
    """
    print("\n=== Testing Basic Arithmetic File Operations ===")

    unary_ops = {'ASCEND_SORT', 'DESCEND_SORT', 'SUM', 'AVG', 'MAX', 'MIN'}
    binary_ops = {
        'ADD', 'SUB', 'MUL', 'DIV',
        'LESS', 'LESS_EQUAL', 'GREATER', 'GREATER_EQUAL', 'EQUAL', 'NO_EQUAL',
    }
    # Validate before any share files are produced or exchanged.
    if type not in unary_ops and type not in binary_ops:
        raise ValueError(f"Unknown operation type: {type}")

    make_share_op = kcal.create_make_share(context)
    reveal_share_op = kcal.create_reveal_share(context)
    start_time = time.time()
    input_file = os.path.join(file_path, f"arith_{rank}.csv")
    share_file0 = os.path.join(file_path, f"arithmetic_{rank}1_result.csv")
    share_file1 = os.path.join(file_path, f"arithmetic_{rank}2_result.csv")
    calculate_file = os.path.join(file_path, f"arithmetic_calculate_result_{rank}.csv")
    reveal_file = os.path.join(file_path, f"arithmetic_reveal_result_{rank}.csv")

    if type in unary_ops:
        make_share_op.run(input_file, 0 if rank == 0 else 1, share_file0)
    elif rank == 0:
        make_share_op.run(input_file, 0, share_file0)
        make_share_op.run(input_file, 1, share_file1)
    else:
        make_share_op.run(input_file, 1, share_file0)
        make_share_op.run(input_file, 0, share_file1)

    # Every supported name is also a kcal.AlgorithmsType member name.
    mpc_op = kcal.create_mpc(context, getattr(kcal.AlgorithmsType, type))
    operands = [share_file0] if type in unary_ops else [share_file0, share_file1]
    mpc_op.run(operands, calculate_file)

    reveal_share_op.run(calculate_file, reveal_file)
    duration_ms = (time.time() - start_time) * 1000
    print(f"Basics arithmetic test completed in: {duration_ms:.2f} ms")
    

# Three-party operators, in-memory invocation
def test_basic_3rd_arithmetic(context, rank: int, file_path: str, type: str):
    """Run one three-party operator on in-memory shares and print the result.

    Reads the first column of ``<file_path>/arith_<rank>.csv`` as floats.
    Rank 0 supplies the first operand and rank 1 the second; every other
    rank supplies no input. The function then runs the operator named by
    ``type``, reveals the output, and prints the elapsed time and the first
    10 results.

    Args:
        context: KCAL context used to create the share/reveal/MPC operators.
        rank: This party's rank; also selects the input CSV file.
        file_path: Directory that holds the input CSV.
        type: One of MUL, LESS, LESS_EQUAL, GREATER, GREATER_EQUAL, EQUAL,
            NO_EQUAL; used as the ``kcal.AlgorithmsType`` member name.

    Raises:
        ValueError: If ``type`` is not one of the supported operators.
    """
    print("\n=== Testing Basic 3rd Arithmetic Operations ===")

    supported_ops = {'MUL', 'LESS', 'LESS_EQUAL', 'GREATER', 'GREATER_EQUAL', 'EQUAL', 'NO_EQUAL'}
    if type not in supported_ops:
        raise ValueError(f"Unknown operation type: {type}")

    make_share_op = kcal.create_make_share(context)
    reveal_share_op = kcal.create_reveal_share(context)
    # Only the requested operator is created; the others would go unused.
    mpc_op = kcal.create_mpc(context, getattr(kcal.AlgorithmsType, type))

    # NOTE(review): ranks other than 0/1 pass [] below, yet their CSV is still
    # read here and must exist.
    input_file = os.path.join(file_path, f"arith_{rank}.csv")
    with open(input_file, 'r', newline='', encoding='utf-8') as f1:
        # csv.reader yields [] for blank lines (e.g. a trailing newline); skip them.
        input_data = [float(row[0]) for row in csv.reader(f1) if row]
    print(f"\ninput_data head 10: {input_data[:10]}, total_len: {len(input_data)}")
    start_time = time.time()
    share1 = kcal.MpcShare()
    share2 = kcal.MpcShare()

    print(f"\n rank {rank}: Processing arithmetic operations...")
    if rank == 0:
        make_share_op.run(input_data, 0, share1)
        make_share_op.run([], 1, share2)
    elif rank == 1:
        make_share_op.run([], 1, share1)
        make_share_op.run(input_data, 0, share2)
    else:
        make_share_op.run([], 1, share1)
        make_share_op.run([], 1, share2)

    out_share = kcal.MpcShare()
    mpc_op.run([share1, share2], out_share)
    output = []
    reveal_share_op.run(out_share, output)

    end_time = time.time()
    print(f"\n{type} cost: {(end_time - start_time) * 1000:.2f} ms")
    print(f"\n{type} result top 10: {output[:10]}")
    

# Three-party operators, file-based invocation
def test_basic_3rd_arithmetic_file(context, rank: int, file_path: str, type: str):
    """Run one three-party operator on file-backed shares and reveal into a file.

    Shares ``<file_path>/arith_<rank>.csv`` into
    ``arithmetic_<rank>1_result.csv`` and ``arithmetic_<rank>2_result.csv``.
    It then runs the operator named by ``type`` into
    ``arithmetic_calculate_result_<rank>.csv`` and reveals that into
    ``arithmetic_reveal_result_<rank>.csv``. The printed time covers only the
    compute and reveal steps, not sharing.

    Args:
        context: KCAL context used to create the share/reveal/MPC operators.
        rank: This party's rank; also part of every file name.
        file_path: Directory holding the input and all generated files.
        type: One of MUL, LESS, LESS_EQUAL, GREATER, GREATER_EQUAL, EQUAL,
            NO_EQUAL; used as the ``kcal.AlgorithmsType`` member name.

    Raises:
        ValueError: If ``type`` is not one of the supported operators. Without
            this check no calculate file would be produced and the reveal step
            would read a missing or stale file.
    """
    print("\n=== Testing Basic Arithmetic File Operations ===")

    supported_ops = {'MUL', 'LESS', 'LESS_EQUAL', 'GREATER', 'GREATER_EQUAL', 'EQUAL', 'NO_EQUAL'}
    if type not in supported_ops:
        raise ValueError(f"Unknown operation type: {type}")

    make_share_op = kcal.create_make_share(context)
    reveal_share_op = kcal.create_reveal_share(context)
    # Only the requested operator is created; the others would go unused.
    mpc_op = kcal.create_mpc(context, getattr(kcal.AlgorithmsType, type))

    input_file = os.path.join(file_path, f"arith_{rank}.csv")
    share_file0 = os.path.join(file_path, f"arithmetic_{rank}1_result.csv")
    share_file1 = os.path.join(file_path, f"arithmetic_{rank}2_result.csv")
    calculate_file = os.path.join(file_path, f"arithmetic_calculate_result_{rank}.csv")
    reveal_file = os.path.join(file_path, f"arithmetic_reveal_result_{rank}.csv")

    # Log the actual rank so each party's output is distinguishable.
    print(f"\n rank {rank}: Processing arithmetic operations...")
    if rank == 0:
        make_share_op.run(input_file, 0, share_file0)
        make_share_op.run(input_file, 1, share_file1)
    elif rank == 1:
        make_share_op.run(input_file, 1, share_file0)
        make_share_op.run(input_file, 0, share_file1)
    else:
        make_share_op.run(input_file, 1, share_file0)
        make_share_op.run(input_file, 1, share_file1)

    start_time = time.time()
    mpc_op.run([share_file0, share_file1], calculate_file)
    reveal_share_op.run(calculate_file, reveal_file)
    duration_ms = (time.time() - start_time) * 1000
    print(f"Basics arithmetic test completed in: {duration_ms:.2f} ms")


def create_parser():
    """Build the command-line parser for the KCAL demo.

    Options:
        --mode          required; "memory" or "file".
        --work_dir      directory for input/output CSVs (default "./data").
        --operate_type  required operator name, e.g. ADD or LESS.
        --use_sm_alg    boolean flag (default False).
        --n_part        boolean flag selecting the three-party setup.
        --rank          required integer rank of this process.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="KCAL python wrapper demo.")
    try:
        parser.add_argument("--mode", type=str, required=True, choices=['memory', 'file'])
        parser.add_argument("--work_dir", type=str, default="./data")
        parser.add_argument("--operate_type", type=str, required=True)
        parser.add_argument("--use_sm_alg", default=False, action="store_true")
        parser.add_argument("--n_part", required=False, action="store_true")
        parser.add_argument("--rank", type=int, required=True)
    except argparse.ArgumentError:
        # add_argument raises ArgumentError for conflicting option definitions.
        # The previous handler named ArgumentParser, which is not an exception
        # class, so it would itself have raised TypeError.
        print("\nParam error.")
        sys.exit(1)
    return parser


def do_memory_func(args):
    """Run the in-memory demo.

    MUL and the comparison operators go to ``test_basic_3rd_arithmetic``;
    every other operator goes to ``test_basic_arithmetic``.
    """
    three_party_ops = ("MUL", "LESS", "LESS_EQUAL", "GREATER", "GREATER_EQUAL", "EQUAL", "NO_EQUAL")
    ctx = create_context(args)
    if args.operate_type in three_party_ops:
        runner = test_basic_3rd_arithmetic
    else:
        runner = test_basic_arithmetic
    runner(ctx, args.rank, args.work_dir, args.operate_type)


def do_file_func(args):
    """Run the file-based demo, using the three-party variant when --n_part is set."""
    ctx = create_context(args)
    runner = test_basic_3rd_arithmetic_file if args.n_part else test_basic_arithmetic_file
    runner(ctx, args.rank, args.work_dir, args.operate_type)


def main(argv=None):
    """Parse CLI arguments, validate the rank, and run the selected demo mode.

    Args:
        argv: Optional argument list; ``None`` makes argparse read ``sys.argv``.

    Exits with status 1 when ``--rank`` is outside 0..2, or when it is 2
    without ``--n_part``.
    """
    parser = create_parser()
    args = parser.parse_args(argv)
    if args.rank not in (0, 1, 2):
        print("Error: --rank must be 0 or 1 or 2", file=sys.stderr)
        sys.exit(1)
    # Without --n_part, create_context builds a two-party world whose link
    # description only lists two nodes, so rank 2 has no slot to join.
    if args.rank == 2 and not args.n_part:
        print("Error: --rank 2 requires --n_part", file=sys.stderr)
        sys.exit(1)
    if args.mode == "memory":
        do_memory_func(args)
    else:
        do_file_func(args)


if __name__ == "__main__":
    # Script entry point; main() is called without argv, so argparse reads sys.argv.
    main()