# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Optional

import torch
from torch import Tensor

from transfer_queue.storage.clients.base import StorageClientFactory, StorageKVClient
from transfer_queue.utils.logging_utils import get_logger
from transfer_queue.utils.serial_utils import batch_decode_from, batch_encode_into
from transfer_queue.utils.yuanrong_utils import find_reachable_host

logger = get_logger(__name__)


YUANRONG_DATASYSTEM_IMPORTED: bool = True

try:
    from yr import datasystem
except ImportError:
    YUANRONG_DATASYSTEM_IMPORTED = False


class StorageStrategy(ABC):
    """Abstract base class for storage strategies."""

    @staticmethod
    @abstractmethod
    def init(config: dict) -> Optional["StorageStrategy"]:
        """Initialize strategy from config; return None if not applicable."""

    @abstractmethod
    def strategy_tag(self) -> Any:
        """Return metadata identifying this strategy (e.g., string name, byte tag)."""

    @abstractmethod
    def supports_put(self, value: Any) -> bool:
        """Check if this strategy can store the given value."""

    @abstractmethod
    def put(self, keys: list[str], values: list[Any]) -> None:
        """Store key-value pairs using this strategy."""

    @abstractmethod
    def supports_get(self, strategy_tag: Any) -> bool:
        """Check if this strategy can retrieve data with given tag."""

    @abstractmethod
    def get(self, keys: list[str], **kwargs) -> list[Any | None]:
        """Retrieve values by keys; kwargs may include shapes/dtypes."""

    @abstractmethod
    def supports_clear(self, strategy_tag: Any) -> bool:
        """Check if this strategy owns data identified by metadata."""

    @abstractmethod
    def clear(self, keys: list[str]) -> None:
        """Delete keys from storage."""


class NPUTensorKVClientAdapter(StorageStrategy):
    """Adapter for YuanRong's high-performance NPU tensor storage.
    Using yr.datasystem.DsTensorClient to connect datasystem backends.
    """

    KEYS_LIMIT: int = 10_000

    def __init__(self, config: dict):
        port = config.get("worker_port")

        if port is None or not isinstance(port, int):
            raise ValueError("Missing or invalid 'worker_port' in config")

        logger.info(f"Auto-detecting reachable host for Yuanrong port {port}...")
        host = find_reachable_host(port)
        if host is None:
            raise ValueError(
                f"Could not find any reachable host for Yuanrong port {port}. "
                "Please ensure yuanrong datasystem is running."
            )
        logger.info(f"Using auto-detected host: {host}")

        self.device_id = torch.npu.current_device()
        torch.npu.set_device(self.device_id)

        self._ds_client = datasystem.DsTensorClient(host, port, self.device_id)
        self._ds_client.init()
        logger.info("YuanrongStorageClient: Create DsTensorClient to connect with yuanrong-datasystem backend!")

    @staticmethod
    def init(config: dict) -> Optional["StorageStrategy"]:
        """Initialize only if NPU and torch_npu are available."""
        torch_npu_imported: bool = True
        try:
            import torch_npu  # noqa: F401
        except ImportError:
            torch_npu_imported = False
        enable = config.get("enable_yr_npu_transport", True)
        if not (enable and torch_npu_imported and torch.npu.is_available()):
            return None

        return NPUTensorKVClientAdapter(config)

    def strategy_tag(self) -> str:
        """Strategy tag for NPU tensor storage. Using a single byte is for better performance."""
        return "1"

    def supports_put(self, value: Any) -> bool:
        """Supports contiguous NPU tensors only."""
        if not (isinstance(value, torch.Tensor) and value.device.type == "npu"):
            return False
        # Only contiguous NPU tensors are supported by this adapter.
        return value.is_contiguous()

    def put(self, keys: list[str], values: list[Any]) -> None:
        """Store NPU tensors in batches; deletes before overwrite."""
        for i in range(0, len(keys), self.KEYS_LIMIT):
            batch_keys = keys[i : i + self.KEYS_LIMIT]
            batch_values = values[i : i + self.KEYS_LIMIT]
            # mset_d2h cannot overwrite existing keys
            try:
                self._ds_client.delete(batch_keys)
            except Exception:
                pass
            self._ds_client.mset_d2h(batch_keys, batch_values)

    def supports_get(self, strategy_tag: str) -> bool:
        """Matches 'DsTensorClient' Strategy tag."""
        return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag()

    def get(self, keys: list[str], **kwargs) -> list[Any | None]:
        """Fetch NPU tensors using pre-allocated empty buffers."""
        shapes = kwargs.get("shapes", None)
        dtypes = kwargs.get("dtypes", None)
        if shapes is None or dtypes is None:
            raise ValueError("YuanrongStorageClient needs Expected shapes and dtypes")
        results = []
        for i in range(0, len(keys), self.KEYS_LIMIT):
            batch_keys = keys[i : i + self.KEYS_LIMIT]
            batch_shapes = shapes[i : i + self.KEYS_LIMIT]
            batch_dtypes = dtypes[i : i + self.KEYS_LIMIT]

            batch_values = self._create_empty_npu_tensorlist(batch_shapes, batch_dtypes)
            self._ds_client.mget_h2d(batch_keys, batch_values)
            # Todo(dpj): consider checking and logging keys that fail during mget_h2d
            results.extend(batch_values)
        return results

    def supports_clear(self, strategy_tag: str) -> bool:
        """Matches 'DsTensorClient' strategy tag."""
        return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag()

    def clear(self, keys: list[str]) -> None:
        """Delete NPU tensor keys in batches."""
        for i in range(0, len(keys), self.KEYS_LIMIT):
            batch = keys[i : i + self.KEYS_LIMIT]
            # Todo(dpj): Test call clear when no (key,value) put in ds
            self._ds_client.delete(batch)

    def _create_empty_npu_tensorlist(self, shapes: list[Any], dtypes: list[Any]) -> list[Tensor]:
        """
        Create a list of empty NPU tensors with given shapes and dtypes.

        Args:
            shapes (list): List of tensor shapes (e.g., [(3,), (2, 4)])
            dtypes (list): List of torch dtypes (e.g., [torch.float32, torch.int64])
        Returns:
            list[Tensor]: List of uninitialized NPU tensors
        """
        tensors: list[Tensor] = []
        for shape, dtype in zip(shapes, dtypes, strict=True):
            tensor = torch.empty(shape, dtype=dtype, device=f"npu:{self.device_id}")
            tensors.append(tensor)
        return tensors


