import os
import logging
import stat
import argparse
# C-style licence banner written verbatim as the first bytes of the generated
# C++ source file. It is runtime text, not a comment of this script, so it must
# be kept exactly as-is.
CANN_COPYRIGHT = '''/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
'''
# Maps an unqualified container template name (namespace and template
# arguments removed) to the C++ type-tag expression emitted for members of
# that container type. Names not listed here are tagged 'AsdOps::BASIC_TYPE'
# by the type-name parser.
TYPE_MAP = {
    'SVector': 'AsdOps::SVECTOR_TYPE',
    'vector': 'AsdOps::VECTOR_TYPE',
}
def __parse_param_file(operation_param_file: str):
    """Extract the struct name and its member declarations from a param header.

    The header is scanned line by line with a small text heuristic (this is
    not a C++ parser):

    * lines starting with ``//`` are ignored;
    * everything before the first line starting with ``struct`` is ignored;
    * scanning stops at the first line starting with ``bool operator==``;
    * lines starting with ``enum`` are ignored;
    * a line of exactly three whitespace-separated tokens whose first token
      is ``struct`` (e.g. ``struct Foo {``) supplies the struct name;
    * any other line containing ``;`` with at least two tokens is treated as
      a member declaration ``<type> <name>...;``.

    Args:
        operation_param_file: Path of the header file to scan.

    Returns:
        A tuple ``(struct_name, mem_list)``. ``struct_name`` is ``''`` when no
        struct header line was found. ``mem_list`` is a list of
        ``[type_token, member_name]`` pairs in declaration order; an array
        suffix such as ``[4]`` is kept on the member name.
    """
    struct_name = ''
    mem_list = []
    in_struct = False
    with open(operation_param_file) as fd:
        for raw_line in fd:
            line = raw_line.strip()
            if line.startswith('//'):
                continue
            if line.startswith('bool operator=='):
                break
            if line.startswith('struct'):
                in_struct = True
            if not in_struct or line.startswith('enum'):
                continue
            # Split on any whitespace run so tabs or alignment spaces between
            # the type and the name do not yield empty or missing tokens.
            fields = line.split()
            if len(fields) == 3 and fields[0] == 'struct':
                struct_name = fields[1]
            elif ';' in line and len(fields) >= 2:
                # Cut the name at the first initializer/terminator character.
                # Using str.strip('{0,') here would treat the argument as a
                # character set and eat a trailing '0' of names like 'dim0'.
                name = fields[1]
                for terminator in ('{', '=', ';'):
                    name = name.split(terminator, 1)[0]
                mem_list.append([fields[0], name])
    return struct_name, mem_list
def __parse_typename(typename: str):
    """Classify a member type token for the generated type table.

    Args:
        typename: Type token as written in the header, e.g. ``int32_t`` or
            ``Mki::SVector<int64_t>``.

    Returns:
        A tuple ``(type_type, decayed_type, is_complete_type)``:

        * ``type_type`` - the ``TYPE_MAP`` tag of the container, or
          ``'AsdOps::BASIC_TYPE'`` when the type is not a known container;
        * ``decayed_type`` - the text between the last ``<`` and the following
          ``>`` for known containers, otherwise ``typename`` unchanged;
        * ``is_complete_type`` - True when ``decayed_type`` starts with
          ``Mki::`` or does not start with an uppercase letter, i.e. when it
          needs no further qualification by the caller.
    """
    # Container name without template arguments and namespace qualifiers.
    outer, _, _ = typename.partition('<')
    container = outer.split('::')[-1]

    if container in TYPE_MAP:
        category = TYPE_MAP[container]
        element = typename.split('<')[-1].split('>')[0]
    else:
        category = 'AsdOps::BASIC_TYPE'
        element = typename

    leading = element[0]
    starts_with_capital = leading.isalpha() and leading.isupper()
    complete = element.startswith('Mki::') or not starts_with_capital
    return (category, element, complete)
def __get_filtered_file_list(origin_file_list):
    """Return the per-operation header names from a directory listing.

    Keeps, in their original order, the entries that end with ``.h`` and are
    neither ``params.h`` nor ``common.h``.

    Args:
        origin_file_list: Iterable of file names (no directory part).

    Returns:
        A new list with the retained file names.
    """
    skipped_headers = ('params.h', 'common.h')
    return [
        name
        for name in origin_file_list
        if name not in skipped_headers and name.endswith('.h')
    ]
