"""internal graph-compatible utility functions"""
import math
from itertools import zip_longest, accumulate
from collections import deque
import operator
import mindspore.context as context
from ..ops import functional as F
from ..ops.primitive import constexpr
from ..common import dtype as mstype
from ..common import Tensor
from .._c_expression import Tensor as Tensor_
from .._c_expression import typing
from .._checkparam import Validator as validator
from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map, rule_for_trigonometric
# `validator.check_axis_type` wrapped with `constexpr`, like the other helpers in
# this module; used below by `_check_axes_range`.
_check_axis_type = constexpr(validator.check_axis_type)
@constexpr
def _check_shape(shape):
    """
    Validate a numpy-style shape argument and normalize it to a tuple.

    Args:
        shape (Union[int, tuple, list]): The requested shape. A single int is
            treated as a 1-D shape.

    Returns:
        tuple: The shape as a tuple (lists are converted, an int is wrapped).

    Raises:
        TypeError: If `shape` is not an int/tuple/list, or an entry is not an int.
        ValueError: If any entry is negative.
    """
    accepted = (int, tuple, list, typing.Tuple, typing.List)
    if not isinstance(shape, accepted):
        raise TypeError(f"only int, tuple and list are allowed for shape, but got {type(shape)}")
    if isinstance(shape, int):
        normalized = (shape,)
    elif isinstance(shape, (list, typing.List)):
        normalized = tuple(shape)
    else:
        normalized = shape
    for dim in normalized:
        if not isinstance(dim, int):
            raise TypeError("each entry in shape should be int.")
        if dim < 0:
            raise ValueError("each entry in shape should no less than 0.")
    return normalized
@constexpr
def _check_dtype(dtype):
    """
    Check the input dtype and convert it to a MindSpore dtype.

    Args:
        dtype: A dtype name string (case-insensitive, looked up in `dtype_map`),
            the Python builtin `int` (-> int32) or `float` (-> float32), another
            Python type converted by `mstype.pytype_to_dtype`, or a MindSpore dtype.

    Returns:
        The converted dtype, guaranteed to be a member of `dtype_tuple`.

    Raises:
        TypeError: If `dtype` is an unrecognized dtype name, or the resulting
            dtype is not one of the supported dtypes.
    """
    if isinstance(dtype, str):
        try:
            dtype = dtype_map[dtype.lower()]
        except KeyError:
            # Report unknown names as a TypeError (the documented failure mode)
            # instead of leaking a bare KeyError from the lookup table.
            raise TypeError(f"only {all_types} are allowed for dtype, but got {dtype}") from None
    elif isinstance(dtype, type):
        if dtype is int:
            dtype = mstype.int32
        elif dtype is float:
            dtype = mstype.float32
        else:
            dtype = mstype.pytype_to_dtype(dtype)
    if dtype not in dtype_tuple:
        # Show the offending dtype itself; its Python type is not informative.
        raise TypeError(f"only {all_types} are allowed for dtype, but got {dtype}")
    return dtype
@constexpr
def _is_shape_empty(shp):
    """
    Check whether a shape describes zero elements.

    Args:
        shp (Union[int, tuple]): A single dimension size or a full shape.

    Returns:
        bool: True if the int is 0, or if `F.shape_mul` of the shape is 0.
    """
    element_count = shp if isinstance(shp, int) else F.shape_mul(shp)
    return element_count == 0
@constexpr
def _check_start_normalize(start, ndim):
    """
    Validate the `start` argument of rollaxis and make it non-negative.

    Args:
        start (int): Requested position; accepted range is [-ndim, ndim].
        ndim (int): Number of dimensions of the tensor.

    Returns:
        int: `start` unchanged if non-negative, otherwise `start + ndim`.

    Raises:
        ValueError: If `start` lies outside [-ndim, ndim].
    """
    if start > ndim or start < -ndim:
        raise ValueError(f"For rollaxis, start {start} is out of bounds. Ranging from {-ndim} to {ndim} is allowed.")
    return start + ndim if start < 0 else start
@constexpr
def _check_axes_range(axes, ndim):
    """
    Validate `axes` and map negative entries to their non-negative equivalents.

    Args:
        axes: Axes of the tensor.
        ndim (int): The number of dimensions of the tensor.

    Return:
        Axes (Union[int, tuple(int)]), as produced by `_canonicalize_axis`.

    Raises:
        TypeError: If the axes are not integer, tuple(int) or list(int).
        ValueError: If duplicate axes exist or some axis is out of bounds.
    """
    _check_axis_type(axes, True, True, True)
    is_sequence = isinstance(axes, (list, tuple))
    if is_sequence:
        # Every element must be an int before the range/duplicate checks.
        _check_element_int(axes)
    return _canonicalize_axis(axes, ndim)
