"""
自定义代码规范检查(仅检查 ruff 无法覆盖的项目特定规则)
"""
import argparse
import json
import re
import sys
from pathlib import Path
from typing import List, Dict
# Catalogue of the project-specific rules: rule id -> severity level and
# human-readable name. Entries are kept in their original declaration order.
RULES = {
    "CODE-001": dict(level="error", name="License 头缺失或格式错误"),
    "CODE-002": dict(level="warning", name="从 core/ 导入(建议迁移到 core_v2/)"),
    "CODE-004": dict(level="error", name="参数值不在有效范围内"),
    "CODE-013": dict(level="error", name="错误的包名引用"),
    "CODE-010": dict(level="info", name="TODO/FIXME/HACK 注释"),
}
# Substrings that must all be present in a file's header text for the
# license check to pass.
LICENSE_MARKERS = ["Copyright", "Huawei Technologies", "Apache License, Version 2.0"]
def check_license_header(file_path: Path, content: str) -> List[Dict]:
    """Check that the license header markers appear at the top of a file.

    Only the first 15 lines of ``content`` are inspected. Every entry of
    ``LICENSE_MARKERS`` must occur somewhere in that header text.

    Args:
        file_path: Path of the file being checked (used for reporting only).
        content: Full text of the file.

    Returns:
        An empty list when all markers are present, otherwise a list with a
        single CODE-001 issue naming the missing markers.
    """
    header_lines = content.split('\n')[:15]
    header_text = '\n'.join(header_lines)

    absent = []
    for marker in LICENSE_MARKERS:
        if marker not in header_text:
            absent.append(marker)

    if not absent:
        return []

    return [{
        "rule": "CODE-001",
        "level": "error",
        "file": str(file_path),
        "line": 1,
        "message": f"License 头缺失或不完整(缺少: {', '.join(absent)})",
        "suggestion": "添加 Apache 2.0 License 头"
    }]
def check_imports(file_path: Path, content: str) -> List[Dict]:
    """Flag imports from the legacy ``akg_agents.core`` package.

    Detects all of the following forms:

    * ``from akg_agents.core.x import y``
    * ``from akg_agents.core import y``
    * ``import akg_agents.core`` / ``import akg_agents.core.x``

    Imports from ``akg_agents.core_v2`` are not reported. Lines whose first
    non-blank character is ``#`` are skipped, and any file whose path
    contains ``kernel_verifier.py`` is exempt from this check.

    Args:
        file_path: Path of the file being checked (used for the exemption
            test and for reporting).
        content: Full text of the file.

    Returns:
        One CODE-002 warning per offending line, in line order.
    """
    issues: List[Dict] = []
    if "kernel_verifier.py" in str(file_path):
        return issues

    # The trailing \b keeps "core_v2" (and any other "core<word-char>" name)
    # from matching, while accepting "core." / "core " / end of line.
    core_import = re.compile(r'(?:from|import)\s+akg_agents\.core\b')

    for lineno, line in enumerate(content.split('\n'), 1):
        stripped = line.strip()
        if stripped.startswith('#'):
            continue
        if core_import.search(line):
            issues.append({
                "rule": "CODE-002",
                "level": "warning",
                "file": str(file_path),
                "line": lineno,
                "message": f"从 core/ 导入: {stripped}",
                "suggestion": "建议迁移到 core_v2/(core/ 正在迁移)"
            })
    return issues
