import os
import re
import logging
def should_skip_directory(dir_name):
    """
    Tell whether the op scan should not descend into a directory.

    Args:
        dir_name: Bare directory name (not a path), compared case-sensitively.

    Returns:
        True if ``dir_name`` is one of the fixed non-operator directories.
    """
    excluded = frozenset((
        'build', 'cmake', 'common', 'docs', 'examples',
        'experimental', 'scripts', 'tests', 'third_party',
    ))
    return dir_name in excluded
def parse_foreach_config(config_str):
    """
    Map a FOREACH_OPDEF config token to the SoC names it mentions.

    The token is upper-cased and probed for each known substring in a fixed
    order. Every SoC name whose substring occurs is collected once, in probe
    order.

    Args:
        config_str: Raw token taken from a FOREACH_OPDEF macro.

    Returns:
        List of unique SoC names (e.g. ``'ascend910b'``); empty if nothing
        matched.
    """
    # NOTE(review): matching is plain substring containment, so a shorter key
    # also fires inside a longer one (e.g. '910' inside '910_93' adds
    # 'ascend910' too). Behaviour kept as-is; confirm whether it is intended.
    probes = (
        ('A2', 'ascend910b'),
        ('910_93', 'ascend910_93'),
        ('A5', 'ascend950'),
        ('910_55', 'ascend910_55'),
        ('910B', 'ascend910b'),
        ('910B_93', 'ascend910_93'),
        ('910B_95', 'ascend950'),
        ('950', 'ascend950'),
        ('910', 'ascend910'),
    )
    token = config_str.upper()
    found_configs = []
    for key, soc_name in probes:
        if key in token and soc_name not in found_configs:
            found_configs.append(soc_name)
    return found_configs
def extract_static_map_configs(content):
    """
    Collect quoted config names found in static std::map initialisers.

    Two scans are applied to ``content``:

    1. ``static const std::map<std::string...> name = {...}`` declarations:
       the body up to the first ``}`` is searched for every quoted
       identifier. Because the body stops at the first ``}``, this alone
       only reaches the first entry of a nested ``{{...}, {...}}`` map.
    2. Brace-initialised entries of the form ``{"name" ...}``: the first
       quoted identifier of each entry is taken as a config name. This
       catches the remaining entries of nested maps.

    Args:
        content: Full text of a C++ source file.

    Returns:
        Sorted list of unique names.
    """
    configs = set()

    map_body_pattern = (
        r'static\s+const\s+std::map<std::string[^>]*>\s+\w+\s*=\s*\{([^}]+)\}'
    )
    for body in re.findall(map_body_pattern, content, re.DOTALL):
        configs.update(re.findall(r'"([a-zA-Z0-9_]+)"', body))

    # The capture group already excludes the quotes, so each match IS the
    # name. Re-scanning it for a quoted string, as before, never matched.
    entry_key_pattern = r'\{"([a-zA-Z0-9_]+)"[^}]*\}'
    configs.update(re.findall(entry_key_pattern, content, re.DOTALL))

    return sorted(configs)
def extract_set_ascend_config_calls(content):
    """
    Collect SoC names passed as string literals to ``SetAscendConfig`` calls.

    Both call shapes are recognised:

    - the one-string form ``SetAscendConfig(x, "ver")``;
    - the two-string form ``SetAscendConfig(x, "ver", "dst_ver")``, from
      which both strings are collected.

    Args:
        content: Full text of a C++ source file.

    Returns:
        List of unique names, in unspecified (set) order.
    """
    one_arg = r'SetAscendConfig\([^,]+,\s*"([^"]+)"\)'
    two_args = r'SetAscendConfig\([^,]+,\s*"([^"]+)",\s*"([^"]+)"\)'

    names = set(re.findall(one_arg, content))
    for version, dst_version in re.findall(two_args, content):
        names.add(version)
        names.add(dst_version)
    return list(names)
def extract_foreach_opdef_configs(content):
    """
    Collect SoC names encoded in FOREACH_OPDEF-style macros.

    Two spellings are scanned:

    - the first argument of ``FOREACH_OPDEF(token, ...``;
    - the suffix of ``FOREACH_OPDEF_END_token(``.

    Each token is stripped and decoded with ``parse_foreach_config``.

    Args:
        content: Full text of a C++ source file.

    Returns:
        List of unique SoC names, in unspecified (set) order.
    """
    macro_patterns = (
        r'FOREACH_OPDEF\(([^,]+),',
        r'FOREACH_OPDEF_END_([^(]+)\(',
    )
    soc_names = set()
    for macro_pattern in macro_patterns:
        for token in re.findall(macro_pattern, content):
            soc_names.update(parse_foreach_config(token.strip()))
    return list(soc_names)
def extract_traditional_aicore_configs(content):
    """
    Collect config names passed to ``AddConfig("...")`` calls.

    Args:
        content: Full text of a C++ source file.

    Returns:
        List of unique names, in order of first appearance in ``content``.
    """
    # ``AddConfig("`` alone also matches the ``.AddConfig(`` and
    # ``this->AICore().AddConfig(`` spellings, so a single scan finds
    # everything the three overlapping patterns used to find.
    names = re.findall(r'AddConfig\("([a-zA-Z0-9_]+)"', content)
    return list(dict.fromkeys(names))
