"""
UT 生成器 CLI - 命令行入口
该脚本是 asc-api-ut-gen skill 的核心入口,用于生成各类 AscendC API 的单元测试代码。
使用方式:
# 通过 Skill 调用
python ut_generator_cli.py --config <config_file>
python ut_generator_cli.py --type aiv --api Add --chip ascend910b1
# 输出配置文件模板
python ut_generator_cli.py --template-config
"""
import argparse
import json
import os
import re
import sys
import time
from dataclasses import asdict
from typing import Dict, List, Any, Optional, TextIO
# Absolute path of the directory that contains this script.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Put that directory first on the module search path before the `ut_generator`
# import that follows; presumably so the sibling module resolves no matter how
# this file is loaded.
sys.path.insert(0, SCRIPT_DIR)
from ut_generator import (
ApiType, ChipArch, TestCase, UTConfig,
create_generator, DTYPE_MAP, GENERATOR_DTYPE_MAP, ARCH_DIR_MAP, NPU_ARCH_MAP,
API_TYPE_MAP, API_TYPE_CHOICES, API_RESTRICTIONS,
AIV_GENERIC_SCALAR_TENSOR_DISPATCH_APIS,
CHIP_ARCH_BY_NAME, REFERENCE_CONSTRAINT_FILES, ensure_reference_constraints_loaded,
get_adv_profile_output_path
)
def write_line(message: str = "", stream: Optional[TextIO] = None) -> None:
    """Emit ``message`` followed by a single newline.

    Args:
        message: Text to write; the default empty string produces a blank line.
        stream: Destination text stream. ``sys.stdout`` is used when omitted.
    """
    # Look up sys.stdout at call time (not definition time) when no stream is given.
    target = sys.stdout if stream is None else stream
    target.write(f"{message}\n")
def parse_coverage_report(report_path: Optional[str]) -> Dict[str, Any]:
    """Read a coverage HTML report and extract the Lines / Functions percentages.

    Args:
        report_path: Path to the coverage HTML file, or ``None`` / empty when
            no report was supplied.

    Returns:
        A dict with the keys ``status`` (``"available"`` or ``"unavailable"``),
        ``lines`` and ``functions`` (floats, or ``None`` when not found),
        ``source`` (the report path) and ``reason`` (why nothing was parsed,
        ``None`` on success). The status is ``"available"`` as soon as at
        least one of the two percentages was found.

    The function never raises for an unreadable report: a missing path, a
    directory, or any other read failure yields an ``"unavailable"`` result.
    """

    def unavailable(reason: str) -> Dict[str, Any]:
        """Build the result returned whenever no percentage can be reported."""
        return {
            "status": "unavailable",
            "lines": None,
            "functions": None,
            "source": report_path or None,
            "reason": reason,
        }

    if not report_path:
        return unavailable("未提供 coverage report 路径")
    if not os.path.exists(report_path):
        return unavailable(f"coverage report 不存在: {report_path}")

    # The path may exist but still be unreadable (a directory, missing
    # permissions, ...). Coverage is auxiliary information, so report the
    # problem instead of letting the exception escape.
    try:
        with open(report_path, "r", encoding="utf-8", errors="ignore") as f:
            content = f.read()
    except OSError as exc:
        return unavailable(f"coverage report 无法读取: {report_path} ({exc})")

    def normalize_html_text(fragment: str) -> str:
        """Replace tags with spaces and collapse runs of whitespace."""
        text = re.sub(r"<[^>]+>", " ", fragment)
        return re.sub(r"\s+", " ", text).strip()

    def find_percent(label: str) -> Optional[float]:
        """Return the percentage shown for ``label``, or ``None`` if absent."""
        label_re = re.escape(label)
        # First attempt: a `headerItem` cell holding the label, immediately
        # followed by a `headerCovTableEntry*` cell holding "<number> %".
        header_pattern = re.compile(
            rf'<td\b(?=[^>]*class=["\'][^"\']*\bheaderItem\b[^"\']*["\'])[^>]*>'
            rf'\s*{label_re}:?\s*</td>\s*'
            rf'<td\b(?=[^>]*class=["\'][^"\']*\bheaderCovTableEntry(?:Lo|Med|Hi)?\b[^"\']*["\'])[^>]*>'
            rf'\s*([0-9]+(?:\.[0-9]+)?)\s*%',
            re.IGNORECASE | re.DOTALL,
        )
        match = header_pattern.search(content)
        if match:
            return float(match.group(1))
        # Fallback: scan every table row whose visible text contains
        # "<label>:" and take the first percentage in that row.
        for row in re.findall(r"<tr\b[^>]*>.*?</tr>", content, re.IGNORECASE | re.DOTALL):
            row_text = normalize_html_text(row)
            if not re.search(rf"(?:^|\s){label_re}\s*:", row_text, re.IGNORECASE):
                continue
            match = re.search(r"([0-9]+(?:\.[0-9]+)?)\s*%", row_text)
            if match:
                return float(match.group(1))
        return None

    lines = find_percent("Lines")
    functions = find_percent("Functions")
    if lines is None and functions is None:
        return unavailable("未在 coverage report 中解析到 Lines / Functions 百分比")
    return {
        "status": "available",
        "lines": lines,
        "functions": functions,
        "source": report_path,
        "reason": None,
    }
