# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# pylint: skip-file

import fnmatch
from dataclasses import dataclass, field
from functools import wraps
from typing import Callable, Tuple, Union

import torch

from mindspeed.args_utils import get_full_args


_MUON_OPTIMIZER_CONFIG_DEFAULTS = {
    "muon_momentum": 0.95,
    "muon_split_qkv": True,
    "muon_nesterov": False,
    "muon_scale_mode": "spectral",
    "muon_fp32_matmul_prec": "medium",
    "muon_coefficient_type": "quintic",
    "muon_num_ns_steps": 5,
    "muon_tp_mode": "blockwise",
    "muon_extra_scale_factor": 1.0,
    "muon_scalar_optimizer": "adam",
    "apply_wd_to_qk_layernorm": False,
    "use_layer_wise_distributed_optimizer": False,
}


def optimizer_config_init_wrapper(init_func):
    """Decorate an optimizer-config initializer so the Muon options always exist.

    The returned initializer first runs ``init_func`` unchanged. Then, for every
    key of ``_MUON_OPTIMIZER_CONFIG_DEFAULTS`` that the instance does not already
    expose, it sets the attribute from the namespace returned by
    ``get_full_args()``. When that namespace lacks the attribute too, the
    hard-coded default is used.

    Args:
        init_func: The initializer to wrap; its first positional argument must
            be the instance being initialized.

    Returns:
        The wrapping initializer. It discards ``init_func``'s return value, which
        is fine for ``__init__``, since that returns None.
    """

    @wraps(init_func)
    def optimizer_config_init(*args, **kwargs):
        init_func(*args, **kwargs)
        config = args[0]  # the instance being initialized (``self``)
        full_args = get_full_args()
        for option, fallback in _MUON_OPTIMIZER_CONFIG_DEFAULTS.items():
            # Never override a value the original initializer already set.
            if hasattr(config, option):
                continue
            setattr(config, option, getattr(full_args, option, fallback))

    return optimizer_config_init


@dataclass(frozen=True)
class ParamPredicate:
    """Wraps a parameter-matching function to make it hashable for ParamKey.

    Example:
        >>> shape_1_param = ParamPredicate(name="s1", fn=lambda param: len(param.shape) == 1)
        >>> shape_1_param(torch.empty(10))
        True
        >>> shape_1_param_copy = ParamPredicate(name="s1", fn=lambda param: len(param.shape) == 1)
        >>> shape_1_param == shape_1_param_copy  # name is used to match
        True
        >>> {shape_1_param, shape_1_param_copy} == {shape_1_param}  # set hashing works properly
        True

    NOTE:
        __hash__ and __eq__ are automatically generated by @dataclass(frozen=True)
        based solely on 'name' because we set compare=False/hash=False on 'fn'.
        Consequently two predicates with the same name compare equal (and
        collapse to one entry in a set/dict) even if their functions differ, so
        names must be unique per distinct predicate.
    """

    # Identity of the predicate; the only field used by __eq__ and __hash__.
    name: str
    # The matcher itself; excluded from comparison and hashing (see NOTE above).
    fn: Callable[[torch.nn.Parameter], bool] = field(compare=False, hash=False)

    def __call__(self, param: torch.nn.Parameter) -> bool:
        """Return the result of applying the wrapped function to ``param``."""
        return self.fn(param)


@dataclass(frozen=True)
class ParamWithNamePredicate:
    """Wraps a (parameter, name) matching function to make it hashable for ParamKey.

    Unlike ParamPredicate, the wrapped function also receives the parameter's
    name. This makes it possible to exclude otherwise-matching parameters by name.

    Example:
        >>> shape_1_not_qkln_param = ParamWithNamePredicate(
        ...     name="s1_not_qkln",
        ...     fn=lambda param, name: (
        ...         (len(param.shape) == 1 or name.endswith(".bias"))
        ...         and not ("q_layernorm." in name or "k_layernorm." in name)
        ...     ),
        ... )
        >>> shape_1_not_qkln_param(torch.empty(10), "interesting.bias")
        True
        >>> shape_1_not_qkln_param(torch.empty(10), "interesting.q_layernorm.bias")
        False

        Note the explicit parentheses around the ``or``: ``and`` binds tighter
        than ``or`` in Python. Without them, every 1-D parameter would match
        regardless of its name.

    NOTE:
        __hash__ and __eq__ are automatically generated by @dataclass(frozen=True)
        based solely on 'name' because we set compare=False/hash=False on 'fn'.
        Two predicates with the same name are therefore treated as identical.
    """

    # Identity of the predicate; the only field used by __eq__ and __hash__.
    name: str
    # The matcher, called as fn(param, param_name); excluded from comparison/hashing.
    fn: Callable[[torch.nn.Parameter, str], bool] = field(compare=False, hash=False)

    def __call__(self, param: torch.nn.Parameter, name: str) -> bool:
        """Return the result of applying the wrapped function to ``param`` and ``name``."""
        return self.fn(param, name)


@dataclass(frozen=True)
class ParamKey:
    """Key to group parameters by. All parameters matched by the same key can
    share an optimizer config specification.

    A parameter matches the key if ANY configured criterion matches: a name
    pattern, a truthy attribute, a predicate, or a name-aware predicate. A key
    whose fields are all left at their empty defaults matches nothing.
    """

    # TODO: Can add layer_id here later.

    name: Union[str, Tuple[str, ...]] = field(default_factory=tuple)
    """Parameter name pattern(s), matched with ``fnmatch.fnmatch`` (shell-style wildcards;
    note that ``*`` also matches ``.`` and ``/``)."""

    attr: Union[str, Tuple[str, ...]] = field(default_factory=tuple)
    """Parameter attribute name(s); a parameter matches if any of these attributes is truthy."""

    predicate: Union[ParamPredicate, Tuple[ParamPredicate, ...]] = field(default_factory=tuple)
    """Predicate(s) to match parameters by. If multiple predicates are provided, the parameter
    matches if any one of them returns True."""

    with_name_predicate: Union[ParamWithNamePredicate, Tuple[ParamWithNamePredicate, ...]] = field(
        default_factory=tuple
    )
    """
    Predicate(s) to match parameters together with their name. If multiple predicates are
      provided, the parameter matches if any one of them returns True. This is useful if you
      need to filter out some parameters from an otherwise positive match by their name.
    """

    @staticmethod
    def _as_tuple(value, single_type) -> tuple:
        """Normalize a field holding either one ``single_type`` item or an iterable of them."""
        if isinstance(value, single_type):
            return (value,)
        return tuple(value)

    def matches(self, param: torch.nn.Parameter, param_name: str) -> bool:
        """Returns true if passed-in parameter (with name) matches this key.

        Criteria are checked in order: name patterns, attributes, predicates,
        then name-aware predicates. Evaluation stops at the first match.

        Args:
            param (torch.nn.Parameter): Handle to parameter object.
            param_name (str): Name of parameter in underlying PyTorch module.

        Returns:
            bool: True if the parameter matches any criterion of this key.
        """
        # Check if name matches.
        if any(fnmatch.fnmatch(param_name, pattern) for pattern in self._as_tuple(self.name, str)):
            return True

        # Check if attribute matches (a missing attribute counts as False).
        if any(getattr(param, attr, False) for attr in self._as_tuple(self.attr, str)):
            return True

        # Check if predicate matches.
        if any(pred(param) for pred in self._as_tuple(self.predicate, ParamPredicate)):
            return True

        # Check if with_name_predicate matches.
        return any(
            pred(param, param_name)
            for pred in self._as_tuple(self.with_name_predicate, ParamWithNamePredicate)
        )