import math
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor
from torch.utils._foreach_utils import (
_device_has_foreach_support,
_group_tensors_by_device_and_dtype,
_has_foreach_support,
)
from typing import List, Optional, Iterable
from mindspeed_llm.fsdp2.distributed.parallel_state import ParallelState
import logging
logger = logging.getLogger(__name__)
def clip_grad_norm(
model, max_norm: float, norm_type: float = 2.0, error_if_nonfinite: bool = False, foreach: bool | None = None
) -> torch.Tensor:
if hasattr(model, "_ep_param_groups"):
return ep_fsdp2_clip_grad_norm(
model,
max_norm,
norm_type=norm_type,
error_if_nonfinite=error_if_nonfinite,
foreach=foreach,
)
grad_norm = fsdp2_clip_grad_norm(
model.parameters(),
max_norm,
norm_type=norm_type,
error_if_nonfinite=error_if_nonfinite,
foreach=foreach,
)
if isinstance(grad_norm, DTensor):
grad_norm = grad_norm.full_tensor()
return grad_norm
@torch.no_grad()
def fsdp2_clip_grad_norm(
parameters: torch.Tensor | Iterable[torch.Tensor],
max_norm: float,
norm_type: float = 2.0,
error_if_nonfinite: bool = False,
foreach: Optional[bool] = None,
) -> torch.Tensor:
r"""
Clip the gradient norm of parameters, with FSDP2 DTensor support.
Safely handles mixed models with both DTensors and ordinary Tensors.
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
else:
parameters = list(parameters)
dtensor_params = []
tensor_params = []
dtensor_grads = []
tensor_grads = []
for p in parameters:
if p.grad is not None:
if isinstance(p.grad, DTensor):
dtensor_params.append(p)
dtensor_grads.append(p.grad)
else:
tensor_params.append(p)
tensor_grads.append(p.grad)
if len(dtensor_grads) > 0:
norm_d = torch.nn.utils.get_total_norm(dtensor_grads, norm_type, error_if_nonfinite, foreach)
if isinstance(norm_d, DTensor):
norm_d = norm_d.full_tensor()
else:
norm_d = torch.tensor(0.0, device=parameters[0].device if parameters else 'cpu')
if len(tensor_grads) > 0:
norm_t = torch.nn.utils.get_total_norm(tensor_grads, norm_type, error_if_nonfinite, foreach)
else:
norm_t = torch.tensor(0.0, device=norm_d.device)
if norm_type >= math.inf:
total_norm = torch.max(norm_d, norm_t)
else:
total_norm = (norm_d**norm_type + norm_t**norm_type) ** (1.0 / norm_type)
ps = ParallelState()
total_norm *= ps.get_group_size("cp")
if len(dtensor_params) > 0:
torch.nn.utils.clip_grads_with_norm_(dtensor_params, max_norm, total_norm, foreach)
if len(tensor_params) > 0:
torch.nn.utils.clip_grads_with_norm_(tensor_params, max_norm, total_norm, foreach)
return total_norm
@torch.no_grad()
def ep_fsdp2_clip_grad_norm(
model, max_norm: float, norm_type: float = 2.0, error_if_nonfinite: bool = False, foreach: bool | None = None
) -> torch.Tensor:
"""
EP-aware gradient clipping for composable FSDP2 with reductions mirroring FSDP1:
- Compute local norms for non-EP and EP parameter groups separately.
- For finite p: sum p-th powers across the appropriate groups, then take 1/p.
• non-EP: all-reduce over FSDP group.
• EP: all-reduce over EP-FSDP group, then over EP group.
- For inf-norm: take elementwise MAX with the same reduction groups (MAX).
- Use a single global clip coefficient for both groups.
"""
ps = ParallelState()
fsdp_group = ps.get_fsdp_group()
ep_group = ps.get_ep_group() if ps.is_ep_enable() else None
ep_fsdp_group = None
if ps.is_ep_enable() and ps.get_efsdp_device_mesh() is not None:
ep_fsdp_group = ps.get_efsdp_group()
ep_params: List[torch.nn.Parameter] = [p for p in model._ep_param_groups.get("ep", []) if p.grad is not None]
non_ep_params: List[torch.nn.Parameter] = [
p for p in model._ep_param_groups.get("non_ep", []) if p.grad is not None
]
non_ep_total = _fsdp2_reduce_group(
params=non_ep_params,
norm_type=norm_type,
reduce_groups=[("fsdp", fsdp_group)],
)
ep_total = _fsdp2_reduce_group(
params=ep_params,
norm_type=norm_type,
reduce_groups=[("ep_fsdp", ep_fsdp_group), ("ep", ep_group)],
)
if math.isinf(norm_type):
total_norm = torch.maximum(non_ep_total, ep_total)
else:
total_norm = (non_ep_total + ep_total) ** (1.0 / float(norm_type))
torch.nn.utils.clip_grads_with_norm_(ep_params, max_norm, total_norm, foreach=False)
torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach=False)
return total_norm
def _local_pth_sum(params: List[torch.nn.Parameter], p: float) -> torch.Tensor:
grads = [p.grad for p in params if p.grad is not None]
if not grads:
return torch.tensor(0.0, device=torch.accelerator.current_device(), dtype=torch.float32)
grads_local = [
g.to_local().detach().to(torch.float32) if isinstance(g, DTensor) else g.detach().to(torch.float32)
for g in grads
]
default_device = (
grads_local[0].device if len(grads_local) > 0 else torch.device(torch.accelerator.current_accelerator().type)
)
res = torch.tensor(0.0, device=default_device, dtype=torch.float32)
with torch.no_grad():
grouped_grads_local = _group_tensors_by_device_and_dtype([grads_local])
for (device, _), ([device_grads_local], _) in grouped_grads_local.items():
if _has_foreach_support(device_grads_local, device) or _device_has_foreach_support(device):
out = torch._foreach_pow_(torch._foreach_norm(device_grads_local, p), p)
res += torch.sum(torch.stack(out)).to(default_device)
else:
for grad_local in device_grads_local:
gn = torch.norm(grad_local, p=p)
res = res + (gn**p).to(default_device)
return res
def _local_max(params: List[torch.nn.Parameter]) -> torch.Tensor:
dev = None
mx = None
for q in params:
g = q.grad
if g is None:
continue
if isinstance(g, DTensor):
g_local = g.to_local()
else:
g_local = g
if dev is None:
dev = g_local.device
mx = torch.tensor(0.0, device=dev, dtype=torch.float32)
gn = torch.max(torch.abs(g_local.detach().to(torch.float32)))
mx = torch.maximum(mx, gn)
if mx is None:
dev = torch.device(torch.accelerator.current_accelerator().type)
mx = torch.tensor(0.0, device=dev, dtype=torch.float32)
return mx
def _fsdp2_reduce_group(
params: List[torch.nn.Parameter],
norm_type: float,
reduce_groups: List[tuple[str, dist.ProcessGroup | None]],
) -> torch.Tensor:
"""Compute local group statistic and reduce over provided groups.
For finite p, returns the globally-reduced sum of p-th powers (not the final norm).
For inf, returns the globally-reduced max.
"""
if math.isinf(norm_type):
val = _local_max(params)
for _, group in reduce_groups:
if group is not None:
dist.all_reduce(val, op=dist.ReduceOp.MAX, group=group)
return val
else:
p = float(norm_type)
val = _local_pth_sum(params, p)
for _, group in reduce_groups:
if group is not None:
dist.all_reduce(val, op=dist.ReduceOp.SUM, group=group)
return val
def _get_rank0_log(msg):
if dist.is_initialized() and dist.get_rank() == 0:
logger.info(msg)