"""
生成 TilingData 桩
用于 UTest 场景下, 生成 Struct 表示的 TilingData 相关头文件.
"""
import argparse
import datetime
import logging
import os
import stat
from pathlib import Path
import regex as re
def check_if_new_tiling_file_path_existed(ori_file: Path) -> Path:
current_path = ori_file
target_dir = ori_file / "op_kernel"
while True:
if target_dir.is_dir():
break
parent_path = current_path.parent
if parent_path == current_path:
return ori_file, False
current_path = parent_path
target_dir = current_path / "op_kernel"
tiling_files = list(target_dir.glob("*tiling_data.h"))
if not tiling_files:
return ori_file, False
new_file = tiling_files[0]
new_path = target_dir / new_file.name
return new_path, True
def process_class_fields(fields_str, class_name):
result = []
seen = set()
arr_pattern = re.compile(
r"^\s*"
r"(?P<type>(?:[\w:<>]+\s+)*[\w:<>* &]+?)\s+"
r"(?P<name>\w+)\s*"
r"\[(?P<len>\d+)\]"
r"(?:\s*=\s*\{\s*[^\}]*\s*})?"
r"\s*;\s*"
r"(?:[ \t]*(?://[^\n]*)?)?$",
re.MULTILINE,
)
var_pattern = re.compile(r"^\s*(?P<type>[\w:<>]+)\s+(?P<name>\w+)\s*(?:=\s*[^;]*)?;", re.MULTILINE)
for line in fields_str.splitlines():
if not line.strip() or line.strip().startswith("//"):
continue
m = arr_pattern.match(line)
if m:
t, n, ln = m.group("type"), m.group("name"), m.group("len")
result.append(("array", t, n, ln))
seen.add(n)
continue
m = var_pattern.match(line)
if m:
t, n = m.group("type"), m.group("name")
if n not in seen:
result.append(("normal", t, n))
seen.add(n)
continue
return result
def find_classes(content):
out = []
class_re = re.compile(r"\bclass\s+(\w+)\s*{")
for m in class_re.finditer(content):
class_name = m.group(1)
start = m.end()
idx = start
braces = 1
while idx < len(content):
c = content[idx]
if c == "{":
braces += 1
elif c == "}":
braces -= 1
if braces == 0:
out.append((class_name, content[start:idx].strip()))
break
idx += 1
return out
def convert_template_tilingkey(ori_file: Path):
with open(ori_file) as f:
content = f.read()
classes = find_classes(content)
output = []
for class_name, fields_str in classes:
fields = process_class_fields(fields_str, class_name)
output.append(f"BEGIN_TILING_DATA_DEF({class_name})")
for entry in fields:
if entry[0] == "normal":
_, field_type, field_name = entry
if field_type in [
"uint32_t",
"int32_t",
"uint8_t",
"uint16_t",
"float",
"uint64_t",
"int64_t",
"double",
]:
output.append(f"TILING_DATA_FIELD_DEF({field_type}, {field_name});")
else:
output.append(f"TILING_DATA_FIELD_DEF_STRUCT({field_type}, {field_name});")
elif entry[0] == "array":
_, field_type, field_name, field_len = entry
output.append(f"TILING_DATA_FIELD_DEF_ARR({field_type}, {field_len}, {field_name});")
output.append("END_TILING_DATA_DEF;")
output.append(f"REGISTER_TILING_DATA_CLASS({class_name}Op, {class_name})\n")
result_code = "\n".join(output)
return result_code
def process_fields(fields_str, struct_name):
field_pattern = re.compile(r"(\w+)\s+(\w+)(?:\s*=\d+)?;")
fields = field_pattern.findall(fields_str)
return fields
def convert_to_old_tiling_struct_style(redirected_file_path):
with open(redirected_file_path) as f:
content = f.read()
struct_pattern = re.compile(r"struct (\w+) {([^}]*)}", re.DOTALL)
structs = struct_pattern.findall(content)
output = []
for struct_name, fields_str in structs:
fields = process_fields(fields_str, struct_name)
output.append(f"BEGIN_TILING_DATA_DEF({struct_name})")
for field_type, field_name in fields:
if field_type in ["uint32_t", "uint8_t", "uint16_t"]:
output.append(f"TILING_DATA_FIELD_DEF({field_type}, {field_name});")
else:
output.append(f"TILING_DATA_FIELD_DEF_STRUCT({field_type}, {field_name});")
output.append("END_TILING_DATA_DEF;")
output.append(f"REGISTER_TILING_DATA_CLASS({struct_name}Op, {struct_name})\n")
result_code = "\n".join(output)
return result_code
class Process:
_WRITE_FLAGS = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
_WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR
@classmethod
def _write_file(cls, file: Path, src: str):
with os.fdopen(os.open(file, cls._WRITE_FLAGS, cls._WRITE_MODES), "w") as fh:
fh.write(src)
@classmethod
def _get_begin_source(cls, ori_file: Path, gen_file: Path) -> str:
bgn_src: str = (
"/**\n"
" * This program is free software, you can redistribute it and/or modify.\n"
" * Copyright (c) {year} Huawei Technologies Co., Ltd.\n"
" * This file is a part of the CANN Open Software.\n"
' * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").\n'
" * Please refer to the License for details. "
"You may not use this file except in compliance with the License.\n"
' * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, '
"WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, "
"INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.\n"
" * See LICENSE in the root of the software repository for the full text of the License.\n"
" */\n"
).format(year=datetime.datetime.today().year)
bgn_src += "\n"
bgn_src += ("/*!\n * \\file {gen_file_name}\n * \\brief Generate {ori_file_name}\n */\n").format(
gen_file_name=gen_file.name, ori_file_name=ori_file.name
)
bgn_src += "\n"
bgn_src += "#pragma once\n"
bgn_src += "\n"
return bgn_src
@classmethod
def _get_tiling_source(cls, ori_file: Path, isTemplateTilingKey: bool = False) -> str:
"""
获取 TilingData 定义源码
:param ori_file: 原始文件
:return: 生成文件内容
"""
rst_source = (
"#include <cstdint>\n#include <cstring>\n#include <securec.h>\n#include <kernel_tiling/kernel_tiling.h>\n\n"
)
pattern = re.compile(r"[(](.*)[)]", re.S)
if isTemplateTilingKey:
lines = convert_template_tilingkey(ori_file)
lines = lines.splitlines()
else:
ori_file, existed_flag = check_if_new_tiling_file_path_existed(ori_file)
if existed_flag:
lines = convert_to_old_tiling_struct_style(ori_file)
lines = lines.splitlines()
else:
with open(ori_file) as fd:
lines = fd.readlines()
for line in lines:
line = line.strip()
struct_src = ""
if line.startswith("BEGIN_TILING_DATA_DEF"):
struct_name = re.findall(pattern, line)[0]
struct_src += ("#pragma pack(1)\nstruct {}\n").format(struct_name)
struct_src += "{\n"
struct_offset = 0
elif line.startswith("TILING_DATA_FIELD_DEF_ARR"):
field_params = re.findall(pattern, line)[0]
fds = field_params.split(",")
fds_dtype = fds[0].strip()
fds_num = int(fds[1].strip())
fds_name = fds[2].strip()
tmp_src, tmp_offset = cls._get_tmp_src(
offset=struct_offset, dtype=fds_dtype, name=fds_name, num=fds_num
)
struct_src += tmp_src
struct_offset += tmp_offset
elif line.startswith("TILING_DATA_FIELD_DEF_STRUCT"):
field_params = re.findall(pattern, line)[0]
fds = field_params.split(",")
struct_src += " {} {};\n".format(fds[0].strip(), fds[1].strip())
elif line.startswith("TILING_DATA_FIELD_DEF"):
field_params = re.findall(pattern, line)[0]
fds = field_params.split(",")
fds_dtype = fds[0].strip()
fds_num = 1
fds_name = fds[1].strip()
tmp_src, tmp_offset = cls._get_tmp_src(
offset=struct_offset, dtype=fds_dtype, name=fds_name, num=fds_num
)
struct_src += tmp_src
struct_offset += tmp_offset
elif line.startswith("END_TILING_DATA_DEF"):
if struct_offset % 8 != 0:
pad_num = 8 - (struct_offset % 8)
struct_src += " uint8_t {}_PH[{}] = {{}};\n".format(struct_name, pad_num)
struct_offset += pad_num
struct_src += "};"
struct_src += "\n"
struct_src += "#pragma pack()\n"
struct_src += "\n"
struct_src += "inline void Init{struct_name}(uint8_t* tiling, {struct_name}* const_data)\n".format(
struct_name=struct_name
)
struct_src += "{\n"
struct_src += (
" (void)memcpy_s(const_data, sizeof({struct_name}), tiling, sizeof({struct_name}));\n".format(
struct_name=struct_name
)
)
struct_src += "}\n"
struct_src += "\n"
rst_source += struct_src
rst_source += (
""
"#undef GET_TILING_DATA\n"
"#define GET_TILING_DATA(tiling_data, tiling_arg) \\\n"
"{struct_name} tiling_data; \\\n"
"Init{struct_name}(tiling_arg, &tiling_data)\n"
"\n"
).format(struct_name=struct_name)
return rst_source
@classmethod
def _get_tiling_whole(cls, ori_file: Path, isTemplateTilingKey: bool = False) -> str:
with open(ori_file) as f:
content = f.read()
return content
@classmethod
def _gen_tiling_h(cls, ori_file: Path, gen_dir: Path):
gen_file = Path(gen_dir, "_gen_" + ori_file.name)
flag = "op_kernel" in [part for part in ori_file.parts]
if not gen_file.exists():
if not flag:
bgn_src = cls._get_begin_source(ori_file=ori_file, gen_file=gen_file)
def_src = cls._get_tiling_source(ori_file=ori_file, isTemplateTilingKey=flag)
source = bgn_src + def_src
else:
source = "\n"
cls._write_file(file=gen_file, src=source)
logging.info("Generate TilingDefFile: %s", gen_file)
return gen_file
@classmethod
def _get_type_size(cls, dtype: str):
mp = {
"int8_t": 1,
"int16_t": 2,
"int32_t": 4,
"int64_t": 8,
"uint8_t": 1,
"uint16_t": 2,
"uint32_t": 4,
"uint64_t": 8,
"float": 4,
}
d_len = mp.get(dtype)
if d_len is None:
raise ValueError(f"Unknown dtype({dtype})")
return d_len
@classmethod
def _get_tmp_src(cls, offset: int, dtype: str, name: str, num: int):
source = ""
result = 0
dtype_size = cls._get_type_size(dtype=dtype)
if offset % dtype_size != 0:
pad_num = dtype_size - (offset % dtype_size)
source += " uint8_t {}_PH[{}] = {{}};\n".format(name, pad_num)
result += pad_num
if num == 1:
source += " {} {} = 0;\n".format(dtype, name)
else:
source += " {} {}[{}] = {{}};\n".format(dtype, name, num)
result += cls._get_type_size(dtype=dtype) * num
return source, result
@classmethod
def gen_tiling_h(cls, ori_files: list[Path], gen_dir: Path):
gen_files: list[Path] = []
gen_dir.mkdir(parents=True, exist_ok=True)
for ori_file in ori_files:
if not ori_file.exists():
raise ValueError(f"Origin file({ori_file}) not exist.")
gen_file = cls._gen_tiling_h(ori_file=ori_file, gen_dir=gen_dir)
gen_files.append(gen_file)
return gen_files
@classmethod
def gen_tiling_data_h(cls, op: str, gen_files: list[Path], data_file: Path):
if not data_file.exists():
bgn_src = cls._get_begin_source(ori_file=data_file, gen_file=data_file)
def_src = ""
for gen_f in gen_files:
def_src += '#include "tiling/{op}/{file_name}"\n'.format(op=op, file_name=gen_f.name)
source = bgn_src + def_src
cls._write_file(file=data_file, src=source)
logging.info("Generate TilingDataFile: %s", data_file)
return data_file
@classmethod
def gen_tiling_stub_h(cls, data_file: Path, stub_file: Path):
if not stub_file.exists():
bgn_src = cls._get_begin_source(ori_file=stub_file, gen_file=stub_file)
def_src = ""
def_src += '#include "{}"\n'.format(data_file.name)
def_src += (
"\n"
"#undef GET_TILING_DATA_WITH_STRUCT\n"
"#define GET_TILING_DATA_WITH_STRUCT(tiling_struct, tiling_data, tiling_arg) \\\n"
"tiling_struct tiling_data; \\\n"
"(void)memcpy_s(&tiling_data, sizeof(tiling_struct), tiling_arg, sizeof(tiling_struct));\n"
"\n"
)
def_src += (
"\n"
"#undef GET_TILING_DATA_MEMBER\n"
"#define GET_TILING_DATA_MEMBER(tiling_type, member, var, tiling) \\\n"
"decltype(tiling_type::member) var; \\\n"
"size_t offset##var = (size_t)(&((tiling_type *)0)->member); \\\n"
"(void)memcpy_s(&var, sizeof(decltype(var)), tiling + offset##var, sizeof(decltype(var))); \n"
)
source = bgn_src + def_src
cls._write_file(file=stub_file, src=source)
logging.info("Generate TilingStubFile: %s", stub_file)
return stub_file
@classmethod
def main(cls):
parser = argparse.ArgumentParser(description="TilingData Generator", epilog="Best Regards!")
parser.add_argument("-o", "--operator", required=True, nargs=1, type=str, help="Target operator.")
parser.add_argument(
"-s",
"--srcs",
required=True,
action="append",
nargs="+",
type=Path,
help="Origin tiling data define files(.h).",
)
parser.add_argument("-d", "--dest", required=True, nargs=1, type=Path, help="Generate directory.")
result = parser.parse_args()
op = result.operator[0].lower()
ori_files: list[Path] = []
for file in result.srcs:
ori_files.append(file[0].absolute())
gen_dir = Path(result.dest[0], "tiling/{}".format(op)).absolute()
data_file = Path(gen_dir, "tiling_data.h")
stub_file = Path(gen_dir, "tiling_stub.h")
gen_files = cls.gen_tiling_h(ori_files=ori_files, gen_dir=gen_dir)
cls.gen_tiling_data_h(op=op, gen_files=gen_files, data_file=data_file)
cls.gen_tiling_stub_h(data_file=data_file, stub_file=stub_file)
if __name__ == "__main__":
logging.basicConfig(format="%(filename)s:%(lineno)d [%(levelname)s] %(message)s", level=logging.DEBUG)
try:
Process.main()
except Exception as e:
logging.error(e)
raise e