def __generate_utils_cpp(op_params_dir: str, namespace: str, dest_file_path: str):
    """Generate the C++ source with ``AsdOps::GetOffset`` specializations.

    Every directory visited while walking ``op_params_dir`` must contain at
    least one per-operation header (see ``__get_filtered_file_list``). Each
    such header is parsed into a struct name and member list, and one
    ``template<> size_t AsdOps::GetOffset<{namespace}::OpParam::{struct}>``
    specialization is written per header. For member index ``i`` the
    generated function returns ``offsetof(struct, member)`` and stores
    ``type_tag | (sizeof(member_type) << AsdOps::SHIFT_BITS)`` in ``type``.

    Args:
        op_params_dir: Root directory that is walked for param headers.
        namespace: C++ namespace that owns ``OpParam``; its lower-cased form
            is also used in the ``params/params.h`` include path.
        dest_file_path: Output file; created or truncated with owner
            read/write permission only.

    Raises:
        SystemExit: With status 1 when a visited directory holds no param
            header.
    """
    op_func_list = []
    for path, _, file_list in os.walk(op_params_dir):
        filtered_files = __get_filtered_file_list(file_list)
        if not filtered_files:
            logging.error("cannot find param header file in directory: %s", os.path.realpath(op_params_dir))
            # The bare exit() helper is injected by the site module and is
            # not guaranteed to exist; raising SystemExit always works.
            raise SystemExit(1)
        for file in filtered_files:
            op_func_list.append(__parse_param_file(os.path.join(path, file)))

    # Open through os.open so the file is created with 0600 permissions.
    open_flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    open_mode = stat.S_IWUSR | stat.S_IRUSR
    with os.fdopen(os.open(dest_file_path, open_flags, open_mode), 'w') as fd:
        fd.write(CANN_COPYRIGHT)
        fd.write('#include <cstddef>\n')
        fd.write(f'#include "{namespace.lower()}/params/params.h"\n')
        fd.write('#include "sink_common.h"\n')
        for struct_name, mem_list in op_func_list:
            qualified_struct = f'{namespace}::OpParam::{struct_name}'
            fd.write(f'\ntemplate<> size_t AsdOps::GetOffset<{qualified_struct}>'
                     '(size_t i, uint64_t &type)\n')
            fd.write('{\n')
            if not mem_list:
                # Struct without members: nothing to index.
                fd.write(' (void)i;\n')
                fd.write(' type = AsdOps::UNDEFINED_TYPE;\n')
                fd.write(' return 0;\n')
            else:
                member_count = len(mem_list)
                fd.write(f' static const size_t offsets[{member_count}] = {{\n')
                for i, member in enumerate(mem_list):
                    # Drop an array suffix such as "[4]"; offsetof takes the bare name.
                    domain = member[1].split('[', 1)[0]
                    fd.write(f' offsetof({qualified_struct}, {domain}), // {i}\n')
                fd.write(' };\n')
                fd.write(f' static const uint64_t types[{member_count}] = {{\n')
                for member in mem_list:
                    type_type, decayed_type, is_complete_type = __parse_typename(member[0])
                    # Types that are not complete are assumed to be nested in the
                    # param struct and are qualified accordingly.
                    type_name = decayed_type if is_complete_type else f'{qualified_struct}::{decayed_type}'
                    fd.write(f' ({type_type} | (sizeof({type_name}) << AsdOps::SHIFT_BITS)),\n')
                fd.write(' };\n')
                fd.write(' type = types[i];\n')
                fd.write(' return offsets[i];\n')
            fd.write('}\n')
if __name__ == "__main__":
    # Command-line entry point: parse the three required options and generate
    # the C++ offsets source file.
    parser = argparse.ArgumentParser(
        description='Generate a C++ source file with AsdOps::GetOffset specializations '
                    'from operation param headers.')
    parser.add_argument('--namespace', type=str, required=True,
                        help='C++ namespace that contains OpParam')
    parser.add_argument('--params_path', type=str, required=True,
                        help='directory that is walked for param header files')
    parser.add_argument('--dest_file_path', type=str, required=True,
                        help='path of the C++ source file to write')
    input_args = parser.parse_args()
    __generate_utils_cpp(input_args.params_path, input_args.namespace, input_args.dest_file_path)