# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.

"""
KCAL Python bindings.
"""
from __future__ import annotations

import collections.abc
import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, Optional, overload

# Names exported by a star-import of this module; every enum, class and
# ``create_*`` factory declared below is listed here.
__all__: list[str] = [
    'AlgorithmsType',
    'TeeMode',
    'ShareType',
    'DummyMode',
    'Config',
    'Party',
    'LinkDesc',
    'Context',
    'Psi',
    'create_psi',
    'PsiUb',
    'create_psi_ub',
    'Pir',
    'create_pir',
    'MpcShare',
    'MpcShareSet',
    'Input',
    'MakeShare',
    'create_make_share',
    'RevealShare',
    'create_reveal_share',
    'MpcOperatorBase',
    'create_mpc',
]


class AlgorithmsType(enum.IntEnum):
    """Integer identifiers for the algorithms and operators declared here.

    ``create_mpc`` takes one of these values (as an ``int``) to select the
    operator to create; ``MpcOperatorBase.GetType`` returns one.
    """
    PSI = 0
    PIR = 1
    ARITHMETIC = 2
    MAKE_SHARE = 3
    REVEAL_SHARE = 4
    ADD = 5
    SUB = 6
    MUL = 7
    DIV = 8
    LESS = 9
    LESS_EQUAL = 10
    GREATER = 11
    GREATER_EQUAL = 12
    EQUAL = 13
    # NOTE(review): presumably "not equal"; the spelling is part of the
    # public interface, so it is left unchanged.
    NO_EQUAL = 14
    SUM = 15
    AVG = 16
    MAX = 17
    MIN = 18
    PSI_UB = 19
    ASCEND_SORT = 20
    DESCEND_SORT = 21


class TeeMode(enum.IntEnum):
    """TEE output mode - maps to DG_TeeMode.

    Passed (as an ``int``) as the ``tee_mode`` argument of ``Psi.run``.
    """
    OUTPUT_STRING = 0  # TEE_OUTPUT_STRING
    OUTPUT_INDEX = 1  # TEE_OUTPUT_INDEX


class ShareType(enum.IntEnum):
    """Kind of an ``MpcShare``, as reported by ``MpcShare.type()``."""
    # NOTE(review): presumably distinguishes fixed-point from other share
    # encodings - confirm against the native library.
    FIX_POINT = 0
    NON_FIX_POINT = 1


class DummyMode(enum.IntEnum):
    """Mode passed (as an ``int``) as ``dummy_mode`` to ``Pir.ClientQuery``."""
    NORMAL = 0
    DUMMY = 1


@dataclass
class Config:
    """Settings object handed to the ``Context`` factory methods.

    Every field has a default, so ``Config()`` is valid. The field meanings
    below are inferred from their names only - TODO confirm against the
    native KCAL documentation.
    """
    useSMAlg: bool = False  # presumably toggles the SM algorithm suite
    fixBits: int = 2  # presumably fixed-point precision bits
    nodeId: int = 0  # presumably this party's node index
    threadCount: int = 16  # presumably worker thread count
    worldSize: int = 2  # presumably the number of participating nodes
    chunkSize: int = 100000000  # units not visible here
    bucketCount: int = 256
    tmpPath: str = ""  # empty string by default


class Party:
    """A single party of a yacl link context, described by ``id`` and ``host``."""
    id: str
    host: str

    def __init__(self, id: str = "", host: str = "") -> None:
        """Declare the constructor; ``id`` and ``host`` both default to ``""``."""


class LinkDesc:
    """Descriptor of a yacl link context.

    The attributes below are declared by type only; their semantics are
    defined by the native binding.
    """
    id: str
    parties: List[Party]
    connect_retry_times: int
    connect_retry_interval_ms: int
    recv_timeout_ms: int
    http_max_payload_size: int
    http_timeout_ms: int
    throttle_window_size: int
    link_type: str

    def __init__(self) -> None:
        """Declare the no-argument constructor."""

    def add_party(self, id: str, host: str) -> None:
        """Add a party identified by ``id`` and reachable at ``host``."""


