import os
import sys
import importlib.util
import signal
def timeout_handler(signum, frame):
raise TimeoutError("Test case timeout (60s)")
def test_single_case(filepath, timeout=60):
"""Test a single Python case file with timeout"""
print(f"\n=== 开始测试单个文件: {filepath} ===")
root_dir = os.path.dirname(os.path.abspath(filepath))
print(f" 根目录: {root_dir}")
sys.path.insert(0, root_dir)
print(f" Python路径已更新")
print(" 正在检测可用设备...")
import torch
device = None
if torch.cuda.is_available():
device = torch.device("cuda:0")
torch.cuda.set_device(device)
print(f" Using GPU: {torch.cuda.get_device_name(device)}")
elif hasattr(torch, 'npu') and torch.npu.is_available():
device_id = os.environ.get('DEVICE_ID', '0')
os.environ['DEVICE_ID'] = str(device_id)
device = torch.device("npu")
torch.npu.manual_seed(0)
torch.npu.set_device(int(device_id))
print(f" Using NPU: device {device_id}")
else:
device = torch.device("cpu")
print(f" Using CPU (GPU/NPU not available)")
relative_path = os.path.relpath(filepath, root_dir)
module_path = relative_path.replace('/', '.').replace('\\', '.').replace('.py', '')
print(f" 相对路径: {relative_path}")
print(f" 模块路径: {module_path}")
print(f" 设置超时: {timeout}秒")
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(timeout)
try:
print(f" 正在导入模块...")
spec = importlib.util.spec_from_file_location(module_path, filepath)
print(f" 模块spec创建成功")
module = importlib.util.module_from_spec(spec)
print(f" 模块对象创建成功")
spec.loader.exec_module(module)
print(f" 模块执行成功")
print(f" 正在获取Model类...")
model_class = getattr(module, 'Model')
print(f" Model类获取成功: {model_class}")
print(f" 正在获取get_inputs函数...")
get_inputs_func = getattr(module, 'get_inputs')
print(f" get_inputs函数获取成功: {get_inputs_func}")
print(f" 正在获取get_init_inputs函数...")
get_init_inputs_func = getattr(module, 'get_init_inputs')
print(f" get_init_inputs函数获取成功: {get_init_inputs_func}")
print(f" 正在获取初始化参数...")
init_params = get_init_inputs_func()
print(f" 初始化参数: {init_params}")
print(f" 正在初始化模型...")
if isinstance(init_params, list):
model = model_class(*init_params)
else:
model = model_class()
print(f" 模型初始化成功: {type(model)}")
if device.type != "cpu":
print(f" 正在将模型移动到设备: {device}")
model = model.to(device)
print(f" 模型移动成功")
print(f" 正在获取输入数据...")
inputs = get_inputs_func()
print(f" 输入数据获取成功: {type(inputs)}")
print(f" 输入长度: {len(inputs) if isinstance(inputs, list) else 1}")
print(f" 正在处理输入数据...")
if isinstance(inputs, list):
print(f" 输入是列表,长度: {len(inputs)}")
device_inputs = [x.to(device) if isinstance(x, torch.Tensor) else x for x in inputs]
print(f" 输入数据移动到设备完成")
print(f" 正在运行模型...")
output = model(*device_inputs)
else:
print(f" 输入是单个张量")
device_inputs = inputs.to(device) if isinstance(inputs, torch.Tensor) else inputs
print(f" 输入数据移动到设备完成")
print(f" 正在运行模型...")
output = model(device_inputs)
print(f" === 文件 {relative_path} 测试完成: PASSED ===")
return True
except TimeoutError:
print(f" === 文件 {relative_path} 测试超时: {timeout}秒 ===")
return False
except Exception as e:
print(f" === 文件 {relative_path} 测试失败: {str(e)} ===")
import traceback
traceback.print_exc()
return False
finally:
signal.alarm(0)
print(f" 超时定时器已取消")
if __name__ == "__main__":
print("=== 静态形状单个用例测试脚本启动 ===")
if len(sys.argv) < 2 or len(sys.argv) > 3:
print("用法: python3 test_single_case.py <path_to_case_file.py> [timeout_seconds]")
print(f" 当前参数数量: {len(sys.argv)}")
print(f" 参数列表: {sys.argv}")
sys.exit(1)
filepath = sys.argv[1]
print(f" 目标文件路径: {filepath}")
if not os.path.exists(filepath):
print(f" 文件不存在: {filepath}")
sys.exit(1)
print(f" 文件存在,开始检查...")
timeout = 60
if len(sys.argv) == 3:
try:
timeout = int(sys.argv[2])
print(f" 自定义超时时间: {timeout}秒")
except ValueError:
print(f" 超时时间必须是整数: {sys.argv[2]}")
sys.exit(1)
else:
print(f" 使用默认超时时间: {timeout}秒")
print(f" 开始测试文件: {filepath}")
print(f" 超时设置: {timeout}秒")
success = test_single_case(filepath, timeout)
print(f"\n=== 单个用例测试完成 ===")
print(f" 测试结果: {'成功' if success else '失败'}")
sys.exit(0 if success else 1)