import io
import os
import pickle
import re
import sys
import threading
from typing import Any, Optional, Union

from typing_extensions import TypeGuard

import torch
import torch_npu
from torch.serialization import (  # noqa: F401
    _check_dill_version,
    _default_to_weights_only,
    _get_storage_alignment,
    _is_torchscript_zip,
    _is_zipfile,
    _legacy_load,
    _load,
    _open_file_like,
    _open_zipfile_reader,
    _open_zipfile_writer,
    _serialization_tls,
    _weights_only_unpickler,
    DEFAULT_PROTOCOL,
    FileLike,
    location_tag,
    MAP_LOCATION,
    normalize_storage_type,
    UNSAFE_MESSAGE,
)
from torch_npu.utils._error_code import ErrCode, pta_error
from torch_npu._init.common.warning_utils import _should_print_warning


# Public names exported by this module.
__all__ = ["load", "save_async"]

# When True, `_warn_legacy_serialization` emits its warning on every call
# instead of only the first call per key. Changed through
# `_set_always_warn_legacy_serialization`.
ALWAYS_WARN_LEGACY_SERIALIZATION = False
# Recomputed by `_update_cpu_remap_info` at the start of every `load` call:
# True when `map_location` is a str / torch.device whose text contains "cpu".
# NOTE(review): no reader of this flag is visible in this file - presumably
# consumed by another torch_npu module; confirm.
RE_MAP_CPU = False
# Maps a device index to the NPU stream that `_save_data_thread` created for
# asynchronous saves on that device.
save_async_stream_map = {}


def _get_always_warn_legacy_serialization():
    """Return the current value of ``ALWAYS_WARN_LEGACY_SERIALIZATION``."""
    return ALWAYS_WARN_LEGACY_SERIALIZATION


def _set_always_warn_legacy_serialization(always_warn: bool) -> None:
    """Rebind the module-level ``ALWAYS_WARN_LEGACY_SERIALIZATION`` flag.

    Args:
        always_warn: New flag value; when true, ``_warn_legacy_serialization``
            no longer limits itself to one warning per key.
    """
    global ALWAYS_WARN_LEGACY_SERIALIZATION
    ALWAYS_WARN_LEGACY_SERIALIZATION = always_warn


def _warn_legacy_serialization(warn_massages, key_flag: str):
    """Print a legacy-serialization warning, by default once per ``key_flag``.

    The "already warned" state is stored as an attribute named
    ``"has_warned_for" + key_flag`` on this function object. The warning is
    printed when ``ALWAYS_WARN_LEGACY_SERIALIZATION`` is enabled or when this
    is the first call for ``key_flag``, and only if ``_should_print_warning()``
    allows printing.

    Args:
        warn_massages: The message handed to ``print``.
        key_flag: Suffix identifying the warning category (the callers in this
            file use ``"load"`` and ``"save"``). An empty or ``None`` flag is
            tracked under the bare ``"has_warned_for"`` attribute.
    """

    def is_first_time(flag):
        # A falsy flag used to produce ``None`` as the attribute name, which
        # makes hasattr() raise TypeError; fall back to the bare prefix.
        warn_key = "has_warned_for" + (flag or "")
        if hasattr(_warn_legacy_serialization, warn_key):
            # Normally the stored marker is True, so this yields False; a
            # marker reset to a falsy value re-enables the warning.
            return not getattr(_warn_legacy_serialization, warn_key)
        setattr(_warn_legacy_serialization, warn_key, True)
        return True

    # Short-circuit on purpose: with "always warn" enabled the first-time
    # marker is neither consulted nor set.
    if _get_always_warn_legacy_serialization() or is_first_time(key_flag):
        if not _should_print_warning():
            return
        print(warn_massages)


