import os
import argparse
import yaml
from torchnpugen.gen import FileManager, parse_native_yaml_struct
from torchnpugen.struct.struct_codegen import parse_struct_yaml, gen_op_api
from torchnpugen.op_codegen_utils import PathManager
def main() -> None:
    """Generate the struct-based aclnn operator files from YAML definitions.

    Command-line arguments:
        -n/--native_yaml: path to the YAML file with operator external
            definitions (required). It is resolved with ``os.path.realpath``
            and checked via ``PathManager.check_directory_path_readable``
            before being loaded with ``yaml.safe_load``.
        --struct_yaml: path to the struct YAML file; passed unchanged to
            ``parse_struct_yaml``.
        -o/--output_dir: install directory handed to ``FileManager``.

    Environment:
        ACLNN_EXTENSION_SWITCH: if set to a non-empty value, templates are
            taken from the ``templates`` directory next to this script;
            otherwise the relative path ``torchnpugen/struct/templates`` is
            used (presumably resolved against the current working directory).
            The raw value (or ``None``) is also forwarded to ``gen_op_api``.
    """
    parser = argparse.ArgumentParser(description='Generate struct aclnn files')
    parser.add_argument(
        '-n',
        '--native_yaml',
        # os.path.realpath(None) further down would fail with an opaque
        # TypeError; let argparse report a missing path as a usage error.
        required=True,
        help='path to source yaml file containing operator external definitions')
    parser.add_argument(
        '--struct_yaml',
        help='path to struct yaml file containing aclnn operators struct definitions')
    parser.add_argument(
        '-o', '--output_dir', help='output directory')
    options = parser.parse_args()

    env_aclnn_extension_switch = os.getenv('ACLNN_EXTENSION_SWITCH')
    if env_aclnn_extension_switch:
        # Extension build: locate templates relative to this script file.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        template_dir_fm = os.path.join(script_dir, "templates")
    else:
        template_dir_fm = "torchnpugen/struct/templates"
    fm = FileManager(
        install_dir=options.output_dir, template_dir=template_dir_fm, dry_run=False
    )

    native_yaml_path = os.path.realpath(options.native_yaml)
    PathManager.check_directory_path_readable(native_yaml_path)
    # Decode explicitly instead of relying on the locale's default encoding.
    with open(native_yaml_path, "r", encoding="utf-8") as f:
        es = yaml.safe_load(f)
    native_functions = parse_native_yaml_struct(es)

    struct_info = parse_struct_yaml(options.struct_yaml, native_functions)
    gen_op_api(fm, struct_info, env_aclnn_extension_switch)
if __name__ == "__main__":
    # Run the generator only when executed as a script, not on import.
    main()