"""
Ascend C Host UT CSV 重构辅助脚本
功能:
1. 从 xxx_def.cpp 自动提取算子定义信息(输入/输出/属性)
2. 自动生成 CSV 列名行模板
3. 自动生成 param.h 结构体框架
"""
import re
import sys
import logging
from typing import List, Tuple
from dataclasses import dataclass
@dataclass
class TensorInfo:
    """One input or output tensor declaration taken from an operator def file.

    Attributes:
        name: Tensor name, i.e. the string literal passed to
            ``Input("...")`` / ``Output("...")``.
        param_type: Raw text found inside ``.ParamType(...)``.
        data_types: ``ge::DT_*`` tokens collected for this tensor.
        formats: ``ge::FORMAT_*`` tokens collected for this tensor.
    """

    name: str
    param_type: str
    data_types: List[str]
    formats: List[str]
@dataclass
class AttrInfo:
    """One attribute declaration taken from an operator def file.

    Attributes:
        name: Attribute name, i.e. the string literal passed to ``Attr("...")``.
        attr_type: Raw text found inside ``.AttrType(...)``.
        cpp_type: Name of the method chained right after ``.AttrType(...)``
            (for example ``Int`` or ``Float``).
        default_value: Raw argument text of that chained method call.
    """

    name: str
    attr_type: str
    cpp_type: str
    default_value: str
def _collect_tensors(content: str, kind: str) -> List[TensorInfo]:
    """Collect the ``this-><kind>("name").ParamType(...)`` declarations in *content*.

    Args:
        content: Full text of the def.cpp file.
        kind: Declaration method to look for, ``"Input"`` or ``"Output"``.

    Returns:
        One TensorInfo per matching declaration, in file order. At most five
        data types and one format are kept per tensor.
    """
    declaration = re.compile(rf'this->{kind}\("([^"]+)"\)\s*\.ParamType\(([^)]+)\)')
    # "List" is optional as a whole word, so both DataType(...) and
    # DataTypeList(...) (and likewise Format / FormatList) are recognised.
    dtype_list = re.compile(r'\.DataType(?:List)?\(\s*\{([^}]+)\}\s*\)')
    format_list = re.compile(r'\.Format(?:List)?\(\s*\{([^}]+)\}\s*\)')

    tensors = []
    for match in declaration.finditer(content):
        # Only look inside the current C++ statement (up to the next ';') so
        # that a tensor without its own DataType/Format list does not pick up
        # the list that belongs to a later declaration.
        statement_end = content.find(';', match.end())
        if statement_end == -1:
            statement_end = len(content)
        chain = content[match.end():statement_end]

        dtype_match = dtype_list.search(chain)
        data_types = re.findall(r'ge::DT_\w+', dtype_match.group(1)) if dtype_match else []

        format_match = format_list.search(chain)
        formats = re.findall(r'ge::FORMAT_\w+', format_match.group(1)) if format_match else []

        tensors.append(TensorInfo(match.group(1), match.group(2), data_types[:5], formats[:1]))
    return tensors


