import os.path
from collections import defaultdict
from functools import lru_cache
from unittest.mock import patch

from safetensors import safe_open
from torch import nn

from msmodelslim.utils.security import json_safe_load


class LayerwiseMixin:
    """Load decoder layers one at a time from a sharded safetensors checkpoint.

    Host classes must provide:
      * ``self.model_path``: directory holding ``model.safetensors.index.json``
        and the shard files named in its ``weight_map``.
      * ``self.config``: passed to the decoder-layer constructor and read for
        ``num_hidden_layers``.
    """

    def get_weight_map(self) -> dict:
        """Return the ``weight_map`` from the checkpoint's safetensors index.

        The index is read from ``<model_path>/model.safetensors.index.json``
        via ``json_safe_load``. This mixin uses the map as
        ``full tensor name -> shard file name`` (relative to ``model_path``).

        The result is cached on the instance. It is re-read only when
        ``model_path`` changes. A per-instance cache is used instead of
        ``functools.lru_cache`` on the method: that would key on ``self`` and
        keep the instance alive for the life of the process.
        """
        index_path = os.path.join(self.model_path, "model.safetensors.index.json")
        cached = getattr(self, "_layerwise_weight_map_cache", None)
        if cached is not None and cached[0] == index_path:
            return cached[1]
        weight_map = json_safe_load(index_path)["weight_map"]
        self._layerwise_weight_map_cache = (index_path, weight_map)
        return weight_map

    def get_state_dict(self, module: nn.Module, prefix: str = "") -> dict:
        """Read ``module``'s parameters from the checkpoint shards on CPU.

        Args:
            module: module whose ``named_parameters()`` are looked up.
            prefix: dotted path of ``module`` inside the full model. It is
                joined with ``.`` in front of each parameter name to form the
                checkpoint key. An empty prefix uses the names unchanged.

        Returns:
            A dict from the *unprefixed* parameter name to the loaded tensor.
            Parameters whose full name is absent from the weight map are
            skipped, so a later strict ``load_state_dict`` reports them as
            missing.
        """
        weight_map = self.get_weight_map()
        # Group the wanted tensors by shard so each file is opened only once.
        shard_to_names = defaultdict(list)
        for local_name, _ in module.named_parameters():
            full_name = f"{prefix}.{local_name}" if prefix else local_name
            if full_name in weight_map:
                shard_to_names[weight_map[full_name]].append((local_name, full_name))

        state_dict = {}
        for shard_name, names in shard_to_names.items():
            shard_path = os.path.join(self.model_path, shard_name)
            with safe_open(shard_path, framework="pt", device="cpu") as f:
                for local_name, full_name in names:
                    state_dict[local_name] = f.get_tensor(full_name)
        return state_dict

    def load_decoder_if_not_exist(self, model: nn.Module, name: str, idx: int) -> nn.Module:
        """Return submodule ``name`` of ``model``, building and loading it if absent.

        A missing layer is built with the same class as
        ``model.model.layers[0]`` (``config=self.config, layer_idx=idx``). Its
        weights are filled from the checkpoint under prefix ``name``. It is
        switched to eval mode and appended to ``model.model.layers``.

        Raises:
            ValueError: if the layer is missing but ``idx`` is not the next
                free position in ``model.model.layers``. Appending would place
                it at the wrong index, so layers must be loaded in order.
            IndexError: if ``model.model.layers`` is empty, because there is no
                template layer to copy the class from.
        """
        try:
            return model.get_submodule(name)
        except AttributeError:
            pass  # Not materialised yet; build it outside the handler.

        module_list: nn.ModuleList = model.model.layers
        if idx != len(module_list):
            raise ValueError(
                f"cannot load {name!r} at index {idx}: model.model.layers holds "
                f"{len(module_list)} layer(s); decoder layers must be loaded in order"
            )
        if len(module_list) == 0:
            raise IndexError("model.model.layers is empty; no template decoder layer to copy")
        template_module = module_list[0]

        # Skip nn.Linear's random init: every weight is overwritten just below.
        with patch.object(nn.Linear, "reset_parameters", lambda _self: None):
            decoder = template_module.__class__(config=self.config, layer_idx=idx)

        state_dict = self.get_state_dict(decoder, prefix=name)
        decoder.load_state_dict(state_dict)
        decoder.eval()
        module_list.append(decoder)
        return decoder

    def generate_decoder_layer(self, model: nn.Module):
        """Yield ``(name, layer)`` for every decoder layer, loading lazily.

        Iterates ``idx`` over ``range(self.config.num_hidden_layers)`` in
        increasing order with ``name = "model.layers.{idx}"``. Each layer comes
        from ``load_decoder_if_not_exist``, so only layers that are absent
        from ``model`` are read from disk.
        """
        for idx in range(self.config.num_hidden_layers):
            name = f"model.layers.{idx}"
            yield name, self.load_decoder_if_not_exist(model, name=name, idx=idx)