import os
import re
import sys
import logging
"""
generate stub func body by return type
"""
# Maps a lookup key to the C++ statements emitted as the body of a generated stub.
# H2CC.get_return_statements tries the keys '<ret_type> <class>::<func>',
# '<ret_type> <class>::' and '<ret_type>' in that order; H2CC.implement_function
# falls back to a key derived from the first token(s) of the declaration.
RETURN_STATEMENTS = {
    'ge::graphStatus':
        ' std::cout << "[ERROR]: stub library libop_common cannot be used for execution, please check your "\n'
        ' << "environment variables and compilation options to make sure you use the correct library."\n'
        ' << std::endl;\n'
        ' return ge::GRAPH_FAILED;'
}
"""
include_dir_key_words selects the include sub-directories that are scanned;
white_list_for_debug further restricts which header files cc files are
generated from when DEBUG is on
"""
# Header-file suffixes accepted by file_endswith_white_list_suffix while DEBUG is True.
white_list_for_debug = ["common_infershape_fns.h"]
# Names appended to the include directory by gen_code to find the headers to scan.
include_dir_key_words = ["op_common"]
"""
these attributes control symbol-table visibility
"""
# Interpolated into pattern_class as an optional token sequence allowed between
# the class/struct keyword and the class name.
GE_ATTR = 'GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY'
"""
maximum code length per line in the Huawei software programming specifications
"""
# H2CC.handle_class_member_func inserts a line break after 'Class::' when the
# qualified definition would be longer than this many characters.
MAX_CODE_LEN_PER_LINE = 100
# When True, file_endswith_white_list_suffix only accepts headers listed in
# white_list_for_debug; when False every header is accepted.
DEBUG = True
# Send log records of level INFO and above to stdout.
logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s',
                    level=logging.INFO)
def need_generate_func(func_line):
    """Tell whether a stub body should be generated for a declaration.

    :param func_line: declaration text; surrounding whitespace is ignored
    :return: False when the text ends with ``default`` or ``delete`` or starts
        with ``typedef`` or ``using``; True otherwise
    """
    # Strip once and let startswith/endswith test all alternatives in one call.
    stripped = func_line.strip()
    if stripped.endswith(("default", "delete")) or stripped.startswith(("typedef", "using")):
        return False
    return True
def file_endswith_white_list_suffix(file):
    """Decide whether a header file should be converted.

    :param file: header file path
    :return: True when DEBUG is off; otherwise True only if ``file`` ends with
        one of the white_list_for_debug entries, each normalised to start with
        exactly one '/'
    """
    if not DEBUG:
        return True
    # Collapse any run of leading slashes (including none) into a single '/',
    # so a suffix only matches a complete path component.
    normalized = (re.sub(r'^/*', '/', entry) for entry in white_list_for_debug)
    return any(file.endswith(entry) for entry in normalized)
