from dataclasses import dataclass, field
from typing import Optional, Any

import ray
from ray.util.placement_group import PlacementGroup, placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from minirl.utils.utils import get_master_addr_port


def sort_placement_group_by_node_ip(pgs: list[PlacementGroup]):
    """Return a new list of placement groups ordered by node IP address.

    Each placement group is keyed by the address of the node recorded for
    entry ``0`` of its ``bundles_to_node_id`` table; the other bundles are not
    inspected. The input list is left untouched.

    Args:
        pgs: Placement groups to order.

    Returns:
        A new list with the same placement groups, sorted by the node address
        (``NodeManagerAddress``) compared as a string.

    Raises:
        KeyError: If the node id of a placement group is not among the nodes
            reported by ``ray.nodes()``.
    """
    # Map every node id known to the cluster to its address.
    address_of_node = {}
    for node_info in ray.nodes():
        address_of_node[node_info["NodeID"]] = node_info["NodeManagerAddress"]

    def _address_of_first_bundle(group):
        # NOTE: relies on a private Ray API to read the placement group table.
        table = ray._private.state.state.placement_group_table(group.id)
        return address_of_node[table["bundles_to_node_id"][0]]

    return sorted(pgs, key=_address_of_first_bundle)


class RayResourcePool:
    """Per-node process counts plus the Ray placement groups reserved for them.

    ``process_on_nodes[i]`` is the number of worker processes (and therefore
    bundles) requested for the i-th placement group.

    Attributes:
        max_colocate_count: Number of CPUs requested per bundle.
        n_gpus_per_node: Fixed at 8; not read anywhere inside this class.
        use_gpu: Whether each bundle also requests one accelerator.
        name_prefix: Prefix used when naming placement groups.
        pgs: ``None`` until ``get_placement_groups`` has run, afterwards the
            placement groups sorted by node IP.
    """

    def __init__(
        self,
        process_on_nodes: Optional[list[int]] = None,
        use_gpu: bool = True,
        name_prefix: Optional[str] = None,
        max_colocate_count: int = 1,
    ) -> None:
        """Store the pool layout; no Ray resources are reserved here.

        Args:
            process_on_nodes: Process count per placement group; ``None`` means
                an empty pool.
            use_gpu: Whether bundles should request one accelerator each.
            name_prefix: Placement group name prefix; ``None`` selects
                ``"default_pool"``.
            max_colocate_count: CPUs requested per bundle.
        """
        if process_on_nodes is None:
            process_on_nodes = []
        self._store = process_on_nodes
        self.max_colocate_count = max_colocate_count
        self.n_gpus_per_node = 8
        self.use_gpu = use_gpu
        self.name_prefix = "default_pool" if name_prefix is None else name_prefix
        # Defined up front so the attribute can be read before any placement
        # group has been created.
        self.pgs = None
        # Placement groups in creation order, i.e. aligned index-by-index with
        # ``self._store``. ``None`` means "not created yet".
        self._pgs_in_store_order = None

    @property
    def world_size(self):
        """Total number of processes across all placement groups."""
        return sum(self._store)

    @property
    def store(self):
        """The per-placement-group process counts passed to the constructor."""
        return self._store

    def __call__(self):
        """Return the per-placement-group process counts (same as ``store``)."""
        return self._store

    def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="npu"):
        """Create the pool's placement groups once and return them.

        The first call reserves one placement group per entry of ``store`` and
        blocks until all of them are ready. Later calls return the same
        placement groups instead of reserving a second set, so several worker
        groups can share one pool; ``strategy``, ``name`` and ``device_name``
        are only honoured by the call that actually creates the groups.

        Args:
            strategy: Ray placement strategy for every group.
            name: Name prefix for the groups; when falsy a prefix is derived
                from ``name_prefix`` and the process counts.
            device_name: ``"npu"`` maps to the ``NPU`` resource and ``"cuda"``
                to ``GPU``; any other value is used as the resource name as-is.

        Returns:
            The placement groups in creation order, so that element ``i``
            corresponds to ``store[i]``. The node-IP-sorted view is available
            as ``self.pgs``.
        """
        if self._pgs_in_store_order is not None:
            return self._pgs_in_store_order

        pg_name_prefix = (
            name if name else f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:"
        )

        resource_name = {"npu": "NPU", "cuda": "GPU"}.get(device_name, device_name)

        bundle = {"CPU": self.max_colocate_count}
        if self.use_gpu:
            bundle[resource_name] = 1

        # One independent bundle dict per process so they never alias.
        pg_scheme = [[bundle.copy() for _ in range(process_count)] for process_count in self._store]

        pgs = [
            placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx))
            for idx, bundles in enumerate(pg_scheme)
        ]

        # Block until every reservation has been granted.
        ray.get([pg.ready() for pg in pgs])

        self.pgs = sort_placement_group_by_node_ip(pgs)
        # NOTE(review): the creation-order list (not the sorted one) is what
        # callers receive, as before, because they index ``store`` by position.
        self._pgs_in_store_order = pgs
        return pgs


