import os
import logging
import gc
from glob import glob
from typing import Optional
from dataclasses import dataclass
from tqdm import tqdm

import torch
from torch.distributed.tensor import DTensor

from mindspeed.fsdp.utils.log import print_rank
from mindspeed_mm.fsdp.checkpoint.dcp_utils import load_metadata, extract_metadata, partial_load_dcp_state_dict
from mindspeed_mm.fsdp.utils.utils import to_empty_if_needed, tensor_to_dtensor
from mindspeed_mm.fsdp.utils.device import get_device_type, empty_cache


logger = logging.getLogger(__name__)


@dataclass
class ParamInfo:
    """
    Picklable description of one checkpoint entry loaded by rank 0.

    Rank 0 builds a list of these for each shard and sends it with
    ``broadcast_object_list`` so that the other ranks can allocate a receive
    buffer of matching shape/dtype before the tensor itself is broadcast.
    """
    # Key of the entry inside the state dict it was loaded into.
    name: Optional[str] = None
    # Shape of the loaded value (taken from ``value.shape`` on rank 0).
    shape: Optional[torch.Size] = None
    # Dtype of the loaded value (taken from ``value.dtype`` on rank 0).
    dtype: Optional[torch.dtype] = None
    # Top-level key of the loaded shard dict that held the entry
    # (e.g. "model" / "optimizer"); None when the shard dict is flat.
    prefix: Optional[str] = None


def chunk_list(lst, chunk_size):
    """Split ``lst`` into ``chunk_size`` contiguous parts of near-equal length.

    Despite its name, ``chunk_size`` is the *number of parts* to produce, not
    the length of each part. The first ``len(lst) % chunk_size`` parts receive
    one extra element; when ``chunk_size`` exceeds ``len(lst)`` the trailing
    parts are empty.

    Args:
        lst: Sliceable sequence to split.
        chunk_size: Number of parts to return.

    Returns:
        A list with ``chunk_size`` slices of ``lst``, in original order.

    Raises:
        ZeroDivisionError: If ``chunk_size`` is 0.
    """
    base_len, remainder = divmod(len(lst), chunk_size)

    parts = []
    start = 0
    for part_idx in range(chunk_size):
        # The first `remainder` parts absorb the leftover elements.
        stop = start + base_len + (1 if part_idx < remainder else 0)
        parts.append(lst[start:stop])
        start = stop
    return parts


def _group_fqns_by_files(fqn_to_files):
    """Group FQNs that are stored in exactly the same set of checkpoint files.

    Args:
        fqn_to_files: Mapping ``fqn -> set of relative file paths``.

    Returns:
        List of ``(files_tuple, fqn_set)`` pairs sorted by ``files_tuple``.
    """
    files_to_fqns = {}
    for fqn, files in fqn_to_files.items():
        files_to_fqns.setdefault(tuple(sorted(files)), set()).add(fqn)
    return sorted(files_to_fqns.items(), key=lambda item: item[0])