class Context:
    """KCAL context object consumed by every operator declared in this module.

    Instances are obtained through the static factory methods below rather
    than by calling the class directly.
    """

    @staticmethod
    def create(config: Config, send_func: collections.abc.Callable, recv_func: collections.abc.Callable,
               read_input_cb: Optional[collections.abc.Callable] = None,
               read_pair_list_cb: Optional[collections.abc.Callable] = None,
               write_output_cb: Optional[collections.abc.Callable] = None) -> Context:
        """
        Create a context driven by caller-supplied callbacks.

        This single declaration covers both call forms: the three-argument
        form (config, send_func, recv_func) and the extended form with the
        optional I/O callbacks. It replaces two same-named definitions of
        which only the last one was ever visible.

        Args:
            config: KCAL configuration
            send_func: Callable used for sending (signature defined by the
                native binding - not visible here)
            recv_func: Callable used for receiving (signature defined by the
                native binding - not visible here)
            read_input_cb: Optional callback; may be omitted or None
            read_pair_list_cb: Optional callback; may be omitted or None
            write_output_cb: Optional callback; may be omitted or None

        Returns:
            The created Context.
        """

    @staticmethod
    def create_with_yacl(config: Config, desc: LinkDesc, rank: int, has_file: bool) -> Context:
        """
        Create context using yacl link (without connecting to mesh).

        Args:
            config: KCAL configuration
            desc: yacl link context descriptor
            rank: Rank passed to the yacl link (presumably this party's
                index in ``desc.parties`` - confirm)
            has_file: Boolean flag forwarded to the native binding
        """

    @staticmethod
    def create_with_link_config(config: Config, desc: LinkDesc, rank: int, has_file: bool, log_details: bool = False) -> Context:
        """
        Create context with link config (automatically connects to mesh).

        Args:
            config: KCAL configuration
            desc: yacl link context descriptor
            rank: Rank passed to the yacl link (presumably this party's
                index in ``desc.parties`` - confirm)
            has_file: Boolean flag forwarded to the native binding
            log_details: Optional flag, False by default
        """

    def get_yacl_context(self) -> Any:
        """Get the underlying yacl link context."""


class Psi:
    """PSI operator; constructed from a :class:`Context`."""

    def __init__(self, ctx: Context) -> None:
        """Declare the constructor taking the KCAL context ``ctx``."""

    def run(self, input: List[str], output: List[Any], tee_mode: int) -> int:
        """
        Execute the PSI operation.

        Args:
            input: List of input strings
            output: List that receives the results
            tee_mode: A ``kcal.TeeMode`` value given as int
                (``TeeMode.OUTPUT_STRING`` or ``TeeMode.OUTPUT_INDEX``)

        Returns:
            An int (presumably a status code - confirm against the native
            binding).
        """


def create_psi(ctx: Context) -> Psi:
    """Factory that returns a :class:`Psi` for the KCAL context ``ctx``."""


class PsiUb:
    """PsiUb operator working on file paths; constructed from a :class:`Context`."""

    def __init__(self, ctx: Context) -> None:
        """Declare the constructor taking the KCAL context ``ctx``."""

    def run(self, inputFilePath: str, outputFilePath: str) -> tuple[int, int]:
        """
        Execute the PsiUb operation.

        Args:
            inputFilePath: Path of the input file
            outputFilePath: Path of the output file

        Returns:
            A pair of ints; their meaning is defined by the native binding
            and is not visible here.
        """


def create_psi_ub(ctx: Context) -> PsiUb:
    """Factory that returns a :class:`PsiUb` for the KCAL context ``ctx``."""


class Pir:
    """PIR operator; constructed from a :class:`Context`.

    ``ServerPreProcess`` and ``ServerAnswer`` each accept two different
    argument lists. They are declared as typing overloads so that both call
    forms stay visible; with plain repeated ``def`` statements the later
    definition would hide the earlier one.
    """

    def __init__(self, ctx: Context) -> None:
        """Declare the constructor taking the KCAL context ``ctx``."""

    @overload
    def ServerPreProcess(self, keys: List[str], values: List[str]) -> int:
        """Pre-process server data given as lists of keys and values."""

    @overload
    def ServerPreProcess(self, input: str, outputPath: str) -> int:
        """Pre-process server data given as strings.

        Judging by the parameter names, ``outputPath`` is a file-system
        path - confirm against the native binding.
        """

    def ClientQuery(self, input: List[str], output: List[Any], dummy_mode: int) -> int:
        """
        Client query for PIR.

        Args:
            input: Input data list
            output: Output result list
            dummy_mode: DummyMode as int (use kcal.DummyMode.NORMAL or .DUMMY)
        """

    @overload
    def ServerAnswer(self) -> int:
        """Answer a query; takes no arguments in this form."""

    @overload
    def ServerAnswer(self, dataPath: str, isDeleteCache: int) -> int:
        """Answer a query using ``dataPath``.

        ``isDeleteCache`` is an int flag (presumably non-zero deletes cached
        data - confirm against the native binding).
        """


