import contextlib
import os
import re
import sys
import textwrap
import warnings
from enum import Enum
from typing import Tuple, List, Dict, Iterable, Iterator, Generic, Callable, Sequence, TypeVar, NoReturn, Optional

if sys.version_info >= (3, 8):
    from typing import Literal
else:
    from typing_extensions import Literal

try:
    from yaml import CSafeLoader as Loader
except ImportError:
    from yaml import SafeLoader as Loader
try:
    from yaml import CSafeDumper as Dumper
except ImportError:
    from yaml import SafeDumper as Dumper
# Public alias for the YAML dumper selected by the imports above: the
# libyaml-backed CSafeDumper when it is available, otherwise SafeDumper.
YamlDumper = Dumper
def assert_never(x: NoReturn) -> NoReturn:
    """Fail loudly when a supposedly exhaustive case analysis falls through.

    The ``NoReturn`` parameter type means a type checker only accepts a call
    here once ``x`` has been narrowed to an impossible type; at runtime,
    reaching this function always raises.

    Raises:
        AssertionError: always, naming the runtime type of ``x``.
    """
    unexpected_type_name = type(x).__name__
    message = "Unhandled type: {}".format(unexpected_type_name)
    raise AssertionError(message)
T = TypeVar("T")  # generic input element type for the helpers below
S = TypeVar("S")  # generic output element type for the helpers below


def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily map ``func`` over ``xs`` and flatten the resulting sequences.

    Every item of ``func`` applied to the first element is produced, then
    every item for the second element, and so on. ``func`` is only called as
    the returned generator is consumed.
    """
    for element in xs:
        for produced in func(element):
            yield produced
class YamlLoader(Loader):
    """Safe YAML loader (libyaml-backed when available) that rejects duplicate mapping keys.

    PyYAML's base ``construct_mapping`` silently keeps the last value when a key
    is repeated; this subclass raises instead so mistakes in hand-written YAML
    are caught.
    """

    def construct_mapping(self, node, deep=False):
        """Construct a mapping after verifying that none of its keys repeat.

        Raises:
            ValueError: on the first key that has already appeared in this
                mapping. The reported line is the 1-based line of that key.
        """
        # A list (not a set) so an unhashable key does not raise TypeError here;
        # it is left for the base class to reject.
        seen_keys = []
        for key_node, _value_node in node.value:
            key = self.construct_object(key_node, deep=deep)
            if key in seen_keys:
                # Mark.line is 0-based: report the 1-based line of the offending
                # key itself instead of the start line of the whole mapping.
                raise ValueError(
                    f"Found a duplicate key in the yaml. key={key}, line={key_node.start_mark.line + 1}"
                )
            seen_keys.append(key)
        # NOTE(review): keys are constructed before the base class flattens
        # `<<` merge keys; confirm merge keys are not used with this loader.
        return super().construct_mapping(node, deep=deep)
class Target(Enum):
    """Output-kind selector; presumably chooses which kind of code is emitted (confirm with callers).

    Members are numbered 1..7 in declaration order, matching what the
    functional ``Enum('Target', (...))`` form assigns.
    """

    DEFINITION = 1
    DECLARATION = 2
    REGISTRATION = 3
    ANONYMOUS_DEFINITION = 4
    ANONYMOUS_DEFINITION_UNSUPPORT = 5
    NAMESPACED_DEFINITION = 6
    NAMESPACED_DECLARATION = 7
# Template for finding a whole-identifier occurrence: fill the `{}` placeholder
# (e.g. with str.format) with the identifier. `(^|\W)` and `($|\W)` require the
# identifier to be bounded by the string start/end or a non-word character, so
# "foo" does not match inside "foobar". Unlike `\b`, the bounding characters
# are consumed as part of the match (in capturing groups).
# NOTE(review): callers are outside this view; confirm they re.escape() the
# identifier if it can contain regex metacharacters.
IDENT_REGEX = r'(^|\W){}($|\W)'
def split_name_params(schema: str) -> Tuple[str, List[str]]:
    """Split a function schema string into its base name and parameter strings.

    ``"add.Tensor(Tensor self, Tensor other)"`` -> ``("add", ["Tensor self", "Tensor other"])``.

    The optional ``.overload`` suffix is matched but discarded. Parameters are
    split on ``", "`` exactly, so an empty parameter list (``"f()"``) yields
    ``[""]``. The ``.*`` group is greedy: everything up to the LAST ``)`` in
    ``schema`` is treated as the parameter list.

    Raises:
        RuntimeError: if ``schema`` does not begin with ``name[.overload](...)``.
    """
    parsed = re.match(r'(\w+)(\.\w+)?\((.*)\)', schema)
    if not parsed:
        raise RuntimeError(f'Unsupported function schema: {schema}')
    base_name = parsed.group(1)
    param_text = parsed.group(3)
    return base_name, param_text.split(', ')
# T and S are declared once near the top of this module and are reused by the
# definitions below. Declaring them a second time here was redundant and makes
# mypy report 'Cannot redefine "T" as a type variable'.
def map_maybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily apply ``func`` to each element of ``xs``, dropping ``None`` results.

    Only ``None`` is filtered out; other falsy results (``0``, ``""``, ``[]``)
    are still yielded, in input order.
    """
    for element in xs:
        outcome = func(element)
        if outcome is None:
            continue
        yield outcome