def _remap_result(cpu_result, map_location):
    """Move every tensor nested inside ``cpu_result`` via ``Tensor.to``.

    Dicts and lists are updated in place and returned as the same object;
    tuples and sets are rebuilt as new plain ``tuple`` / ``set`` objects.
    Anything that is not a tensor or one of those four containers is returned
    untouched.

    Args:
        cpu_result: Arbitrary (possibly nested) object produced by a load.
        map_location: Argument forwarded to ``Tensor.to`` for each tensor.

    Returns:
        The remapped structure.
    """
    remappable = (torch.Tensor, tuple, set, list, dict)

    def relocate(value):
        if isinstance(value, torch.Tensor):
            return value.to(map_location)
        if isinstance(value, dict):
            # Only existing keys are reassigned, so mutating while iterating
            # over items() is safe.
            for name, entry in value.items():
                if isinstance(entry, remappable):
                    value[name] = relocate(entry)
            return value
        if isinstance(value, list):
            for position, entry in enumerate(value):
                if isinstance(entry, remappable):
                    value[position] = relocate(entry)
            return value
        if isinstance(value, tuple):
            return tuple(relocate(entry) for entry in value)
        if isinstance(value, set):
            return {relocate(entry) for entry in value}
        return value

    return relocate(cpu_result)


def _update_cpu_remap_info(map_location):
    """Refresh the module-level ``RE_MAP_CPU`` flag for ``map_location``.

    The flag becomes True only when ``map_location`` is a ``str`` or
    ``torch.device`` whose string form contains ``"cpu"``; every other value
    (including ``None``, callables and dicts) sets it to False.
    """
    global RE_MAP_CPU
    targets_cpu = isinstance(map_location, (str, torch.device)) and "cpu" in str(
        map_location
    )
    RE_MAP_CPU = targets_cpu


def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
    """Return True when ``name_or_buffer`` is a ``str`` or ``os.PathLike``.

    Used by ``load`` to reject the ``mmap`` option for anything that is not a
    filesystem path.
    """
    return isinstance(name_or_buffer, (str, os.PathLike))