def check_parameter_values(file_path: Path, content: str) -> List[Dict]:
    """Check string literals assigned to well-known parameters.

    For each of ``backend``, ``framework``, ``dsl`` and ``arch`` the first
    ``<name> = "<value>"`` occurrence on a line is inspected and reported
    when ``<value>`` is not in the allowed list. Values containing ``{`` or
    ``}`` are treated as templates and skipped. Lines whose first non-blank
    character is ``#`` are skipped.

    The parameter name must not be directly preceded by a letter or digit,
    so unrelated identifiers that merely end with a parameter name (for
    example ``search`` or ``march`` versus ``arch``) are not reported.
    Underscore-prefixed variants such as ``default_backend`` and attribute
    assignments such as ``self.backend`` are still checked.

    Args:
        file_path: Path of the file being checked (used for reporting only).
        content: Full text of the file.

    Returns:
        One CODE-004 error per (line, parameter) with an invalid value.
    """
    issues: List[Dict] = []
    valid_values = {
        "backend": ["cuda", "ascend", "cpu"],
        "framework": ["torch", "mindspore"],
        "dsl": [
            "triton", "triton_cuda", "triton_ascend", "cpp",
            "cuda_c", "tilelang_cuda", "ascendc", "pypto",
        ],
        "arch": [
            "a100", "v100",
            "ascend910b1", "ascend910b2",
            "ascend910b3", "ascend910b4",
            "ascend910_9362", "ascend910_9372",
            "ascend910_9381", "ascend910_9382",
            "ascend910_9391", "ascend910_9392",
            "ascend310p3", "x86_64", "aarch64",
        ],
    }
    # Compile once per call instead of once per line. The lookbehind rejects
    # a preceding letter/digit but deliberately allows "_" and ".".
    patterns = {
        param: re.compile(
            rf'(?<![A-Za-z0-9]){param}\s*=\s*["\']([^"\']+)["\']'
        )
        for param in valid_values
    }

    for lineno, line in enumerate(content.split('\n'), 1):
        if line.strip().startswith('#'):
            continue
        for param, valid in valid_values.items():
            match = patterns[param].search(line)
            if match is None:
                continue
            value = match.group(1)
            # Braces indicate a format template, not a concrete value.
            if '{' in value or '}' in value:
                continue
            if value not in valid:
                issues.append({
                    "rule": "CODE-004",
                    "level": "error",
                    "file": str(file_path),
                    "line": lineno,
                    "message": f"{param}='{value}' 不在有效值范围",
                    "suggestion": f"有效值: {', '.join(valid)}"
                })
    return issues
# Obsolete package names mapped to the description used in the report.
# Note that "kernel_generator" is a substring of "ai_kernel_generator";
# check_package_name matches longest-first to avoid reporting both for one
# occurrence.
WRONG_PACKAGE_NAMES = {
    "ai_kernel_generator": "旧包名 'ai_kernel_generator'",
    "kernel_generator": "旧包名 'kernel_generator'",
}
# Matches the singular "akg_agent" as a whole word. The trailing (?!s) is
# redundant after \b but is kept unchanged.
WRONG_PACKAGE_REGEX = re.compile(
    r'\bakg_agent\b(?!s)'
)


def check_package_name(file_path: Path, content: str) -> List[Dict]:
    """Flag import lines that reference an obsolete or misspelled package name.

    Only lines containing the text ``import`` are examined; lines whose
    first non-blank character is ``#`` are skipped. Two kinds of mistakes
    are reported, both as CODE-013 errors:

    * an old package name listed in ``WRONG_PACKAGE_NAMES``;
    * the singular ``akg_agent`` instead of ``akg_agents``.

    Old names are matched longest-first and each matched name is removed
    from the working copy of the line before shorter names are tried, so a
    single ``ai_kernel_generator`` occurrence is reported once rather than
    also being reported as ``kernel_generator``.

    Args:
        file_path: Path of the file being checked (used for reporting only).
        content: Full text of the file.

    Returns:
        The list of CODE-013 issues, in line order.
    """
    issues: List[Dict] = []
    # sorted() is stable, so names of equal length keep declaration order.
    names_longest_first = sorted(
        WRONG_PACKAGE_NAMES.items(),
        key=lambda item: len(item[0]),
        reverse=True,
    )

    for lineno, line in enumerate(content.split('\n'), 1):
        stripped = line.strip()
        if stripped.startswith('#') or 'import' not in line:
            continue

        remaining = line
        for wrong_name, desc in names_longest_first:
            if wrong_name not in remaining:
                continue
            issues.append({
                "rule": "CODE-013",
                "level": "error",
                "file": str(file_path),
                "line": lineno,
                "message": f"错误的包名 ({desc}): {stripped}",
                "suggestion": "应为 'akg_agents'"
            })
            # Consume the match so a shorter name contained in it is not
            # reported a second time; a space keeps neighbours separated.
            remaining = remaining.replace(wrong_name, ' ')

        if WRONG_PACKAGE_REGEX.search(line):
            issues.append({
                "rule": "CODE-013",
                "level": "error",
                "file": str(file_path),
                "line": lineno,
                "message": f"包名少了 s (akg_agent → akg_agents): {stripped}",
                "suggestion": "应为 'akg_agents'(复数)"
            })
    return issues
