import os
import re
import sys
# Matches a function declaration ending in ';' whose parameter list contains no
# nested parentheses or semicolons. The first alternative additionally handles an
# ACL_FUNC_VISIBILITY macro followed by whitespace that ends in newline(s) before
# the rest of the declaration; the second matches a declaration on a single line.
PATTERN_FUNCTION = re.compile(r'ACL_FUNC_VISIBILITY\s+\n+.+\w+\([^();]*\);|.+\w+\([^();]*\);')
# Group 1 captures the non-space token right before the function name plus the
# following ' ' or '*'. implement_function strips surrounding whitespace from it
# and uses the result as the return type; a trailing '*' survives the strip and
# marks a pointer return.
PATTERN_RETURN = re.compile(r'([^ ]+[ *])\w+\([^;]+;')
# Stub body text emitted for each recognised return type. 'void' gets an empty
# body. Pointer types not listed here are handled separately by implement_function
# (they return nullptr).
RETURN_STATEMENTS = {
'aclDataType': ' return ACL_DT_UNDEFINED;',
'aclFormat': ' return ACL_FORMAT_UNDEFINED;',
# aclError stubs print a diagnostic at run time and return ACL_ERROR_COMPILING_STUB_MODE.
'aclError': ' printf("[ERROR]: stub library cannot be used for execution, please check your \
environment variables and compilation options to make sure you use the correct ACL library.\\n");\n \
return static_cast<aclError>(ACL_ERROR_COMPILING_STUB_MODE);',
'void': '',
'size_t': ' return static_cast<size_t>(0);',
'uint8_t': ' return static_cast<uint8_t>(0);',
'int32_t': ' return static_cast<int32_t>(0);',
'uint32_t': ' return static_cast<uint32_t>(0);',
'int64_t': ' return static_cast<int64_t>(0);',
'uint64_t': ' return static_cast<uint64_t>(0);',
'aclFloat16': ' return static_cast<aclFloat16>(0);',
'float': ' return 0.0f;',
'bool': ' return false;',
'double': ' return static_cast<double>(0.0f);',
}
def _collect_matching_files(inc_dir, keyword):
    """Return '/'-separated paths of files under inc_dir whose name contains keyword."""
    matched = []
    for root, dirs, files in os.walk(inc_dir):
        # Sorting dirs in place makes os.walk descend in a fixed order, so the
        # generated stub files do not depend on filesystem listing order.
        dirs.sort()
        for name in sorted(files):
            if keyword in name:
                # Normalise Windows separators so paths can be used in #include lines.
                matched.append(os.path.join(root, name).replace('\\', '/'))
    return matched


def collect_header_files(cblas_inc_dir, op_compiler_inc_dir, op_exec_inc_dir, mdl_inc_dir):
    """Collect the header files to stub from each include directory.

    Each directory is walked recursively. A file is selected when its name
    contains the directory's keyword: "cblas", "op_compiler", "acl_op.h" or
    "mdl.h" respectively.

    Returns:
        A 4-tuple of lists (cblas, op_compiler, op_exec, mdl) of '/'-separated
        paths in deterministic order: directories top-down, with subdirectories
        and file names sorted.
    """
    cblas_headers = _collect_matching_files(cblas_inc_dir, "cblas")
    op_compiler_headers = _collect_matching_files(op_compiler_inc_dir, "op_compiler")
    op_exec_headers = _collect_matching_files(op_exec_inc_dir, "acl_op.h")
    mdl_headers = _collect_matching_files(mdl_inc_dir, "mdl.h")
    return cblas_headers, op_compiler_headers, op_exec_headers, mdl_headers
