import os
import sys
import importlib.util
import argparse
def test_case_file(filepath, root_dir):
"""Test a single case file"""
relative_path = os.path.relpath(filepath, root_dir)
print(f"\n=== 开始测试文件: {relative_path} ===")
print(" 正在检测可用设备...")
import torch
device = None
if torch.cuda.is_available():
device = torch.device("cuda:0")
torch.cuda.set_device(device)
print(f" Using GPU: {torch.cuda.get_device_name(device)}")
elif hasattr(torch, 'npu') and torch.npu.is_available():
device_id = os.environ.get('DEVICE_ID', '0')
os.environ['DEVICE_ID'] = str(device_id)
device = torch.device("npu")
torch.npu.manual_seed(0)
torch.npu.set_device(int(device_id))
print(f" Using NPU: device {device_id}")
else:
device = torch.device("cpu")
print(f" Using CPU (GPU/NPU not available)")
try:
print(f" 正在导入模块: {relative_path}")
module_path = relative_path.replace('/', '.').replace('\\', '.').replace('.py', '')
print(f" 模块路径: {module_path}")
spec = importlib.util.spec_from_file_location(module_path, filepath)
print(f" 模块spec创建成功")
module = importlib.util.module_from_spec(spec)
print(f" 模块对象创建成功")
spec.loader.exec_module(module)
print(f" 模块执行成功")
print(f" 正在获取Model类...")
model_class = getattr(module, 'Model')
print(f" Model类获取成功: {model_class}")
print(f" 正在获取get_inputs函数...")
get_inputs_func = getattr(module, 'get_inputs')
print(f" get_inputs函数获取成功: {get_inputs_func}")
print(f" 正在获取get_init_inputs函数...")
get_init_inputs_func = getattr(module, 'get_init_inputs')
print(f" get_init_inputs函数获取成功: {get_init_inputs_func}")
print(f" 正在获取初始化参数...")
init_params = get_init_inputs_func()
print(f" 初始化参数: {init_params}")
print(f" 正在初始化模型...")
if isinstance(init_params, list):
model = model_class(*init_params)
else:
model = model_class()
print(f" 模型初始化成功: {type(model)}")
if device.type != "cpu":
print(f" 正在将模型移动到设备: {device}")
model = model.to(device)
print(f" 模型移动成功")
print(f" 正在获取输入数据...")
inputs = get_inputs_func()
print(f" 输入数据获取成功: {type(inputs)}")
print(f" 正在处理输入数据...")
if isinstance(inputs, list):
print(f" 输入是列表,长度: {len(inputs)}")
device_inputs = [x.to(device) if isinstance(x, torch.Tensor) else x for x in inputs]
print(f" 输入数据移动到设备完成")
print(f" 正在运行模型...")
output = model(*device_inputs)
else:
print(f" 输入是单个张量")
device_inputs = inputs.to(device) if isinstance(inputs, torch.Tensor) else inputs
print(f" 输入数据移动到设备完成")
print(f" 正在运行模型...")
output = model(device_inputs)
print(f" === 文件 {relative_path} 测试完成: PASSED ===")
return True
except Exception as e:
print(f" === 文件 {relative_path} 测试失败: {str(e)} ===")
import traceback
traceback.print_exc()
return False
def test_all_cases(root_dir):
"""Test all Python case files in the akg_kernels_bench directory"""
print(f"\n=== 开始测试所有静态形状用例 ===")
print(f" 根目录: {root_dir}")
sys.path.insert(0, root_dir)
print(f" Python路径已更新")
total_tests = 0
passed_tests = 0
failed_tests = 0
failed_files = []
print(f" 开始遍历目录查找测试文件...")
for dirpath, dirnames, filenames in os.walk(root_dir):
print(f" 正在检查目录: {dirpath}")
if 'dynamic_shape' in dirpath:
print(f" 跳过dynamic_shape目录")
continue
print(f" 检查Python文件...")
for filename in filenames:
if (filename.endswith('.py') and
filename != 'test_all_cases.py' and
'test_single_case' not in filename and
'test_' not in filename):
filepath = os.path.join(dirpath, filename)
print(f" 发现测试文件: {filename}")
total_tests += 1
print(f" 开始测试文件 {total_tests}: {filename}")
if test_case_file(filepath, root_dir):
passed_tests += 1
print(f" 文件 {filename} 测试通过")
else:
failed_tests += 1
relative_path = os.path.relpath(filepath, root_dir)
failed_files.append(relative_path)
print(f" 文件 {filename} 测试失败")
else:
print(f" 跳过文件: {filename} (测试脚本或包含'test'关键字)")
print("\n" + "="*50)
print("TEST SUMMARY (STATIC SHAPE CASES ONLY)")
print("="*50)
print(f"Total tests: {total_tests}")
print(f"Passed: {passed_tests}")
print(f"Failed: {failed_tests}")
if failed_files:
print("\nFailed files:")
for file in failed_files:
print(f" - {file}")
return failed_tests == 0
def test_cases_from_file(file_list_path, root_dir):
"""Test cases listed in a file"""
print(f"\n=== 开始从文件列表测试用例 ===")
print(f" 文件列表路径: {file_list_path}")
print(f" 根目录: {root_dir}")
sys.path.insert(0, root_dir)
print(f" Python路径已更新")
print(f" 正在读取文件列表...")
with open(file_list_path, 'r') as f:
case_files = [line.strip() for line in f.readlines() if line.strip() and not line.startswith('#')]
print(f" 文件列表读取完成,共 {len(case_files)} 个文件")
total_tests = 0
passed_tests = 0
failed_tests = 0
failed_files = []
print(f" 开始测试文件列表中的用例...")
for i, case_file in enumerate(case_files):
print(f" 正在处理第 {i+1}/{len(case_files)} 个文件: {case_file}")
filepath = os.path.join(root_dir, case_file)
if os.path.exists(filepath):
print(f" 文件存在,开始测试...")
total_tests += 1
if test_case_file(filepath, root_dir):
passed_tests += 1
print(f" 文件 {case_file} 测试通过")
else:
failed_tests += 1
failed_files.append(case_file)
print(f" 文件 {case_file} 测试失败")
else:
print(f" 文件不存在: {filepath}")
failed_tests += 1
failed_files.append(case_file)
print("\n" + "="*50)
print("TEST SUMMARY")
print("="*50)
print(f"Total tests: {total_tests}")
print(f"Passed: {passed_tests}")
print(f"Failed: {failed_tests}")
if failed_files:
print("\nFailed files:")
for file in failed_files:
print(f" - {file}")
return failed_tests == 0
def test_new_cases(root_dir):
"""Test only new/modified cases by checking git status"""
print(f"\n=== 开始测试新修改的用例 ===")
print(f" 根目录: {root_dir}")
sys.path.insert(0, root_dir)
print(f" Python路径已更新")
try:
print(f" 正在检查git状态...")
import subprocess
result = subprocess.run(['git', 'status', '--porcelain', '*.py'],
cwd=root_dir, capture_output=True, text=True)
if result.returncode != 0:
print(f" git状态检查失败,返回码: {result.returncode}")
print(f" 错误输出: {result.stderr}")
return False
print(f" git状态检查成功")
lines = result.stdout.strip().split('\n')
print(f" 原始git输出行数: {len(lines)}")
modified_files = []
for i, line in enumerate(lines):
if line.strip():
print(f" 解析第 {i+1} 行: {line}")
parts = line.split()
if len(parts) >= 2:
filepath = parts[1] if parts[0] in ['M', 'A', '??'] else parts[0]
print(f" 解析文件路径: {filepath}")
if filepath.endswith('.py') and 'static_shape' in filepath:
modified_files.append(filepath)
print(f" 添加到修改文件列表: {filepath}")
else:
print(f" 跳过文件: {filepath} (非Python文件或不在static_shape目录)")
else:
print(f" 跳过无效行: {line}")
if not modified_files:
print(" 没有发现修改/新增的Python文件")
return True
print(f" 发现 {len(modified_files)} 个修改/新增的Python文件:")
for f in modified_files:
print(f" - {f}")
total_tests = 0
passed_tests = 0
failed_tests = 0
failed_files = []
print(f" 开始测试修改的文件...")
for i, case_file in enumerate(modified_files):
print(f" 正在测试第 {i+1}/{len(modified_files)} 个文件: {case_file}")
filepath = os.path.join(root_dir, case_file)
if os.path.exists(filepath):
print(f" 文件存在,开始测试...")
total_tests += 1
if test_case_file(filepath, root_dir):
passed_tests += 1
print(f" 文件 {case_file} 测试通过")
else:
failed_tests += 1
failed_files.append(case_file)
print(f" 文件 {case_file} 测试失败")
else:
print(f" 文件不存在: {filepath}")
failed_tests += 1
failed_files.append(case_file)
print("\n" + "="*50)
print("TEST SUMMARY (Modified/New Files Only)")
print("="*50)
print(f"Total tests: {total_tests}")
print(f"Passed: {passed_tests}")
print(f"Failed: {failed_tests}")
if failed_files:
print("\nFailed files:")
for file in failed_files:
print(f" - {file}")
return failed_tests == 0
except Exception as e:
print(f" 检查git状态时发生错误: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
print("=== 静态形状测试脚本启动 ===")
root_dir = os.path.dirname(os.path.abspath(__file__))
print(f" 脚本路径: {__file__}")
print(f" 根目录: {root_dir}")
print(f" 正在解析命令行参数...")
parser = argparse.ArgumentParser(description='Test akg_kernels_bench cases')
parser.add_argument('--file-list', '-f', help='Test cases listed in a file')
parser.add_argument('--new-only', '-n', action='store_true', help='Test only new/modified cases (using git status)')
args = parser.parse_args()
print(f" 命令行参数解析完成:")
print(f" --file-list: {args.file_list}")
print(f" --new-only: {args.new_only}")
if args.file_list:
print(f" 选择测试模式: 从文件列表测试")
success = test_cases_from_file(args.file_list, root_dir)
elif args.new_only:
print(f" 选择测试模式: 仅测试新修改的用例")
success = test_new_cases(root_dir)
else:
print(f" 选择测试模式: 测试所有静态形状用例")
print(f" 测试目录: {root_dir}")
success = test_all_cases(root_dir)
print(f"\n=== 测试脚本执行完成 ===")
print(f" 最终结果: {'成功' if success else '失败'}")
sys.exit(0 if success else 1)