"""Patterns for describing graphs"""
from mindspore.ops import Primitive
from mindspore.common.tensor import Tensor
from mindspore._c_expression import Pattern, OneOf_, Prim_, Call_, NoneOf_, Any, NewTensor_, NewParameter_, Imm
# Public names exported by `from <this module> import *`.
# `Any` and `Imm` are exported as imported from mindspore._c_expression above.
__all__ = [
    "OneOf", "Prim", "Call", "NoneOf",
    "Any",
    "NewTensor", "NewParameter",
    "Imm",
]
class OneOf(OneOf_):
    r"""
    Express a pattern that accepts any one of several candidate patterns.
    """

    def __init__(self, patterns=None):
        r"""
        Args:
            patterns (Union[:class:`mindspore.graph_utils.graph_pattern`,
                tuple[:class:`mindspore.graph_utils.graph_pattern`],
                list[:class:`mindspore.graph_utils.graph_pattern`]]): the allowed patterns.
                A single Pattern is wrapped into a one-element list; a tuple/list must contain
                only Pattern instances.

        Raises:
            TypeError: if `patterns` is neither a Pattern nor a tuple/list made up solely of
                Patterns. Note that this includes the default value None.
        """
        self.patterns = patterns
        if isinstance(patterns, Pattern):
            candidates = [patterns]
        elif isinstance(patterns, (tuple, list)) and all(isinstance(item, Pattern) for item in patterns):
            candidates = patterns
        else:
            raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")
        OneOf_.__init__(self, candidates)
class Prim(Prim_):
    r"""
    Express a pattern of certain primitive type(s).

    NOTE:
        This pattern will match and only match the primitive value node. If matching primitive CNode is needed,
        please refer to the Call pattern.
    """

    def __init__(self, types, name=None):
        r"""
        Args:
            types (Union[str, :class:`mindspore.ops.Primitive`, list[:class:`mindspore.ops.Primitive`],
                tuple[:class:`mindspore.ops.Primitive`]]):
                Specify allowed types.
                If it is a string, the form could be
                    1) a single primitive type, e.g. 'Conv2D'
                    2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D'
                It can also be a Primitive or a list/tuple of Primitives, e.g. [ops.Conv2D(1, 6)]
            name (str): name of the pattern, optional. Default: None. When None, the name is derived
                from `types`: the string itself, the Primitive's `name`, or the concatenation of the
                `name` of every Primitive in the list/tuple.

        Raises:
            TypeError: if `name` is not None and not a string, or if `types` is not one of the
                accepted forms.
        """
        if name is not None and not isinstance(name, str):
            raise TypeError(f"Expect string, got : {name}")
        self.name = name
        if isinstance(types, str):
            if self.name is None:
                self.name = types
            # 'MatMul|Conv2D' -> ['MatMul', 'Conv2D']
            self.types = types.split('|')
        elif isinstance(types, Primitive):
            if self.name is None:
                self.name = types.name
            self.types = [types]
        elif isinstance(types, (tuple, list)) and all(isinstance(tp, Primitive) for tp in types):
            if self.name is None:
                # Primitive names are concatenated with no separator.
                self.name = "".join(prim.name for prim in types)
            self.types = types
        else:
            raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
        Prim_.__init__(self, self.types, self.name)
class Call(Call_):
    r"""
    Express a primitive CNode.
    """

    def __init__(self, prim_pattern, inputs=None):
        r"""
        Args:
            prim_pattern (Union[str, :class:`mindspore.graph_utils.graph_pattern.Prim`,
                :class:`mindspore.ops.Primitive`]): Primitive ValueNode in the Primitive CNode.
                Any Pattern instance is accepted here, e.g. a :class:`Prim` pattern.
            inputs (Union[list[:class:`mindspore.graph_utils.graph_pattern`],
                tuple[:class:`mindspore.graph_utils.graph_pattern`]]):
                Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; if specified, input
                patterns should be of right order and each element should be one of the exposed Pattern instance.

        Raises:
            TypeError: if `prim_pattern` is not a Pattern, Primitive or string, or if `inputs` is not
                None and not a tuple/list made up solely of Patterns.
        """
        if not isinstance(prim_pattern, (Pattern, str, Primitive)):
            raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}")
        self.prim_pattern = prim_pattern
        # An empty input list is what gets passed to the backend when `inputs` is None.
        self.inputs = []
        if inputs is not None:
            # `arg` rather than `input` to avoid shadowing the builtin.
            if not (isinstance(inputs, (tuple, list)) and all(isinstance(arg, Pattern) for arg in inputs)):
                raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
            self.inputs = inputs
        Call_.__init__(self, self.prim_pattern, self.inputs)
class NoneOf(NoneOf_):
    r"""
    Express a pattern which forbids a list of patterns.

    NOTE:
        NoneOf pattern should not be the root pattern.
    """

    def __init__(self, patterns=None):
        r"""
        Args:
            patterns (Union[:class:`mindspore.graph_utils.graph_pattern`,
                tuple[:class:`mindspore.graph_utils.graph_pattern`],
                list[:class:`mindspore.graph_utils.graph_pattern`]], optional): forbidden patterns.
                A single Pattern is wrapped into a one-element list; a tuple/list must contain only
                Pattern instances. Default: None, in which case an empty tuple of forbidden patterns
                is passed to the backend.

        Raises:
            TypeError: if `patterns` is not None, a Pattern, or a tuple/list made up solely of Patterns.
        """
        self.patterns = patterns
        if patterns is None:
            forbidden = ()
        elif isinstance(patterns, Pattern):
            forbidden = [patterns]
        elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
            forbidden = patterns
        else:
            raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
        NoneOf_.__init__(self, forbidden)
class NewTensor(NewTensor_):
    r"""
    Wrap a freshly supplied Tensor so it can be used in the target.
    """

    def __init__(self, input_tensor):
        r"""
        Args:
            input_tensor (:class:`mindspore.common.tensor.Tensor`): the tensor to be used in the target.

        Raises:
            TypeError: if `input_tensor` is not a Tensor.
        """
        self.input_tensor = input_tensor
        if not isinstance(input_tensor, Tensor):
            raise TypeError(f"Expect input_tensor to be a Tensor, got : {input_tensor}")
        NewTensor_.__init__(self, input_tensor)
class NewParameter(NewParameter_):
    r"""
    New Parameter to be used in the target.
    """

    def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False):
        r"""
        Args:
            para_name (str): name for the new Parameter.
            default_tensor (:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter.
            requires_grad (bool): True if the parameter requires gradient. Default: False.
            layerwise_parallel (bool): switch for layerwise parallel mode. Default: False.

        Raises:
            TypeError: if `para_name` is not a str, `default_tensor` is not a Tensor, or
                `requires_grad` / `layerwise_parallel` is not a bool.
        """
        self.para_name = para_name
        self.default_tensor = default_tensor
        self.requires_grad = requires_grad
        self.layerwise_parallel = layerwise_parallel
        arguments_valid = (isinstance(para_name, str) and isinstance(default_tensor, Tensor)
                           and isinstance(requires_grad, bool) and isinstance(layerwise_parallel, bool))
        if not arguments_valid:
            # Adjacent f-strings are concatenated at compile time. A backslash continuation inside
            # the literal would instead copy the next source line's indentation into the message.
            raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), "
                            f"layerwise_parallel(bool), got : {para_name}, {default_tensor}, "
                            f"{requires_grad}, {layerwise_parallel}")
        NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad,
                               self.layerwise_parallel)