import os
import stat
import functools
import hashlib
from typing import (List, Dict, Optional, Set, Callable, Any,
Union, Sequence, TypeVar, Iterable)
from collections import defaultdict
from codegen.code_template import CodeTemplate
from codegen.model import (FunctionSchema, NativeFunction,
NativeFunctionsGroup, OperatorName,
SchemaKind, assert_never)
from codegen.utils import concat_map, PathManager
# Generic type variable; FileManager.write_sharded uses it to tie the element
# type of ``items`` to the argument type of ``key_fn`` and ``env_callable``.
T = TypeVar('T')
def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
    """Validate the ``structured_delegate`` links between native functions.

    Every function whose ``structured_delegate`` is set must point at an
    operator that is present in ``funcs`` and is itself marked ``structured``.

    Args:
        funcs: The native functions to cross-check.

    Raises:
        ValueError: If a delegate target is not present in ``funcs``, or if
            it is present but not marked as structured.
    """
    # Index by operator name; a later duplicate name replaces an earlier one.
    func_map: Dict[OperatorName, NativeFunction] = {f.func.name: f for f in funcs}
    for f in funcs:
        delegate_name = f.structured_delegate
        if delegate_name is None:
            continue
        delegate_func = func_map.get(delegate_name)
        if delegate_func is None:
            # Without this guard the attribute access below would fail with
            # an AttributeError on None that names neither operator.
            raise ValueError(
                f"{f.func.name} is marked as a structured_delegate pointing to "
                f"{delegate_name}, but {delegate_name} was not found among the "
                f"native functions.")
        if not delegate_func.structured:
            raise ValueError(
                f"{f.func.name} is marked as a structured_delegate pointing to "
                f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. "
                f"Consider adding 'structured=True' to the delegated operator")
# Characters that must be written as escape sequences inside a C++ string
# literal, mapped to their escaped spelling.
_CPP_STRING_ESCAPES = str.maketrans({
    '\\': '\\\\',
    '"': '\\"',
    '\a': '\\a',
    '\b': '\\b',
    '\f': '\\f',
    '\n': '\\n',
    '\r': '\\r',
    '\v': '\\v',
    '\t': '\\t',
})


def cpp_string(s: str) -> str:
    """Convert a python string into a c++ string literal.

    Backslashes, double quotes and the control characters ``\\a``, ``\\b``,
    ``\\f``, ``\\n``, ``\\r``, ``\\v`` and ``\\t`` are replaced by their C++
    escape sequences, and the result is wrapped in double quotes.

    Args:
        s: The text to embed in generated C++ source.

    Returns:
        The quoted and escaped literal.
    """
    # A single translate pass maps each input character independently, so the
    # backslashes introduced by one escape are never escaped a second time.
    return f'"{s.translate(_CPP_STRING_ESCAPES)}"'
@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    """Load the template stored at ``template_fn``.

    Results are memoised without a size bound, keyed on the path string, so
    each distinct path is passed to ``CodeTemplate.from_file`` only once.
    """
    template = CodeTemplate.from_file(template_fn)
    return template
def string_stable_hash(s: str) -> int:
    """Return a hash of ``s`` that is stable across interpreter runs.

    The string is encoded, hashed with SHA-256, and the digest is read as a
    little-endian unsigned integer.

    Args:
        s: The text to hash.

    Returns:
        A non-negative integer derived from the SHA-256 digest.
    """
    try:
        data = s.encode('latin1')
    except UnicodeEncodeError:
        # latin1 cannot represent code points above U+00FF. Falling back to
        # UTF-8 only for those inputs leaves every hash that could be
        # computed before unchanged.
        data = s.encode('utf-8')
    digest = hashlib.sha256(data).digest()
    return int.from_bytes(digest, byteorder='little')