def load(
    f: FileLike,
    map_location: MAP_LOCATION = None,
    pickle_module: Any = None,
    *,
    weights_only: Optional[bool] = None,
    mmap: Optional[bool] = None,
    **pickle_load_args: Any,
) -> Any:
    """Load an object from ``f``; installed over ``torch.load`` by this module.

    Zip-format checkpoints are passed to ``torch.serialization._load`` with
    the caller's ``map_location``. Legacy (non-zip) checkpoints that are not
    loaded in weights-only mode are always deserialized onto the CPU first;
    the result is then moved with ``_remap_result`` only when ``map_location``
    is a ``str`` or ``torch.device`` that does not name the CPU.

    Args:
        f: File path or file-like object to read from.
        map_location: Storage remapping request. For legacy files only ``str``
            and ``torch.device`` values are honoured.
        pickle_module: Module used for unpickling. Must be left as ``None``
            when ``weights_only`` is in effect; defaults to ``pickle``
            otherwise.
        weights_only: Use the restricted weights-only unpickler. When ``None``
            the value comes from ``_default_to_weights_only(pickle_module)``
            and may be overridden by the ``TORCH_FORCE_WEIGHTS_ONLY_LOAD`` /
            ``TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD`` environment variables.
        mmap: Back storages with a file mapping of ``f`` (zip files given by
            path only). When ``None`` the value comes from
            ``torch.utils.serialization.config.load.mmap``.
        **pickle_load_args: Extra keyword arguments for the unpickler;
            ``encoding`` defaults to ``"utf-8"``.

    Returns:
        The deserialized object.

    Raises:
        RuntimeError: Both force environment variables are set,
            ``weights_only`` is combined with an explicit ``pickle_module``
            or with a TorchScript archive, or ``mmap`` is requested for a
            non-zip file.
        ValueError: ``mmap`` is requested but ``f`` is not a path.
        pickle.UnpicklingError: The weights-only unpickler raised a
            ``RuntimeError``.
    """
    _update_cpu_remap_info(map_location)
    torch._C._log_api_usage_once("torch.load")
    DOCS_MESSAGE = "\n\nCheck the documentation of torch.load to learn more about types accepted by default with weights_only."

    def _get_wo_message(message: str) -> str:
        """Turn a weights-only unpickler error into a user-facing hint."""
        unsafe_global_pattern = r"GLOBAL (\S+) was not an allowed global by default."
        has_unsafe_global = re.search(unsafe_global_pattern, message) is not None
        blocklist_pattern = r"whose module (\S+) is blocked"
        has_blocklist = re.search(blocklist_pattern, message) is not None
        import_pattern = r"(\S+) must be (\S+) to load"
        has_import = re.search(import_pattern, message) is not None
        if has_unsafe_global:
            updated_message = (
                "Weights only load failed. This file can still be loaded, to do so you have two options, "
                "\033[1mdo those steps only if you trust the source of the checkpoint\033[0m. "
                f"\n\t(1) {UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check "
                "the recommended steps in the following error message.\n\tWeightsUnpickler error: "
                + message
            )
        else:
            if has_import:
                return f"Weights only load failed. {message}\n {UNSAFE_MESSAGE}\n"
            else:
                updated_message = f"Weights only load failed. {UNSAFE_MESSAGE}\n"
                if not has_blocklist:
                    updated_message += (
                        "Please file an issue with the following so that we can make "
                        "`weights_only=True` compatible with your use case: WeightsUnpickler error: "
                    )
            updated_message += message
        return updated_message + DOCS_MESSAGE

    weights_only_not_set = weights_only is None
    if weights_only_not_set:
        weights_only = _default_to_weights_only(pickle_module)

    true_values = ["1", "y", "yes", "true"]
    # Add ability to force safe only or non-safe weight loads via environment variables
    force_weights_only_load = (
        os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0") in true_values
    )
    force_no_weights_only_load = (
        os.getenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "0") in true_values
    )

    if force_weights_only_load and force_no_weights_only_load:
        raise RuntimeError(
            "Only one of `TORCH_FORCE_WEIGHTS_ONLY_LOAD` or `TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD` "
            "should be set, but both were set." + pta_error(ErrCode.PARAM)
        )
    elif force_weights_only_load:
        weights_only = True
    elif force_no_weights_only_load:
        # TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD can only override if callsite did not explicitly set weights_only
        if weights_only_not_set:
            # The two literals are concatenated, so the first must end with a
            # space to keep "the" and "`weights_only`" apart.
            print(
                "Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the "
                "`weights_only` argument was not explicitly passed to `torch.load`, forcing weights_only=False."
            )
            weights_only = False

    if weights_only:
        if pickle_module is not None:
            raise RuntimeError(
                "Can not safely load weights when explicit pickle_module is specified"
                + pta_error(ErrCode.PARAM)
            )
    else:
        if pickle_module is None:
            pickle_module = pickle

    # make flipping default BC-compatible
    if mmap is None:
        from torch.utils.serialization import config

        mmap = config.load.mmap

    _check_dill_version(pickle_module)

    if "encoding" not in pickle_load_args:
        pickle_load_args["encoding"] = "utf-8"

    with _open_file_like(f, "rb") as opened_file:
        if _is_zipfile(opened_file):
            # Remember where the archive starts so the stream can be rewound
            # before it is handed to torch.jit.load.
            orig_position = opened_file.tell()
            overall_storage = None
            with _open_zipfile_reader(opened_file) as opened_zipfile:
                if _is_torchscript_zip(opened_zipfile):
                    print(
                        "Warning: 'torch.load' received a zip file that looks like a TorchScript archive"
                        " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to silence this warning)"
                    )
                    if weights_only:
                        raise RuntimeError(
                            "Cannot use ``weights_only=True`` with TorchScript archives passed to "
                            "``torch.load``. "
                            + UNSAFE_MESSAGE
                            + pta_error(ErrCode.PARAM)
                        )
                    opened_file.seek(orig_position)
                    return torch.jit.load(opened_file, map_location=map_location)
                if mmap:
                    if not _is_path(f):
                        raise ValueError(
                            "f must be a file path in order to use the mmap argument"
                            + pta_error(ErrCode.PARAM)
                        )
                    size = os.path.getsize(f)
                    overall_storage = torch.UntypedStorage.from_file(
                        os.fspath(f), False, size
                    )
                if weights_only:
                    try:
                        return _load(
                            opened_zipfile,
                            map_location,
                            _weights_only_unpickler,
                            overall_storage=overall_storage,
                            **pickle_load_args,
                        )
                    except RuntimeError as e:
                        raise pickle.UnpicklingError(
                            _get_wo_message(str(e)) + pta_error(ErrCode.SYSCALL)
                        ) from None
                return _load(
                    opened_zipfile,
                    map_location,
                    pickle_module,
                    overall_storage=overall_storage,
                    **pickle_load_args,
                )
        else:
            if mmap:
                raise RuntimeError(
                    "mmap can only be used with files saved with "
                    "`torch.save(_use_new_zipfile_serialization=True), "
                    "please torch.save your checkpoint with this option in order to use mmap."
                    + pta_error(ErrCode.PARAM)
                )
            if weights_only:
                try:
                    return _legacy_load(
                        opened_file,
                        map_location,
                        _weights_only_unpickler,
                        **pickle_load_args,
                    )
                except RuntimeError as e:
                    raise pickle.UnpicklingError(
                        _get_wo_message(str(e)) + pta_error(ErrCode.SYSCALL)
                    ) from None

            warn_massage = (
                'Warning: since the loaded file is not a zipfile, only "torch.device" and "str" type parameters '
                "are currently supported for parameter types of map_location. If parameter types of map_location is "
                '"Callable[[torch.Tensor, str], torch.Tensor]" or "Dict[str, str]", which is only support for '
                "zipfile, all tensors are currently loaded onto the CPU, which may introduce problems."
            )
            _warn_legacy_serialization(warn_massage, "load")

            # Legacy files are always materialized on the CPU first.
            cpu_result = _legacy_load(
                opened_file, "cpu", pickle_module, **pickle_load_args
            )
            if isinstance(map_location, str):
                targets_cpu = "cpu" in map_location
            elif isinstance(map_location, torch.device):
                targets_cpu = "cpu" in map_location.type
            else:
                # None, callables and dicts are not remapped for legacy files.
                return cpu_result
            if targets_cpu:
                return cpu_result
            return _remap_result(cpu_result, map_location)


