"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import pytest
import numpy as np
import operator
from unittest.mock import patch, MagicMock
from ascend_utils.common.utils import (
get_attrs_of_obj,
concatenate_name_in_network,
FullPermutation,
CallParams,
ResListToRelease,
amp_enabled,
OperatorAttrName,
check_model_backend,
count_parameters
)
class TestObj:
    """Plain attribute holder used as input for the ``get_attrs_of_obj`` tests.

    Instances carry two public attributes (``attr1``, ``attr2``) and one
    underscore-prefixed attribute (``_private_attr``).
    """

    # The class name matches pytest's default ``Test*`` collection pattern,
    # but this is a helper, not a test case. Opting out of collection avoids
    # the PytestCollectionWarning raised for classes with ``__init__``.
    __test__ = False

    def __init__(self):
        self.attr1 = "value1"
        self.attr2 = 42
        self._private_attr = "private"
def test_get_attrs_of_obj_given_obj_when_no_filter_then_return_all_attrs():
    """Without a filter, all three attribute names appear in the result's text form."""
    rendered = str(get_attrs_of_obj(TestObj()))
    for expected_name in ("attr1", "attr2", "_private_attr"):
        assert expected_name in rendered
def test_get_attrs_of_obj_given_obj_when_filter_then_return_filtered_attrs():
    """With a string-only filter, every returned item is a ``str``."""
    def is_string(value):
        return isinstance(value, str)

    filtered = get_attrs_of_obj(TestObj(), filter_func=is_string)
    for item in filtered:
        assert isinstance(item, str)
def test_concatenate_name_in_network_given_none_and_name_when_called_then_return_subname():
    """A ``None`` parent name yields the sub-name unchanged."""
    sub_name = "subname"
    assert concatenate_name_in_network(None, sub_name) == "subname"
def test_concatenate_name_in_network_given_empty_and_name_when_called_then_return_subname():
    """An empty parent name yields the sub-name unchanged."""
    sub_name = "subname"
    assert concatenate_name_in_network("", sub_name) == "subname"
def test_concatenate_name_in_network_given_name_and_subname_when_called_then_return_concatenated():
    """A non-empty parent and a child name are joined with a dot."""
    joined = concatenate_name_in_network("parent", "child")
    expected = "parent.child"
    assert joined == expected
def test_fullpermutation_get_all_permutations_given_negative_when_called_then_raise_error():
    """A negative size is rejected with ``ValueError``."""
    negative_size = -1
    with pytest.raises(ValueError):
        # list() consumes the result in case it is produced lazily.
        list(FullPermutation.get_all_permutations(negative_size))
def test_fullpermutation_get_all_permutations_given_non_int_when_called_then_raise_error():
    """A non-integer size (here the string ``"3"``) is rejected with ``ValueError``."""
    non_int_size = "3"
    with pytest.raises(ValueError):
        # list() consumes the result in case it is produced lazily.
        list(FullPermutation.get_all_permutations(non_int_size))
def test_fullpermutation_get_all_permutations_given_zero_when_called_then_return_empty_list():
    """Size zero yields exactly one permutation: the empty list."""
    expected = [[]]
    assert list(FullPermutation.get_all_permutations(0)) == expected
def test_fullpermutation_get_all_permutations_given_three_when_called_then_return_all_permutations():
    """Size three yields six permutations, including both extreme orderings."""
    permutations = list(FullPermutation.get_all_permutations(3))
    assert len(permutations) == 6
    for ordering in ([2, 1, 0], [0, 1, 2]):
        assert ordering in permutations
def test_fullpermutation_get_all_combinations_given_empty_list_when_called_then_return_empty_list():
    """An empty input yields exactly one combination: the empty list."""
    expected = [[]]
    assert list(FullPermutation.get_all_combinations([])) == expected
def test_fullpermutation_get_all_combinations_given_single_item_when_called_then_return_single_combinations():
    """The input ``[2]`` yields ``[0]`` and ``[1]``, in that order."""
    combinations = list(FullPermutation.get_all_combinations([2]))
    expected = [[0], [1]]
    assert combinations == expected
def test_fullpermutation_get_all_combinations_given_multiple_items_when_called_then_return_all_combinations():
    """The input ``[1, 2]`` yields two combinations: ``[0, 0]`` and ``[0, 1]``."""
    combinations = list(FullPermutation.get_all_combinations([1, 2]))
    assert len(combinations) == 2
    for combo in ([0, 0], [0, 1]):
        assert combo in combinations
def test_callparams_init_given_args_kwargs_when_called_then_store_values():
    """``CallParams`` exposes its positional and keyword arguments as ``args`` / ``kwargs``."""
    positional = (1, 2)
    keyword = {"a": 3, "b": 4}
    call_params = CallParams(*positional, **keyword)
    assert call_params.args == positional
    assert call_params.kwargs == keyword