class FileManager:
    """Write generated files below ``install_dir`` and record their names.

    Each output is registered in ``filenames`` before anything is written;
    registering the same path twice is an error. In dry-run mode the names
    are still recorded but no environment is computed and no file is touched.
    """

    install_dir: str  # prefix joined (with '/') to every output file name
    template_dir: str  # directory template files are loaded from
    dry_run: bool  # when True, record outputs without writing them
    filenames: Set[str]  # every output path registered so far

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        """Store the configuration and start with an empty output set."""
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.filenames = set()
        self.dry_run = dry_run

    @staticmethod
    def _write_if_changed(filename: str, contents: str) -> None:
        """Write ``contents`` to ``filename`` unless the file already holds it.

        The path is resolved with ``os.path.realpath`` first. If the file
        cannot be read it is treated as having no previous contents.

        Args:
            filename: Destination path.
            contents: Full text the file should contain.
        """
        filepath = os.path.realpath(filename)
        old_contents: Optional[str]
        try:
            with open(filepath, 'r') as f:
                old_contents = f.read()
        except IOError:
            old_contents = None
        if contents == old_contents:
            return
        # NOTE(review): remove_path_safety is defined outside this file;
        # presumably it deletes any existing path -- confirm.
        PathManager.remove_path_safety(filepath)
        # O_TRUNC ensures no stale trailing bytes survive if the path still
        # exists when it is opened (os.fdopen in "w" mode does not truncate
        # an already-open descriptor).
        flags = os.O_RDWR | os.O_CREAT | os.O_TRUNC
        # New files are created readable/writable by the owner only.
        with os.fdopen(os.open(filepath, flags, stat.S_IWUSR | stat.S_IRUSR), "w") as f:
            f.write(contents)
        # After writing, permissions become read+execute for owner and group.
        os.chmod(filepath, stat.S_IRUSR | stat.S_IEXEC | stat.S_IRGRP | stat.S_IXGRP)

    def write_with_template(self, filename: str, template_fn: str,
                            env_callable: Callable[[], Union[str, Dict[str, Any]]]) -> None:
        """Register ``filename`` and, unless in dry-run mode, write it.

        Args:
            filename: Output name relative to ``install_dir``.
            template_fn: Template name relative to ``template_dir``; only
                used when the environment is a dict.
            env_callable: Called with no arguments (and only when not in
                dry-run mode). A ``str`` result is written verbatim; a
                ``dict`` result is substituted into the template. The dict is
                modified in place: ``generated_comment`` is filled in when
                absent and ``legacy_th_headers`` is set to an empty list.

        Raises:
            ValueError: If the same output path was registered before.
        """
        filename = f'{self.install_dir}/{filename}'
        if filename in self.filenames:
            raise ValueError(f"duplicate file write {filename}")
        self.filenames.add(filename)
        if self.dry_run:
            return
        env = env_callable()
        if isinstance(env, dict):
            if 'generated_comment' not in env:
                # NOTE(review): the marker is assembled from two literals,
                # presumably so this source file never contains the
                # contiguous marker text itself -- confirm.
                comment = "@" + "generated by tools/codegen/gen.py"
                comment += " from {}".format(os.path.basename(template_fn))
                env['generated_comment'] = comment
            env['legacy_th_headers'] = []
            template = _read_template(os.path.join(self.template_dir, template_fn))
            self._write_if_changed(filename, template.substitute(env))
        elif isinstance(env, str):
            self._write_if_changed(filename, env)
        else:
            assert_never(env)

    def write(self, filename: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]) -> None:
        """Write ``filename`` using the template that has the same name."""
        self.write_with_template(filename, filename, env_callable)

    def write_sharded(
            self,
            filename: str,
            items: Iterable[T],
            *,
            key_fn: Callable[[T], str],
            env_callable: Callable[[T], Dict[str, List[str]]],
            num_shards: int,
            base_env: Optional[Dict[str, Any]] = None,
            sharded_keys: Set[str]
    ) -> None:
        """Distribute ``items`` over several output files built from one template.

        For ``filename`` equal to ``Foo.cpp`` this writes ``Foo_0.cpp`` ...
        ``Foo_{num_shards-1}.cpp`` plus ``FooEverything.cpp``, which receives
        every item. An item's shard is chosen by hashing ``key_fn(item)``
        with ``string_stable_hash``. The "Everything" file is written but is
        removed from ``filenames`` afterwards.

        Args:
            filename: Template name; the shard id is inserted before the
                last ``.`` (or appended when there is none).
            items: Items to distribute.
            key_fn: Maps an item to the string that is hashed.
            env_callable: Maps an item to the list-valued entries it
                contributes; every key must be in ``sharded_keys``.
            num_shards: Number of numbered shards; must be at least 1.
            base_env: Optional entries copied into every shard environment.
            sharded_keys: Names of the list-valued entries that accumulate
                per-item contributions.

        Raises:
            ValueError: If ``num_shards`` is less than 1, or if an output
                path was already registered.
            TypeError: If ``base_env`` holds a non-list under a sharded key.
            KeyError: If ``env_callable`` returns a key not in ``sharded_keys``.
        """
        if num_shards < 1:
            # Otherwise the modulo below fails with ZeroDivisionError (zero)
            # or produces an index into an empty shard list (negative).
            raise ValueError(f"num_shards must be at least 1, got {num_shards}")

        everything: Dict[str, Any] = {'shard_id': 'Everything'}
        shards: List[Dict[str, Any]] = [{'shard_id': f'_{i}'} for i in range(num_shards)]
        all_shards = [everything] + shards

        if base_env is not None:
            for shard in all_shards:
                shard.update(base_env)

        for key in sharded_keys:
            for shard in all_shards:
                if key in shard:
                    if not isinstance(shard[key], list):
                        raise TypeError("sharded keys in base_env must be a list.")
                    # Copy so that shards never share (and mutate) the list
                    # object that came from base_env.
                    shard[key] = shard[key].copy()
                else:
                    shard[key] = []

        def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
            """Append every list in ``from_`` to the matching list in ``into``."""
            for k, v in from_.items():
                if k not in sharded_keys:
                    raise KeyError(f"undeclared sharded key {k}")
                into[k] += v

        for item in items:
            key = key_fn(item)
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)
            merge_env(shards[sid], env)
            merge_env(everything, env)

        dot_pos = filename.rfind('.')
        if dot_pos == -1:
            dot_pos = len(filename)
        base_filename = filename[:dot_pos]
        extension = filename[dot_pos:]

        for shard in all_shards:
            shard_id = shard['shard_id']
            # Bind the current shard as a default so the callable does not
            # depend on the loop variable's later value.
            self.write_with_template(f"{base_filename}{shard_id}{extension}",
                                     filename,
                                     lambda shard=shard: shard)

        # The "Everything" file is written above but is not kept in the
        # recorded output list.
        self.filenames.discard(
            f"{self.install_dir}/{base_filename}Everything{extension}")

    def write_outputs(self, filename: str) -> None:
        """Write a file containing the list of all outputs which are
        generated by this script.

        The recorded names are sorted and each one is followed by ``;``.
        """
        self._write_if_changed(
            filename,
            ''.join(name + ";" for name in sorted(self.filenames)))
def get_grouped_native_functions(
        native_functions: Sequence[NativeFunction]) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
    """Bundle native functions that share a signature into groups.

    Functions are bucketed by ``f.func.signature()``; within a bucket each
    one is stored under its ``f.func.kind()``. Each bucket is then offered to
    ``NativeFunctionsGroup.from_dict``: when that yields a group, the group
    stands in for the bucket, otherwise the bucket's functions are returned
    individually.

    Raises:
        ValueError: If two functions share both signature and kind.
    """
    by_signature: Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]] = defaultdict(dict)
    for native_fn in native_functions:
        variants = by_signature[native_fn.func.signature()]
        kind = native_fn.func.kind()
        if kind in variants:
            raise ValueError("f.func.kind() should not be in d.")
        variants[kind] = native_fn

    def flatten_pre_group(
            variants: Dict[SchemaKind, NativeFunction]) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
        """Return the bucket as one group if possible, else its members."""
        group = NativeFunctionsGroup.from_dict(variants)
        if group is not None:
            return [group]
        return list(variants.values())

    buckets = list(by_signature.values())
    return list(concat_map(flatten_pre_group, buckets))