def get_token_usage(args: argparse.Namespace) -> Dict[str, Any]:
    """Collect externally supplied token counts; the CLI itself uses no LLM tokens.

    Command-line values (``args.prompt_tokens``, ``args.completion_tokens``,
    ``args.total_tokens``) win when any of them is set. Otherwise the
    ``ASC_API_UT_PROMPT_TOKENS`` / ``ASC_API_UT_COMPLETION_TOKENS`` /
    ``ASC_API_UT_TOTAL_TOKENS`` environment variables are consulted.

    Returns:
        A dict with ``prompt_tokens``, ``completion_tokens``, ``total_tokens``,
        ``source`` (``"cli_args"``, ``"env"`` or ``"not_provided"``) and
        ``reason`` (an explanatory note, or ``None``).
    """

    def env_int(raw: Optional[str], name: str, rejected: List[str]) -> Optional[int]:
        """Parse one environment value; record it in ``rejected`` if not an int."""
        text = "" if raw is None else raw.strip()
        if not text:
            return None
        try:
            return int(text)
        except ValueError:
            rejected.append(f"{name}={raw!r}")
            return None

    prompt = args.prompt_tokens
    completion = args.completion_tokens
    total = args.total_tokens
    note = None

    if prompt is None and completion is None and total is None:
        origin = None
        env_names = (
            "ASC_API_UT_PROMPT_TOKENS",
            "ASC_API_UT_COMPLETION_TOKENS",
            "ASC_API_UT_TOTAL_TOKENS",
        )
        raw_values = [os.getenv(name) for name in env_names]
        if any(raw is not None for raw in raw_values):
            origin = "env"
            rejected: List[str] = []
            prompt, completion, total = [
                env_int(raw, name, rejected)
                for name, raw in zip(env_names, raw_values)
            ]
            if rejected:
                note = "忽略非法 token 环境变量: " + ", ".join(rejected)
    else:
        origin = "cli_args"

    # Derive the total when only the two components are known.
    if total is None and prompt is not None and completion is not None:
        total = prompt + completion

    if origin is None:
        return {
            "prompt_tokens": None,
            "completion_tokens": None,
            "total_tokens": None,
            "source": "not_provided",
            "reason": "本地 CLI 生成不消耗 LLM token,调用环境也未提供 token usage",
        }
    return {
        "prompt_tokens": prompt,
        "completion_tokens": completion,
        "total_tokens": total,
        "source": origin,
        "reason": note,
    }
def print_generation_report(report: Dict[str, Any], stream: Optional[TextIO]) -> None:
    """Print a short human-readable summary of a generation run.

    Args:
        report: Dict that provides ``token_usage`` (with ``prompt_tokens``,
            ``completion_tokens``, ``total_tokens``, ``source`` and an optional
            ``reason``), ``elapsed_seconds`` and ``current_coverage`` (with
            ``status``, ``lines``, ``functions``, ``source`` and ``reason``).
        stream: Text stream that receives the summary; forwarded to
            ``write_line``.
    """
    token_usage = report["token_usage"]
    coverage = report["current_coverage"]

    def fmt(value: Any) -> Any:
        """Show a placeholder for values that were not supplied."""
        return "未提供" if value is None else value

    def fmt_percent(value: Optional[float]) -> str:
        """Render a percentage, or the placeholder when the metric is missing."""
        # Coverage can be "available" with only one of the two metrics parsed;
        # without this guard the line would read e.g. "Functions=None%".
        return "未提供" if value is None else f"{value}%"

    write_line("生成报告:", stream)
    write_line(
        " Token 消耗: "
        f"prompt={fmt(token_usage['prompt_tokens'])}, "
        f"completion={fmt(token_usage['completion_tokens'])}, "
        f"total={fmt(token_usage['total_tokens'])}"
        f" ({token_usage['source']})",
        stream,
    )
    if token_usage.get("reason"):
        write_line(f" Token 说明: {token_usage['reason']}", stream)
    write_line(f" 生成耗时: {report['elapsed_seconds']}s", stream)
    if coverage["status"] == "available":
        write_line(
            f" 当前覆盖率: Lines={fmt_percent(coverage['lines'])}, "
            f"Functions={fmt_percent(coverage['functions'])}",
            stream,
        )
        write_line(f" 覆盖率来源: {coverage['source']}", stream)
    else:
        write_line(f" 当前覆盖率: 未获取 ({coverage['reason']})", stream)
