"""Numpy compatibility patches for numpy 1.24+ removed aliases."""
from typing import List
from mx_driving.patcher.patch import BasePatch, LegacyPatch, Patch
class NumpyCompat(Patch):
    """Compatibility patch restoring the ``np.bool`` / ``np.float`` / ``np.int`` aliases.

    Class attributes:
        name: identifier of this patch (``"numpy_compat"``).
        legacy_name: older identifier ``"numpy_type"`` (presumably kept so
            existing configs that use the old name still resolve — confirm
            against the patcher).
        target_module: the module being patched, ``"numpy"``.
        apply_before_collect: set to ``True``; exact timing semantics are
            defined by the patcher framework (not visible here).
    """

    name = "numpy_compat"
    legacy_name = "numpy_type"
    target_module = "numpy"
    apply_before_collect = True

    @staticmethod
    def _patch_deprecated_aliases(np_module, _options):
        """Write each missing builtin-type alias directly onto ``np_module``.

        This is the explicit equivalent of the 1.0 approach: every missing
        alias becomes a real attribute, so ``np.bool`` / ``np.float`` /
        ``np.int`` live in ``numpy.__dict__`` (and appear in ``dir()``).
        This avoids hooking ``numpy.__getattr__``, which could introduce
        subtle edge-case differences. Aliases that already exist are left
        untouched.

        Args:
            np_module: the ``numpy`` module object, patched in place.
            _options: patch options; not used by this patch.

        Raises:
            AttributeError: when all three aliases were already present,
                meaning this NumPy runtime does not need the patch.
        """
        patched_any = False
        # Check-then-set one alias at a time, in the order bool, float, int.
        for alias, builtin_type in (("bool", bool), ("float", float), ("int", int)):
            if hasattr(np_module, alias):
                continue
            setattr(np_module, alias, builtin_type)
            patched_any = True

        if patched_any:
            return

        # Nothing was missing: report that the patch is unnecessary here.
        raise AttributeError(
            "deprecated aliases already exist in this NumPy runtime; "
            "numpy_compat is not needed"
        )

    @classmethod
    def patches(cls, options=None) -> List[BasePatch]:
        """Build the list of patches this compat shim contributes.

        Args:
            options: forwarded unchanged to the ``LegacyPatch`` wrapper.

        Returns:
            A one-element list holding a ``LegacyPatch`` that runs
            ``_patch_deprecated_aliases`` against the ``numpy`` module.
        """
        alias_patch = LegacyPatch(
            cls._patch_deprecated_aliases,
            target_module="numpy",
            options=options,
        )
        return [alias_patch]