class GeneralKVClientAdapter(StorageStrategy):
    """Adapter for general-purpose KV storage with serialization.
    Using yr.datasystem.KVClient to connect datasystem backends.
    The serialization method uses 'batch_encode_into' and 'batch_decode_from' from 'transfer_queue.utils.serial_utils'.
    """

    PUT_KEYS_LIMIT: int = 10_000
    GET_CLEAR_KEYS_LIMIT: int = 10_000
    DS_MAX_WORKERS: int = 16

    def __init__(self, config: dict):
        port = config.get("worker_port")

        if port is None or not isinstance(port, int):
            raise ValueError("Missing or invalid 'worker_port' in config")

        logger.info(f"Auto-detecting reachable host for Yuanrong port {port}...")
        host = find_reachable_host(port)
        if host is None:
            raise ValueError(
                f"Could not find any reachable host for Yuanrong port {port}. "
                "Please ensure yuanrong datasystem is running."
            )
        logger.info(f"Using auto-detected host: {host}")

        self._ds_client = datasystem.KVClient(host, port)
        self._ds_client.init()
        logger.info("YuanrongStorageClient: Create KVClient to connect with yuanrong-datasystem backend!")

    @staticmethod
    def init(config: dict) -> Optional["StorageStrategy"]:
        """Always enabled for general objects."""
        return GeneralKVClientAdapter(config)

    def strategy_tag(self) -> str:
        """Strategy tag for general KV storage. Using a single byte is for better performance."""
        return "2"

    def supports_put(self, value: Any) -> bool:
        """Accepts any Python object."""
        return True

    def put(self, keys: list[str], values: list[Any]) -> None:
        """Store objects via zero-copy serialization in batches."""
        for i in range(0, len(keys), self.PUT_KEYS_LIMIT):
            batch_keys = keys[i : i + self.PUT_KEYS_LIMIT]
            batch_vals = values[i : i + self.PUT_KEYS_LIMIT]
            self.mset_zero_copy(batch_keys, batch_vals)

    def supports_get(self, strategy_tag: str) -> bool:
        """Matches 'KVClient' strategy tag."""
        return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag()

    def get(self, keys: list[str], **kwargs) -> list[Any | None]:
        """Retrieve and deserialize objects in batches."""
        results = []
        for i in range(0, len(keys), self.GET_CLEAR_KEYS_LIMIT):
            batch_keys = keys[i : i + self.GET_CLEAR_KEYS_LIMIT]
            objects = self.mget_zero_copy(batch_keys)
            results.extend(objects)
        return results

    def supports_clear(self, strategy_tag: str) -> bool:
        """Matches 'KVClient' strategy tag."""
        return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag()

    def clear(self, keys: list[str]) -> None:
        """Delete keys in batches."""
        for i in range(0, len(keys), self.GET_CLEAR_KEYS_LIMIT):
            batch_keys = keys[i : i + self.GET_CLEAR_KEYS_LIMIT]
            self._ds_client.delete(batch_keys)

    def mset_zero_copy(self, keys: list[str], objs: list[Any]):
        """Store multiple objects in zero-copy mode using parallel serialization and buffer packing.

        Args:
            keys (list[str]): List of string keys under which the objects will be stored.
            objs (list[Any]): List of Python objects to store (e.g., tensors, strings).
        """
        buffers: list = []

        def alloc(sizes):
            # DataSystem buffers must be converted via MutableData() to obtain
            # a memoryview-compatible data structure for zero-copy packing.
            mcreate_bufs = self._ds_client.mcreate(keys, sizes)
            buffers.extend(mcreate_bufs)
            return [buf.MutableData() for buf in mcreate_bufs]

        batch_encode_into(objs, alloc, num_workers=self.DS_MAX_WORKERS)
        self._ds_client.mset_buffer(buffers)

    def mget_zero_copy(self, keys: list[str]) -> list[Any]:
        """Retrieve multiple objects in zero-copy mode by directly deserializing from shared memory buffers.

        Args:
            keys (list[str]): List of string keys to retrieve from storage.

        Returns:
            list[Any]: List of deserialized objects corresponding to the input keys.
        """
        buffers = self._ds_client.get_buffers(keys)
        valid_indexes = [i for i, buf in enumerate(buffers) if buf is not None]
        if valid_indexes and len(valid_indexes) < len(keys):
            logger.warning(
                f"{len(keys) - len(valid_indexes)} requested keys were not found in openYuanrong-datasystem storage. "
                f"Returned results will contain None for these keys."
            )
        valid_bufs = [buffers[i] for i in valid_indexes]
        decoded_objs = batch_decode_from(valid_bufs)
        results = [None] * len(keys)
        for idx, obj in zip(valid_indexes, decoded_objs, strict=True):
            results[idx] = obj
        return results