@constexpr
def _get_device():
    """Return the context's `device_target` setting (e.g. `GPU`, `CPU`, `Ascend`)."""
    target = context.get_context('device_target')
    return target
@constexpr
def _infer_out_shape(*shapes):
    """
    Compute the shape that results from broadcasting `shapes` together.

    Shapes are aligned from their trailing dimension; missing leading
    dimensions count as 1. In each aligned position the output size is 0 if
    any size is 0, otherwise the largest size, and every input size must be
    either 1 or that output size.

    Raises:
        ValueError: If the shapes cannot be broadcast together.
    """
    out_dims = deque()
    columns = zip_longest(*(reversed(shape) for shape in shapes), fillvalue=1)
    for column in columns:
        target = 0 if 0 in column else max(column)
        for size in column:
            if size not in (1, target):
                raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}')
        out_dims.appendleft(target)
    return tuple(out_dims)
@constexpr
def _can_broadcast(*shapes):
    """
    Returns True if shapes can broadcast, False if they cannot.

    The decision is delegated to `_infer_out_shape`, which raises ValueError
    for incompatible shapes.
    """
    try:
        _infer_out_shape(*shapes)
    except ValueError:
        return False
    return True
@constexpr
def _check_axis_in_range(axis, ndim):
    """
    Check that `axis` is an int within [-ndim, ndim) and return it modulo ndim.

    Raises:
        TypeError: If `axis` is not an int.
        ValueError: If `axis` is outside [-ndim, ndim).
    """
    if not isinstance(axis, int):
        raise TypeError(f'axes should be integers, not {type(axis)}')
    in_bounds = -ndim <= axis < ndim
    if not in_bounds:
        raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}')
    return axis % ndim
@constexpr
def _check_axis_valid(axes, ndim):
    """
    Validate `axes` against `ndim` and return a form suitable for built-in operators.

    Args:
        axes (Union[None, int, tuple, list]): Requested axes.
        ndim (int): Number of dimensions.

    Returns:
        `F.make_range(ndim)` when `axes` is None; otherwise a tuple of
        non-negative axes (a single int becomes a 1-tuple).

    Raises:
        TypeError / ValueError: From `_check_axis_in_range` for bad entries.
        ValueError: If a sequence contains the same axis more than once.
    """
    if axes is None:
        return F.make_range(ndim)
    if not isinstance(axes, (tuple, list)):
        return (_check_axis_in_range(axes, ndim),)
    normalized = tuple(_check_axis_in_range(ax, ndim) for ax in axes)
    if len(set(normalized)) != len(normalized):
        raise ValueError('duplicate value in "axis"')
    return normalized
@constexpr
def _check_shape_aligned(shape1, shape2):
    """
    Check that shape1 and shape2 are valid shapes to perform an inner product.

    The trailing dimensions of the two shapes must be equal.

    Raises:
        ValueError: If the trailing dimensions differ. The message names the
            index of the trailing dimension of each shape.
    """
    if shape1[-1] != shape2[-1]:
        # Report the real (last) dimension index of each shape, not "dim 0".
        raise ValueError(f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (dim {len(shape1) - 1}) '
                         f'!= {shape2[-1]} (dim {len(shape2) - 1})')
@constexpr
def _tile_size(shape, out_shape, ndim):
    """
    Return per-axis tile multiples such that tiling `shape` yields `out_shape`.

    Positions where the two shapes agree (or beyond the shorter one) get 1;
    positions where they differ take the `out_shape` size.
    """
    tile = [1 for _ in range(ndim)]
    for pos, (src_dim, dst_dim) in enumerate(zip(shape, out_shape)):
        if src_dim == dst_dim:
            continue
        tile[pos] = dst_dim
    return tuple(tile)
@constexpr
def _raise_type_error(info, param=None):
    """
    Raise TypeError in both graph and pynative mode.

    Args:
        info (str): Message to display.
        param (python obj): Any object that graph mode can recognize. When not
            None, the type of `param` is appended to the message. Default: None.

    Raises:
        TypeError: Always.
    """
    message = info if param is None else info + f"{type(param)}"
    raise TypeError(message)
@constexpr
def _raise_value_error(info, param=None):
    """
    Raise ValueError in both graph/pynative mode.

    Args:
        info (str): info string to display
        param (python obj): any object that can be recognized by graph mode. If
            not None, its value is appended to the message. Default is None.

    Raises:
        ValueError: Always.
    """
    if param is None:
        raise ValueError(info)
    raise ValueError(info + f"{param}")