def ensure_parent_dir(path: str) -> None:
    """Create the directory that will contain ``path`` if it has one.

    A bare file name (no directory component) needs nothing created, so the
    function returns without touching the file system in that case.
    """
    folder = os.path.dirname(path)
    if not folder:
        return
    os.makedirs(folder, exist_ok=True)
def create_default_config() -> Dict[str, Any]:
    """Build the sample configuration used as the config-file template.

    Returns:
        A fresh dict describing an ``aiv`` / ``Add`` / ``ascend910b1`` setup
        with one ``half`` and one ``float`` test case of 256 elements each.
    """
    api_name = "Add"
    data_size = 256

    def sample_case(dtype: str) -> Dict[str, Any]:
        """Return one template test case for the given dtype."""
        return {
            "name": f"{api_name}_{dtype}_{data_size}",
            "data_size": data_size,
            "dtype": dtype,
            "input_count": 2,
            "has_mask": False,
            "additional_params": {},
        }

    # Keys are inserted in the order they should appear in the printed JSON.
    template: Dict[str, Any] = {
        "api_type": "aiv",
        "api_name": api_name,
        "chip": "ascend910b1",
    }
    template["test_cases"] = [sample_case(dtype) for dtype in ("half", "float")]
    template["kernel_params"] = {}
    template["custom_includes"] = []
    template["output_dir"] = None
    return template
def default_kernel_params(api_type: str) -> Dict[str, int]:
    """Return default kernel parameters for the given API type string.

    Only the AIC type gets defaults (``m``/``k``/``n``); every other type
    receives an empty dict.
    """
    if api_type != ApiType.AIC.value:
        return {}
    return {"m": 16, "k": 64, "n": 16}
def default_input_count(api_type: str, api_name: str) -> int:
    """Pick the default number of inputs for CLI-generated test cases.

    Returns 1 for AIV APIs whose lower-cased name is listed in
    ``AIV_GENERIC_SCALAR_TENSOR_DISPATCH_APIS``; 2 for everything else.
    """
    uses_single_input = (
        api_type == ApiType.AIV.value
        and api_name.lower() in AIV_GENERIC_SCALAR_TENSOR_DISPATCH_APIS
    )
    return 1 if uses_single_input else 2