@StorageClientFactory.register("YuanrongStorageClient")
class YuanrongStorageClient(StorageKVClient):
    """
    Storage client for YuanRong DataSystem.

    Use different storage strategies depending on the data type.
    Supports storing and fetching both:
    - NPU tensors via NPUTensorKVClientAdapter (for high performance).
    - General objects (CPU tensors, str, bool, list, etc.) via GeneralKVClientAdapter with serialization.
    """

    ROUTE_ITEM_AS_VALUE = "value"
    ROUTE_ITEM_AS_BACKEND_META = "backend_meta"

    def __init__(self, config: dict[str, Any]):
        if not YUANRONG_DATASYSTEM_IMPORTED:
            raise ImportError("YuanRong DataSystem not installed.")

        port = config.get("worker_port")

        if port is None or not isinstance(port, int):
            raise ValueError("Missing or invalid 'worker_port' in config")

        super().__init__(config)

        # Storage strategies are prioritized in ascending order of list element index.
        # In other words, the later in the order, the lower the priority.
        storage_strategies_priority = [NPUTensorKVClientAdapter, GeneralKVClientAdapter]
        self._strategies: list[StorageStrategy] = []
        for strategy_cls in storage_strategies_priority:
            strategy = strategy_cls.init(config)
            if strategy is not None:
                self._strategies.append(strategy)

        if not self._strategies:
            raise RuntimeError("No storage backend available for YuanrongStorageClient")

    def put(self, keys: list[str], values: list[Any]) -> list[str]:
        """Stores multiple key-value pairs to remote storage.

        Automatically routes NPU tensors to high-performance tensor storage,
        and other objects to general-purpose KV storage.

        Args:
            keys (List[str]): List of unique string identifiers.
            values (List[Any]): List of values to store (tensors, scalars, dicts, etc.).

        Returns:
            List[str]: storage strategy tag of YuanrongStorageClient in the same order as input keys.
        """
        if not isinstance(keys, list) or not isinstance(values, list):
            raise ValueError("keys and values must be lists")
        if len(keys) != len(values):
            raise ValueError("Number of keys must match number of values")

        routed_indexes = self._route_to_strategies(
            values, lambda strategy_, item_: strategy_.supports_put(item_), item_label=self.ROUTE_ITEM_AS_VALUE
        )

        # Define the 'put_task': Slicing the input list and calling the backend strategy.
        # The closure captures local 'keys' and 'values' for zero-overhead parameter passing.
        def put_task(strategy, indexes):
            strategy.put([keys[i] for i in indexes], [values[i] for i in indexes])
            return strategy.strategy_tag(), indexes

        # Dispatch tasks and map strategy_tag back to original positions
        strategy_tags: list[str] = [""] * len(keys)
        for tag, indexes in self._dispatch_tasks(routed_indexes, put_task):
            for original_index in indexes:
                strategy_tags[original_index] = tag
        return strategy_tags

    def get(
        self,
        keys: list[str],
        shapes: list[Any] | None = None,
        dtypes: list[Any] | None = None,
        custom_backend_meta: list[str] | None = None,
    ) -> list[Any]:
        """Retrieves multiple values from remote storage with expected metadata.

        Requires shape and dtype hints to reconstruct NPU tensors correctly.

        Args:
            keys (List[str]): Keys to fetch.
            shapes (List[List[int]]): Expected tensor shapes (use [] for scalars).
            dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data.
            custom_backend_meta (List[str]): StorageStrategy tag for each key

        Returns:
            List[Any]: Retrieved values in the same order as input keys.
        """
        if shapes is None or dtypes is None or custom_backend_meta is None:
            raise ValueError("YuanrongStorageClient.get() needs Expected shapes, dtypes and custom_backend_meta")

        if not (len(keys) == len(shapes) == len(dtypes) == len(custom_backend_meta)):
            raise ValueError("Lengths of keys, shapes, dtypes, custom_backend_meta must match")

        if any(not tag for tag in custom_backend_meta):
            raise ValueError(
                "Some keys have no backend metadata (empty string), indicating they "
                "were not previously stored. Ensure all keys have been put before calling get."
            )

        strategy_tags = custom_backend_meta
        routed_indexes = self._route_to_strategies(
            strategy_tags,
            lambda strategy_, item_: strategy_.supports_get(item_),
            item_label=self.ROUTE_ITEM_AS_BACKEND_META,
        )

        # Define the 'get_task': handles slicing of keys, shapes, and dtypes simultaneously.
        def get_task(strategy, indexes):
            res = strategy.get(
                [keys[i] for i in indexes], shapes=[shapes[i] for i in indexes], dtypes=[dtypes[i] for i in indexes]
            )
            return res, indexes

        # Gather results and restore original order
        results = [None] * len(keys)
        for strategy_res, indexes in self._dispatch_tasks(routed_indexes, get_task):
            for value, original_index in zip(strategy_res, indexes, strict=True):
                results[original_index] = value
        return results

    def clear(self, keys: list[str], custom_backend_meta: list[str] | None = None) -> None:
        """Deletes multiple keys from remote storage.

        Args:
            keys (List[str]): List of keys to remove.
            custom_backend_meta (List[str]): StorageStrategy tag for each key
        """
        if not isinstance(keys, list) or not isinstance(custom_backend_meta, list):
            raise ValueError("keys and custom_backend_meta must be a list")

        if len(custom_backend_meta) != len(keys):
            raise ValueError("custom_backend_meta length must match keys")

        strategy_tags = custom_backend_meta
        routed_indexes = self._route_to_strategies(
            strategy_tags,
            lambda strategy_, item_: strategy_.supports_clear(item_),
            ignore_unmatched=True,
            item_label=self.ROUTE_ITEM_AS_BACKEND_META,
        )

        def clear_task(strategy, indexes):
            strategy.clear([keys[i] for i in indexes])

        # Execute deletions (no return values needed)
        self._dispatch_tasks(routed_indexes, clear_task)

    def _route_to_strategies(
        self,
        items: list[Any],
        selector: Callable[[StorageStrategy, Any], bool],
        *,
        ignore_unmatched: bool = False,
        item_label: str,
    ) -> dict[StorageStrategy, list[int]]:
        """Groups item indices by the first strategy that supports them.

        Used to route data to storage strategies by grouped indexes.

        Args:
            items: A list used to distinguish which storage strategy the data is routed to.
                   e.g., route <keys, values> for put based on types of values,
                   or route <keys, Optional[shapes], Optional[dtypes]> for get/clear based on strategy_tags.
                   The order must correspond to the original keys.
            selector: A function that determines whether a strategy supports an item.
                     Signature: `(strategy: StorageStrategy, item: Any) -> bool`.
            ignore_unmatched: If True, items that don't match any strategy will be ignored (not included in output).
                              If False, a ValueError will be raised for any unmatched item.
            item_label: Description of what `items` represents, used in error messages.
                        Use ROUTE_ITEM_AS_VALUE for put (user-provided data),
                        or ROUTE_ITEM_AS_BACKEND_META for get/clear (backend metadata).

        Returns:
            A dictionary mapping each active strategy to a list of indexes in `items`
            that it should handle. Every index appears exactly once.
        """
        unmatched_count = 0
        warning_count = 0
        routed_indexes: dict[StorageStrategy, list[int]] = {s: [] for s in self._strategies}
        for i, item in enumerate(items):
            for strategy in self._strategies:
                if selector(strategy, item):
                    routed_indexes[strategy].append(i)
                    break
            else:
                if ignore_unmatched:
                    if item:  # non-empty item → real tag, backend likely unavailable
                        warning_count += 1
                    unmatched_count += 1
                else:
                    if item_label == self.ROUTE_ITEM_AS_BACKEND_META:
                        raise ValueError(
                            "Cannot retrieve stored data because the backend that originally "
                            "stored it is unavailable in the current process or node. Please "
                            "check that the configuration and NPU resource availability are "
                            "consistent across all processes and nodes."
                        )
                    else:
                        raise ValueError(f"No storage backend can handle {item_label} of type {type(item).__name__}.")
        if warning_count > 0:
            logger.warning(
                f"{warning_count} stored items could not be processed because the backend "
                f"that originally handled them may be unavailable in the current process or "
                f"node. Please check that the configuration and NPU resource availability "
                f"are consistent across all processes and nodes."
            )
        if unmatched_count > warning_count:
            logger.debug(
                f"{unmatched_count - warning_count} items with empty {item_label} "
                f"will be silently skipped (likely not previously stored)."
            )

        return routed_indexes

    @staticmethod
    def _dispatch_tasks(routed_tasks: dict[StorageStrategy, list[int]], task_function: Callable) -> list[Any]:
        """Executes tasks across one or more storage strategies, optionally in parallel.

        Optimizes for common case: if only one strategy is involved, runs synchronously
        to avoid thread overhead. Otherwise, uses thread pool for concurrency.

        Args:
            routed_tasks: Mapping from strategy to list of indexes it should process.
            task_function: Callable accepting `(strategy, list_of_indexes)` and returning any result.

        Returns:
            List of results from `task_function`, one per active strategy, in arbitrary order.
            Each result typically includes data and the corresponding indices for reassembly.
        """
        active_tasks = [(strategy, indexes) for strategy, indexes in routed_tasks.items() if indexes]

        if not active_tasks:
            return []

        # Fast path: single strategy → avoid threading
        if len(active_tasks) == 1:
            return [task_function(*active_tasks[0])]

        # Parallel path: overlap NPU and CPU operations
        # Cap the number of worker threads to avoid resource exhaustion if many
        # strategies are added in the future.
        max_workers = min(len(active_tasks), 4)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # futures' results are from task_function
            futures = [executor.submit(task_function, strategy, indexes) for strategy, indexes in active_tasks]
            return [f.result() for f in futures]