def _plan_rank0_shards(metadata, wanted_keys, checkpoint_dir, optimizer_prefix):
    """Partition the checkpoint's FQNs into the shards rank 0 loads one by one.

    Args:
        metadata: Checkpoint metadata; only ``metadata.storage_data`` is read
            (keys expose ``.fqn``, values expose ``.relative_path``).
        wanted_keys: Set of fully-qualified keys requested by the caller;
            every shard is filtered down to these keys.
        checkpoint_dir: Directory that is globbed for ``*.distcp`` files.
        optimizer_prefix: Top-level prefix that marks optimizer entries.

    Returns:
        ``(shard_info_list, has_optim_object_shard)`` where ``shard_info_list``
        is a list of key lists and ``has_optim_object_shard`` is True when
        element 0 holds the optimizer entries stored in a single file.
    """
    fqn_to_files = {}
    for index, storage_info in metadata.storage_data.items():
        fqn_to_files.setdefault(index.fqn, set()).add(storage_info.relative_path)

    optim_marker = f"{optimizer_prefix}."
    sharded, unsharded, optim_unsharded = {}, {}, {}
    for fqn, files in fqn_to_files.items():
        if len(files) > 1:
            sharded[fqn] = files
        elif fqn.startswith(optim_marker):
            optim_unsharded[fqn] = files
        else:
            unsharded[fqn] = files

    shard_info_list = []
    has_optim_object_shard = bool(optim_unsharded)
    if has_optim_object_shard:
        # Single-file optimizer entries always form shard 0.
        shard_info_list.append([fqn for fqn in sorted(optim_unsharded) if fqn in wanted_keys])

    if not sharded:
        # Nothing spans several files: one shard per group of co-located FQNs.
        for _, fqns in _group_fqns_by_files(unsharded):
            shard_info_list.append(sorted(fqn for fqn in fqns if fqn in wanted_keys))
        return shard_info_list, has_optim_object_shard

    # NOTE(review): when multi-file entries exist, single-file non-optimizer
    # entries are not scheduled at all (they end up in the final "not loaded"
    # warning). This mirrors the previous behavior - confirm it is intended.
    file_groups = _group_fqns_by_files(sharded)
    if len(file_groups) > 1:
        for _, fqns in file_groups:
            shard_info_list.append(sorted(fqn for fqn in fqns if fqn in wanted_keys))
        return shard_info_list, has_optim_object_shard

    # Every multi-file entry spans the same files, so grouping by file gives a
    # single huge shard. Split by FQN instead, into (#files // files-per-entry)
    # parts, to bound how much rank 0 holds at once.
    sorted_items = sorted(sharded.items(), key=lambda item: item[0])
    files_per_entry = len(sorted_items[0][1])
    file_num = len(glob(os.path.join(checkpoint_dir, "*.distcp")))
    part_count = max(file_num // files_per_entry, 1)
    for part in chunk_list(sorted_items, part_count):
        shard_info_list.append([fqn for fqn, _ in part if fqn in wanted_keys])
    return shard_info_list, has_optim_object_shard


@torch.no_grad()
def rank0_load_and_broadcast_weights(load_state, storage_reader):
    """Load a DCP checkpoint on rank 0 only and broadcast it to every rank.

    Rank 0 reads the checkpoint metadata, splits the requested keys into
    shards, loads one shard at a time and broadcasts its contents; every rank
    copies the received values into its own model/optimizer state dict. The
    function issues collectives unconditionally, so it must be called on all
    ranks of the default process group.

    Args:
        load_state: Mapping with a ``"model"`` entry and an optional
            ``"optimizer"`` entry. Each entry must provide ``state_dict()``;
            the optimizer entry must also provide ``load_state_dict()``.
        storage_reader: Reader handed to ``load_metadata`` and
            ``partial_load_dcp_state_dict``; its ``path`` attribute is globbed
            for ``*.distcp`` files.

    Returns:
        None. Target tensors are updated in place; keys that were requested but
        never received are reported through a warning.
    """
    MODEL = "model"
    OPTIMIZER = "optimizer"

    model_state_dict = load_state[MODEL].state_dict()
    params_to_load = set(model_state_dict.keys())
    # A set, because every FQN in the checkpoint is tested against it.
    dcp_keys = {f"{MODEL}.{key}" for key in model_state_dict.keys()}

    optim_state_dict = None
    if OPTIMIZER in load_state:
        optim_state_dict = load_state[OPTIMIZER].state_dict()
        params_to_load.update(optim_state_dict.keys())
        dcp_keys.update(f"{OPTIMIZER}.{key}" for key in optim_state_dict.keys())

    torch_device = torch.device(get_device_type())
    global_rank = torch.distributed.get_rank()

    metadata = None
    shard_info_list = []
    has_optim_object_shard = False
    if global_rank == 0:
        metadata = load_metadata(storage_reader)
        shard_info_list, has_optim_object_shard = _plan_rank0_shards(
            metadata, dcp_keys, storage_reader.path, OPTIMIZER
        )

    # Share the plan size and whether shard 0 is the optimizer object shard, so
    # every rank takes the same branch (and issues the same collectives) below.
    plan_tensor = torch.tensor(
        [len(shard_info_list), int(has_optim_object_shard)],
        dtype=torch.int64,
        device=torch_device,
    )
    torch.distributed.broadcast(plan_tensor, src=0)
    shard_count, optim_flag = (int(v) for v in plan_tensor.tolist())
    has_optim_object_shard = bool(optim_flag)

    shard_iterable = tqdm(
        range(shard_count),
        desc="Loading checkpoint shards",
        disable=int(os.getenv("LOCAL_RANK", "-1")) > 0,
    )

    for shard_id in shard_iterable:
        shard_state_dict = None

        if shard_id == 0 and has_optim_object_shard and optim_state_dict is not None:
            # Single-file optimizer entries are sent as pickled Python objects
            # and assigned (not copied) into the optimizer state dict.
            if global_rank == 0:
                shard_metadata = extract_metadata(shard_info_list[shard_id], metadata)
                shard_state_dict = partial_load_dcp_state_dict(shard_metadata, storage_reader)
            else:
                shard_state_dict = {}

            broadcast_list = [shard_state_dict]
            torch.distributed.broadcast_object_list(broadcast_list, src=0)
            shard_state_dict = broadcast_list[0]

            # `.get` tolerates a shard in which no requested optimizer key
            # was found.
            for key, value in shard_state_dict.get(OPTIMIZER, {}).items():
                if key in optim_state_dict:
                    optim_state_dict[key] = value
                    params_to_load.discard(key)

            load_state[OPTIMIZER].load_state_dict(optim_state_dict)
            continue

        param_info_list = []
        if global_rank == 0:
            shard_metadata = extract_metadata(shard_info_list[shard_id], metadata)
            shard_state_dict = partial_load_dcp_state_dict(shard_metadata, storage_reader)

            if MODEL in shard_state_dict or OPTIMIZER in shard_state_dict:
                # Nested layout: {prefix: {name: tensor}}.
                for prefix, prefix_state_dict in shard_state_dict.items():
                    param_info_list.extend(
                        ParamInfo(name=k, shape=v.shape, dtype=v.dtype, prefix=prefix)
                        for k, v in prefix_state_dict.items()
                    )
            else:
                # Flat layout: {name: tensor}.
                param_info_list.extend(
                    ParamInfo(name=k, shape=v.shape, dtype=v.dtype)
                    for k, v in shard_state_dict.items()
                )

        # Tell the other ranks which tensors (and shapes/dtypes) will follow.
        broadcast_list = [param_info_list]
        torch.distributed.broadcast_object_list(broadcast_list, src=0)
        param_info_list = broadcast_list[0]

        for param_info in param_info_list:
            param_name = param_info.name
            if param_name not in params_to_load:
                continue

            if param_info.prefix == OPTIMIZER:
                target_state_dict = optim_state_dict
            else:
                target_state_dict = model_state_dict
            # Decide *before* the collective so that all ranks skip together
            # instead of failing on the lookup after the broadcast.
            if target_state_dict is None or param_name not in target_state_dict:
                continue

            if global_rank == 0:
                if param_info.prefix is None:
                    source_dict = shard_state_dict
                else:
                    source_dict = shard_state_dict[param_info.prefix]
                tensor = source_dict[param_name].to(torch_device, non_blocking=True)
            else:
                tensor = torch.empty(param_info.shape, dtype=param_info.dtype, device=torch_device)

            torch.distributed.broadcast(tensor, src=0)
            params_to_load.discard(param_name)

            target_tensor = target_state_dict[param_name]
            if isinstance(target_tensor, DTensor):
                # Re-wrap the full tensor with the target's mesh/placements
                # before copying into the DTensor.
                target_tensor.copy_(
                    tensor_to_dtensor(tensor, target_tensor.device_mesh, target_tensor.placements)
                )
            else:
                target_tensor.copy_(tensor)

            del tensor

        # Drop the shard before collecting so rank 0 does not keep it alive
        # while the next shard is being read.
        shard_state_dict = None
        gc.collect()
        empty_cache()

    if len(params_to_load) > 0:
        print_rank(logger.warning, f"These weights were not loaded from the checkpoint, param keys: {params_to_load}.")
    print_rank(logger.info, "Finished loading and broadcasting checkpoint tensors from rank 0.")