Hh00855061cleanCode2
df32a54a创建于 2024年11月5日历史提交
# Copyright Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.

import importlib
import inspect
import logging


class FunctionReplace:
    def __init__(self, ori_function_define, new_function):
        self.new_function = new_function
        self.ori_function = None
        self.location = None
        self.attr_name = None
        if ori_function_define is None:
            return
        elif inspect.isfunction(ori_function_define):
            self.ori_function = ori_function_define
            self.location, self.attr_name = self.get_location(self.ori_function)
        elif callable(ori_function_define):
            self.ori_function = ori_function_define.__call__
            self.location, self.attr_name = self.get_location(self.ori_function)
        elif isinstance(ori_function_define, tuple) and len(ori_function_define) == 2:
            self.location, self.attr_name = ori_function_define
            self.ori_function = getattr(self.location, self.attr_name)

        if not all((self.ori_function, self.location, self.attr_name, self.new_function)):
            warn_msg = "{} replace failed.".format(ori_function_define)
            logging.warning(warn_msg)
            raise ValueError(warn_msg)

    def __enter__(self):
        self.replace()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.recover()

    @staticmethod
    def nothing():
        return FunctionReplace(None, None)

    @staticmethod
    def get_location(function_ins):
        if not hasattr(function_ins, "__module__"):
            warning_msg = "function {} do not has attr __module__.".format(str(function_ins))
            logging.warning(warning_msg)
            raise ValueError(warning_msg)
        module = importlib.import_module(function_ins.__module__)
        qualified_name = function_ins.__qualname__.split('.')
        classes = qualified_name[:-1]
        attr_name = qualified_name[-1]
        location = module
        for class_name in classes:
            location = getattr(location, class_name, None)
            if location is None:
                break

        if location is None:
            warning_msg = "{} do not exists".format('.'.join(classes))
            logging.warning(warning_msg)
            raise ValueError(warning_msg)
        return location, attr_name

    def replace(self):
        if all((self.ori_function, self.location, self.attr_name, self.new_function)):
            setattr(self.location, self.attr_name, self._get_method(self.new_function))

    def recover(self):
        if all((self.ori_function, self.location, self.attr_name, self.new_function)):
            setattr(self.location, self.attr_name, self._get_method(self.ori_function))

    def _get_method(self, func):
        if inspect.isclass(self.location):
            func_cls_name = inspect.getattr_static(self.location, self.attr_name).__class__.__name__
            if func_cls_name in ('staticmethod', 'classmethod'):
                return staticmethod(func)
        return func