import functools
import inspect
import fnmatch
from typing import List, Dict, Any, Union, Callable, Tuple
import contextvars
import torch
import torch.nn as nn
from msmodelslim.pytorch.llm_ptq.llm_ptq_tools.timestep.manager import TimestepManager
from msmodelslim.utils.logging import logger
MAX_RECURSION_DEPTH = 20
class InputCapture:
"""Handles capturing and storing function inputs and outputs."""
_captured_inputs_var = contextvars.ContextVar("captured_inputs", default=[])
@classmethod
def reset(cls) -> None:
"""Reset all captured inputs."""
cls._captured_inputs_var.set([])
@classmethod
def get_all(cls) -> List[Dict[str, Any]]:
"""Get all captured inputs."""
return cls._captured_inputs_var.get()
@classmethod
def add_record(cls, record: Dict[str, Any]) -> None:
"""Add a new record to the captured inputs."""
inputs = cls._captured_inputs_var.get()
inputs.append(record)
cls._captured_inputs_var.set(inputs)
@classmethod
def capture_forward_inputs(
cls,
func: Callable,
capture_mode: str = 'args',
) -> Callable:
"""
Decorator to capture inputs to a forward function.
Args:
func: Forward function to decorate
capture_mode: 'args', 'kwargs', 'timestep'
Returns:
Wrapped function
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
sig = inspect.signature(func)
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()
is_method = 'self' in sig.parameters
captured_args = list(bound.args[1:]) if is_method else list(bound.args)
captured_kwargs = bound.arguments.copy()
if is_method and 'self' in captured_kwargs:
del captured_kwargs['self']
if capture_mode == 'args':
captured_kwargs = {}
record = captured_args
elif capture_mode == 'kwargs':
captured_args = []
record = captured_kwargs
elif capture_mode == 'timestep':
record = {
"tag": "",
"timestep_idx": TimestepManager.get_timestep_idx(),
"module_name": func.__qualname__,
"args": captured_args,
"kwargs": captured_kwargs
}
else:
raise ValueError(f"Invalid capture_mode: {capture_mode}. Must be 'args' or 'kwargs' or 'timestep'")
result = func(*args, **kwargs)
record = to_device(record, device='cpu')
cls.add_record(record)
return result
return wrapper
class DumperManager(nn.Module):
"""Module that listens to and captures forward pass inputs and outputs."""
def __init__(
self,
module: nn.Module,
capture_mode: str = 'args',
):
"""
Initialize a listener for the given module.
Args:
module: Module to listen to
capture_mode: 'args' or 'kwargs' or 'timestep'
"""
super().__init__()
self.module = module
self.capture_mode = capture_mode
self.old_forward = None
if capture_mode not in {'args', 'kwargs', 'timestep'}:
raise ValueError(f"Invalid capture_mode: {capture_mode}. Must be 'args' or 'kwargs' or 'timestep'")
self._add_hook(self.module)
def save(self, path: str = '__output.pth') -> List[Dict[str, Any]]:
"""Save captured data and restore original forward method."""
data = InputCapture.get_all()
torch.save(data, path)
if self.old_forward:
self.module.forward = self.old_forward
self.old_forward = None
logger.info('Captured data saved to: %r', path)
return data
def reset(self) -> None:
"""Reset captured inputs."""
InputCapture.reset()
def _add_hook(self, module: nn.Module) -> Callable:
"""Add forward hook to the module."""
self.old_forward = module.forward
wrapper = InputCapture.capture_forward_inputs(
self.old_forward,
capture_mode=self.capture_mode,
)
module.forward = wrapper
return wrapper
def get_rank():
"""
Get the rank of the current process.
Returns:
int: Non-negative rank (in default group) if distributed is initialized; -1 otherwise.
"""
if torch.distributed.is_initialized():
return torch.distributed.get_rank()
else:
return -1
def get_disable_layer_names(model: nn.Module,
layer_include: Union[List[str], Tuple[str], str],
layer_exclude: Union[List[str], Tuple[str], str]) -> List[str]:
"""
Get the names of layers to be disabled based on inclusion and exclusion patterns using fnmatch.
Args:
model: The neural network module
layer_include: Patterns for layers to include. Can be a string, list or tuple of strings.
layer_exclude: Patterns for layers to exclude. Can be a string, list or tuple of strings.
Returns:
List of layer names that should be disabled for quantization.
"""
if isinstance(layer_include, str):
layer_include = [layer_include]
if isinstance(layer_exclude, str):
layer_exclude = [layer_exclude]
all_layer_names = []
quant_layer_names = set()
for name, mod in model.named_modules():
if isinstance(mod, nn.Linear):
all_layer_names.append(name)
if layer_include and not any(fnmatch.fnmatch(name, pattern) for pattern in layer_include):
continue
if layer_exclude and any(fnmatch.fnmatch(name, pattern) for pattern in layer_exclude):
continue
quant_layer_names.add(name)
disable_layer_names = [name for name in all_layer_names if name not in quant_layer_names]
return disable_layer_names
def to_device(data, device, depth=0):
""" recursive function to move data to the specified device """
if depth > MAX_RECURSION_DEPTH:
raise RecursionError(f"Maximum recursion depth {MAX_RECURSION_DEPTH} exceeded")
if isinstance(data, dict):
return {k: to_device(v, device, depth=depth + 1) for k, v in data.items()}
elif isinstance(data, list):
return [to_device(item, device, depth=depth + 1) for item in data]
elif isinstance(data, tuple):
return tuple(to_device(item, device, depth=depth + 1) for item in data)
elif isinstance(data, torch.Tensor):
return data.to(device)
else:
return data
def get_rank_suffix_file(base_name, ext, is_distributed, rank):
"""
生成带rank后缀的文件名,分布式环境下添加_rank标识,非分布式环境直接使用基础名称
参数:
base_name (str): 文件名基础部分(不含后缀)
ext (str): 文件后缀(不含小数点)
is_distributed (bool): 是否为分布式环境
rank (int): 当前进程的rank值(分布式环境下有效)
返回:
str: 处理后的完整文件名(含后缀)
"""
if is_distributed:
return f"{base_name}_{rank}.{ext}"
return f"{base_name}.{ext}"