import os
import argparse
from torchnpugen.gen import parse_native_yaml, FileManager, ENABLE_DVM
from torchnpugen.op_codegen_utils import concatMap, PathManager
def _parse_args() -> argparse.Namespace:
    """Define the generator's command-line options and parse them from ``sys.argv``."""
    parser = argparse.ArgumentParser(description='Generate backend stub files')
    # NOTE(review): --to_cpu and --impl_path are parsed but not referenced by main().
    parser.add_argument(
        '--to_cpu', type=str, default="TRUE", help='move op which npu does not support to cpu')
    # Both yaml paths are mandatory: os.path.realpath(None) would otherwise fail
    # later with an unhelpful TypeError.
    parser.add_argument(
        '-s',
        '--source_yaml',
        required=True,
        help='path to source yaml file containing operator external definitions')
    parser.add_argument(
        '--deprecate_yaml',
        required=True,
        help='path to yaml file containing functions which is deprecated.')
    parser.add_argument(
        '-o', '--output_dir', help='output directory')
    # NOTE: type=bool converts any non-empty string (including "False") to True.
    # The value is not forwarded anywhere; generated files are always written.
    parser.add_argument(
        '--dry_run', type=bool, default=False,
        help='accepted for compatibility; currently ignored (files are always written)')
    parser.add_argument(
        '--version', type=str, default=None,
        help='pytorch version, e.g. 2.1 (used when PYTORCH_VERSION is not set)')
    parser.add_argument(
        '--impl_path', type=str, default=None, help='path to the source C++ file containing kernel definitions')
    return parser.parse_args()


def _resolve_torch_dir(cli_version) -> str:
    """Return the ``v<major>r<minor>`` directory name for the target PyTorch version.

    The ``PYTORCH_VERSION`` environment variable takes precedence; ``cli_version``
    (the ``--version`` option, may be None) is used only when it is unset or empty.

    Raises:
        RuntimeError: if no version is available or it has no minor component.
    """
    version = os.environ.get('PYTORCH_VERSION') or cli_version
    if not version:
        raise RuntimeError(
            "PyTorch version unknown: set the PYTORCH_VERSION environment variable or pass --version")
    parts = version.split('.')
    if len(parts) < 2:
        raise RuntimeError(f"PyTorch version {version!r} must look like <major>.<minor>[...]")
    return f"v{parts[0]}r{parts[1]}"


def _build_includes_block(aclnn_extension: bool, enable_dvm: bool) -> str:
    """Return the ``#include`` lines substituted into the ``OpInterface.cpp`` template.

    Args:
        aclnn_extension: when true, ``FormatHelper.h`` and ``op_plugin/utils/op_log.h``
            are emitted commented out (prefixed with ``// ``).
        enable_dvm: when true, the DVM interface and lazy-fusion kernel headers are
            inserted just before the ``op_log.h`` line.
    """
    comment_prefix = "// " if aclnn_extension else ""
    dvm_includes = (
        '#include "op_plugin/DvmOpsInterface.h"\n'
        '#include "op_plugin/ops/dvm/lazy_fusion_kernel.h"\n'
    ) if enable_dvm else ''
    return (
        '#include "torch_npu/csrc/framework/interface/EnvVariables.h"\n'
        f'{comment_prefix}#include "torch_npu/csrc/framework/FormatHelper.h"\n'
        '#include "torch_npu/csrc/core/npu/npu_log.h"\n'
        '#include "torch_npu/csrc/core/npu/NpuVariables.h"\n'
        '#include "torch_npu/csrc/core/npu/NPUException.h"\n'
        '#include "op_plugin/AclOpsInterface.h"\n'
        '#include "op_plugin/OpApiInterface.h"\n'
        '#include "op_plugin/SparseOpsInterface.h"\n'
        f'{dvm_includes}{comment_prefix}#include "op_plugin/utils/op_log.h"\n'
        '#include "op_plugin/OpInterface.h"\n'
    )


def main() -> None:
    """Generate the op-plugin interface headers and ``OpInterface.cpp``.

    Operator definitions are read from ``--source_yaml`` and ``--deprecate_yaml``
    via ``parse_native_yaml``. The following files are rendered into ``--output_dir``:

    * ``OpInterface.h``: every backend's declarations, de-duplicated and sorted,
      in the ``op_plugin`` namespace.
    * One ``*Interface.h`` per backend: ``op_api``, ``acl_op``, ``sparse`` and,
      when ``ENABLE_DVM`` is set, ``lazy_fusion``.
    * ``OpInterface.cpp``: the dispatch registrations plus the include block.

    Environment:
        ACLNN_EXTENSION_SWITCH: when non-empty, templates are loaded from the
            ``templates`` directory next to this script, and two includes are
            emitted commented out.
        PYTORCH_VERSION: ``<major>.<minor>[...]``. Falls back to ``--version``.

    Raises:
        RuntimeError: if the PyTorch version cannot be determined.
    """
    options = _parse_args()

    source_yaml_path = os.path.realpath(options.source_yaml)
    deprecate_yaml_path = os.path.realpath(options.deprecate_yaml)
    PathManager.check_directory_path_readable(source_yaml_path)
    PathManager.check_directory_path_readable(deprecate_yaml_path)
    backend_declarations, dispatch_registrations_body = parse_native_yaml(
        source_yaml_path, deprecate_yaml_path)

    aclnn_extension = bool(os.getenv('ACLNN_EXTENSION_SWITCH'))
    if aclnn_extension:
        # Resolved relative to this script, independent of the current working directory.
        template_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")
    else:
        # Resolved relative to the current working directory.
        template_dir = "torchnpugen/templates"
    # NOTE: dry_run is hard-coded; the --dry_run option is not forwarded here.
    fm = FileManager(install_dir=options.output_dir, template_dir=template_dir, dry_run=False)

    torch_dir = _resolve_torch_dir(options.version)

    # Union of every backend's declarations, de-duplicated and sorted.
    all_functions = sorted({decl for decls in backend_declarations.values() for decl in decls})
    fm.write_with_template(
        "OpInterface.h",
        "Interface.h",
        lambda: {
            "torch_dir": torch_dir,
            "namespace": "op_plugin",
            "declarations": all_functions,
        },
    )

    header_files = {
        "op_api": "OpApiInterface.h",
        "acl_op": "AclOpsInterface.h",
        "sparse": "SparseOpsInterface.h",
    }
    if ENABLE_DVM:
        header_files["lazy_fusion"] = "DvmOpsInterface.h"
    for op_type, file_name in header_files.items():
        fm.write_with_template(
            file_name,
            "Interface.h",
            # Bind op_type now. A plain closure would see only the final loop
            # value if FileManager invoked the callable after the loop advanced.
            lambda op_type=op_type: {
                "torch_dir": torch_dir,
                "namespace": op_type,
                "declarations": backend_declarations[op_type],
            },
        )

    includes_block = _build_includes_block(aclnn_extension, ENABLE_DVM)
    fm.write_with_template(
        "OpInterface.cpp",
        "OpInterface.cpp",
        lambda: {
            "namespace": "op_plugin",
            "declarations": dispatch_registrations_body,
            "includes_block": includes_block,
        },
    )
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()