import os
import argparse
from codegen.gen import parse_native_yaml, FileManager
from codegen.utils import concatMap, PathManager
def _str_to_bool(value: str) -> bool:
    """Interpret a command-line flag value such as ``true`` or ``0`` as a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter looks at the text instead.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling.
    """
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _torch_dir_from_version(version: str) -> str:
    """Build the ``v<major>r<minor>`` tag from a dotted version string.

    Raises:
        ValueError: if *version* lacks a non-empty major or minor component.
    """
    parts = version.split('.')
    if len(parts) < 2 or not parts[0] or not parts[1]:
        raise ValueError(
            f"expected a version of the form '<major>.<minor>[...]', got {version!r}")
    return f"v{parts[0]}r{parts[1]}"


def main() -> None:
    """Parse the command line and render the operator interface files.

    The source and deprecation YAML files are handed to ``parse_native_yaml``;
    its two results feed the ``Interface.h`` and ``OpInterface.cpp`` templates,
    which are written through a ``FileManager`` rooted at ``--output_dir``.

    The PyTorch version is read from the ``PYTORCH_VERSION`` environment
    variable; ``--version`` is used only when that variable is unset or empty.
    ``--to_cpu`` and ``--impl_path`` are accepted but not read by this function.

    Exits with an argparse usage error (status 2) when a required option is
    missing, a boolean flag is malformed, or no usable version is available.
    """
    parser = argparse.ArgumentParser(description='Generate backend stub files')
    parser.add_argument(
        '--to_cpu', type=str, default="TRUE", help='move op which npu does not support to cpu')
    # Both YAML paths are passed to os.path.realpath(), which rejects None,
    # so declare them required and fail with a usage message instead.
    parser.add_argument(
        '-s',
        '--source_yaml',
        required=True,
        help='path to source yaml file containing operator external definitions')
    parser.add_argument(
        '--deprecate_yaml',
        required=True,
        help='path to yaml file containing functions which is deprecated.')
    parser.add_argument(
        '-o', '--output_dir', help='output directory')
    parser.add_argument(
        '--dry_run', type=_str_to_bool, default=False,
        help='true/false; forwarded to FileManager as its dry_run flag')
    parser.add_argument(
        '--version', type=str, default=None,
        help='pytorch version, used when PYTORCH_VERSION is not set in the environment')
    parser.add_argument(
        '--impl_path', type=str, default=None, help='path to the source C++ file containing kernel definitions')
    options = parser.parse_args()

    # Resolve the version before doing any real work so a bad environment
    # fails fast. The environment variable keeps precedence so existing
    # invocations behave exactly as before.
    version = os.environ.get('PYTORCH_VERSION') or options.version
    if not version:
        parser.error('pytorch version is unknown: set PYTORCH_VERSION or pass --version')
    try:
        torch_dir = _torch_dir_from_version(version)
    except ValueError as err:
        parser.error(str(err))

    source_yaml_path = os.path.realpath(options.source_yaml)
    deprecate_yaml_path = os.path.realpath(options.deprecate_yaml)
    PathManager.check_directory_path_readable(source_yaml_path)
    PathManager.check_directory_path_readable(deprecate_yaml_path)

    backend_declarations, dispatch_registrations_body = parse_native_yaml(source_yaml_path, deprecate_yaml_path)

    def make_file_manager(install_dir: str) -> FileManager:
        # The template directory is a relative path.
        return FileManager(
            install_dir=install_dir, template_dir="codegen/templates", dry_run=options.dry_run
        )

    fm = make_file_manager(options.output_dir)

    # De-duplicated, sorted union of the declarations of every backend.
    # NOTE(review): concatMap with a one-element list per item presumably
    # just re-yields each item; kept as in the original.
    all_functions = sorted(set(concatMap(
        lambda f: [f],
        set(v for sublist in backend_declarations.values() for v in sublist))))

    fm.write_with_template(
        "OpInterface.h",
        "Interface.h",
        lambda: {
            "torch_dir": torch_dir,
            "namespace": "op_plugin",
            "declarations": all_functions,
        },
    )

    header_files = {
        "op_api": "OpApiInterface.h",
        "acl_op": "AclOpsInterface.h",
        "sparse": "SparseOpsInterface.h",
    }
    for op_type, file_name in header_files.items():
        fm.write_with_template(
            file_name,
            "Interface.h",
            # Bind op_type as a default so the callable stays correct even if
            # it is invoked after the loop variable has moved on.
            lambda op_type=op_type: {
                "torch_dir": torch_dir,
                "namespace": op_type,
                "declarations": backend_declarations[op_type],
            },
        )

    fm.write_with_template(
        "OpInterface.cpp",
        "OpInterface.cpp",
        lambda: {
            "namespace": "op_plugin",
            "declarations": dispatch_registrations_body,
        },
    )
# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()