def validate_config(config: Dict[str, Any]) -> List[str]:
    """Validate a raw configuration dict (as loaded from JSON or built by the CLI).

    Malformed structures (wrong root type, ``test_cases`` that is not a list,
    test-case entries that are not objects) are reported as validation errors
    rather than raising, so callers can print them uniformly.

    Args:
        config: The configuration to check.

    Returns:
        A list of human-readable error messages; empty when the config is valid.
    """
    if not isinstance(config, dict):
        # Nothing else can be checked without a mapping at the root.
        return [f"配置必须是 JSON 对象, 实际类型: {type(config).__name__}"]

    errors: List[str] = []
    api_type = config.get('api_type')
    chip = config.get('chip')
    api_name = config.get('api_name')

    valid_types = API_TYPE_CHOICES
    if api_type not in valid_types:
        errors.append(f"无效的 api_type: {api_type}, 有效值: {valid_types}")

    valid_chips = list(CHIP_ARCH_BY_NAME.keys())
    if chip not in valid_chips:
        errors.append(f"无效的 chip: {chip}, 有效值: {valid_chips}")

    if not api_name:
        errors.append("api_name 不能为空")
    elif not isinstance(api_name, str):
        # Later stages call str methods on the name (e.g. ``.lower()``).
        errors.append(f"api_name 必须是字符串, 实际类型: {type(api_name).__name__}")

    # Only a string can be a valid restriction key; this also avoids a
    # TypeError when a malformed config supplies an unhashable api_type.
    restriction = API_RESTRICTIONS.get(api_type, {}) if isinstance(api_type, str) else {}
    allowed_chips = restriction.get('allowed_chips', [])
    if allowed_chips and chip not in allowed_chips:
        message = restriction.get(
            'message',
            f"{api_type} API 不支持当前芯片: {{chip}}"
        )
        errors.append(message.format(chip=chip))

    test_cases = config.get('test_cases', [])
    if not isinstance(test_cases, (list, tuple)):
        errors.append(f"test_cases 必须是列表, 实际类型: {type(test_cases).__name__}")
        return errors

    valid_dtypes = list(GENERATOR_DTYPE_MAP.keys())
    for index, tc in enumerate(test_cases):
        if not isinstance(tc, dict):
            errors.append(f"test_cases[{index}] 必须是对象, 实际类型: {type(tc).__name__}")
            continue
        if tc.get('dtype') not in valid_dtypes:
            errors.append(
                f"当前通用生成器不支持直接生成 dtype: {tc.get('dtype')}, "
                f"可直接生成值: {valid_dtypes}; 其他内置类型需按目标 API 自定义初始化"
            )
    return errors
def parse_config(config: Dict[str, Any]) -> UTConfig:
    """Convert a raw configuration dict into a ``UTConfig`` object.

    Missing per-test-case fields fall back to defaults (name
    ``<api_name>_case``, 256 elements, ``half``, two inputs, no mask, no
    additional params). ``api_type`` and ``chip`` strings are mapped through
    ``API_TYPE_MAP`` and ``CHIP_ARCH_BY_NAME``.
    """
    cases = [
        TestCase(
            name=raw.get('name', f"{config['api_name']}_case"),
            data_size=raw.get('data_size', 256),
            dtype=raw.get('dtype', 'half'),
            input_count=raw.get('input_count', 2),
            has_mask=raw.get('has_mask', False),
            additional_params=raw.get('additional_params', {}),
        )
        for raw in config.get('test_cases', [])
    ]
    return UTConfig(
        api_type=API_TYPE_MAP[config['api_type']],
        api_name=config['api_name'],
        chip=CHIP_ARCH_BY_NAME[config['chip']],
        test_cases=cases,
        kernel_params=config.get('kernel_params', {}),
        custom_includes=config.get('custom_includes', []),
        output_dir=config.get('output_dir'),
    )
def _default_output_path(config: UTConfig) -> str:
    """Return the default relative test-file path for the config's API type."""
    base_dir = ARCH_DIR_MAP[config.chip]
    api = config.api_name.lower()
    api_type = config.api_type
    if api_type == ApiType.AIV:
        return f"tests/api/basic_api/{base_dir}/{base_dir}_aiv/test_operator_{api}.cpp"
    if api_type == ApiType.AIC:
        return f"tests/api/basic_api/{base_dir}/{base_dir}_aic/test_operator_{api}.cpp"
    if api_type == ApiType.C_API:
        npu_arch = NPU_ARCH_MAP[config.chip]
        return f"tests/api/c_api/npu_arch_{npu_arch}/vector_compute/test_{api}.cpp"
    if api_type == ApiType.ADV:
        # A profile-specific location, when one is returned, takes precedence.
        profile_output = get_adv_profile_output_path(config.api_name, config.kernel_params)
        if profile_output:
            return profile_output
        return f"tests/api/adv_api/{api}/test_operator_{api}.cpp"
    if api_type == ApiType.SIMT:
        return f"tests/api/simt_api/{base_dir}_simt/test_operator_simt_{api}.cpp"
    if api_type == ApiType.REG_COMPUTE:
        return f"tests/api/reg_compute_api/{base_dir}_reg_compute/test_operator_reg_{api}.cpp"
    if api_type == ApiType.UTILS:
        return f"tests/api/utils/test_{api}.cpp"
    return f"tests/api/test_{api}.cpp"