@dataclass
class ResourcePoolManager:
    """Builds and looks up ``RayResourcePool`` objects by name."""

    # Pool name -> process count per placement group for that pool.
    resource_pool_spec: dict[str, list[int]]
    # Pool name -> pool object; filled in by ``create_resource_pool``.
    resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)

    def create_resource_pool(self):
        """Instantiate one ``RayResourcePool`` per entry of the specification.

        Every pool is built with ``use_gpu=True`` and ``max_colocate_count=1``
        and uses its own name as the placement group name prefix. Pools are
        stored in ``resource_pool_dict`` under their name, replacing any entry
        that already exists for that name.
        """
        for pool_name in self.resource_pool_spec:
            self.resource_pool_dict[pool_name] = RayResourcePool(
                process_on_nodes=self.resource_pool_spec[pool_name],
                use_gpu=True,
                max_colocate_count=1,
                name_prefix=pool_name,
            )

    def get_resource_pool(self, resource_pool_name: str) -> RayResourcePool:
        """Return the pool registered under ``resource_pool_name``.

        Raises:
            KeyError: If no pool with that name has been created.
        """
        pools = self.resource_pool_dict
        return pools[resource_pool_name]


class RayClassWithInitArgs:
    """Bundles an actor class with its constructor arguments and Ray options.

    ``cls`` is used as ``cls.options(**options).remote(*args, **kwargs)``, so it
    must provide that interface (as a ``ray.remote``-decorated class does).
    Calling the instance creates one actor inside a placement group bundle.
    """

    def __init__(self, cls, *args, **kwargs) -> None:
        """Remember the actor class and the arguments for its constructor."""
        self.cls = cls
        self.args = args
        self.kwargs = kwargs
        # Extra actor options accumulated through ``update_options``.
        self._options = {}

    def update_options(self, options: dict):
        """Update the Ray actor creation options.

        Args:
            options: Dictionary of options to update
        """
        self._options.update(options)

    @staticmethod
    def _device_options(use_gpu: bool, num_gpus, device_name: str) -> dict:
        """Translate the accelerator request into Ray actor options.

        Returns an empty dict when ``use_gpu`` is false, so CPU-only actors do
        not trip the unsupported-device check.

        Raises:
            ValueError: If an accelerator is requested for a ``device_name``
                other than ``"cuda"`` or ``"npu"``.
        """
        if not use_gpu:
            return {}
        if device_name == "cuda":
            return {"num_gpus": num_gpus}
        if device_name == "npu":
            return {"resources": {"NPU": num_gpus}}
        raise ValueError(f"Unsupported device_name: {device_name}")

    def __call__(
        self,
        placement_group,
        placement_group_bundle_idx,
        use_gpu: bool = True,
        num_gpus=1,
        device_name="npu",
    ) -> Any:
        """Create the actor in the given placement group bundle.

        Args:
            placement_group: Placement group to schedule the actor into.
            placement_group_bundle_idx: Index of the bundle inside that group.
            use_gpu: Whether to request an accelerator for the actor.
            num_gpus: Accelerator amount to request (may be fractional).
            device_name: ``"cuda"`` requests ``num_gpus``; ``"npu"`` requests
                the custom ``NPU`` resource.

        Returns:
            Whatever ``cls.options(...).remote(...)`` returns.

        Raises:
            ValueError: If ``use_gpu`` is true and ``device_name`` is neither
                ``"cuda"`` nor ``"npu"``.
        """
        options = {
            "scheduling_strategy": PlacementGroupSchedulingStrategy(
                placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_idx
            )
        }
        # Stored options win over the default scheduling strategy ...
        options.update(self._options)
        # ... and the accelerator request wins over stored options.
        options.update(self._device_options(use_gpu, num_gpus, device_name))

        return self.cls.options(**options).remote(*self.args, **self.kwargs)


