import argparse
import os
import shutil
import sys
import re
import logging
class OpGenerator:
"""算子工程生成器"""
def __init__(self, op_type, op_name, output_path, template_variant):
self.op_type = op_type
self.op_name = op_name
self.output_path = output_path
self.template_name = "add_example"
self.script_dir = os.path.dirname(os.path.abspath(__file__))
if template_variant == "aicpu":
self.template_dir = os.path.abspath(os.path.join(self.script_dir, 'template', 'add_example_aicpu'))
else:
self.template_dir = os.path.abspath(os.path.join(self.script_dir, 'template', 'add_example'))
self.dest_dir = os.path.abspath(os.path.join(self.output_path, self.op_type, self.op_name))
def run(self):
"""执行生成流程"""
self._validate_inputs()
self._copy_template()
self._rename_files()
self._replace_content()
self._update_cmake_chain()
logging.info(f"成功为 {self.op_type}/{self.op_name} 创建算子工程!")
logging.info(f"工程路径: {self.dest_dir}")
def _validate_inputs(self):
"""校验输入参数的有效性和安全性"""
if not self.op_type or not self.op_name:
raise ValueError("算子类型和算子名称均不能为空。")
if not re.match(r"^[a-zA-Z0-9_]+$", self.op_type):
raise ValueError(f"算子类型 '{self.op_type}' 包含无效字符。只允许字母、数字和下划线。")
if not re.match(r"^[a-zA-Z0-9_]+$", self.op_name):
raise ValueError(f"算子名称 '{self.op_name}' 包含无效字符。只允许字母、数字和下划线。")
if os.path.exists(self.dest_dir):
raise FileExistsError(f"目标目录 '{self.dest_dir}' 已存在。")
def _copy_template(self):
"""复制模板文件到目标目录"""
logging.info(f"使用模板在 '{self.dest_dir}' 创建算子工程...")
if not os.path.exists(self.template_dir):
raise FileNotFoundError(f"找不到模板目录 '{self.template_dir}'。请确保模板目录存在。")
try:
shutil.copytree(self.template_dir, self.dest_dir)
except OSError as e:
raise OSError(f"复制模板文件失败: {e}") from e
def _rename_files(self):
"""重命名文件和目录中的占位符"""
for root, dirs, files in os.walk(self.dest_dir, topdown=False):
for name in files + dirs:
if self.template_name not in name:
continue
old_path = os.path.join(root, name)
new_name = name.replace(self.template_name, self.op_name)
new_path = os.path.join(root, new_name)
try:
os.rename(old_path, new_path)
except OSError as e:
raise OSError(f"重命名 '{old_path}' 到 '{new_path}' 失败: {e}") from e
@staticmethod
def _create_category_cmake(dir_path):
"""为新的分类目录创建CMakeLists.txt,使用与math/conversion/random相同的glob模式"""
cmake_lines = [
"file(GLOB SUBDIRECTORIES LIST_DIRECTORIES true",
" RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/*)",
"foreach(SUBDIR ${SUBDIRECTORIES})",
" if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${SUBDIR}/CMakeLists.txt)",
" add_subdirectory(${SUBDIR})",
" endif()",
"endforeach()",
]
cmake_file = os.path.join(dir_path, "CMakeLists.txt")
if not os.path.exists(cmake_file):
with open(cmake_file, 'w', encoding='utf-8') as f:
f.write("\n".join(cmake_lines) + "\n")
logging.info(f"Created CMakeLists.txt in {dir_path}")
@staticmethod
def _add_to_ops_category_list(category_list_file, child):
"""将新的分类添加到OPS_CATEGORY_LIST中"""
with open(category_list_file, 'r', encoding='utf-8') as f:
cat_content = f.read()
updated = re.sub(
r'(set\s*\(\s*OPS_CATEGORY_LIST\s+[^)]*)',
rf'\1\n "{child}"',
cat_content
)
with open(category_list_file, 'w', encoding='utf-8') as f:
f.write(updated)
logging.info(f"Added '{child}' to OPS_CATEGORY_LIST in {category_list_file}")
@staticmethod
def _check_included_in_cmake(content, cmake_file, child):
"""检查child是否已被CMakeLists.txt包含,返回(already_included, ops_category_file_path)"""
already_included = bool(re.search(
rf'add_subdirectory\s*\(\s*{re.escape(child)}\s*[\s\)]', content
))
ops_category_file_path = None
if already_included:
return True, None
cmake_dir = os.path.dirname(cmake_file)
files_to_check = [(content, None)]
for inc_match in re.finditer(r'include\s*\(\s*([^)\s]+)\s*\)', content):
inc_path = inc_match.group(1)
if not os.path.isabs(inc_path):
inc_path = os.path.join(cmake_dir, inc_path)
if os.path.exists(inc_path):
try:
with open(inc_path, 'r', encoding='utf-8') as inc_f:
files_to_check.append((inc_f.read(), inc_path))
except (IOError, OSError):
pass
for check_content, file_path in files_to_check:
match = re.search(r'set\s*\(\s*OPS_CATEGORY_LIST\s+([^)]+)\)', check_content)
if not match:
continue
categories = [c.strip('"') for c in match.group(1).split()]
if child in categories:
return True, None
if file_path is not None:
ops_category_file_path = file_path
return False, ops_category_file_path
@staticmethod
def _append_subdirectory(cmake_file, child):
"""在CMakeLists.txt末尾追加add_subdirectory(child)"""
with open(cmake_file, 'a', encoding='utf-8') as f:
f.write(f"\nadd_subdirectory({child})\n")
logging.info(f"Added add_subdirectory({child}) to {cmake_file}")
def _update_cmake_chain(self):
"""在父目录的CMakeLists.txt中添加add_subdirectory,确保构建系统能找到新目录"""
current = self.dest_dir
project_root = os.path.abspath(self.output_path)
while True:
parent = os.path.dirname(current)
child = os.path.basename(current)
parent_abs = os.path.abspath(parent)
cmake_file = os.path.join(parent, "CMakeLists.txt")
if os.path.exists(cmake_file):
with open(cmake_file, 'r', encoding='utf-8') as f:
content = f.read()
already_included, ops_category_file = self._check_included_in_cmake(
content, cmake_file, child
)
uses_glob = bool(re.search(r'file\s*\(\s*GLOB', content))
if not already_included and not uses_glob:
if ops_category_file:
self._add_to_ops_category_list(ops_category_file, child)
else:
self._append_subdirectory(cmake_file, child)
else:
if parent_abs != project_root and os.path.isdir(parent):
self._create_category_cmake(parent)
if parent_abs == project_root:
break
current = parent
def _replace_content_in_file(self, file_path, replacements):
"""Helper to replace content in a single file."""
try:
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
except (IOError, OSError) as e:
logging.warning(f"读取文件 '{file_path}' 失败: {e}")
return
original_content = content
for old, new in replacements.items():
content = content.replace(old, new)
if content == original_content:
return
try:
with open(file_path, 'w', encoding='utf-8') as f:
f.write(content)
except (IOError, OSError) as e:
logging.warning(f"写入文件 '{file_path}' 失败: {e}")
def _replace_content(self):
"""替换文件内容中的占位符"""
op_name_capitalized = ''.join(word.capitalize() for word in self.op_name.split('_'))
template_name_capitalized = ''.join(word.capitalize() for word in self.template_name.split('_'))
replacements = {
self.template_name: self.op_name,
self.template_name.upper(): self.op_name.upper(),
template_name_capitalized: op_name_capitalized,
"add_example": self.op_name,
}
for root, _, files in os.walk(self.dest_dir):
for file in files:
if file.endswith(('.pyc', '.pyo')):
continue
file_path = os.path.join(root, file)
self._replace_content_in_file(file_path, replacements)
def execute(args):
"""根据命令行参数执行算子生成"""
generator = OpGenerator(
op_type=args.op_type,
op_name=args.op_name,
output_path=args.output_path,
template_variant=args.template_variant
)
generator.run()
def register_parser(subparsers):
"""为 opgen 命令注册解析器。"""
parser_opgen = subparsers.add_parser('opgen', help='生成项目骨架')
parser_opgen.add_argument('--op_type', '-t', required=True, help='算子分类,例如 math')
parser_opgen.add_argument('--op_name', '-n', required=True, help='新算子的名称,例如 asinh')
parser_opgen.add_argument('--output_path', '-p', default='.', help='生成工程的根路径')
parser_opgen.add_argument(
'--template_variant', '-v',
choices=['default', 'aicpu'], default='default', help='选择模板变种'
)
parser_opgen.set_defaults(func=execute)
def main():
"""主函数,用于独立执行"""
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s', stream=sys.stdout)
parser = argparse.ArgumentParser(description="生成项目骨架")
parser.add_argument('--op_type', '-t', required=True, help='算子分类,例如 math')
parser.add_argument('--op_name', '-n', required=True, help='新算子的名称,例如 asinh')
parser.add_argument('--output_path', '-p', default='.', help='生成工程的根路径')
parser.add_argument('--template_variant', '-v', choices=['default', 'aicpu'], default='default', help='选择模板变种')
args = parser.parse_args()
try:
execute(args)
except Exception as e:
logging.error(f"发生非预期的错误,退出。错误信息: {e}")
sys.exit(1)
if __name__ == "__main__":
main()