import copy
from typing import Callable, Any, List, Optional, Generator
import operator
import numpy as np
def get_attrs_of_obj(obj, filter_func=None) -> List[Any]:
if filter_func is None:
return [getattr(obj, attr_name) for attr_name in dir(obj)]
else:
return [attr for attr in (getattr(obj, attr_name) for attr_name in dir(obj)) if filter_func(attr)]
def concatenate_name_in_network(name_in_network: Optional[str], sub_name: str) -> str:
if name_in_network is None or name_in_network == "":
return sub_name
else:
return name_in_network + "." + sub_name
class FullPermutation:
cache = [[[]], None, None, None, None, None, None, None]
@classmethod
def get_all_permutations(cls, max_index: int) -> Generator[List, None, None]:
"""
Get all possible permutations.
Args:
max_index: max index.
Examples:
>>> x = FullPermutation.get_all_permutations(3)
>>> print(list(x))
[[2, 1, 0], [1, 2, 0], [1, 0, 2], [2, 0, 1], [0, 2, 1], [0, 1, 2]]
"""
if not isinstance(max_index, int) or max_index < 0:
raise ValueError("index must be int")
if max_index < len(FullPermutation.cache):
if FullPermutation.cache[max_index] is not None:
return cls.cache[max_index]
return_lists = list(cls._get_all_permutations(max_index))
FullPermutation.cache[max_index] = return_lists
return return_lists
else:
return cls._get_all_permutations(max_index)
@classmethod
def get_all_combinations(cls, cnt_list: List[int]) -> Generator[List, None, None]:
"""
Get all possible combinations.
Args:
cnt_list: max index list.
Examples:
>>> x = FullPermutation.get_all_combinations([1, 2, 3])
>>> print(list(x))
[[0, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 1], [0, 0, 2], [0, 1, 2]]
"""
if cnt_list is None or len(cnt_list) == 0:
yield []
return
ret_combination = [0] * len(cnt_list)
while True:
yield copy.copy(ret_combination)
now_x = ret_combination[0]
if now_x + 1 < cnt_list[0]:
ret_combination[0] = now_x + 1
continue
for i, max_in_list in enumerate(cnt_list):
ret_combination[i] += 1
if ret_combination[i] >= max_in_list:
ret_combination[i] = 0
else:
break
else:
return
@classmethod
def _get_all_permutations(cls, max_index) -> Generator[List, None, None]:
for index_1_seq in cls.get_all_permutations(max_index - 1):
for location in range(max_index):
seq_return = copy.copy(index_1_seq)
seq_return.insert(location, max_index - 1)
yield seq_return
class CallParams:
"""
to save function call params
"""
def __init__(self, *args, **kwargs):
self._args = args
self._kwargs = kwargs
@property
def args(self):
return self._args
@property
def kwargs(self):
return self._kwargs
def call(self, func: Callable):
return func(*self.args, **self.kwargs)
class ResListToRelease:
def __init__(self, *args):
self.res_list = args
def __enter__(self):
for res in self.res_list:
res.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
for res in self.res_list:
res.__exit__(exc_type, exc_val, exc_tb)
def amp_enabled():
try:
from apex import amp
except ImportError:
return False
return hasattr(amp._amp_state, 'handle')
class OperatorAttrName:
attr_names = set([f"__{x}__" for x in dir(operator) if not x.startswith("__")])
def check_model_backend(model):
"""
Check model is a MindSpore or PyTorch model.
Args:
model: model instance
Returns: backend name
"""
if model is None:
raise ValueError("The model can't be None!")
try:
from torch.nn.modules import Module
except ModuleNotFoundError:
pass
else:
if isinstance(model, Module):
return "pytorch"
try:
from mindspore.nn.cell import Cell
except ModuleNotFoundError:
pass
else:
if isinstance(model, Cell):
return "mindspore"
raise ValueError("The model must be a MindSpore or PyTorch model, and with MindSpore or PyTorch environment!")
def count_parameters(network):
if hasattr(network, "parameters_dict"):
import mindspore as ms
if not isinstance(network, ms.nn.Cell):
raise TypeError("Provided network is not a mindspore.nn.Cell")
param_dict = network.parameters_dict(recurse=True)
return sum([np.prod(param.shape) for param in param_dict.values() if isinstance(param, ms.Parameter)])
elif hasattr(network, "state_dict"):
import torch
if not isinstance(network, torch.nn.Module):
raise TypeError("Provided network is not a torch.nn.Module")
return sum([np.prod(param.shape) for param in network.parameters()])
else:
raise AttributeError("network should be an instance of torch.nn.Module or mindspore.nn.Cell")