def _npu_save(
    obj,
    zip_file,
    pickle_module,
    pickle_protocol,
    _disable_byteorder_record,
):
    """Serialize ``obj`` into ``zip_file``; installed as ``torch.serialization._save``.

    Records are written in this order: ``data.pkl`` (the pickle stream, with
    storages replaced by persistent ids), ``.format_version``,
    ``.storage_alignment``, optionally ``byteorder``, and finally one
    ``data/<key>`` record per referenced storage.

    Args:
        obj: Object to serialize.
        zip_file: Writer exposing ``write_record`` / ``write_record_metadata``.
        pickle_module: Module providing the ``Pickler`` base class.
        pickle_protocol: Pickle protocol passed to the pickler.
        _disable_byteorder_record: When true, the ``byteorder`` record is
            skipped.

    Raises:
        RuntimeError: Two storages share a data pointer but differ in dtype.
        ValueError: ``sys.byteorder`` is neither ``"little"`` nor ``"big"``.
    """
    # storage key -> untyped storage, filled while pickling and written below.
    serialized_storages = {}
    # storage._cdata -> storage key, so a shared storage is written only once.
    id_map: dict[int, str] = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    storage_dtypes: dict[int, torch.dtype] = {}

    def persistent_id(obj):
        # Returns the persistent-id tuple for storages and None for every
        # other object, which leaves that object to be pickled normally.
        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            if isinstance(obj, torch.storage.TypedStorage):
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                is_fake = hasattr(obj, "_fake_device") and obj._fake_device is not None
                is_meta = str(storage.device) == "meta"
                if storage.device.type == "cpu" or is_fake or is_meta:
                    storage_numel = obj._size()
                else:
                    # Device storage: derive the element count from a tensor
                    # built on top of the storage instead of obj._size().
                    storage_tensor = torch_npu._C._tensor_construct_from_storage(
                        storage
                    )
                    storage_numel = (
                        storage_tensor.size().numel()
                        * storage_tensor.element_size()
                        // obj._element_size()
                    )

            else:
                storage = obj
                # Untyped storages are recorded as raw bytes.
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                is_meta = str(storage.device) == "meta"
                if storage.device.type == "cpu" or is_meta or storage.data_ptr() == 0:
                    storage_numel = storage.nbytes()
                else:
                    storage_tensor = torch_npu._C._tensor_construct_from_storage(
                        storage
                    )
                    storage_numel = (
                        storage_tensor.size().numel() * storage_tensor.element_size()
                    )

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if str(storage.device) != "meta" and storage.data_ptr() != 0:
                if storage.data_ptr() in storage_dtypes:
                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            "Cannot save multiple tensors or storages that "
                            "view the same data as different types"
                        )
                else:
                    storage_dtypes[storage.data_ptr()] = storage_dtype

            # Keys are assigned in first-seen order: "0", "1", ...
            storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
            if hasattr(obj, "_fake_device") and obj._fake_device is not None:
                location = str(obj._fake_device)
            else:
                location = location_tag(storage)
            serialized_storages[storage_key] = storage

            return ("storage", storage_type, storage_key, location, storage_numel)

        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()

    # persistent_id is supplied by subclassing; the method delegates to the
    # closure defined above.
    class PyTorchPickler(pickle_module.Pickler):  # type: ignore[name-defined]
        def persistent_id(self, obj):
            return persistent_id(obj)

    pickler = PyTorchPickler(data_buf, protocol=pickle_protocol)
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    zip_file.write_record("data.pkl", data_value, len(data_value))
    # .format_version is used to track
    #     1. version 1 represents the order of storages being changed from
    #        lexicographical based on keys to numerically ordered based on keys
    #     2. version 2 represents including storage_alignment as a record
    #        within the zipfile
    zip_file.write_record(".format_version", "1", len("1"))
    storage_alignment = str(_get_storage_alignment())
    zip_file.write_record(
        ".storage_alignment", storage_alignment, len(storage_alignment)
    )

    # Write byte order marker
    if not _disable_byteorder_record:
        if sys.byteorder not in ["little", "big"]:
            raise ValueError("Unknown endianness type: " + sys.byteorder)

        zip_file.write_record("byteorder", sys.byteorder, len(sys.byteorder))

    # Write each storage to a record named data/<storage key> in the zip archive
    for key in serialized_storages:
        name = f"data/{key}"
        storage = serialized_storages[key]
        global _serialization_tls
        # With skip_data set, only the record metadata (name and size) is
        # written and the storage bytes are omitted.
        if _serialization_tls.skip_data:
            num_bytes = storage.nbytes() if hasattr(storage, "nbytes") else 0
            zip_file.write_record_metadata(name, num_bytes)
            continue
        if storage.device.type != "cpu":
            storage_tensor = torch_npu._C._tensor_construct_from_storage(storage)
            num_bytes = storage_tensor.size().numel() * storage_tensor.element_size()
        else:
            num_bytes = storage.nbytes()

        # given that we copy things around anyway, we might use storage.cpu()
        # this means that to get tensors serialized, you need to implement
        # .cpu() on the underlying Storage
        if storage.device.type != "cpu":
            from torch.utils.serialization import config

            if (
                config.save.use_pinned_memory_for_d2h
                and (acc := torch.accelerator.current_accelerator(check_available=True))
                is not None
                and acc.type == storage.device.type
            ):
                # Copy device -> pinned host memory, then wait on the current
                # stream of the storage's device before using the host copy.
                new_storage = torch.empty(
                    num_bytes, dtype=torch.uint8, device="cpu", pin_memory=True
                ).untyped_storage()
                new_storage.copy_(storage)
                torch.accelerator.current_stream(storage.device.index).synchronize()
                storage = new_storage
            else:
                storage = storage.cpu()

        # Fallback for cases where tensor reduce path has already materialized
        # NPU tensors as CPU storages before reaching _npu_save.
        else:
            from torch.utils.serialization import config

            if (
                config.save.use_pinned_memory_for_d2h
                and (acc := torch.accelerator.current_accelerator(check_available=True))
                is not None
                and acc.type == "npu"
            ):
                # The storage is already on the CPU here; it is re-copied into
                # a pinned buffer when the pinned-memory option is enabled.
                new_storage = torch.empty(
                    num_bytes, dtype=torch.uint8, device="cpu", pin_memory=True
                ).untyped_storage()
                new_storage.copy_(storage)
                torch.accelerator.current_stream(storage.device.index).synchronize()
                storage = new_storage
        # Now that it is on the CPU we can directly copy it into the zip file
        zip_file.write_record(name, storage, num_bytes)