@constexpr
def _raise_runtime_error(info, param=None):
    """
    Raise RuntimeError in both graph and pynative mode.

    Args:
        info (str): Message to display.
        param (python obj): Any object that graph mode can recognize. When not
            None, its value is appended to the message. Default: None.

    Raises:
        RuntimeError: Always.
    """
    message = info if param is None else info + f"{param}"
    raise RuntimeError(message)
@constexpr
def _raise_unimplemented_error(info, param=None):
    """
    Raise NotImplementedError in both graph and pynative mode.

    Args:
        info (str): Message to display.
        param (python obj): Any object that graph mode can recognize. When not
            None, its value is appended to the message. Default: None.

    Raises:
        NotImplementedError: Always.
    """
    message = info if param is None else info + f"{param}"
    raise NotImplementedError(message)
@constexpr
def _empty(dtype, shape):
    """Build an uninitialized `Tensor_` with the given dtype and shape."""
    uninitialized = Tensor_(dtype, shape)
    return uninitialized
@constexpr
def _promote(dtype1, dtype2):
    """
    Return the dtype that `dtype1` and `dtype2` promote to.

    Equal dtypes are returned as-is; otherwise `promotion_rule` is consulted
    with the pair in either order. A pair absent in both orders raises KeyError.
    """
    if dtype1 == dtype2:
        return dtype1
    key = (dtype1, dtype2)
    if key not in promotion_rule:
        key = (dtype2, dtype1)
    return promotion_rule[key]
@constexpr
def _promote_for_trigonometric(dtype):
    """Look up the result dtype for trigonometric ops in `rule_for_trigonometric`."""
    promoted = rule_for_trigonometric[dtype]
    return promoted
@constexpr
def _max(*args):
    """Return the maximum, with the same argument rules as builtin `max`."""
    largest = max(*args)
    return largest
@constexpr
def _min(*args):
    """Returns the minimum value, with the same argument rules as builtin `min`."""
    return min(*args)
@constexpr
def _abs(arg):
    """Return the absolute value of `arg` via builtin `abs`."""
    magnitude = abs(arg)
    return magnitude
@constexpr
def _check_same_type(dtype1, dtype2):
    """Return whether `dtype1` compares equal to `dtype2`."""
    same = dtype1 == dtype2
    return same
@constexpr
def _check_is_float(dtype):
    """Return whether `dtype` is float16 or float32 (other float widths return False)."""
    float_types = (mstype.float16, mstype.float32)
    return dtype in float_types
@constexpr
def _check_is_int(dtype):
    """Return whether `dtype` is an instance of `typing.Int`."""
    is_int = isinstance(dtype, typing.Int)
    return is_int
@constexpr
def _canonicalize_axis(axis, ndim):
    """
    Check that axes are within `ndim` dimensions and make negative axes non-negative.

    Args:
        axis (Union[int, tuple(int), list(int)]): Axes of the tensor.
        ndim (int): The number of dimensions of the tensor.

    Return:
        Axis (Union[int, tuple(int)]). Multiple axes come back as a sorted
        tuple; a single axis (an int, or a one-element sequence) comes back as
        an int. An empty sequence raises IndexError.

    Raises:
        TypeError / ValueError: From `_check_axis_in_range` for bad entries.
        ValueError: If the normalized axes contain duplicates.
    """
    axes = [axis] if isinstance(axis, int) else axis
    for ax in axes:
        _check_axis_in_range(ax, ndim)
    normalized = tuple(ax + ndim if ax < 0 else ax for ax in axes)
    if len(set(normalized)) != len(normalized):
        raise ValueError(f"duplicate axes in {normalized}.")
    if len(normalized) > 1:
        return tuple(sorted(normalized))
    return normalized[0]
@constexpr
def _broadcast_tuples(tup1, tup2):
    """
    Broadcast two 1D tuples to the same length; ints are converted to 1-tuples first.

    A length-1 sequence is repeated to match the other's length. The inputs
    are not modified: repetition builds a new sequence of the same type.

    Raises:
        TypeError: If either input is not an int, tuple or list.
        ValueError: If the lengths differ and neither is 1.
    """
    tup1 = (tup1,) if isinstance(tup1, int) else tup1
    tup2 = (tup2,) if isinstance(tup2, int) else tup2
    if not isinstance(tup1, (tuple, list)) or not isinstance(tup2, (tuple, list)):
        raise TypeError("input shift and axis must be tuple or list or int.")
    if len(tup1) == len(tup2):
        return tup1, tup2
    # Use `seq * n` rather than `seq *= n`: the in-place form would extend a
    # caller-supplied list as a side effect.
    if len(tup1) == 1:
        tup1 = tup1 * len(tup2)
    elif len(tup2) == 1:
        tup2 = tup2 * len(tup1)
    else:
        raise ValueError("shape mismatch: objects cannot be broadcast to a single shape")
    return tup1, tup2