def parse_def_file(file_path: str) -> Tuple[List[TensorInfo], List[TensorInfo], List[AttrInfo]]:
    """Parse an ``xxx_def.cpp`` file and extract the operator definition.

    Args:
        file_path: Path of the def.cpp file (read as UTF-8).

    Returns:
        (inputs, outputs, attrs): the input tensors, output tensors and
        attributes found in the file, each in file order.

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    inputs = _collect_tensors(content, "Input")
    outputs = _collect_tensors(content, "Output")

    # Whitespace (including newlines) is allowed between the chained calls,
    # matching what the Input/Output pattern accepts. The argument of the
    # type method may be empty, e.g. ``.Int()`` on an attribute that has no
    # default value; default_value is then the empty string.
    attr_pattern = r'this->Attr\("([^"]+)"\)\s*\.AttrType\(([^)]+)\)\s*\.(\w+)\(([^)]*)\)'
    attrs = [
        AttrInfo(match.group(1), match.group(2), match.group(3), match.group(4))
        for match in re.finditer(attr_pattern, content)
    ]
    return inputs, outputs, attrs
def generate_csv_headers(inputs: List[TensorInfo], outputs: List[TensorInfo],
                         attrs: List[AttrInfo], test_type: str = "infershape") -> str:
    """Build the CSV header row for a test-case table.

    Column order: ``case_name``, ``expectResult``, three columns
    (``_shape``, ``_dtype``, ``_format``) per input tensor, the same three per
    output tensor, one column per attribute, the two tiling expectation
    columns when ``test_type`` is ``"tiling"``, and finally ``case_annotation``.

    Args:
        inputs: Input tensor list.
        outputs: Output tensor list.
        attrs: Attribute list.
        test_type: Test type; only the value ``"tiling"`` adds extra columns.

    Returns:
        The comma-joined header row.
    """
    columns = ["case_name", "expectResult"]

    # Inputs come first, then outputs; both contribute the same column triple.
    for tensor in [*inputs, *outputs]:
        columns += [f"{tensor.name}{suffix}" for suffix in ("_shape", "_dtype", "_format")]

    columns += [attr.name for attr in attrs]

    if test_type == "tiling":
        columns += ["expectTilingKey", "expectTilingDataHash"]

    columns.append("case_annotation")
    return ",".join(columns)
def generate_param_struct(inputs: List[TensorInfo], outputs: List[TensorInfo],
                          attrs: List[AttrInfo], op_name: str) -> str:
    """Generate a simplified skeleton of the param.h struct fields.

    The result is a text fragment with three commented sections: one C++
    field per attribute, then one ``TensorDescription`` field per input
    tensor and one per output tensor.

    Args:
        inputs: Input tensor list.
        outputs: Output tensor list.
        attrs: Attribute list.
        op_name: Operator name in lower snake case; it is converted to camel
            case for the leading comment.

    Returns:
        The struct skeleton text, terminated by a newline.
    """
    op_name_camel = "".join(word.capitalize() for word in op_name.split("_"))

    # Maps the attribute type method name to the C++ field type. List kinds
    # map to std::vector; anything not listed falls back to int64_t.
    cpp_type_map = {
        "Int": "int64_t",
        "Float": "float",
        "String": "std::string",
        "Bool": "bool",
        "ListInt": "std::vector<int64_t>",
        "ListFloat": "std::vector<float>",
        "ListBool": "std::vector<bool>",
        "ListListInt": "std::vector<std::vector<int64_t>>",
    }

    lines = [
        f"// {op_name_camel} 参数结构体框架",
        "// 需要手动调整和补充完整",
        "",
        "// 属性字段",
    ]
    for attr in attrs:
        cpp_type = cpp_type_map.get(attr.cpp_type, "int64_t")
        lines.append(f"{cpp_type} {attr.name}; // {attr.attr_type}, default: {attr.default_value}")

    lines += ["", "// 输入 Tensor 字段(Tiling 测试)"]
    for tensor in inputs:
        lines.append(f"gert::TilingContextPara::TensorDescription {tensor.name} = TD_DEFAULT; // {tensor.param_type}")

    lines += ["", "// 输出 Tensor 字段(Tiling 测试)"]
    for tensor in outputs:
        lines.append(f"gert::TilingContextPara::TensorDescription {tensor.name} = TD_DEFAULT; // {tensor.param_type}")

    # Every line, including the last one, is newline-terminated.
    return "\n".join(lines) + "\n"
def main():
    """Command-line entry point.

    Reads the def.cpp path from ``sys.argv[1]``, parses it, logs what was
    found together with the generated CSV header rows and param struct
    skeleton, and writes the same three sections to
    ``<prefix>_csv_template.txt``, where ``<prefix>`` is the input path
    without its ``_def.cpp`` suffix.

    Exits with status 1 when no path is given or the file cannot be read.
    """
    import os  # local import: only needed here to derive the operator name

    logging.basicConfig(level=logging.INFO, format='%(message)s', stream=sys.stdout)
    if len(sys.argv) < 2:
        logging.error("使用方法: python generate_csv_template.py <xxx_def.cpp>")
        logging.error("示例: python generate_csv_template.py fused_infer_attention_score_def.cpp")
        sys.exit(1)

    def_file = sys.argv[1]
    try:
        inputs, outputs, attrs = parse_def_file(def_file)
    except (OSError, UnicodeDecodeError) as exc:
        # Report an unreadable input as a normal error instead of a traceback.
        logging.error("无法读取文件 %s: %s", def_file, exc)
        sys.exit(1)

    logging.info(f"\n找到 {len(inputs)} 个输入 Tensor:")
    for tensor in inputs:
        logging.info(f" - {tensor.name}: {tensor.param_type}")
    logging.info(f"\n找到 {len(outputs)} 个输出 Tensor:")
    for tensor in outputs:
        logging.info(f" - {tensor.name}: {tensor.param_type}")
    logging.info(f"\n找到 {len(attrs)} 个属性:")
    for attr in attrs:
        logging.info(f" - {attr.name}: {attr.cpp_type} ({attr.attr_type})")

    # Strip only a trailing "_def.cpp". The prefix keeps any directory part so
    # the template is written next to the input file, while the operator name
    # is the bare file stem (no directories) so the camel-cased name is valid.
    suffix = "_def.cpp"
    output_prefix = def_file[:-len(suffix)] if def_file.endswith(suffix) else def_file
    op_name = os.path.basename(output_prefix)

    csv_headers = generate_csv_headers(inputs, outputs, attrs, "infershape")
    csv_headers_tiling = generate_csv_headers(inputs, outputs, attrs, "tiling")
    param_struct = generate_param_struct(inputs, outputs, attrs, op_name)

    # The same three sections are shown on the console and saved to the file.
    sections = [
        ("CSV 列名行(InferShape 测试):", csv_headers),
        ("CSV 列名行(Tiling 测试):", csv_headers_tiling),
        ("参数结构体框架:", param_struct),
    ]
    rule = "=" * 80

    for title, body in sections:
        logging.info("\n" + rule)
        logging.info(title)
        logging.info(rule)
        logging.info(body)

    output_file = f"{output_prefix}_csv_template.txt"
    with open(output_file, 'w', encoding='utf-8') as f:
        # Sections are separated by one blank line; the last ends with a newline.
        f.write("\n".join(f"{rule}\n{title}\n{rule}\n{body}\n" for title, body in sections))

    logging.info(f"\n已保存到文件: {output_file}")
    logging.info("\n模板中需要修改的地方已在注释中标注,完善模板后请删除所有的注释!")


if __name__ == "__main__":
    main()