import math
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor
from torch.utils._foreach_utils import (
    _device_has_foreach_support,
    _group_tensors_by_device_and_dtype,
    _has_foreach_support,
)
from typing import List, Optional, Iterable
from mindspeed_llm.fsdp2.distributed.parallel_state import ParallelState
import logging

logger = logging.getLogger(__name__)


def clip_grad_norm(
    model, max_norm: float, norm_type: float = 2.0, error_if_nonfinite: bool = False, foreach: bool | None = None
) -> torch.Tensor:
    # EP-aware path (FSDP2 + EP): maintain mathematical parity with FSDP1 clipper
    if hasattr(model, "_ep_param_groups"):
        return ep_fsdp2_clip_grad_norm(
            model,
            max_norm,
            norm_type=norm_type,
            error_if_nonfinite=error_if_nonfinite,
            foreach=foreach,
        )

    grad_norm = fsdp2_clip_grad_norm(
        model.parameters(),
        max_norm,
        norm_type=norm_type,
        error_if_nonfinite=error_if_nonfinite,
        foreach=foreach,
    )
    if isinstance(grad_norm, DTensor):
        grad_norm = grad_norm.full_tensor()
    return grad_norm


@torch.no_grad()
def fsdp2_clip_grad_norm(
    parameters: torch.Tensor | Iterable[torch.Tensor],
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: Optional[bool] = None,
) -> torch.Tensor:
    r"""
    Clip the gradient norm of parameters, with FSDP2 DTensor support.
    Safely handles mixed models with both DTensors and ordinary Tensors.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    else:
        # prevent generators from being exhausted
        parameters = list(parameters)

    # Strictly separate DTensor and native Tensor
    dtensor_params = []
    tensor_params = []
    dtensor_grads = []
    tensor_grads = []

    for p in parameters:
        if p.grad is not None:
            if isinstance(p.grad, DTensor):
                dtensor_params.append(p)
                dtensor_grads.append(p.grad)
            else:
                tensor_params.append(p)
                tensor_grads.append(p.grad)

    # Compute gradient norms independently
    if len(dtensor_grads) > 0:
        norm_d = torch.nn.utils.get_total_norm(dtensor_grads, norm_type, error_if_nonfinite, foreach)
        if isinstance(norm_d, DTensor):
            norm_d = norm_d.full_tensor()
    else:
        norm_d = torch.tensor(0.0, device=parameters[0].device if parameters else 'cpu')

    if len(tensor_grads) > 0:
        norm_t = torch.nn.utils.get_total_norm(tensor_grads, norm_type, error_if_nonfinite, foreach)
    else:
        norm_t = torch.tensor(0.0, device=norm_d.device)

    # Merge two norms mathematically rigorously
    if norm_type >= math.inf:
        total_norm = torch.max(norm_d, norm_t)
    else:
        total_norm = (norm_d**norm_type + norm_t**norm_type) ** (1.0 / norm_type)

    # Context Parallel (CP) scale handling
    ps = ParallelState()
    total_norm *= ps.get_group_size("cp")

    # Execute clipping separately to prevent foreach_mul operator crash
    if len(dtensor_params) > 0:
        torch.nn.utils.clip_grads_with_norm_(dtensor_params, max_norm, total_norm, foreach)
    if len(tensor_params) > 0:
        torch.nn.utils.clip_grads_with_norm_(tensor_params, max_norm, total_norm, foreach)

    return total_norm


@torch.no_grad()
def ep_fsdp2_clip_grad_norm(
    model, max_norm: float, norm_type: float = 2.0, error_if_nonfinite: bool = False, foreach: bool | None = None
) -> torch.Tensor:
    """
    EP-aware gradient clipping for composable FSDP2 with reductions mirroring FSDP1:

    - Compute local norms for non-EP and EP parameter groups separately.
    - For finite p: sum p-th powers across the appropriate groups, then take 1/p.
      • non-EP: all-reduce over FSDP group.
      • EP: all-reduce over EP-FSDP group, then over EP group.
    - For inf-norm: take elementwise MAX with the same reduction groups (MAX).
    - Use a single global clip coefficient for both groups.
    """

    ps = ParallelState()
    fsdp_group = ps.get_fsdp_group()
    ep_group = ps.get_ep_group() if ps.is_ep_enable() else None
    # For EP params sharded by FSDP2 along hidden dimension
    ep_fsdp_group = None
    if ps.is_ep_enable() and ps.get_efsdp_device_mesh() is not None:
        ep_fsdp_group = ps.get_efsdp_group()

    # Build param groups (filter out params without grads)
    ep_params: List[torch.nn.Parameter] = [p for p in model._ep_param_groups.get("ep", []) if p.grad is not None]
    non_ep_params: List[torch.nn.Parameter] = [
        p for p in model._ep_param_groups.get("non_ep", []) if p.grad is not None
    ]

    # Compute and reduce non-EP
    non_ep_total = _fsdp2_reduce_group(
        params=non_ep_params,
        norm_type=norm_type,
        reduce_groups=[("fsdp", fsdp_group)],
    )
    # Compute and reduce EP: first across ep_fsdp, then across ep
    ep_total = _fsdp2_reduce_group(
        params=ep_params,
        norm_type=norm_type,
        reduce_groups=[("ep_fsdp", ep_fsdp_group), ("ep", ep_group)],
    )

    if math.isinf(norm_type):
        total_norm = torch.maximum(non_ep_total, ep_total)
    else:
        total_norm = (non_ep_total + ep_total) ** (1.0 / float(norm_type))

    # Apply the same clip coefficient to both groups
    torch.nn.utils.clip_grads_with_norm_(ep_params, max_norm, total_norm, foreach=False)
    torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach=False)

    return total_norm


# compute local sum of param guard norm
def _local_pth_sum(params: List[torch.nn.Parameter], p: float) -> torch.Tensor:
    grads = [p.grad for p in params if p.grad is not None]
    if not grads:
        # At this point, 0.0 on the current device needs to be returned; otherwise, an error may occur in the subsequent all_reduce operation.
        return torch.tensor(0.0, device=torch.accelerator.current_device(), dtype=torch.float32)

    grads_local = [
        g.to_local().detach().to(torch.float32) if isinstance(g, DTensor) else g.detach().to(torch.float32)
        for g in grads
    ]
    default_device = (
        grads_local[0].device if len(grads_local) > 0 else torch.device(torch.accelerator.current_accelerator().type)
    )
    res = torch.tensor(0.0, device=default_device, dtype=torch.float32)
    with torch.no_grad():
        grouped_grads_local = _group_tensors_by_device_and_dtype([grads_local])
        for (device, _), ([device_grads_local], _) in grouped_grads_local.items():
            if _has_foreach_support(device_grads_local, device) or _device_has_foreach_support(device):
                out = torch._foreach_pow_(torch._foreach_norm(device_grads_local, p), p)
                res += torch.sum(torch.stack(out)).to(default_device)
            else:
                for grad_local in device_grads_local:
                    gn = torch.norm(grad_local, p=p)
                    res = res + (gn**p).to(default_device)
    return res


def _local_max(params: List[torch.nn.Parameter]) -> torch.Tensor:
    dev = None
    mx = None
    for q in params:
        g = q.grad
        if g is None:
            continue
        if isinstance(g, DTensor):
            g_local = g.to_local()
        else:
            g_local = g
        if dev is None:
            dev = g_local.device
            mx = torch.tensor(0.0, device=dev, dtype=torch.float32)
        gn = torch.max(torch.abs(g_local.detach().to(torch.float32)))
        mx = torch.maximum(mx, gn)
    if mx is None:
        dev = torch.device(torch.accelerator.current_accelerator().type)
        mx = torch.tensor(0.0, device=dev, dtype=torch.float32)
    return mx


def _fsdp2_reduce_group(
    params: List[torch.nn.Parameter],
    norm_type: float,
    reduce_groups: List[tuple[str, dist.ProcessGroup | None]],
) -> torch.Tensor:
    """Compute local group statistic and reduce over provided groups.

    For finite p, returns the globally-reduced sum of p-th powers (not the final norm).
    For inf, returns the globally-reduced max.
    """
    if math.isinf(norm_type):
        val = _local_max(params)
        for _, group in reduce_groups:
            if group is not None:
                dist.all_reduce(val, op=dist.ReduceOp.MAX, group=group)
        return val
    else:
        p = float(norm_type)
        val = _local_pth_sum(params, p)
        for _, group in reduce_groups:
            if group is not None:
                dist.all_reduce(val, op=dist.ReduceOp.SUM, group=group)
        return val


def _get_rank0_log(msg):
    if dist.is_initialized() and dist.get_rank() == 0:
        logger.info(msg)