@constexpr
def _expanded_shape(ndim, axis_size, axis):
    """Return an `ndim`-long shape of 1s, with `axis_size` at position `axis`."""
    return tuple(axis_size if dim == axis else 1 for dim in range(ndim))
@constexpr
def _add_unit_axes(shape, ndim, append=False):
    """
    Pad `shape` with 1s so that it has `ndim` dimensions.

    The 1s are prepended by default, or appended when `append` is True. An int
    shape is treated as a 1-tuple. Shapes already having at least `ndim`
    dimensions are returned unchanged (as a tuple).
    """
    if isinstance(shape, int):
        shape = (shape,)
    missing = ndim - len(shape)
    if missing <= 0:
        return tuple(shape)
    padding = [1] * missing
    dims = list(shape)
    return tuple(dims + padding if append else padding + dims)
@constexpr
def _check_element_int(lst):
    """
    Check that every element of `lst` is an int.

    Returns:
        True when all elements pass.

    Raises:
        TypeError: On the first non-int element.
    """
    for element in lst:
        if isinstance(element, int):
            continue
        raise TypeError(f"Each element in {lst} should be integer, but got {type(element)}.")
    return True
@constexpr
def _type_convert(force, obj):
    """Return `force(obj)`, i.e. `obj` converted by the callable `force`."""
    converted = force(obj)
    return converted
@constexpr
def _list_comprehensions(obj, item=None, return_tuple=False, make_none=False):
    """
    Generate a new list/tuple with one entry per element of a source sequence.

    Args:
        obj (Union[int, list, tuple]): The source. An int `n` is treated as
            `range(n)`, so it becomes the length of the result.
        item: The value to fill with. Default: None. When None (and
            `make_none` is False), the source's own elements are copied.
        return_tuple (bool): If True, return a tuple, else a list. Default: False.
        make_none (bool): If True, fill every position with None, taking
            precedence over `item`. Default: False.

    Returns:
        List or tuple.
    """
    source = range(obj) if isinstance(obj, int) else obj
    if make_none:
        values = [None for _ in source]
    elif item is None:
        values = list(source)
    else:
        values = [item for _ in source]
    return tuple(values) if return_tuple else values
@constexpr
def _tuple_setitem(tup, idx, value):
    """Return a copy of `tup` (as a tuple) with position `idx` replaced by `value`."""
    items = list(tup)
    items[idx] = value
    updated = tuple(items)
    return updated
@constexpr
def _iota(dtype, num, increasing=True):
    """
    Create a 1-D tensor of `dtype` holding [0, 1, ..., num-1].

    When `increasing` is False the values are in reverse order
    ([num-1, ..., 1, 0]).

    Raises:
        ValueError: If `num` is not positive.
    """
    if num <= 0:
        raise ValueError("zero shape Tensor is not currently supported.")
    count = int(num)
    values = list(range(count)) if increasing else list(range(count - 1, -1, -1))
    return Tensor(values, dtype)
@constexpr
def _ceil(number):
    """Return `math.ceil(number)`, usable from graph mode."""
    rounded_up = math.ceil(number)
    return rounded_up
@constexpr
def _seq_prod(seq1, seq2):
    """Return the element-wise product of seq1 and seq2, truncated to the shorter one."""
    return tuple(left * right for left, right in zip(seq1, seq2))
@constexpr
def _make_tensor(val, dtype):
    """Build a `Tensor` from value `val` with dtype `dtype`."""
    tensor = Tensor(val, dtype)
    return tensor
@constexpr
def _tuple_slice(tup, start, end):
    """Return `tup[start:end]`."""
    window = slice(start, end)
    return tup[window]
@constexpr
def _isscalar(x):
    """Return True if `x` is a Number/Int/UInt/Float/Bool/String type object."""
    scalar_types = (typing.Number, typing.Int, typing.UInt, typing.Float,
                    typing.Bool, typing.String)
    return isinstance(x, scalar_types)
@constexpr
def _cumprod(x):
    """Return the running products of `x` as a tuple, e.g. (2, 3, 4) -> (2, 6, 24)."""
    running = accumulate(x, operator.mul)
    return tuple(running)
@constexpr
def _in(x, y):
    """Return whether `x` is contained in `y`."""
    found = x in y
    return found
@constexpr
def _callable_const(x):
    """Return True if `x` is a `typing.Function` (a function in graph mode)."""
    is_function = isinstance(x, typing.Function)
    return is_function
@constexpr
def _check_is_inf(x, negative=False):
    """Return whether `x` equals +inf, or -inf when `negative` is truthy."""
    target = float('-inf') if negative else float('inf')
    return x == target