def collect_functions(file_path):
    """Return every function declaration PATTERN_FUNCTION finds in a header.

    Args:
        file_path: path of the header file to scan.

    Returns:
        A list of matched declaration strings, in file order.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the build
    # machine's locale encoding (the platform default varies, e.g. on Windows).
    with open(file_path, encoding='utf-8') as f:
        content = f.read()
    # PATTERN_FUNCTION has no capture groups, so findall yields whole matches.
    return PATTERN_FUNCTION.findall(content)
def implement_function(func):
    """Turn one function declaration into a stub definition.

    Args:
        func: a declaration as matched by PATTERN_FUNCTION, ending with ';'.

    Returns:
        The declaration without its trailing ';', followed by a braced body.
        The body holds the RETURN_STATEMENTS entry for the return type, or
        'return nullptr;' for a pointer return type or when no return type
        can be extracted.

    Raises:
        RuntimeError: if the extracted return type is neither listed in
            RETURN_STATEMENTS nor a pointer type.
    """
    # Dropping the trailing ';' turns the declaration into a definition header.
    function_def = func[:-1] + '\n{\n'
    match = PATTERN_RETURN.search(func)
    if not match:
        # No return type could be extracted; keep the historical fallback.
        body = ' return nullptr;'
    else:
        ret_type = match.group(1).strip()
        if ret_type in RETURN_STATEMENTS:
            body = RETURN_STATEMENTS[ret_type]
        elif ret_type.endswith('*'):
            body = ' return nullptr;'
        else:
            raise RuntimeError("[ERROR] Unhandled return type: " + ret_type)
    return function_def + body + '\n}\n'
def generate_stub_file(cblas_inc_dir, op_compiler_inc_dir, op_exec_inc_dir, mdl_inc_dir):
    """Collect the headers under each include directory and build their stub contents.

    Args:
        cblas_inc_dir, op_compiler_inc_dir, op_exec_inc_dir, mdl_inc_dir:
            include directories. Each is also the prefix that generate_function
            removes from header paths when writing #include lines.

    Returns:
        A 4-tuple (cblas_content, op_compiler_content, op_exec_content,
        mdl_content), each the list of lines returned by generate_function.
    """
    inc_dirs = (cblas_inc_dir, op_compiler_inc_dir, op_exec_inc_dir, mdl_inc_dir)
    header_groups = collect_header_files(*inc_dirs)
    print("header files have been collected")
    labels = ("cblas", "op_compiler", "op_exec", "mdl")
    contents = []
    for label, header_files, inc_dir in zip(labels, header_groups, inc_dirs):
        contents.append(generate_function(header_files, inc_dir))
        print("{}_content has been generated".format(label))
    return tuple(contents)
def generate_function(header_files, inc_dir):
    """Build the lines of one stub source file.

    The result starts with #include lines for <stdio.h>, <stdint.h> and every
    '.h' file in header_files. For each such header, a "// stub for ..." marker
    follows, then one stub definition per declaration found in it. Files not
    ending in '.h' are ignored.

    Args:
        header_files: header paths. Each is expected to start with inc_dir;
            that prefix is sliced off to form the #include name.
        inc_dir: the include directory prefix, including its trailing '/'.

    Returns:
        A list of strings suitable for file.writelines().
    """
    headers = [header for header in header_files if header.endswith('.h')]
    content = ['#include <stdio.h>\n', '#include <stdint.h>\n']
    content.extend('#include "{}"\n'.format(header[len(inc_dir):]) for header in headers)
    print("include content build success")
    content.append('\n')
    total = 0
    for header in headers:
        content.append("// stub for {}\n".format(header[len(inc_dir):]))
        functions = collect_functions(header)
        print("inc file:{}, functions numbers:{}".format(header, len(functions)))
        total += len(functions)
        for func in functions:
            content.append("{}\n".format(implement_function(func)))
        content.append("\n")
    print("implement content build success")
    print('total functions number is {}'.format(total))
    return content
def gen_code(cblas_inc_dir, op_compiler_inc_dir, op_exec_inc_dir, mdl_inc_dir,
             cblas_stub_path, op_compiler_stub_path, op_exec_stub_path, mdl_stub_path):
    """Generate the four stub sources and write each one to its output path.

    Each include directory is given a trailing '/' if it lacks one, so that it
    can be used as a prefix of the collected header paths. The stub contents
    are then produced by generate_stub_file. They are written, in order, to the
    cblas, op_compiler, op_exec and mdl stub paths, each opened in 'w' mode.
    """
    inc_dirs = [inc if inc.endswith('/') else inc + '/'
                for inc in (cblas_inc_dir, op_compiler_inc_dir, op_exec_inc_dir, mdl_inc_dir)]
    generated = generate_stub_file(*inc_dirs)
    print("cblas_content, op_compiler_content, op_exec_content, mdl_content have been generated")
    stub_paths = (cblas_stub_path, op_compiler_stub_path, op_exec_stub_path, mdl_stub_path)
    for stub_path, lines in zip(stub_paths, generated):
        with open(stub_path, mode='w') as out:
            out.writelines(lines)
if __name__ == '__main__':
    # Positional arguments: the four include directories, then the four stub
    # output paths, in the order gen_code expects them. Extra arguments are ignored.
    if len(sys.argv) < 9:
        usage = ("usage: {} CBLAS_INC_DIR OP_COMPILER_INC_DIR OP_EXEC_INC_DIR MDL_INC_DIR "
                 "CBLAS_STUB_FILE OP_COMPILER_STUB_FILE OP_EXEC_STUB_FILE MDL_STUB_FILE")
        # sys.exit with a string prints it to stderr and exits with status 1,
        # instead of an opaque IndexError traceback.
        sys.exit(usage.format(sys.argv[0]))
    (cblas_inc_dir, op_compiler_inc_dir, op_exec_inc_dir, mdl_inc_dir,
     cblas_stub_file, op_compiler_stub_file, op_exec_stub_file, mdl_stub_file) = sys.argv[1:9]
    gen_code(cblas_inc_dir, op_compiler_inc_dir, op_exec_inc_dir, mdl_inc_dir,
             cblas_stub_file, op_compiler_stub_file, op_exec_stub_file, mdl_stub_file)