def concat_map(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily apply ``func`` to each element of ``xs`` and flatten the results in order.

    ``func`` is only called as the returned generator is consumed.
    """
    for element in xs:
        yield from func(element)
@contextlib.contextmanager
def context(msg_fn: Callable[[], str]) -> Iterator[None]:
    """Annotate any ``Exception`` escaping the ``with`` body with extra context.

    ``msg_fn`` is only called when an exception occurs. Its result is indented
    by one space per line and appended, on a new line, to the exception's first
    argument (or becomes the sole first argument when ``args`` is empty). The
    remaining ``args`` are kept, and the same exception object is re-raised.
    ``BaseException`` subclasses outside ``Exception`` pass through untouched.
    """
    try:
        yield
    except Exception as exc:
        note = textwrap.indent(msg_fn(), ' ')
        if exc.args:
            note = f'{exc.args[0]}\n{note}'
        exc.args = (note, *exc.args[1:])
        raise
class NamespaceHelper:
    """A helper for constructing the namespace open and close strings for a nested set of namespaces.

    e.g. for namespace_str torch::lazy,

    prologue:
    namespace torch {
    namespace lazy {

    epilogue:
    } // namespace lazy
    } // namespace torch
    """

    def __init__(self, namespace_str: str, entity_name: str = "", max_level: int = 2):
        """Build the prologue/epilogue strings for ``namespace_str``.

        Args:
            namespace_str: ``"::"``-joined namespace path, e.g. ``"torch::lazy"``.
                An empty string yields one empty component, i.e. the prologue
                ``"namespace  {"`` (an anonymous namespace in C++).
            entity_name: optional class/function name, exposed via ``entity_name``.
            max_level: maximum number of ``"::"``-separated components accepted.

        Raises:
            ValueError: if ``namespace_str`` has more than ``max_level`` components.
        """
        cpp_namespaces = namespace_str.split("::")
        if len(cpp_namespaces) > max_level:
            # Both halves must be f-strings, and the space between them kept;
            # otherwise the message shows a literal "{namespace_str}" and "level(s)of".
            raise ValueError(
                f"Codegen doesn't support more than {max_level} level(s) "
                f"of custom namespace. Got {namespace_str}."
            )
        self.cpp_namespace_ = namespace_str
        self.prologue_ = "\n".join([f"namespace {n} {{" for n in cpp_namespaces])
        # Close in reverse order so the innermost namespace is closed first.
        self.epilogue_ = "\n".join(
            [f"}} // namespace {n}" for n in reversed(cpp_namespaces)]
        )
        self.namespaces_ = cpp_namespaces
        self.entity_name_ = entity_name

    @staticmethod
    def from_namespaced_entity(
        namespaced_entity: str, max_level: int = 2
    ) -> "NamespaceHelper":
        """Build a helper from a fully qualified name such as ``"torch::lazy::add"``.

        The last ``"::"`` component becomes ``entity_name``; the preceding ones
        form the namespace. A name without ``"::"`` gets an empty namespace.

        Raises:
            ValueError: if the namespace part exceeds ``max_level`` components.
        """
        names = namespaced_entity.split("::")
        entity_name = names[-1]
        namespace_str = "::".join(names[:-1])
        return NamespaceHelper(
            namespace_str=namespace_str, entity_name=entity_name, max_level=max_level
        )

    @property
    def prologue(self) -> str:
        """One ``namespace X {`` line per component, outermost first."""
        return self.prologue_

    @property
    def epilogue(self) -> str:
        """One ``} // namespace X`` line per component, innermost first."""
        return self.epilogue_

    @property
    def entity_name(self) -> str:
        """The class/function name given at construction (may be empty)."""
        return self.entity_name_

    def get_cpp_namespace(self, default: str = "") -> str:
        """
        Return the namespace string from joining all the namespaces by "::" (hence no leading "::").
        Return default if namespace string is empty.
        """
        return self.cpp_namespace_ if self.cpp_namespace_ else default
class OrderedSet(Generic[T]):
    """A set that iterates in insertion order.

    Elements are stored as the keys of a dict whose values are all ``None``;
    re-adding an element keeps its original position. ``__eq__`` is defined
    without ``__hash__``, so instances are unhashable.
    """

    storage: Dict[T, Literal[None]]

    def __init__(self, iterable: Optional[Iterable[T]] = None):
        # dict.fromkeys maps every element to None, keeping first-seen order.
        self.storage = {} if iterable is None else dict.fromkeys(iterable)

    def __contains__(self, item: T) -> bool:
        return item in self.storage

    def __iter__(self) -> Iterator[T]:
        # Iterating a dict walks its keys in insertion order.
        return iter(self.storage)

    def update(self, items: "OrderedSet[T]") -> None:
        """Add every element of ``items`` (in its order), in place."""
        self.storage.update(items.storage)

    def add(self, item: T) -> None:
        """Add ``item``; an element already present keeps its position."""
        self.storage[item] = None

    def copy(self) -> "OrderedSet[T]":
        """Return a shallow copy backed by its own dict."""
        clone: OrderedSet[T] = OrderedSet()
        clone.storage = dict(self.storage)
        return clone

    @staticmethod
    def union(*args: "OrderedSet[T]") -> "OrderedSet[T]":
        """Return a new set with the elements of all ``args`` in first-seen order.

        At least one argument is required (``args[0]`` is indexed).
        """
        combined = args[0].copy()
        for extra in args[1:]:
            combined.update(extra)
        return combined

    def __or__(self, other: "OrderedSet[T]") -> "OrderedSet[T]":
        return OrderedSet.union(self, other)

    def __ior__(self, other: "OrderedSet[T]") -> "OrderedSet[T]":
        self.update(other)
        return self

    def __eq__(self, other: object) -> bool:
        """Compare elements, ignoring order.

        Against another OrderedSet the backing dicts are compared; against
        anything else the result is ``set(elements) == other``.
        """
        if not isinstance(other, OrderedSet):
            return set(self.storage) == other
        # dict equality ignores insertion order.
        return self.storage == other.storage
class PathManager:
    """Filesystem path safety checks: existence, ownership, symlinks and permissions."""

    @classmethod
    def check_path_owner_consistent(cls, path: str) -> None:
        """
        Check whether the path belongs to the owner of the current process.

        Args:
            path: the path to check. ``os.stat`` follows symlinks, so for a
                link the owner of its target is compared.

        Raises:
            RuntimeError: if ``path`` does not exist.

        An owner mismatch is only reported through ``warnings.warn`` (a
        ``UserWarning``); it does not raise.
        """
        if not os.path.exists(path):
            msg = f"The path does not exist: {path}"
            raise RuntimeError(msg)
        # Needs the module-level `import warnings`; without it this branch
        # raised NameError instead of emitting the warning.
        if os.stat(path).st_uid != os.getuid():
            warnings.warn(f"Warning: The {path} owner does not match the current process.")

    @classmethod
    def check_directory_path_readable(cls, path: str) -> None:
        """
        Check that the path exists, is not a symlink and is readable by this process.

        Runs ``check_path_owner_consistent`` first (which may warn). Despite
        the name, the path is not verified to be a directory.

        Args:
            path: the path to check.

        Raises:
            RuntimeError: if the path does not exist, is a symlink, or is not
                readable (``os.access(path, os.R_OK)`` is false).
        """
        cls.check_path_owner_consistent(path)
        if os.path.islink(path):
            msg = f"Invalid path is a soft chain: {path}"
            raise RuntimeError(msg)
        if not os.access(path, os.R_OK):
            msg = f"The path permission check failed: {path}"
            raise RuntimeError(msg)

    @classmethod
    def remove_path_safety(cls, path: str) -> None:
        """
        Remove ``path`` if it exists, refusing to operate on symlinks.

        A missing path is a no-op. ``os.remove`` is used, so a directory is
        not removed (``os.remove`` raises ``OSError`` for it).

        Raises:
            RuntimeError: if ``path`` is a symlink (including a dangling one).
        """
        if os.path.islink(path):
            raise RuntimeError(f"Invalid path is a soft chain: {path}")
        if os.path.exists(path):
            os.remove(path)
def get_version(ver, all_version):
    """
    Expand a version specification into a list of concrete versions.

    Args:
        ver: one of
            * ``None`` -> ``[]``;
            * a ``[start, end]`` list -> the inclusive slice of ``all_version``
              from ``start`` to ``end``; ``end == 'newest'`` means "through the
              last entry" (only the first two items are used);
            * ``'all_version'`` -> ``all_version`` itself (same object, not a copy);
            * a comma-separated string such as ``"v1, v2"`` or ``"v1,v2"`` ->
              the listed names with surrounding whitespace stripped.
        all_version: ordered sequence of all known versions.

    Raises:
        ValueError: if a list bound is not present in ``all_version``.
    """
    if ver is None:
        return []
    if isinstance(ver, list):
        start_ver = all_version.index(ver[0])
        if ver[1] == 'newest':
            return all_version[start_ver:]
        end_ver = all_version.index(ver[1])
        return all_version[start_ver:end_ver + 1]
    if ver == 'all_version':
        return all_version
    # Split on bare commas and strip each item so mixed separators work:
    # "v1, v2,v3" -> ["v1", "v2", "v3"] (the old ', '-or-',' heuristic
    # produced ["v1", "v2,v3"]).
    return [item.strip() for item in ver.split(',')]