def create_workers_from_pgs(worker_cls, resource_pool: RayResourcePool, name_prefix: str, device_name: str = "npu"):
    """Launch one worker per process slot of ``resource_pool``.

    The rendezvous address and port are obtained by running
    ``get_master_addr_port`` in bundle 0 of the first placement group. Workers
    are then created placement group by placement group, each in the bundle
    whose index equals its local rank, with the distributed environment
    variables passed through the actor's ``runtime_env``.

    Args:
        worker_cls: Callable wrapper offering ``update_options`` and a call
            signature like ``RayClassWithInitArgs``. Its stored options are
            updated for every worker created.
        resource_pool: Pool supplying the placement groups and process counts.
        name_prefix: Prefix for actor names and the ``WG_PREFIX`` variable.
        device_name: Forwarded to the pool and to ``worker_cls``.

    Returns:
        The created workers, ordered by global rank.

    Raises:
        ValueError: If the pool yields no placement group, or a placement
            group has fewer bundles than processes requested for it.
        RuntimeError: If the number of created workers differs from the
            pool's world size.
    """
    placement_groups = resource_pool.get_placement_groups(device_name=device_name)
    if not placement_groups:
        raise ValueError("No placement group created for rollout workers")

    # Resolve the rendezvous endpoint from inside the first bundle.
    rendezvous_strategy = PlacementGroupSchedulingStrategy(
        placement_group=placement_groups[0],
        placement_group_bundle_index=0,
    )
    endpoint_ref = get_master_addr_port.options(scheduling_strategy=rendezvous_strategy).remote()
    master_addr, master_port = ray.get(endpoint_ref)

    workers = []
    for pg_idx, pg in enumerate(placement_groups):
        local_world_size = resource_pool.store[pg_idx]
        if local_world_size > pg.bundle_count:
            raise ValueError(
                f"Requested {local_world_size} workers in pg[{pg_idx}] "
                f"but only {pg.bundle_count} bundles exist."
            )

        local_size_text = str(local_world_size)
        for local_rank in range(local_world_size):
            # Global rank equals the number of workers created so far.
            rank = len(workers)
            env_vars = {
                "WORLD_SIZE": str(resource_pool.world_size),
                "RANK": str(rank),
                "LOCAL_WORLD_SIZE": local_size_text,
                "LOCAL_RANK": str(local_rank),
                "MASTER_ADDR": master_addr,
                "MASTER_PORT": master_port,
                "WG_PREFIX": name_prefix,
                "WG_BACKEND": "ray",
                "RAY_LOCAL_WORLD_SIZE": local_size_text,
            }
            actor_options = {
                "runtime_env": {"env_vars": env_vars},
                "name": f"{name_prefix}_{pg_idx}:{local_rank}",
            }
            worker_cls.update_options(actor_options)
            workers.append(
                worker_cls(
                    placement_group=pg,
                    placement_group_bundle_idx=local_rank,
                    use_gpu=resource_pool.use_gpu,
                    num_gpus=1 / resource_pool.max_colocate_count,
                    device_name=device_name,
                )
            )

    if len(workers) != resource_pool.world_size:
        raise RuntimeError(f"Worker count mismatch: {len(workers)} != {resource_pool.world_size}")

    return workers