# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import inspect
import logging
import re
import sys
from collections.abc import Callable
from dataclasses import dataclass

import torch.nn as nn

logger = logging.getLogger(__name__)


@dataclass
class ModuleMatch:
    parent: nn.Module
    attr_name: str
    module: nn.Module
    full_name: str

    @property
    def is_meta(self) -> bool:
        p = next(self.module.parameters(), None)
        return p is not None and p.device.type == "meta"

    def replace(self, new_module: nn.Module, log: bool = False) -> None:
        setattr(self.parent, self.attr_name, new_module)
        if log:
            logger.info(f"   {self.full_name}")


@dataclass
class FunctionMatch:
    module_path: str
    func_name: str
    func: Callable

    @property
    def full_path(self) -> str:
        return f"{self.module_path}.{self.func_name}"

    def replace(self, new_func: Callable, log: bool = False) -> None:
        if mod := sys.modules.get(self.module_path):
            setattr(mod, self.func_name, new_func)
            if log:
                logger.info(f"   {self.full_path}")


@dataclass
class MethodMatch:
    module_path: str
    class_name: str
    method_name: str
    cls: type
    method: Callable

    @property
    def full_path(self) -> str:
        return f"{self.module_path}.{self.class_name}.{self.method_name}"

    def replace(self, new_method: Callable, log: bool = False) -> None:
        setattr(self.cls, self.method_name, new_method)
        if log:
            logger.info(f"   {self.full_path}")


def _get_package(model: nn.Module | None = None, package: str | None = None) -> str:
    if package:
        return package
    if model:
        return model.__class__.__module__.rsplit(".", 1)[0]
    raise ValueError("Must provide either model or package")


def find_modules(model: nn.Module, pattern: str) -> list[ModuleMatch]:
    regex = re.compile(pattern)
    return [
        ModuleMatch(
            model.get_submodule(name.rsplit(".", 1)[0]) if "." in name else model,
            name.rsplit(".", 1)[-1] if "." in name else name,
            module,
            name,
        )
        for name, module in model.named_modules()
        if regex.search(module.__class__.__name__) and name
    ]


def find_functions(
    func_name: str,
    model: nn.Module | None = None,
    package: str | None = None,
) -> list[FunctionMatch]:
    pkg = _get_package(model, package)
    result = []
    for path, mod in sys.modules.items():
        if not mod or not path.startswith(pkg):
            continue
        func = getattr(mod, func_name, None)
        if callable(func) and not isinstance(func, type):
            result.append(FunctionMatch(path, func_name, func))
    return result


def find_methods(
    class_name: str,
    method_name: str,
    model: nn.Module | None = None,
    package: str | None = None,
) -> list[MethodMatch]:
    pkg = _get_package(model, package)
    matches = []

    for mod_path, mod in sys.modules.items():
        if mod is None or not mod_path.startswith(pkg):
            continue

        cls = getattr(mod, class_name, None)
        if (
            inspect.isclass(cls)
            and cls.__module__ == mod_path
            and hasattr(cls, method_name)
        ):
            matches.append(
                MethodMatch(
                    module_path=mod_path,
                    class_name=class_name,
                    method_name=method_name,
                    cls=cls,
                    method=getattr(cls, method_name),
                )
            )

    return matches


def replace_modules(model: nn.Module, pattern: str, factory: Callable) -> int:
    matches = find_modules(model, pattern)
    for m in matches:
        try:
            m.replace(factory(m.module))
        except Exception as e:
            logger.error(f"  ✗ {m.full_name}: {e}")
    return len(matches)


def replace_functions(
    func_name: str,
    new_func: Callable,
    model: nn.Module | None = None,
    package: str | None = None,
) -> int:
    matches = find_functions(func_name, model=model, package=package)
    for m in matches:
        m.replace(new_func)
    return len(matches)


def replace_methods(
    class_name: str,
    method_name: str,
    new_method: Callable,
    model: nn.Module | None = None,
    package: str | None = None,
) -> int:
    matches = find_methods(class_name, method_name, model=model, package=package)
    for m in matches:
        m.replace(new_method)
    return len(matches)