def extract_ai_core_configs(file_path):
    """
    Collect every AICore config / SoC name referenced by a ``*_def.cpp`` file.

    The file is read as UTF-8, and the names found by four scanners are merged:

    - ``extract_traditional_aicore_configs``;
    - ``extract_foreach_opdef_configs``;
    - ``extract_static_map_configs``;
    - ``extract_set_ascend_config_calls``.

    Args:
        file_path: Path of the source file to scan.

    Returns:
        Sorted list of unique names. Returns ``[]`` if the file cannot be
        opened or is not valid UTF-8; the error is logged.
    """
    # Only the read is guarded: failures in the scanners are real bugs and
    # should not be disguised as "file had no configs".
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
    except (OSError, UnicodeDecodeError) as e:
        logging.error("读取文件 %s 时出错: %s", file_path, e)
        return []

    configs = set(extract_traditional_aicore_configs(content))
    configs.update(extract_foreach_opdef_configs(content))
    configs.update(extract_static_map_configs(content))
    configs.update(extract_set_ascend_config_calls(content))
    return sorted(configs)
def ceil_div(a, b):
    """
    Return ``a / b`` rounded up, computed as ``(a + b - 1) // b``.

    The identity is exact for integer ``a`` and positive integer ``b``.
    """
    quotient, _remainder = divmod(a + b - 1, b)
    return quotient
def split_list_by_num_groups(lst, num_groups):
    """
    Split ``lst`` into ``num_groups`` consecutive slices.

    Every slice holds ``ceil(len(lst) / num_groups)`` items, except that the
    trailing slices may be shorter or even empty (e.g. 9 items into 4 groups
    gives sizes 3, 3, 3, 0).

    Args:
        lst: Sequence to split; it is sliced, not copied element-wise.
        num_groups: Number of slices to produce. Zero raises
            ZeroDivisionError; a negative value yields ``[]``.

    Returns:
        List of ``num_groups`` slices of ``lst``.
    """
    chunk = (len(lst) + num_groups - 1) // num_groups  # ceiling division
    return [lst[k * chunk:(k + 1) * chunk] for k in range(num_groups)]
# Pinned op groupings, keyed by SoC family ("ascend950" or "default").
# Each value maps a group index to a list of op names. ``grouped`` allocates
# one result slot per entry and indexes that list with the key, so keys are
# expected to be the ints 0..N-1. Both tables are currently empty, so no op
# is pinned.
GROUPING_CONFIGS = dict(
    default={},
    ascend950={},
)
def _iter_op_def_files(repository_path):
    """Yield ``(op_name, path)`` for each ``*_def.cpp`` file outside skipped directories."""
    suffix = '_def.cpp'
    for root, dirs, files in os.walk(repository_path):
        # Pruning ``dirs`` in place stops os.walk from descending into skipped
        # directories; sorting makes traversal (and pinned-group order)
        # deterministic.
        dirs[:] = sorted(d for d in dirs if not should_skip_directory(d))
        for file_name in sorted(files):
            if file_name.endswith(suffix):
                yield file_name[:-len(suffix)], os.path.join(root, file_name)


def grouped(repository_path, soc, group_size):
    """
    Partition the ops supported on ``soc`` into groups.

    Processing steps:

    1. Every ``*_def.cpp`` under ``repository_path`` (outside directories
       rejected by ``should_skip_directory``) is an op; its name is the file
       name without the suffix.
    2. An op is kept only if ``extract_ai_core_configs`` reports ``soc`` for
       it. ``"950"`` is accepted as an alias of ``"ascend950"`` for both
       config selection and matching.
    3. Kept ops listed in the selected ``GROUPING_CONFIGS`` table go to that
       pinned group (first matching index wins). All other kept ops are
       sorted and split with ``split_list_by_num_groups``.

    Args:
        repository_path: Root directory to scan.
        soc: SoC name to filter on (e.g. ``"ascend910b"``, ``"ascend950"``,
            ``"950"``).
        group_size: Target number of groups.
            - If ``group_size > 8``, the non-empty pinned groups count
              against this budget.
            - Otherwise the remaining ops are split into ``group_size`` groups
              in addition to the pinned ones.
            - At least one remainder group is always produced, so no op is
              dropped.

    Returns:
        List of op-name lists: the non-empty pinned groups in config-index
        order, followed by the remainder groups. Trailing remainder groups
        may be empty (see ``split_list_by_num_groups``).
    """
    is_950 = soc in ("950", "ascend950")
    config = GROUPING_CONFIGS.get("ascend950" if is_950 else "default")
    # Extracted configs always use full SoC names, so the short "950" alias
    # would never match on its own.
    wanted_socs = {soc, "ascend950"} if is_950 else {soc}

    # op name -> index of the first pinned group listing it.
    op_to_group = {}
    for idx, op_list in config.items():
        for op in op_list:
            op_to_group.setdefault(op, idx)

    pinned = [[] for _ in range(len(config))]
    remain = []
    for op_name, full_path in _iter_op_def_files(repository_path):
        if wanted_socs.isdisjoint(extract_ai_core_configs(full_path)):
            continue
        idx = op_to_group.get(op_name)
        if idx is None:
            remain.append(op_name)
        else:
            pinned[idx].append(op_name)

    # Empty pinned groups are dropped. The remainder budget below already
    # assumes that, so returning them would exceed group_size.
    non_empty_pinned = [group for group in pinned if group]
    if group_size > 8:
        num_remain_groups = group_size - len(non_empty_pinned)
    else:
        num_remain_groups = group_size
    # A count of zero would divide by zero, and a negative count would
    # silently drop every remaining op.
    num_remain_groups = max(num_remain_groups, 1)

    return non_empty_pinned + split_list_by_num_groups(sorted(remain), num_remain_groups)
def main(repository_path, soc, group_size):
    """
    Script entry point that delegates to ``grouped``.

    Args:
        repository_path: Root directory to scan for ``*_def.cpp`` files.
        soc: SoC name to filter ops on.
        group_size: Target number of groups.

    Returns:
        The list of op-name groups produced by ``grouped``.
    """
    groups = grouped(repository_path, soc, group_size)
    return groups