# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.

from . import libkcal


def create_link_desc(d: dict) -> libkcal.LinkDesc:
    """Create a yacl ``LinkDesc`` configuration from a plain dict.

    Two input formats are supported:

    1. Unified configuration format (recommended; both ends use the same configuration):
        {
            "id": "kcals_psi_test",  # optional, defaults to "root"
            "nodes": [
                {"party": "alice", "address": "127.0.0.1:41929"},
                {"party": "bob", "address": "127.0.0.1:56815"}
            ],
            # optional parameters
            "connect_retry_times": 3,
            "connect_retry_interval_ms": 1000,
            "recv_timeout_ms": 60000,
        }

    2. Parties format (kept for compatibility with older versions):
        {
            "id": "kcals_psi_test",  # optional
            "parties": [
                {"id": "0", "host": "127.0.0.1:5001"},
                {"id": "1", "host": "127.0.0.1:5002"}
            ],
            ...
        }

    Optional keys read in both formats, with the default used when absent:
        connect_retry_times (3), connect_retry_interval_ms (1000),
        recv_timeout_ms (60000), http_max_payload_size (0),
        http_timeout_ms (0), throttle_window_size (0), link_type ("normal").

    Args:
        d: Configuration dict in one of the formats above.

    Returns:
        A populated ``libkcal.LinkDesc``.

    Raises:
        ValueError: If neither ``nodes`` nor ``parties`` is present, if the
            list is not a list/tuple or is empty, or if an entry is not a
            dict or lacks its required keys.

    Note:
        - In the unified format, the order of nodes determines each party's
          rank (the node at index 0 has rank 0). The node's ``party`` name is
          required but is not passed to the descriptor; the index is used as
          the party id instead.
        - If the same configuration is used on both ends, the underlying layer
          automatically establishes the correct connection based on self_rank.
        - Rank conversion is determined by the node order on the Python side.
        - If both ``nodes`` and ``parties`` are given, ``nodes`` takes
          precedence and ``parties`` is ignored.
    """

    desc = libkcal.LinkDesc()
    desc.id = d.get("id", "root")

    if "nodes" in d:
        nodes_list = d["nodes"]
        if not isinstance(nodes_list, (list, tuple)):
            raise ValueError("'nodes' must be a list or tuple")
        # A link without any party is meaningless; fail here with a clear
        # message instead of producing an unusable descriptor.
        if not nodes_list:
            raise ValueError("'nodes' must not be empty")

        for i, node in enumerate(nodes_list):
            if not isinstance(node, dict):
                raise ValueError(f"nodes[{i}] must be a dict")
            if "party" not in node or "address" not in node:
                raise ValueError(f"nodes[{i}] must contain 'party' and 'address'")
            # The list index (as a string) is the party id, which fixes the rank.
            desc.add_party(str(i), node["address"])
    elif "parties" in d:
        parties_list = d["parties"]
        if not isinstance(parties_list, (list, tuple)):
            raise ValueError("'parties' must be a list or tuple")
        if not parties_list:
            raise ValueError("'parties' must not be empty")
        for i, p in enumerate(parties_list):
            if not isinstance(p, dict):
                raise ValueError(f"parties[{i}] must be a dict")
            if "id" not in p or "host" not in p:
                raise ValueError(f"parties[{i}] must contain 'id' and 'host'")
            # Normalize to str so this branch registers ids the same way the
            # "nodes" branch does (e.g. accepts {"id": 0, ...}).
            desc.add_party(str(p["id"]), p["host"])
    else:
        raise ValueError("Missing required field: 'nodes' or 'parties'")

    desc.connect_retry_times = d.get("connect_retry_times", 3)
    desc.connect_retry_interval_ms = d.get("connect_retry_interval_ms", 1000)
    desc.recv_timeout_ms = d.get("recv_timeout_ms", 60000)
    desc.http_max_payload_size = d.get("http_max_payload_size", 0)
    desc.http_timeout_ms = d.get("http_timeout_ms", 0)
    desc.throttle_window_size = d.get("throttle_window_size", 0)
    desc.link_type = d.get("link_type", "normal")

    return desc


def create_link_from_nodes(
    nodes: list, link_id: str = "kcal_link", **kwargs
) -> libkcal.LinkDesc:
    """Build a LinkDesc from an ordered list of nodes.

    Convenience wrapper that packs its arguments into the unified ``"nodes"``
    config format and delegates to ``create_link_desc``.

    Args:
        nodes: Ordered node entries, e.g.
            ``[{"party": "alice", "address": "127.0.0.1:41929"}, ...]``.
            A node's position in the list is its rank: the first node is
            rank 0, the second is rank 1, and so on.
        link_id: Value stored under the config's ``"id"`` key.
        **kwargs: Additional config keys (such as ``connect_retry_times`` or
            ``recv_timeout_ms``). They are merged in last, so a keyword named
            ``id`` overrides ``link_id``.

    Returns:
        The LinkDesc returned by ``create_link_desc``.

    Raises:
        ValueError: Propagated from ``create_link_desc`` when ``nodes`` is
            malformed.

    Example:
        # Both ends use exactly the same node list.
        nodes = [
            {"party": "alice", "address": "127.0.0.1:41929"},
            {"party": "bob", "address": "127.0.0.1:56815"}
        ]

        # rank 0 (``config`` is the context configuration object)
        desc = create_link_from_nodes(nodes, "test_link")
        ctx = kcal.Context.create_with_link_config(config, desc, rank=0)

        # rank 1 reuses the very same desc.
        ctx = kcal.Context.create_with_link_config(config, desc, rank=1)
    """
    link_config = {"id": link_id, "nodes": nodes, **kwargs}
    return create_link_desc(link_config)