"""
the patterns below are used to analyse .h files
"""
# A declaration: leading whitespace (group 1), text that starts with a letter,
# '~' or '_' and contains a ')' with no '{' after it (group 2), then ';' and the
# rest of the text (group 3). H2CC.handle_func1 keeps only group 2.
pattern_func = re.compile(r"""(^[\s]*)([a-zA-Z~_].*[)](?!.*{).*)(;.*)\n$""", re.VERBOSE | re.MULTILINE | re.DOTALL)
# A line that starts with a '//' comment.
pattern_comment = re.compile(r'^\s*//')
# A line that starts a '/* ... */' comment, and a line that ends one.
pattern_comment_2_start = re.compile(r'^\s*/[*]')
pattern_comment_2_end = re.compile(r'[*]/\s*$')
# A '#define' line, and a line ending with a continuation backslash.
pattern_define = re.compile(r'^\s*#define')
pattern_define_return = re.compile(r'\\\s*$')
# A line that starts a static_assert, and a line that ends with ');'.
pattern_static_assert = re.compile(r'^\s*static_assert')
pattern_static_assert_return = re.compile(r'\);\s*$')
# A line containing only whitespace.
pattern_blank_line = re.compile(r'^\s*$')
# Declaration keywords that H2CC.h2cc removes before emitting a definition.
pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)')
# Group 1 is the indentation of a line whose first visible character is a letter, '~' or '_'.
pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]')
# Group 1 is the text directly in front of a '(' (an identifier or an 'operator...' form);
# H2CC.handle_class_member_func prefixes it with the class qualifier.
pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]')
# A line that starts a template header, and a line that ends with '>'.
pattern_template = re.compile(r'^\s*template')
pattern_template_end = re.compile(r'>\s*$')
# A line containing 'namespace' followed later by '{'.
pattern_namespace = re.compile(r'namespace.*{')
# A class/struct definition line (no ';' after the name); group 3 is the class
# name, optionally ending with '<', after an optional GE_ATTR.
pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+<?)(?!.*;)' % GE_ATTR)
# Opening and closing braces anywhere in a line.
pattern_start = re.compile('{')
pattern_end = re.compile('}')
# Return type (group 1), the spaces after it (group 2) and following '&'/'*' characters
# (group 3). H2CC.get_return_statements swaps groups 2 and 3 so that the reference/pointer
# marks end up attached to the type.
pat_format_func = re.compile(r"""^(
(?:const[ ]+)?
(?:
[:\w]+
|
std::(?:vector|shared_ptr)<[:\w ]+>
|
std::(?:vector|shared_ptr)<std::vector<[:\w ]+>>
|
std::(?:map|unordered_map|pair)<[:\w]+[, ]+[:\w]+>
)
)
([ ]+)
([&*]+)""", re.VERBOSE)
# Splits the start of a member-function definition into the named groups
# 'ret_type', 'class_name' and 'func_name' (an identifier or 'operator=').
pat_search_func = re.compile(r"""^(?:const[ ]+)?
(?P<ret_type>
(?:
[:\w]+
|
std::(?:vector|shared_ptr)<[:\w ]+>
|
std::(?:vector|shared_ptr)<std::vector<[:\w ]+>>
|
std::(?:map|unordered_map|pair)<[:\w]+[, ]+[:\w]+>
)
(?:[&*]+)?
)
[ ]+
(?P<class_name>\w+)
::
\n?
(?P<func_name>\w+|operator=)
[ ]*
\(""", re.VERBOSE)
class H2CC(object):
    """Convert the function declarations of one C++ header into stub definitions.

    The header is read completely at construction time. :meth:`h2cc` then walks
    it line by line while tracking brace nesting in ``stack`` (entries are
    ``'namespace_now'``, ``'class_now'`` or ``'normal_now'``), the enclosing
    class names in ``stack_class`` and the template headers of those classes in
    ``stack_template``. Every declaration that passes the filters is written to
    the output file followed by a body built by :meth:`implement_function`.
    """

    def __init__(self, input_file, output_file, shared_includes_content):
        """
        :param input_file: path of the header file to read
        :param output_file: path of the stub source file to (over)write
        :param shared_includes_content: strings written verbatim at the top of the output
        """
        self.input_file = input_file
        self.output_file = output_file
        self.shared_includes_content = shared_includes_content
        self.line_index = 0
        # Parser state is created before any file is opened so that __del__ always
        # finds a consistent object, even when one of the open() calls below raises.
        self.stack = []
        self.stack_class = []
        self.stack_template = []
        self.func_list_exist = []
        self.input_fd = None
        self.output_fd = None
        self.input_fd = open(self.input_file, 'r')
        self.input_content = self.input_fd.readlines()
        self.output_fd = open(self.output_file, 'w')

    def __del__(self):
        """Close whichever of the two files were actually opened."""
        for handle in (getattr(self, 'input_fd', None), getattr(self, 'output_fd', None)):
            if handle is not None:
                handle.close()

    def _has_line(self):
        """Return True while ``line_index`` still points inside the header."""
        return self.line_index < len(self.input_content)

    @staticmethod
    def _strip_template_defaults(template_text):
        """Remove default template arguments (``= ...``) from a template header."""
        template_text = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_text)
        template_text = re.sub(r'\s*=.*,', ',', template_text)
        return re.sub(r'\s*=.*', '', template_text)

    def just_skip(self):
        """Advance ``line_index`` past content that is never turned into a stub.

        The checks run in sequence, each on whatever line is current after the
        previous one: one blank or ``//`` line, one ``/* ... */`` block, one
        ``#define`` (including backslash-continued lines) and one
        ``static_assert``. Every access is bounds-checked so that a header
        ending inside such a construct does not raise IndexError.
        """
        lines = self.input_content
        if self._has_line() and (pattern_blank_line.search(lines[self.line_index])
                                 or pattern_comment.search(lines[self.line_index])):
            self.line_index += 1
        if self._has_line() and pattern_comment_2_start.search(lines[self.line_index]):
            while self._has_line() and not pattern_comment_2_end.search(lines[self.line_index]):
                self.line_index += 1
            self.line_index += 1
        if self._has_line() and pattern_define.search(lines[self.line_index]):
            while self._has_line() and (pattern_blank_line.search(lines[self.line_index])
                                        or pattern_define_return.search(lines[self.line_index])):
                self.line_index += 1
            self.line_index += 1
        if self._has_line() and pattern_static_assert.search(lines[self.line_index]):
            while self._has_line() and not pattern_static_assert_return.search(lines[self.line_index]):
                self.line_index += 1
            self.line_index += 1

    def write_inc_content(self):
        """Write every shared include line to the output file."""
        for shared_include_content in self.shared_includes_content:
            self.output_fd.write(shared_include_content)

    def h2cc(self):
        """Generate the whole stub file.

        :return: None
        """
        logging.info("start generate cc_file[%s] from h_file[%s]", self.output_file, self.input_file)
        self.write_inc_content()
        while self._has_line():
            # Each of the three steps below may consume lines, so the end of the
            # header has to be re-checked before the current line is read again.
            self.just_skip()
            if not self._has_line():
                break
            self.handle_namespace()
            if not self._has_line():
                break
            template_string = self.handle_template()
            if not self._has_line():
                break
            line = self.input_content[self.line_index]
            match_class = pattern_class.search(line)
            match_start = pattern_start.search(line)
            if self.handle_class(template_string, line, match_start, match_class) == "continue":
                continue
            if self.handle_stack(match_start) == "continue":
                continue
            handle_func1_result, line, start_i = self.handle_func1(line)
            if handle_func1_result == "continue":
                continue
            # The friend test must run before pattern_keyword removes the keyword;
            # a friend declaration is not a member and gets no 'Class::' qualifier.
            is_friend = re.search(r'\bfriend\s', line) is not None
            line = pattern_keyword.sub('', line)
            logging.info("line[%s]", line)
            if self.stack_class and not is_friend:
                line, func_name = self.handle_class_member_func(line, template_string)
            else:
                line, func_name = self.handle_normal_func(line, template_string)
            need_generate = need_generate_func(line)
            line += self.implement_function(line)
            line = self.gen_comment(start_i) + line
            self.write_func_content(line, func_name, need_generate)
            self.line_index += 1
        # Make the generated text visible on disk without waiting for __del__.
        self.output_fd.flush()
        logging.info('Added %s functions', len(self.func_list_exist))
        logging.info('Successfully converted,please see %s', self.output_file)

    def handle_func1(self, line):
        """Collect one complete declaration starting at the current line.

        :param line: the current header line
        :return: tuple ``(status, text, start_index)``; ``status`` is ``"pass"``
            when ``text`` holds a declaration to emit and ``"continue"`` when the
            caller should move on (no ``(`` in the line, an inline definition,
            text that is not a declaration, or an unterminated parameter list)
        """
        if '(' not in line:
            self.line_index += 1
            return "continue", line, None
        start_i = self.line_index
        space_match = pattern_leading_space.search(line)
        # Indentation of the first line; removed from every continuation line.
        indent = space_match.group(1) if space_match else ''
        if ')' not in line:
            # The parameter list spans several lines: append lines until one has ')'.
            while True:
                self.line_index += 1
                if not self._has_line():
                    return "continue", line, start_i
                continuation = self.input_content[self.line_index]
                if indent and continuation.startswith(indent):
                    continuation = continuation[len(indent):]
                line += continuation
                if ')' in continuation:
                    break
        last_line = self.input_content[self.line_index]
        if pattern_start.search(last_line):
            # A '{' means the header already defines the function; remember the
            # open brace unless it is closed again on the same line.
            if not pattern_end.search(last_line):
                self.stack.append('normal_now')
            self.line_index += 1
            return "continue", line, start_i
        logging.info("line[%s]", line)
        (line, match) = pattern_func.subn(r'\2\n', line)
        logging.info("line[%s]", line)
        # A return type written alone on the previous line belongs to this
        # declaration; start_i > 0 keeps index -1 from wrapping to the last line.
        if start_i > 0 and re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]):
            line = self.input_content[start_i - 1] + line
        line = line.lstrip()
        if not match:
            self.line_index += 1
            return "continue", line, start_i
        return "pass", line, start_i

    def handle_stack(self, match_start):
        """Update the brace stack for the current line.

        :param match_start: match object for ``{`` in the current line, or None
        :return: ``"continue"`` when the line was consumed here (it contains a
            brace or lies inside a ``'normal_now'`` scope), else ``"pass"``
        """
        line = self.input_content[self.line_index]
        match_end = pattern_end.search(line)
        if match_start:
            self.stack.append('normal_now')
        if match_end:
            if self.stack:
                top_status = self.stack.pop()
            else:
                # An unbalanced '}' must not abort the conversion.
                top_status = None
                logging.warning("unbalanced '}' at line %s of %s", self.line_index + 1, self.input_file)
            if top_status == 'namespace_now':
                self.output_fd.write(line + '\n')
            elif top_status == 'class_now':
                self.stack_class.pop()
                self.stack_template.pop()
        if match_start or match_end:
            self.line_index += 1
            return "continue"
        if self.stack and self.stack[-1] == 'normal_now':
            self.line_index += 1
            return "continue"
        return "pass"

    def handle_class(self, template_string, line, match_start, match_class):
        """Enter a class/struct scope when the current line opens one.

        :param template_string: template header collected just before the line ('' if none)
        :param line: the current header line
        :param match_start: match object for ``{`` in ``line``, or None
        :param match_class: match object of pattern_class on ``line``, or None
        :return: ``"continue"`` when a class was entered, else ``"pass"``
        """
        if not match_class:
            return "pass"
        self.stack_template.append(template_string)
        self.stack.append('class_now')
        class_name = match_class.group(3)
        if '<' in class_name:
            # Specialisation: append the text up to the '>' that balances the first '<'.
            k = line.index('<')
            fit = 1
            for ii in range(k + 1, len(line)):
                if line[ii] == '<':
                    fit += 1
                if line[ii] == '>':
                    fit -= 1
                if fit == 0:
                    break
            class_name += line[k + 1:ii + 1]
        logging.info('class_name[%s]', class_name)
        self.stack_class.append(class_name)
        # Skip ahead to the line holding the opening brace of the class body.
        while not match_start:
            self.line_index += 1
            if not self._has_line():
                break
            line = self.input_content[self.line_index]
            match_start = pattern_start.search(line)
        self.line_index += 1
        return "continue"

    def handle_template(self):
        """Collect a template header starting at the current line.

        :return: the header text (all lines up to one ending with ``>``), or ''
            when the current line does not start with ``template``
        """
        line = self.input_content[self.line_index]
        template_string = ''
        if pattern_template.search(line):
            template_string = line
            while not pattern_template_end.search(line):
                self.line_index += 1
                if not self._has_line():
                    # Header ended inside the template list; the caller stops.
                    return template_string
                line = self.input_content[self.line_index]
                template_string += line
            self.line_index += 1
        return template_string

    def handle_namespace(self):
        """Copy a namespace-opening line to the output and record the scope."""
        line = self.input_content[self.line_index]
        if pattern_namespace.search(line):
            self.output_fd.write(line + '\n')
            self.stack.append('namespace_now')
            self.line_index += 1

    def handle_normal_func(self, line, template_string):
        """Build the definition header of a function that is not a class member.

        :param line: declaration text
        :param template_string: template header of the function ('' if none)
        :return: tuple ``(definition_text, func_name)``; ``func_name`` is the
            text up to the last ``)`` and is used to detect duplicates
        """
        template_line = ''
        if template_string != '':
            template_line = re.sub(r'\s*template', 'template', template_string)
            template_line = self._strip_template_defaults(template_line)
        # Default argument values are not allowed to be repeated in a definition.
        line = re.sub(r'\s*=.*,', ',', line)
        line = re.sub(r'\s*=.*\)', ')', line)
        line = template_line + line
        func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group()
        logging.info("line[%s]", line)
        logging.info("func_name[%s]", func_name)
        return line, func_name

    def handle_class_member_func(self, line, template_string):
        """Build the definition header of a member of the innermost class.

        :param line: declaration text
        :param template_string: template header of the member itself ('' if none)
        :return: tuple ``(definition_text, func_name)``; ``func_name`` is the
            text up to the last ``)`` and is used to detect duplicates
        """
        template_line = ''
        x = ''
        if template_string != '':
            template_string = re.sub(r'\s*template', 'template', template_string)
            template_string = self._strip_template_defaults(template_string)
        class_template = self.stack_template[-1]
        if class_template != '':
            if not re.search(r'<\s*>', class_template):
                template_line = re.sub(r'^\s*template', 'template', class_template)
                if not re.search(r'<.*>', self.stack_class[-1]):
                    # Reduce 'template <class T, typename U>' to '<T, U>', the
                    # argument list appended to the class name in the qualifier.
                    x = re.sub(r'template\s*<', '<', template_line)
                    x = re.sub(r'\n', '', x)
                    x = re.sub(r'\s*=.*,', ',', x)
                    x = re.sub(r'\s*=.*\>', '>', x)
                    x = x.rstrip()
                    x = re.sub(r'(class|typename)\s+|(<class>|<typename>\s*class)', '',
                               x)
                    x = re.sub(r'<\s+', '<', x)
                    x = re.sub(r'\s+>', '>', x)
                    x = re.sub(r'\s+,', ',', x)
                    x = re.sub(r',\s+', ', ', x)
        # Drop the pure-virtual marker and default argument values.
        line = re.sub(r'\s*=\s+0', '', line)
        line = re.sub(r'\s*=\s+.*,', ',', line)
        line = re.sub(r'\s*=\s+.*\)', ')', line)
        logging.info("x[%s]\nline[%s]", x, line)
        rep_fmt = '%s%s::{}%s' % (self.stack_class[-1], x, r'\1(')
        temp_line = pattern_func_name.sub(rep_fmt.format(''), line, count=1)
        if len(temp_line) > MAX_CODE_LEN_PER_LINE:
            # Too long for one line: break right after the 'Class::' qualifier.
            line = pattern_func_name.sub(rep_fmt.format('\n'), line, count=1)
        else:
            line = temp_line
        logging.info("line[%s]", line)
        template_line = self._strip_template_defaults(template_line)
        line = template_line + template_string + line
        func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group()
        logging.info("line[%s]", line)
        logging.info("func_name[%s]", func_name)
        return line, func_name

    def write_func_content(self, content, func_name, need_generate):
        """Write a generated definition unless it is filtered out or a duplicate.

        :param content: complete text of the definition
        :param func_name: key used to detect definitions that were already written
        :param need_generate: False to suppress the definition
        """
        if func_name not in self.func_list_exist and need_generate:
            self.output_fd.write(content)
            self.func_list_exist.append(func_name)
            logging.info('add func:[%s]', func_name)

    def gen_comment(self, start_i):
        """Return the comment placed directly above a declaration.

        :param start_i: index of the first line of the declaration
        :return: the ``/* ... */`` block or the run of ``//`` lines that ends on
            the line above the declaration (above a single ``template`` line, if
            there is one), with leading whitespace removed; '' when there is none
        """
        lines = self.input_content
        comment_line = ''
        k = start_i - 1
        if k >= 0 and pattern_template.search(lines[k]):
            k -= 1
        if k < 0:
            # Nothing above the declaration; negative indices would wrap around.
            return comment_line
        if pattern_comment_2_end.search(lines[k]):
            comment_line = lines[k].lstrip()
            while k > 0 and not pattern_comment_2_start.search(lines[k]):
                k -= 1
                comment_line = lines[k].lstrip() + comment_line
        else:
            # Walk upwards, including line 0, while the lines are '//' comments.
            for j in range(k, -1, -1):
                c_line = lines[j]
                if not pattern_comment.search(c_line):
                    break
                comment_line = re.sub(r'\s*//', '//', c_line) + comment_line
        return comment_line

    @staticmethod
    def get_return_statements(func):
        """Look up the stub body for a member-function definition.

        :param func: definition header text
        :return: the RETURN_STATEMENTS entry for ``'<type> <class>::<func>'``,
            else for ``'<type> <class>::'``, else for ``'<type>'``; None when
            ``func`` does not match pat_search_func or no key is present
        """
        # Attach '&' / '*' to the return type so they become part of 'ret_type'.
        func = pat_format_func.sub(r'\1\3\2', func)
        m = pat_search_func.search(func)
        if not m:
            return None
        logging.info('ret_type: %s, class_name: %s, func_name: %s', *m.group('ret_type', 'class_name', 'func_name'))
        type_cls_func_name = '%s %s::%s' % m.group('ret_type', 'class_name', 'func_name')
        if type_cls_func_name in RETURN_STATEMENTS:
            logging.info('type_cls_func_name:[%s] matched!', type_cls_func_name)
            return RETURN_STATEMENTS[type_cls_func_name]
        type_cls_name = '%s %s::' % m.group('ret_type', 'class_name')
        if type_cls_name in RETURN_STATEMENTS:
            logging.info('type_cls_name:[%s] matched!', type_cls_name)
            return RETURN_STATEMENTS[type_cls_name]
        type_only = m.group('ret_type')
        if type_only in RETURN_STATEMENTS:
            logging.info('type_only:[%s] matched!', type_only)
            return RETURN_STATEMENTS[type_only]
        return None

    @staticmethod
    def implement_function(func):
        """Build the ``{ ... }`` body text for a definition header.

        :param func: definition header text
        :return: ``'{\\n'`` + statements + ``'\\n}\\n\\n'``; the statements come
            from :meth:`get_return_statements` or, failing that, from a
            RETURN_STATEMENTS key derived from the leading tokens of ``func``;
            they are empty (and a warning is logged) when no key matches
        """
        function_def = '{\n'
        return_statements = H2CC.get_return_statements(func)
        if return_statements is not None:
            function_def += return_statements
        else:
            all_items = func.split()
            # Skip a leading 'const' when a type token follows it.
            start = 1 if len(all_items) > 1 and all_items[0] == "const" else 0
            return_type = all_items[start] if all_items else ''
            if return_type.startswith(('std::map', 'std::set', 'std::vector')):
                return_type = "std::map"
            if return_type.endswith('*') or (
                    len(all_items) > start + 1 and all_items[start + 1].startswith('*')) or return_type.startswith(
                    'std::unique_ptr'):
                return_type = "Ptr"
            if len(all_items) > start + 1 and all_items[start + 1].startswith('&'):
                return_type += "&"
            if return_type in RETURN_STATEMENTS:
                function_def += RETURN_STATEMENTS[return_type]
            else:
                logging.info("Unhandled func[%s]", func)
                logging.warning("Unhandled return type[%s]", return_type)
        function_def += '\n'
        function_def += '}\n'
        function_def += '\n'
        return function_def