def check_todo_comments(file_path: Path, content: str) -> List[Dict]:
    """Report TODO-style markers found after a ``#`` on any line.

    A line is reported when it contains ``#`` followed by optional
    whitespace and one of TODO, FIXME, XXX or HACK (case-insensitive).

    Args:
        file_path: Path of the file being checked (used for reporting only).
        content: Full text of the file.

    Returns:
        One CODE-010 info issue per matching line, in line order. The
        message carries at most the first 80 characters of the stripped line.
    """
    marker = re.compile(r'#\s*(TODO|FIXME|XXX|HACK)', re.IGNORECASE)
    reported_path = str(file_path)
    return [
        {
            "rule": "CODE-010",
            "level": "info",
            "file": reported_path,
            "line": lineno,
            "message": f"待办注释: {text.strip()[:80]}",
            "suggestion": "考虑创建 Issue 跟踪或在本次提交中完成"
        }
        for lineno, text in enumerate(content.split('\n'), 1)
        if marker.search(text)
    ]
def check_file(file_path: Path, repo_path: Path) -> List[Dict]:
    """Run every check against a single file.

    Files whose suffix is not ``.py`` are ignored. The file is read as
    UTF-8 from ``repo_path / file_path``; if reading fails for any reason a
    single issue with rule ``"ERROR"`` is returned instead of raising.

    Args:
        file_path: Path of the file, joined onto ``repo_path`` for reading
            and reported as given.
        repo_path: Directory that ``file_path`` is resolved against.

    Returns:
        The concatenated issues of all checks, in the order license header,
        imports, parameter values, package name, TODO comments.
    """
    if file_path.suffix != '.py':
        return []

    target = repo_path / file_path
    try:
        with open(target, 'r', encoding='utf-8') as handle:
            content = handle.read()
    except Exception as e:
        # Surface the read failure as an issue so the overall run continues.
        return [{
            "rule": "ERROR",
            "level": "error",
            "file": str(file_path),
            "line": 0,
            "message": f"无法读取文件: {str(e)}",
            "suggestion": "检查文件是否存在且可读"
        }]

    checkers = (
        check_license_header,
        check_imports,
        check_parameter_values,
        check_package_name,
        check_todo_comments,
    )
    collected = []
    for checker in checkers:
        collected.extend(checker(file_path, content))
    return collected
def main():
    """Command-line entry point: check the given files and emit a JSON report.

    Arguments (parsed from ``sys.argv``):
        --files: Whitespace-separated list of file paths (required).
        --repo-path: Directory the file paths are resolved against
            (default ``"."``).
        --output: If given, the JSON report is written to this file;
            otherwise it is printed to standard output.

    The report contains every issue plus counts of errors, warnings and
    infos. The process exits with status 1 when at least one error-level
    issue was found and 0 otherwise; warnings and infos never fail the run.
    """
    parser = argparse.ArgumentParser(description="自定义代码规范检查")
    parser.add_argument("--files", required=True, help="文件列表(空格分隔)")
    parser.add_argument("--repo-path", default=".", help="仓库路径")
    parser.add_argument("--output", help="输出 JSON 文件")
    args = parser.parse_args()

    repo_path = Path(args.repo_path).resolve()
    # str.split() with no separator never yields empty strings.
    files = args.files.split()

    all_issues = []
    for file_str in files:
        all_issues.extend(check_file(Path(file_str), repo_path))

    errors = [i for i in all_issues if i["level"] == "error"]
    warnings = [i for i in all_issues if i["level"] == "warning"]
    infos = [i for i in all_issues if i["level"] == "info"]

    result = {
        "status": "fail" if errors else "pass",
        "total_errors": len(errors),
        # Warning counts are additive keys; existing keys are unchanged.
        "total_warnings": len(warnings),
        "total_infos": len(infos),
        "issues": all_issues,
        "summary": {
            "files_checked": len(files),
            "errors": len(errors),
            "warnings": len(warnings),
            "infos": len(infos)
        }
    }

    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=2, ensure_ascii=False)
    else:
        print(json.dumps(result, indent=2, ensure_ascii=False))

    sys.exit(1 if errors else 0)


if __name__ == "__main__":
    main()