def save_async(
    obj: object,
    f,
    pickle_module: Any = pickle,
    pickle_protocol: int = DEFAULT_PROTOCOL,
    _use_new_zipfile_serialization: bool = True,
    _disable_byteorder_record: bool = False,
    model: torch.nn.Module = None,
) -> None:
    """Start a background thread that saves ``obj`` to ``f``.

    The call returns as soon as the worker thread has been started; the
    pickling, device-to-host copies and file writing happen in
    ``_save_data_thread`` on the NPU device that is current at call time.

    Args:
        obj: Object to save.
        f: Destination handed to the zip-file writer in the worker thread.
        pickle_module: Module used for pickling.
        pickle_protocol: Pickle protocol to use.
        _use_new_zipfile_serialization: Must not be ``False``; only the zip
            format is supported.
        _disable_byteorder_record: Forwarded to the worker thread.
        model: Optional module forwarded to the worker thread, which may
            register a backward hook on it.

    Raises:
        RuntimeError: ``_use_new_zipfile_serialization`` is ``False``.
    """
    if _use_new_zipfile_serialization is False:
        # Build ONE message string. A comma between the literals used to pass
        # two positional arguments, which made the error render as a tuple.
        raise RuntimeError(
            'Error: torch_npu.save_async with "_use_new_zipfile_serialization = False"'
            " is not recommended for npu tensor, which may bring unexpected errors and hopefully"
            ' set "_use_new_zipfile_serialization = True", '
            "if it is necessary to use this, please convert the npu tensor to cpu tensor for saving"
            + pta_error(ErrCode.PARAM)
        )

    _check_dill_version(pickle_module)
    save_args = (
        obj,
        f,
        pickle_module,
        pickle_protocol,
        _use_new_zipfile_serialization,
        _disable_byteorder_record,
    )

    # Capture the caller's device so the worker thread can select it too.
    device = torch.npu.current_device()
    save_thread = threading.Thread(
        target=_save_data_thread, args=(save_args, device, model)
    )
    save_thread.start()