def test_callparams_call_given_function_when_called_then_execute_function():
    """``CallParams.call`` invokes the target with the stored arguments."""
    def add_all(a, b, c=0):
        return a + b + c

    stored = CallParams(1, 2, c=3)
    # 1 + 2 + 3
    assert stored.call(add_all) == 6
class MockResource:
    """Minimal context manager that records whether it was entered and exited.

    The ``entered`` / ``exited`` flags are created only when the corresponding
    hook runs, so callers can detect a missing call with ``hasattr``.
    """

    def __enter__(self):
        """Mark the resource as entered and hand back the resource itself."""
        setattr(self, "entered", True)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Mark the resource as exited; exceptions are not suppressed."""
        setattr(self, "exited", True)
        return None
def test_reslisttorelease_enter_given_resources_when_called_then_enter_all():
    """Entering ``ResListToRelease`` enters every wrapped resource."""
    res1 = MockResource()
    res2 = MockResource()
    # The value bound by ``as`` was never used, so the binding is dropped.
    with ResListToRelease(res1, res2):
        # ``entered`` exists only after ``__enter__`` ran. The default makes a
        # regression fail the assertion instead of raising AttributeError.
        assert getattr(res1, "entered", False)
        assert getattr(res2, "entered", False)
def test_reslisttorelease_exit_given_resources_when_called_then_exit_all():
    """Exiting ``ResListToRelease`` exits every wrapped resource."""
    resources = (MockResource(), MockResource())
    release_list = ResListToRelease(*resources)
    # Drive the context-manager protocol by hand so enter and exit are explicit.
    release_list.__enter__()
    release_list.__exit__(None, None, None)
    for resource in resources:
        assert hasattr(resource, "exited")
def test_amp_enabled_given_apex_available_when_amp_enabled_then_return_true():
    """With a mocked ``apex`` whose ``amp._amp_state.handle`` is ``True``, AMP reads as enabled."""
    fake_modules = {"apex": MagicMock()}
    # The module stub must be installed before the attribute patch resolves
    # its target, so patch.dict comes first.
    with patch.dict("sys.modules", fake_modules), \
            patch("apex.amp._amp_state.handle", True):
        assert amp_enabled()
def test_amp_enabled_given_apex_not_available_when_called_then_return_false():
    """With ``apex`` mapped to ``None`` in ``sys.modules``, AMP reads as disabled."""
    unavailable = {"apex": None}
    with patch.dict("sys.modules", unavailable):
        enabled = amp_enabled()
        assert not enabled
def test_operatorattrname_attr_names_when_accessed_then_contains_operator_methods():
    """``OperatorAttrName.attr_names`` contains the add and subtract dunder names."""
    for dunder in ("__add__", "__sub__"):
        assert dunder in OperatorAttrName.attr_names
def test_check_model_backend_given_none_when_called_then_raise_error():
    """Passing ``None`` as the model raises ``ValueError``."""
    missing_model = None
    with pytest.raises(ValueError):
        check_model_backend(missing_model)
def test_check_model_backend_given_pytorch_model_when_called_then_return_pytorch():
    """A mock model is reported as ``"pytorch"`` when ``torch.nn.modules`` is stubbed.

    The stub exposes ``Module`` as ``object``, so any instance passes an
    ``isinstance`` check against it.
    """
    fake_torch = MagicMock()
    stub_modules = MagicMock()
    stub_modules.Module = object
    with patch.dict("sys.modules", {"torch.nn.modules": stub_modules}):
        model = fake_torch.nn.Module()
        backend = check_model_backend(model)
        assert backend == "pytorch"
def test_check_model_backend_given_mindspore_model_when_called_then_return_mindspore():
    """A mock model is reported as ``"mindspore"`` when ``mindspore.nn.cell`` is stubbed.

    The stub exposes ``Cell`` as ``object``, so any instance passes an
    ``isinstance`` check against it.
    """
    fake_mindspore = MagicMock()
    stub_cell = MagicMock()
    stub_cell.Cell = object
    with patch.dict("sys.modules", {"mindspore.nn.cell": stub_cell}):
        model = fake_mindspore.nn.Cell()
        backend = check_model_backend(model)
        assert backend == "mindspore"
def test_check_model_backend_given_invalid_model_when_called_then_raise_error():
    """A plain string is not a recognised model and raises ``ValueError``."""
    bogus_model = "not a model"
    with pytest.raises(ValueError):
        check_model_backend(bogus_model)
def test_count_parameters_given_invalid_network_when_called_then_raise_error():
    """A plain string passed as the network raises ``AttributeError``."""
    bogus_network = "not a network"
    with pytest.raises(AttributeError):
        count_parameters(bogus_network)