import pytest
import numpy as np
import operator
from unittest.mock import patch, MagicMock
from ascend_utils.common.utils import (
get_attrs_of_obj,
concatenate_name_in_network,
FullPermutation,
CallParams,
ResListToRelease,
amp_enabled,
OperatorAttrName,
check_model_backend,
count_parameters
)
class TestObj:
def __init__(self):
self.attr1 = "value1"
self.attr2 = 42
self._private_attr = "private"
def test_get_attrs_of_obj_given_obj_when_no_filter_then_return_all_attrs():
obj = TestObj()
attrs = get_attrs_of_obj(obj)
assert "attr1" in str(attrs)
assert "attr2" in str(attrs)
assert "_private_attr" in str(attrs)
def test_get_attrs_of_obj_given_obj_when_filter_then_return_filtered_attrs():
obj = TestObj()
attrs = get_attrs_of_obj(obj, filter_func=lambda x: isinstance(x, str))
assert all(isinstance(attr, str) for attr in attrs)
def test_concatenate_name_in_network_given_none_and_name_when_called_then_return_subname():
result = concatenate_name_in_network(None, "subname")
assert result == "subname"
def test_concatenate_name_in_network_given_empty_and_name_when_called_then_return_subname():
result = concatenate_name_in_network("", "subname")
assert result == "subname"
def test_concatenate_name_in_network_given_name_and_subname_when_called_then_return_concatenated():
result = concatenate_name_in_network("parent", "child")
assert result == "parent.child"
def test_fullpermutation_get_all_permutations_given_negative_when_called_then_raise_error():
with pytest.raises(ValueError):
list(FullPermutation.get_all_permutations(-1))
def test_fullpermutation_get_all_permutations_given_non_int_when_called_then_raise_error():
with pytest.raises(ValueError):
list(FullPermutation.get_all_permutations("3"))
def test_fullpermutation_get_all_permutations_given_zero_when_called_then_return_empty_list():
result = list(FullPermutation.get_all_permutations(0))
assert result == [[]]
def test_fullpermutation_get_all_permutations_given_three_when_called_then_return_all_permutations():
result = list(FullPermutation.get_all_permutations(3))
assert len(result) == 6
assert [2, 1, 0] in result
assert [0, 1, 2] in result
def test_fullpermutation_get_all_combinations_given_empty_list_when_called_then_return_empty_list():
result = list(FullPermutation.get_all_combinations([]))
assert result == [[]]
def test_fullpermutation_get_all_combinations_given_single_item_when_called_then_return_single_combinations():
result = list(FullPermutation.get_all_combinations([2]))
assert result == [[0], [1]]
def test_fullpermutation_get_all_combinations_given_multiple_items_when_called_then_return_all_combinations():
result = list(FullPermutation.get_all_combinations([1, 2]))
assert len(result) == 2
assert [0, 0] in result
assert [0, 1] in result
def test_callparams_init_given_args_kwargs_when_called_then_store_values():
params = CallParams(1, 2, a=3, b=4)
assert params.args == (1, 2)
assert params.kwargs == {"a": 3, "b": 4}
def test_callparams_call_given_function_when_called_then_execute_function():
def test_func(a, b, c=0):
return a + b + c
params = CallParams(1, 2, c=3)
result = params.call(test_func)
assert result == 6
class MockResource:
def __enter__(self):
self.entered = True
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.exited = True
def test_reslisttorelease_enter_given_resources_when_called_then_enter_all():
res1 = MockResource()
res2 = MockResource()
with ResListToRelease(res1, res2) as rl:
assert res1.entered
assert res2.entered
def test_reslisttorelease_exit_given_resources_when_called_then_exit_all():
res1 = MockResource()
res2 = MockResource()
rl = ResListToRelease(res1, res2)
rl.__enter__()
rl.__exit__(None, None, None)
assert hasattr(res1, 'exited')
assert hasattr(res2, 'exited')
def test_amp_enabled_given_apex_available_when_amp_enabled_then_return_true():
with patch.dict('sys.modules', {'apex': MagicMock()}):
with patch('apex.amp._amp_state.handle', True):
assert amp_enabled()
def test_amp_enabled_given_apex_not_available_when_called_then_return_false():
with patch.dict('sys.modules', {'apex': None}):
assert not amp_enabled()
def test_operatorattrname_attr_names_when_accessed_then_contains_operator_methods():
assert "__add__" in OperatorAttrName.attr_names
assert "__sub__" in OperatorAttrName.attr_names
def test_check_model_backend_given_none_when_called_then_raise_error():
with pytest.raises(ValueError):
check_model_backend(None)
def test_check_model_backend_given_pytorch_model_when_called_then_return_pytorch():
mock_torch = MagicMock()
mock_module = MagicMock()
mock_module.Module = object
with patch.dict('sys.modules', {'torch.nn.modules': mock_module}):
model = mock_torch.nn.Module()
assert check_model_backend(model) == "pytorch"
def test_check_model_backend_given_mindspore_model_when_called_then_return_mindspore():
mock_ms = MagicMock()
mock_cell = MagicMock()
mock_cell.Cell = object
with patch.dict('sys.modules', {'mindspore.nn.cell': mock_cell}):
model = mock_ms.nn.Cell()
assert check_model_backend(model) == "mindspore"
def test_check_model_backend_given_invalid_model_when_called_then_raise_error():
with pytest.raises(ValueError):
check_model_backend("not a model")
def test_count_parameters_given_invalid_network_when_called_then_raise_error():
with pytest.raises(AttributeError):
count_parameters("not a network")