def _save_data_thread(save_args, device, model: torch.nn.Module = None):
    """Worker for ``save_async``: pickle ``obj`` and write it as a zip file.

    One NPU stream per device is created lazily and cached in
    ``save_async_stream_map``. When a stream is created for ``device`` and
    ``model`` is a module, a full backward hook is registered on it that makes
    the current stream wait for the save stream. Pickling and the copies of
    device storages to the CPU run under the save stream; the zip file is
    written afterwards.

    Args:
        save_args: 6-tuple ``(obj, f, pickle_module, pickle_protocol,
            _use_new_zipfile_serialization, _disable_byteorder_record)``.
        device: NPU device index to select in this thread.
        model: Optional module that receives the synchronizing backward hook.
    """
    global save_async_stream_map
    torch.npu.set_device(device)

    def hook_fn(*args):
        torch.npu.current_stream().wait_stream(save_async_stream_map.get(device))

    if device not in save_async_stream_map:
        save_async_stream = torch.npu.Stream()
        save_async_stream_map[device] = save_async_stream
        # The hook is only registered together with stream creation, i.e. on
        # the first asynchronous save for this device.
        if isinstance(model, torch.nn.Module):
            model.register_full_backward_hook(hook_fn)
    else:
        save_async_stream = save_async_stream_map[device]

    # NOTE(review): the last two entries (_use_new_zipfile_serialization and
    # _disable_byteorder_record) are not used here, and no "byteorder" or
    # ".format_version" record is written - confirm this is intended.
    obj, f, pickle_module, pickle_protocol, _, _ = save_args

    with torch.npu.stream(save_async_stream):
        data_value, serialized_storages = _save(obj, pickle_module, pickle_protocol)
        storage_value = []
        for key in sorted(serialized_storages.keys()):
            name = f"data/{key}"
            storage = serialized_storages.get(key)
            # given that we copy things around anyway, we might use storage.cpu()
            # this means that to get tensors serialized, you need to implement
            # .cpu() on the underlying Storage
            if storage.device.type != "cpu":
                storage = storage.cpu()
            # The storage is on the CPU at this point, so its size can be read
            # directly (the former device-side size branch was unreachable).
            storage_value.append((name, storage, storage.nbytes()))

    with _open_zipfile_writer(f) as opened_zipfile:
        opened_zipfile.write_record("data.pkl", data_value, len(data_value))

        # storage_value keeps each CPU storage alive while its raw pointer is
        # handed to the writer.
        for name, storage, num_bytes in storage_value:
            opened_zipfile.write_record(name, storage.data_ptr(), num_bytes)


