import argparse
import math
import os
import re
import tempfile
from typing import List

from autofuse import ascendc_compile
def str2bool(v):
    """Convert a textual flag such as ``"yes"`` or ``"0"`` into a ``bool``.

    Matching is case-insensitive. Suitable as an ``argparse`` ``type=``
    converter.

    Raises:
        ValueError: if *v* is not one of the recognised spellings.
    """
    normalized = v.lower()
    truthy_words = ('true', '1', 'yes', 'y')
    falsy_words = ('false', '0', 'no', 'n')
    if normalized in truthy_words:
        return True
    if normalized in falsy_words:
        return False
    raise ValueError(f"Invalid boolean value: '{v}'")
def camel_to_snake(camel_str):
    """Convert a CamelCase string to snake_case.

    Two regex passes: the first inserts ``_`` before each capitalised word
    (``"HTTPServer"`` -> ``"HTTP_Server"``), the second before a capital that
    follows a lowercase letter or digit. The result is lowercased.
    """
    words_split = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', camel_str)
    boundaries_split = re.sub('([a-z0-9])([A-Z])', r'\1_\2', words_split)
    return boundaries_split.lower()
def gen_valid_name(t_name):
    """Sanitise *t_name* into an identifier-like string.

    Every run of non-alphanumeric characters (per ``str.isalnum``) collapses
    into a single ``_``. A single leading ``_`` is then dropped, and ``t_`` is
    prepended if the result would start with a digit. A trailing ``_`` is
    kept.
    """
    pieces = []
    for ch in t_name:
        if ch.isalnum():
            pieces.append(ch)
        elif not pieces or pieces[-1] != '_':
            # Only emit one '_' per run of separator characters.
            pieces.append('_')

    name = ''.join(pieces)
    if name.startswith('_'):
        name = name[1:]
    if name[:1].isdigit():
        name = "t_" + name
    return name
