"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
"""
msmodelslim.utils.patch.torch 模块的单元测试
"""
from unittest.mock import Mock
import torch
from torch import nn
import pytest
from msmodelslim.utils.patch.torch import (
patch_torch,
_is_torch_nn_module_has_get_submodule,
_is_torch_nn_module_has_set_submodule,
_is_torch_has_get_default_device,
_TORCH_DEFAULT_DEVICE,
)
class TestTorchPatch:
@staticmethod
def test_patch_keeps_get_submodule_when_existing():
if not hasattr(nn.Module, "get_submodule"):
pytest.skip("当前 PyTorch 无原生 get_submodule,跳过此用例")
original_method = nn.Module.get_submodule
patch_torch()
assert (
nn.Module.get_submodule is original_method
), "补丁不应覆盖原生 get_submodule"
@staticmethod
def test_patch_keeps_set_submodule_when_existing():
if not hasattr(nn.Module, "set_submodule"):
pytest.skip("当前 PyTorch 无原生 set_submodule,跳过此用例")
original_method = nn.Module.set_submodule
patch_torch()
assert (
nn.Module.set_submodule is original_method
), "补丁不应覆盖原生 set_submodule"
@staticmethod
def test_patch_adds_get_default_device_when_missing():
if hasattr(torch, "get_default_device"):
delattr(torch, "get_default_device")
assert not _is_torch_has_get_default_device(), "初始状态应无 get_default_device"
patch_torch()
assert _is_torch_has_get_default_device(), "补丁应补充 get_default_device"
assert torch.get_default_device() == torch.device("cpu"), "初始默认设备应为 CPU"
target_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.set_default_device(target_device)
assert torch.get_default_device() == target_device, "默认设备应更新成功"
@staticmethod
def test_patch_keeps_get_default_device_when_existing():
if not hasattr(torch, "get_default_device"):
pytest.skip("当前 PyTorch 无原生 get_default_device,跳过此用例")
original_method = torch.get_default_device
patch_torch()
assert (
torch.get_default_device is original_method
), "补丁不应覆盖原生 get_default_device"
@staticmethod
def test_patch_adds_get_submodule_when_missing(mock_self):
if hasattr(nn.Module, "get_submodule"):
delattr(nn.Module, "get_submodule")
assert not _is_torch_nn_module_has_get_submodule(), "初始状态应无 get_submodule"
patch_torch()
assert _is_torch_nn_module_has_get_submodule(), "补丁应补充 get_submodule"
assert isinstance(mock_self.test_model.get_submodule("conv"), nn.Conv2d)
assert isinstance(
mock_self.test_model.get_submodule("inner.linear"), nn.Linear
)
assert (
mock_self.test_model.get_submodule("inner.non_exist") is None
)
@staticmethod
def test_patch_adds_set_submodule_when_missing(mock_self):
if hasattr(nn.Module, "set_submodule"):
delattr(nn.Module, "set_submodule")
assert not _is_torch_nn_module_has_set_submodule(), "初始状态应无 set_submodule"
patch_torch()
assert _is_torch_nn_module_has_set_submodule(), "补丁应补充 set_submodule"
new_linear = nn.Linear(5, 2)
mock_self.test_model.set_submodule("inner.linear", new_linear)
assert mock_self.test_model.inner.linear is new_linear, "嵌套子模块应设置成功"
new_conv = nn.Conv2d(16, 32, 3)
mock_self.test_model.set_submodule("conv", new_conv)
assert mock_self.test_model.conv is new_conv, "顶层子模块应设置成功"
@pytest.fixture
def mock_self(self):
mock = Mock()
"""测试前准备:保存原始方法+初始化测试模型"""
mock.original_get_submodule = getattr(nn.Module, "get_submodule", None)
mock.original_set_submodule = getattr(nn.Module, "set_submodule", None)
mock.original_get_default_device = getattr(torch, "get_default_device", None)
mock.original_set_default_device = getattr(torch, "set_default_device", None)
class InnerModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 5)
self.relu = nn.ReLU()
class TestModel(nn.Module):
def __init__(self):
super().__init__()
self.inner = InnerModel()
self.conv = nn.Conv2d(3, 16, 3)
mock.test_model = TestModel()
yield mock
for name, original in [
("get_submodule", mock.original_get_submodule),
("set_submodule", mock.original_set_submodule),
]:
if original is not None:
setattr(nn.Module, name, original)
elif hasattr(nn.Module, name):
delattr(nn.Module, name)
for name, original in [
("get_default_device", mock.original_get_default_device),
("set_default_device", mock.original_set_default_device),
]:
if original is not None:
setattr(torch, name, original)
elif hasattr(torch, name):
delattr(torch, name)
global _TORCH_DEFAULT_DEVICE
_TORCH_DEFAULT_DEVICE = torch.device("cpu")