def _save(obj, pickle_module, pickle_protocol):
    """Pickle ``obj`` and collect the storages it references.

    Storages are not embedded in the pickle stream; each one is replaced by a
    persistent id ``("storage", storage_type, key, location, numel)`` and
    gathered in a dict so the caller can write it separately.

    Args:
        obj: Object to pickle. For an ``nn.Module`` the ``_backward_hooks``
            dict is emptied while pickling and restored afterwards.
        pickle_module: Module providing the ``Pickler`` base class.
        pickle_protocol: Pickle protocol passed to the pickler.

    Returns:
        ``(data_value, serialized_storages)``: the pickle bytes and a dict
        mapping storage key (``"0"``, ``"1"``, ...) to the untyped storage.

    Raises:
        RuntimeError: Two storages share a data pointer but differ in dtype.
    """
    serialized_storages = {}
    # storage._cdata -> storage key, so a shared storage is recorded once.
    id_map: dict[int, str] = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    storage_dtypes: dict[int, torch.dtype] = {}

    def _storage_persistent_id(obj):
        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            if isinstance(obj, torch.storage.TypedStorage):
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                if storage.device.type != "cpu":
                    # Device storage: derive the element count from a tensor
                    # built on top of the storage.
                    storage_tensor = torch_npu._C._tensor_construct_from_storage(
                        storage
                    )
                    storage_numel = (
                        storage_tensor.size().numel()
                        * storage_tensor.element_size()
                        // obj._element_size()
                    )
                else:
                    storage_numel = obj._size()

            else:
                storage = obj
                # Untyped storages are recorded as raw bytes.
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                if storage.device.type != "cpu":
                    storage_tensor = torch_npu._C._tensor_construct_from_storage(
                        storage
                    )
                    storage_numel = (
                        storage_tensor.size().numel() * storage_tensor.element_size()
                    )
                else:
                    storage_numel = storage.nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in storage_dtypes:
                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            "Cannot save multiple tensors or storages that "
                            "view the same data as different types"
                            + pta_error(ErrCode.VALUE)
                        )
                else:
                    storage_dtypes[storage.data_ptr()] = storage_dtype

            storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
            location = location_tag(storage)
            serialized_storages[storage_key] = storage

            return ("storage", storage_type, storage_key, location, storage_numel)

        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()

    # Provide persistent_id by subclassing, as `_npu_save` does, instead of
    # assigning it as an attribute on the pickler instance.
    class _StoragePickler(pickle_module.Pickler):  # type: ignore[name-defined]
        def persistent_id(self, obj):
            return _storage_persistent_id(obj)

    pickler = _StoragePickler(data_buf, protocol=pickle_protocol)
    if isinstance(obj, torch.nn.Module):
        # Backward hooks are kept out of the pickle stream (presumably because
        # the closure registered by `_save_data_thread` cannot be pickled -
        # confirm). try/finally restores them even when pickling fails.
        saved_hooks = obj._backward_hooks.copy()
        obj._backward_hooks.clear()
        try:
            pickler.dump(obj)
        finally:
            obj._backward_hooks.update(saved_hooks)
    else:
        pickler.dump(obj)
    return data_buf.getvalue(), serialized_storages


def _add_serialization_methods():
    """Install the NPU serialization overrides into ``torch``.

    Replaces ``torch.serialization._save`` with ``_npu_save`` and
    ``torch.load`` with this module's ``load``, wraps
    ``torch.serialization._legacy_save`` so that a warning is issued before
    the original implementation runs, and registers
    ``torch_npu.utils.storage._rebuild_npu_tensor`` and ``Format`` as safe
    globals for weights-only loading.
    """
    torch.serialization._save = _npu_save
    torch.load = load

    _orig_legacy_save = torch.serialization._legacy_save

    def _npu_legacy_save(obj, f, pickle_module, pickle_protocol):
        # Build ONE message string. A comma between the literals used to turn
        # this into a tuple, so the tuple repr was printed.
        warn_massage = (
            'Warning: torch.save with "_use_new_zipfile_serialization = False" is not recommended '
            "for npu tensor, which may bring unexpected errors and hopefully set "
            '"_use_new_zipfile_serialization = True", '
            "if it is necessary to use this, please convert the npu tensor to cpu tensor for saving"
        )
        _warn_legacy_serialization(warn_massage, "save")
        return _orig_legacy_save(obj, f, pickle_module, pickle_protocol)

    torch.serialization._legacy_save = _npu_legacy_save

    torch.serialization.add_safe_globals([torch_npu.utils.storage._rebuild_npu_tensor])

    # Imported here rather than at module top (presumably to avoid an import
    # cycle - confirm).
    from torch_npu.npu._format import Format
    torch.serialization.add_safe_globals([Format])