from typing import Any, Optional, List
import torch
class DebugBwd(torch.autograd.Function):
    """Identity autograd function that logs the gradient flowing through it.

    The forward pass hands the input tensor back unchanged. The backward pass
    prints a summary of the incoming gradient (optionally restricted to a
    single distributed rank) and passes that gradient through unmodified.

    This class does not trigger a breakpoint by itself. ``backward`` binds the
    saved input and ``extra_info`` to local names so that they are available
    for inspection when a breakpoint is placed there manually.
    """

    @staticmethod
    def forward(
        ctx: Any,
        x: torch.Tensor,
        debug_rank: int = 0,
        print_info: Optional[str] = None,
        extra_info: Optional[List[Any]] = None,
    ) -> torch.Tensor:
        """Record the debug settings on ``ctx`` and return ``x`` unchanged.

        Args:
            ctx: Autograd context object.
            x: Tensor whose gradient should be monitored.
            debug_rank: Rank that is allowed to print during backward. A
                negative value makes every rank print.
            print_info: Label placed at the start of the log line.
            extra_info: Arbitrary extra objects (for example the forward
                inputs) stored on ``ctx`` for later inspection.

        Returns:
            The input tensor ``x``.
        """
        ctx.save_for_backward(x)
        ctx.debug_info = {
            "debug_rank": debug_rank,
            "print_info": print_info,
            "extra_info": extra_info,
        }
        return x

    @staticmethod
    def backward(
        ctx: Any, grad_output: torch.Tensor
    ) -> tuple[torch.Tensor, None, None, None]:
        """Print a summary of ``grad_output`` and pass it through unchanged.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient with respect to the output of ``forward``.

        Returns:
            One entry per ``forward`` input: ``grad_output`` for ``x`` and
            ``None`` for each of the three non-tensor arguments.
        """
        # Not used by the logging below; kept as locals so they can be
        # examined from a debugger stopped inside this method.
        x = ctx.saved_tensors[0]  # noqa: F841
        extra_info = ctx.debug_info["extra_info"]  # noqa: F841

        debug_rank = ctx.debug_info["debug_rank"]
        print_info = ctx.debug_info["print_info"]

        # Check is_available() first: on builds without distributed support
        # the rest of the torch.distributed API may be missing.
        dist = torch.distributed
        if dist.is_available() and dist.is_initialized():
            rank_id = dist.get_rank()
        else:
            rank_id = 0

        if debug_rank == rank_id or debug_rank < 0:
            log_parts = [
                f"[DebugBwd] {print_info or ' '}",
                f"shape={grad_output.shape}",
                f"device={grad_output.device}",
                f"grad_sum={grad_output.sum():.6f}",
            ]
            print_str = " | ".join(log_parts)
            print(f"\n{print_str}\nGrad Output: {grad_output}\n", flush=True)

        # forward() takes four inputs after ctx, so four gradients are returned.
        return grad_output, None, None, None
def debug_fn(
    x: torch.Tensor,
    debug_rank: int = 0,
    print_info: Optional[str] = None,
    extra_info: Optional[List[Any]] = None,
) -> torch.Tensor:
    """Convenience wrapper that routes ``x`` through ``DebugBwd.apply``.

    Wrap the tensor whose gradient should be monitored; the arguments are
    forwarded positionally, in order, to ``DebugBwd.apply``.

    Args:
        x: Tensor to wrap.
        debug_rank: Forwarded to ``DebugBwd.apply`` unchanged.
        print_info: Forwarded to ``DebugBwd.apply`` unchanged.
        extra_info: Forwarded to ``DebugBwd.apply`` unchanged.

    Returns:
        Whatever ``DebugBwd.apply`` returns for these arguments.

    Example:
        x = debug_fn(x, print_info="conv1_input")  # observe the gradient of x
        out = model(x)
    """
    forwarded_args = (x, debug_rank, print_info, extra_info)
    return DebugBwd.apply(*forwarded_args)