def collect_header_files(path):
    """Collect every ``.h`` file below a directory.

    :param path: directory to walk
    :return: tuple ``(header_files, shared_includes_content)``. ``header_files``
        holds the '/'-separated paths of all files ending in ``.h`` whose name
        does not contain ``git``. ``shared_includes_content`` holds one
        ``#include "..."`` line per header, relative to the parent directory of
        ``path``, followed by ``#include <iostream>``.
    """
    header_files = []
    shared_includes_content = []
    # Length of the prefix removed from each header path to form its include
    # path. rfind() gives -1 for a path without '/', so nothing is removed in
    # that case instead of raising ValueError.
    prefix_len = path.replace('\\', '/').rfind('/') + 1
    for root, dirs, files in os.walk(path):
        files.sort()
        # Sorting in place also fixes the order in which os.walk descends.
        dirs.sort()
        for file in files:
            if "git" in file or not file.endswith('.h'):
                continue
            file_path = os.path.join(root, file).replace('\\', '/')
            header_files.append(file_path)
            shared_includes_content.append('#include "{}"\n'.format(file_path[prefix_len:]))
    shared_includes_content.append('#include <iostream>\n')
    return header_files, shared_includes_content
def generate_stub_file(inc_dir, out_cc_dir):
    """Generate one ``stub_<name>.cc`` file per accepted header below ``inc_dir``.

    :param inc_dir: directory that is searched for header files
    :param out_cc_dir: output directory prefix; the stub file name is appended
        to it as-is
    :return: None
    """
    headers, includes = collect_header_files(inc_dir)
    for header in headers:
        if file_endswith_white_list_suffix(header):
            # 'dir/foo.h' -> 'dir/stub_foo.cc'; only the file name part is kept.
            stub_path = re.sub(r'([^/]+)\.h$', r'stub_\1.cc', header)
            stub_name = stub_path[stub_path.rindex('/') + 1:]
            h_2_cc = H2CC(header, out_cc_dir + stub_name, includes)
            h_2_cc.h2cc()
def gen_code(inc_dir, out_cc_dir):
    """Generate stub files for every entry of include_dir_key_words.

    :param inc_dir: include directory; a trailing '/' is added when missing
    :param out_cc_dir: output directory; a trailing '/' is added when missing
    :return: None
    """
    include_root = inc_dir if inc_dir.endswith('/') else inc_dir + '/'
    output_root = out_cc_dir if out_cc_dir.endswith('/') else out_cc_dir + '/'
    for key_word in include_dir_key_words:
        generate_stub_file(include_root + key_word, output_root)
def main():
    """Command-line entry point: ``<script> <inc_dir> <out_cc_dir>``.

    Logs an error and does nothing else when the number of command-line
    arguments is not exactly two.
    """
    argv = sys.argv
    if len(argv) == 3:
        gen_code(argv[1], argv[2])
    else:
        logging.error("script %s must have 2 input parameters!", argv[0])


if __name__ == '__main__':
    main()