def parse_compile_args(argv):
    """Parse the compile command-line options contained in *argv*.

    Args:
        argv: list of argument strings, without the program name. ``None``
            makes ``argparse`` fall back to ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``graph_name``,
        ``output_file``, ``output_path``, ``force_unknown`` (bool via
        ``str2bool``), ``config_file``, ``soc_version`` and
        ``compile_options``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--graph_name', default='autofuse', type=str, help='Graph name.')
    parser.add_argument('--output_file', required=True, type=str, help='Destination directory.')
    parser.add_argument('--output_path', default='', type=str, help='Output directory.')
    parser.add_argument('--force_unknown', default=False, type=str2bool, help='force unknown shape.')
    parser.add_argument('--config_file', default='', type=str, help='PGO tiling config file after tuning.')
    parser.add_argument('--soc_version', default='Ascend910B', type=str, help='chip soc version.')
    parser.add_argument('--compile_options', default='', type=str, help='Compile options of tiling and kernel.')
    return parser.parse_args(argv)
def generate_file(dst_dir, file_name, text):
    """Write *text* to ``dst_dir/file_name``, creating *dst_dir* if needed.

    Any existing file at that path is overwritten. The file is always written
    as UTF-8 so that generated sources containing non-ASCII characters do not
    depend on the host's locale encoding.
    """
    os.makedirs(dst_dir, exist_ok=True)
    file_path = os.path.join(dst_dir, file_name)
    with open(file_path, "w", encoding="utf-8") as file:
        file.write(text)
def parse_env_flags(env_name):
    """Parse a ``;``-separated ``--key=value`` list from environment variable *env_name*.

    Each entry is split at its first ``=``. Leading ``-`` characters are
    removed from the key, and surrounding whitespace is stripped from the
    entry, key and value, so ``"--a=1; --b=2"`` yields both flags. Entries
    without ``=`` or with an empty key are ignored. Later duplicates override
    earlier ones.

    Returns:
        dict mapping flag name to its (string) value. It is empty when the
        variable is unset or empty.
    """
    result = {}
    flags = os.getenv(env_name)
    if not flags:
        return result
    for param in flags.split(';'):
        key_part, sep, value_part = param.partition('=')
        if not sep:
            continue
        key = key_part.strip().lstrip('-')
        if key:
            result[key] = value_part.strip()
    return result
def get_dfx_env_result():
    """Return the flags in ``AUTOFUSE_DFX_FLAGS`` as a dict (see ``parse_env_flags``)."""
    env_var = 'AUTOFUSE_DFX_FLAGS'
    return parse_env_flags(env_var)
def get_debug_flag():
    """Return True when the DFX flag ``codegen_compile_debug`` equals ``true``.

    The comparison is case-insensitive; an absent flag counts as ``false``.
    """
    raw_value = get_dfx_env_result().get('codegen_compile_debug', "false")
    return raw_value.lower() == 'true'
def get_pgo_topn():
    """Return the PGO top-N count from the ``autofuse_pgo_topn`` DFX flag.

    Falls back to 5 when the flag is absent, not an integer, or negative.
    Zero is passed through unchanged.
    """
    default_topn = 5
    raw_value = get_dfx_env_result().get('autofuse_pgo_topn', str(default_topn))
    try:
        parsed = int(raw_value)
    except ValueError:
        return default_topn
    return default_topn if parsed < 0 else parsed
def get_pgo_env_flag():
    """Return True when ``AUTOFUSE_FLAGS`` sets ``autofuse_enable_pgo`` to ``true``.

    The comparison is case-insensitive; a missing variable or key counts as
    disabled.
    """
    autofuse_flags = parse_env_flags('AUTOFUSE_FLAGS')
    enable_pgo = autofuse_flags.get('autofuse_enable_pgo', "false")
    return enable_pgo.lower() == 'true'
def prepare_compile_context(argv, stage, tiling_repr):
    """Parse *argv* and choose the working directory for a compile.

    Sets ``stage`` and ``tiling_repr`` on the parsed args. For the ``'host'``
    stage it appends ``-D_GLIBCXX_USE_CXX11_ABI=0`` to ``compile_options``.
    ``graph_name`` is sanitised via ``gen_valid_name`` and then
    ``camel_to_snake``.

    Returns:
        ``(args, temp_dir_ctx, auto_cleanup)``.

        * No ``--output_path`` and debug off: ``temp_dir_ctx`` is a
          ``tempfile.TemporaryDirectory`` the caller must enter/clean up, and
          ``auto_cleanup`` is True.
        * Otherwise: ``temp_dir_ctx`` is None and ``auto_cleanup`` is False.
          ``args.temp_dir`` is ``--output_path`` if given, else a
          ``tempfile.mkdtemp()`` directory that is left on disk.
    """
    args = parse_compile_args(argv)
    args.stage = stage
    args.tiling_repr = tiling_repr
    if stage == 'host':
        # NOTE(review): presumably selects the pre-C++11 libstdc++ ABI to match
        # the host libraries the tiling code links against -- confirm.
        args.compile_options = (args.compile_options + " -D_GLIBCXX_USE_CXX11_ABI=0").strip()
    args.graph_name = camel_to_snake(gen_valid_name(args.graph_name))

    keep_outputs = bool(args.output_path) or get_debug_flag()
    if keep_outputs:
        args.temp_dir = args.output_path or tempfile.mkdtemp()
        return args, None, False

    temp_dir_ctx = tempfile.TemporaryDirectory()
    args.temp_dir = temp_dir_ctx.name
    return args, temp_dir_ctx, True
def execute_compile(sources, args):
    """Write the generated sources for ``args.stage`` and run the AscendC compile.

    For stage ``'all'`` or ``'host'``, the tiling header and host tiling
    source go to ``<temp_dir>/host`` and ``args.host_files`` is set. For
    ``'all'`` or ``'device'``, the tiling header and kernel source go to
    ``<temp_dir>/device`` and ``args.device_files`` is set.
    ``ascendc_compile.main(args)`` is then invoked. Any other stage value
    writes nothing but still reaches ``main``.

    Returns:
        ``args.temp_dir``.
    """
    tiling_header = "autofuse_tiling_data.h"
    # (sub-directory == stage name, source key, file suffix, args attribute)
    stage_layout = (
        ('host', 'host_impl_code', "_tiling_func.cpp", 'host_files'),
        ('device', 'kernel_impl_code', "_op_kernel.cpp", 'device_files'),
    )
    for sub_dir, code_key, suffix, attr_name in stage_layout:
        src_name = args.graph_name + suffix
        if args.stage not in ('all', sub_dir):
            continue
        out_dir = os.path.join(args.temp_dir, sub_dir)
        generate_file(out_dir, tiling_header, sources['tiling_struct_code'])
        generate_file(out_dir, src_name, sources[code_key])
        setattr(args, attr_name, os.path.join(out_dir, src_name))
    ascendc_compile.main(args)
    return args.temp_dir
def compile_core(sources, argv: List[str], stage='all', tiling_repr=None):
    """Prepare a compile context for *argv* and compile *sources* in it.

    *sources* needs the keys ``'tiling_struct_code'``, ``'host_impl_code'``
    and ``'kernel_impl_code'``; only those relevant to *stage* are read.

    NOTE(review): when an auto-cleaned temporary directory is used, it is
    removed on leaving the ``with`` block, so the returned path no longer
    exists by the time the caller receives it.
    """
    args, tmp_ctx, cleanup = prepare_compile_context(argv, stage, tiling_repr)
    if cleanup:
        with tmp_ctx:
            return execute_compile(sources, args)
    return execute_compile(sources, args)
def jit_compile(tiling_def, host_tiling, op_kernel, argv: List[str]):
    """Compile both host tiling code and device kernel (stage ``'all'``)."""
    sources = {
        'tiling_struct_code': tiling_def,
        'host_impl_code': host_tiling,
        'kernel_impl_code': op_kernel,
    }
    return compile_core(sources, argv)
def host_compile(tiling_def_code, tiling_impl_code, argv: List[str]):
    """Compile only the host tiling code (stage ``'host'``); no kernel source."""
    sources = {
        'tiling_struct_code': tiling_def_code,
        'host_impl_code': tiling_impl_code,
        'kernel_impl_code': None,
    }
    return compile_core(sources, argv, stage='host')
def kernel_compile(tiling_def_code, kernel_impl_code, argv: List[str], *, tiling_repr=None):
    """Compile only the device kernel (stage ``'device'``); no host source.

    *tiling_repr* is forwarded unchanged and stored on the compile args.
    """
    sources = {
        'tiling_struct_code': tiling_def_code,
        'host_impl_code': None,
        'kernel_impl_code': kernel_impl_code,
    }
    return compile_core(sources, argv, stage='device', tiling_repr=tiling_repr)
# Sentinel written for invalid/failed measurements (~DBL_MAX as printed with
# 6 significant digits). Any value at or above it is treated as "no result".
_PGO_INVALID_TIME = 1.79769e+308


def extract_time(line):
    """Parse the time that follows the last ``#`` in a PGO result line.

    A line without ``#`` is parsed as a whole.

    Returns:
        The parsed float. ``float('inf')`` is returned when the value is
        missing, unparsable, NaN, or at least the ``1.79769e+308`` sentinel
        (in any spelling, e.g. ``1.7976931348623157e+308``). This makes such
        entries sort after every real measurement; NaN would otherwise break
        ``sorted``/``min`` ordering.
    """
    try:
        value = float(line.split('#')[-1].strip())
    except ValueError:
        return float('inf')
    if math.isnan(value) or value >= _PGO_INVALID_TIME:
        return float('inf')
    return value
def pgo_get_top_result(search_path, top_n=5):
    """Rank the PGO candidates in *search_path* by time (``extract_time``).

    Blank lines are ignored. The last non-blank line is treated as the origin
    line; all earlier lines are candidates, sorted ascending by time (stable).

    Returns:
        ``(top_lines, origin_line, top_n)``. *top_lines* holds every candidate
        when *top_n* is 0 or exceeds the candidate count, else the first
        *top_n*. ``(None, None, None)`` is returned for a file with no
        non-blank lines. *top_n* is echoed back unchanged.
    """
    with open(search_path, 'r') as handle:
        entries = [text for text in (raw.strip() for raw in handle) if text]
    if not entries:
        return None, None, None
    *candidates, origin_line = entries
    ranked = sorted(candidates, key=extract_time)
    keep_all = top_n == 0 or top_n > len(ranked)
    top_lines = ranked if keep_all else ranked[:top_n]
    return top_lines, origin_line, top_n
def pgo_write_config(config_path, tiling_data, is_last_result=False):
    """Overwrite *config_path* with a status line and the tiling data.

    The first line is ``1`` when *is_last_result* is truthy, else ``0``. The
    second line is ``str(tiling_data)``. Leaving the ``with`` block closes
    the file, which also flushes it.
    """
    status_flag = '1' if is_last_result else '0'
    with open(config_path, 'w') as out:
        out.write(status_flag + '\n')
        out.write(f"{tiling_data}\n")
def pgo_generate_config(search_path, config_path, topn=5):
    """Write the fastest of the last ``topn + 1`` result lines to *config_path*.

    Blank lines in *search_path* are ignored. Among the final ``topn + 1``
    non-blank lines, the first one with the minimum ``extract_time`` is
    chosen; with ``topn == 0`` only the last line is considered. If no
    candidate has a finite time, the last line is used instead. The choice is
    written via ``pgo_write_config`` with ``is_last_result=True``.

    Raises:
        ValueError: if *search_path* contains no non-blank lines.
    """
    with open(search_path, 'r') as file:
        lines = [line.strip() for line in file if line.strip()]
    if not lines:
        # min() on an empty sequence would raise an opaque ValueError; fail
        # with a message that names the offending file instead.
        raise ValueError(f"PGO search result file is empty: '{search_path}'")
    target_lines = lines[-(topn + 1):]
    result = min(target_lines, key=extract_time)
    if extract_time(result) == float('inf'):
        # No usable measurement among the candidates: fall back to the last line.
        result = lines[-1]
    pgo_write_config(config_path, result, is_last_result=True)