def create_pir(ctx: Context) -> Pir:
    """Factory that returns a :class:`Pir` for the KCAL context ``ctx``."""


class MpcShare:
    """Share object consumed and produced by the MPC operators in this module."""

    def __init__(self) -> None:
        """Declare the no-argument constructor."""

    def size(self) -> int:
        """Return the size of this share as an int."""

    def type(self) -> ShareType:
        """Return the :class:`ShareType` of this share."""


class MpcShareSet:
    """Collection built from a list of :class:`MpcShare` objects."""

    @staticmethod
    def Create(shares: List[MpcShare]) -> MpcShareSet:
        """Build a share set from ``shares``."""

    def Get(self) -> Any:
        """Return the underlying object; its type is not specified here."""


class Input:
    """Input/Output container used by MPC operations."""

    def __init__(self) -> None:
        """Declare the no-argument constructor."""

    @staticmethod
    def create() -> Input:
        """Return a new :class:`Input`."""

    def Set(self, data: Any) -> None:
        """Set the container from ``data``."""

    def Fill(self, data: Any) -> None:
        """Fill the container with ``data``."""

    def Size(self) -> int:
        """Return the size of the container as an int."""


class MakeShare:
    """Operator that produces shares; constructed from a :class:`Context`.

    ``run`` has two call forms, declared as typing overloads so that both
    remain visible to type checkers and IDEs.
    """

    def __init__(self, ctx: Context) -> None:
        """Declare the constructor taking the KCAL context ``ctx``."""

    @overload
    def run(self, input: List[Any], is_recv_share: int, out_share: MpcShare) -> int:
        """Make a share from the in-memory list ``input`` into ``out_share``.

        ``is_recv_share`` is an int flag whose exact meaning is defined by
        the native binding (presumably whether this party receives the
        share - confirm).
        """

    @overload
    def run(self, inputFilePath: str, isRecvShare: int, shareFilePath: str) -> tuple[int, int]:
        """Make a share from ``inputFilePath`` and write it to ``shareFilePath``.

        Returns a pair of ints; their meaning is defined by the native
        binding and is not visible here.
        """


def create_make_share(ctx: Context) -> MakeShare:
    """Factory that returns a :class:`MakeShare` for the KCAL context ``ctx``."""


class RevealShare:
    """Operator that reveals shares; constructed from a :class:`Context`.

    ``run`` has two call forms, declared as typing overloads; with plain
    repeated ``def`` statements the second definition would hide the first.
    """

    def __init__(self, ctx: Context) -> None:
        """Declare the constructor taking the KCAL context ``ctx``."""

    @overload
    def run(self, input_share: MpcShare, output: List[Any]) -> int:
        """Reveal ``input_share`` into the list ``output``."""

    @overload
    def run(self, shareFilePath: str, outputFilePath: str) -> tuple[int, int]:
        """Reveal the share stored at ``shareFilePath`` into ``outputFilePath``.

        Returns a pair of ints; their meaning is defined by the native
        binding and is not visible here.
        """


def create_reveal_share(ctx: Context) -> RevealShare:
    """Factory that returns a :class:`RevealShare` for the KCAL context ``ctx``."""


class MpcOperatorBase(ABC):
    """Abstract base of the MPC operators returned by ``create_mpc``.

    ``run`` has two call forms. Both are declared as abstract typing
    overloads: a plain second ``def run`` would replace the first one and
    silently drop its ``@abstractmethod`` marker.
    """

    def GetType(self) -> AlgorithmsType:
        """Return the :class:`AlgorithmsType` of this operator."""

    @overload
    @abstractmethod
    def run(self, shares: List[MpcShare], out_share: MpcShare) -> int:
        """Run the operator on in-memory ``shares``, writing to ``out_share``."""

    @overload
    @abstractmethod
    def run(self, input_file_paths: List[str], output_file_path: str) -> tuple[int, int]:
        """Run the operator on the files in ``input_file_paths``.

        Writes to ``output_file_path`` and returns a pair of ints; their
        meaning is defined by the native binding and is not visible here.
        """


def create_mpc(ctx: Context, type: int) -> MpcOperatorBase:
    """
    Create MPC operator.

    Args:
        ctx: KCAL context
        type: AlgorithmsType as int (use kcal.AlgorithmsType.ADD, .MUL, etc.)

    Returns:
        An MpcOperatorBase for the requested algorithm type.
    """