def get_output_path(config: UTConfig) -> str:
    """Return the path the generated UT file should be written to.

    Without ``config.output_dir`` the default path for the API type is
    returned. With it, the value is treated as a directory and the
    auto-generated file name is appended, matching the documented
    ``--output-dir`` meaning ("output directory, file name generated
    automatically").

    For backward compatibility two kinds of ``output_dir`` values are still
    returned verbatim:
      * ``"-"``, which the caller interprets as "write to stdout";
      * a value that has a file extension and is not an existing directory,
        i.e. something that was evidently meant as a complete file path.
    """
    output_dir = config.output_dir
    if not output_dir:
        return _default_output_path(config)
    if output_dir == '-':
        return output_dir
    has_extension = bool(os.path.splitext(output_dir)[1])
    if has_extension and not os.path.isdir(output_dir):
        return output_dir
    file_name = os.path.basename(_default_output_path(config))
    return os.path.join(output_dir, file_name)
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the argument parser with every option the CLI understands."""
    parser = argparse.ArgumentParser(
        description='AscendC API UT 代码生成器 CLI',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='''
示例:
# 从配置文件生成
%(prog)s --config ut_config.json
# 命令行参数生成
%(prog)s --type aiv --api Add --chip ascend910b1
# 输出配置模板
%(prog)s --template-config
# 列出支持的配置
%(prog)s --list-supported
'''
    )
    parser.add_argument('--config', '-c', type=str,
                        help='配置文件路径 (JSON 格式)')
    parser.add_argument('--type', '-t', choices=API_TYPE_CHOICES,
                        help='API 类型')
    parser.add_argument('--api', '-a', type=str,
                        help='API 名称 (如: Add, Mmad, asc_add)')
    parser.add_argument('--chip', choices=list(CHIP_ARCH_BY_NAME.keys()),
                        help='目标芯片架构')
    parser.add_argument('--output', '-o', type=str, default=None,
                        help='输出文件路径')
    parser.add_argument('--output-dir', type=str, default=None,
                        help='输出目录(自动生成文件名)')
    parser.add_argument('--data-size', '-d', type=int, default=256,
                        help='测试数据大小 (默认: 256)')
    parser.add_argument('--dtype', type=str, default='half',
                        choices=list(GENERATOR_DTYPE_MAP.keys()),
                        help='通用模板可直接初始化的数据类型 (默认: half)')
    parser.add_argument('--test-count', type=int, default=3,
                        help='生成测试用例数量 (默认: 3)')
    parser.add_argument('--template-config', action='store_true',
                        help='输出配置文件模板并退出')
    parser.add_argument('--list-supported', action='store_true',
                        help='列出支持的配置选项')
    parser.add_argument('--validate', action='store_true',
                        help='仅验证配置,不生成代码')
    parser.add_argument('--coverage-report', type=str, default=None,
                        help='当前覆盖率报告 HTML 路径,用于报告 Lines / Functions 覆盖率')
    parser.add_argument('--prompt-tokens', type=int, default=None,
                        help='外部调用环境记录的 prompt token 数')
    parser.add_argument('--completion-tokens', type=int, default=None,
                        help='外部调用环境记录的 completion token 数')
    parser.add_argument('--total-tokens', type=int, default=None,
                        help='外部调用环境记录的 total token 数')
    parser.add_argument('--report-json', type=str, default=None,
                        help='将生成报告写入指定 JSON 文件')
    return parser


def _print_supported_options() -> None:
    """Print the reference files, API types, chips, dtypes and restrictions."""
    write_line("=== 结构化 reference 约束 ===")
    for name, path in REFERENCE_CONSTRAINT_FILES.items():
        write_line(f" - {name}: {path}")
    write_line("=== 支持的 API 类型 ===")
    for t in API_TYPE_CHOICES:
        write_line(f" - {t}")
    write_line("\n=== 支持的芯片架构 ===")
    for chip, arch_dir in ARCH_DIR_MAP.items():
        npu_arch = NPU_ARCH_MAP[chip]
        write_line(f" - {chip.value}: impl={arch_dir}, NPU_ARCH={npu_arch}")
    write_line("\n=== 文档内置数据类型 ===")
    for dtype, info in DTYPE_MAP.items():
        init = "generic-init" if info.get("generic_ut_generation", False) else "api-specific-init"
        write_line(f" - {dtype}: size={info['size']}, bits={info.get('bit_width', 'unknown')}, {init}")
    write_line("\n=== 通用 UT 生成可直接初始化的数据类型 ===")
    for dtype, info in GENERATOR_DTYPE_MAP.items():
        write_line(f" - {dtype}: size={info['size']}")
    write_line("\n=== 架构限制 ===")
    for api_type, restriction in API_RESTRICTIONS.items():
        allowed = ", ".join(restriction.get('allowed_chips', []))
        write_line(f" - {api_type}: {allowed}")


def _config_from_cli_args(args: argparse.Namespace) -> Dict[str, Any]:
    """Build a raw configuration dict from ``--type`` / ``--api`` / ``--chip``."""
    dtypes = [args.dtype]
    if args.dtype == 'half' and args.test_count > 1:
        dtypes.append('float')
    # NOTE(review): at most two cases are produced (one for non-half dtypes)
    # even though --test-count defaults to 3 - confirm the intended meaning.
    test_cases = [
        {
            "name": f"{args.api}_{dtype}_{args.data_size * (i + 1)}",
            "data_size": args.data_size * (i + 1),
            "dtype": dtype,
            "input_count": default_input_count(args.type, args.api),
            "has_mask": False,
            "additional_params": {},
        }
        for i, dtype in enumerate(dtypes[:args.test_count])
    ]
    return {
        "api_type": args.type,
        "api_name": args.api,
        "chip": args.chip,
        "test_cases": test_cases,
        "kernel_params": default_kernel_params(args.type),
        "custom_includes": [],
        "output_dir": args.output_dir,
    }


def main() -> int:
    """Run the CLI: parse arguments, validate the config, generate and report.

    Returns:
        The process exit status: 0 on success, 1 on any handled failure
        (reference loading, config loading/validation/parsing, generation or
        file writing). Invalid argument combinations exit through
        ``parser.error``.
    """
    start_time = time.monotonic()
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.template_config:
        template = create_default_config()
        write_line(json.dumps(template, indent=2, ensure_ascii=False))
        return 0

    try:
        ensure_reference_constraints_loaded()
    except RuntimeError as exc:
        write_line(f"结构化 reference 加载失败: {exc}", sys.stderr)
        return 1

    if args.list_supported:
        _print_supported_options()
        return 0

    if args.config:
        # A missing/unreadable file or malformed JSON is a user error: report
        # it like the other failures instead of surfacing a traceback.
        # (json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.)
        try:
            with open(args.config, 'r', encoding='utf-8') as f:
                config_dict = json.load(f)
        except (OSError, ValueError) as exc:
            write_line(f"配置文件读取失败: {args.config}: {exc}", sys.stderr)
            return 1
    elif args.type and args.api and args.chip:
        config_dict = _config_from_cli_args(args)
    else:
        # parser.error() prints usage and exits, so nothing below runs.
        parser.error("必须提供 --config 或 (--type, --api, --chip)")

    errors = validate_config(config_dict)
    if errors:
        write_line("配置验证失败:", sys.stderr)
        for error in errors:
            write_line(f" - {error}", sys.stderr)
        return 1
    if args.validate:
        write_line("配置验证通过")
        return 0

    try:
        config = parse_config(config_dict)
    except Exception as e:
        write_line(f"配置解析失败: {e}", sys.stderr)
        return 1

    try:
        generator = create_generator(config)
        code = generator.generate()
    except Exception as e:
        write_line(f"代码生成失败: {e}", sys.stderr)
        return 1

    output_path = args.output or get_output_path(config)
    if output_path == '-' or output_path is None:
        # Code goes to stdout, so the report is sent to stderr to keep the
        # generated source clean.
        write_line(code)
        report_stream = sys.stderr
    else:
        try:
            ensure_parent_dir(output_path)
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(code)
        except OSError as exc:
            write_line(f"UT 文件写入失败: {output_path}: {exc}", sys.stderr)
            return 1
        write_line(f"UT 文件已生成: {output_path}")
        report_stream = sys.stdout

    report = {
        "api_type": config.api_type.value,
        "api_name": config.api_name,
        "chip": config.chip.value,
        "output_path": output_path,
        "elapsed_seconds": round(time.monotonic() - start_time, 3),
        "token_usage": get_token_usage(args),
        "current_coverage": parse_coverage_report(args.coverage_report),
    }
    print_generation_report(report, report_stream)

    if args.report_json:
        try:
            ensure_parent_dir(args.report_json)
            with open(args.report_json, 'w', encoding='utf-8') as f:
                json.dump(report, f, indent=2, ensure_ascii=False)
        except OSError as exc:
            write_line(f"报告文件写入失败: {args.report_json}: {exc}", sys.stderr)
            return 1
    return 0
if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())