"""
PyPTO 精度变化折线图绘制工具
从验证结果日志文件中提取精度数据并绘制变化趋势图
"""
import re
import sys
import os
import logging
import matplotlib.pyplot as plt
import numpy as np
# Route every record at INFO level or above to stdout as a bare message
# (no timestamp, level, or logger-name prefix).
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(level=logging.INFO, format='%(message)s', handlers=[_stdout_handler])

# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
def extract_operator_name(log_file_path):
    """Derive the operator name from a verification log path.

    The directory part of the path is ignored. When the file name ends with
    ``_verify_result.log`` the text before that suffix is returned; for any
    other file name the fallback ``'operator'`` is returned.
    """
    suffix = '_verify_result.log'
    base_name = os.path.basename(log_file_path)
    if not base_name.endswith(suffix):
        return 'operator'
    # Drop exactly the trailing suffix, keeping everything before it.
    return base_name[:len(base_name) - len(suffix)]
# Decimal or scientific-notation number (captured) as it appears after "rtol=" / "atol=".
_NUM_PATTERN = r'([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)'
# A checkpoint header line starts with "<digits>_<name>:".
_CHECKPOINT_RE = re.compile(r'^(\d+_[^:\s]+):')
_TOLERANCE_RE = re.compile(f'Tolerance: rtol={_NUM_PATTERN}, atol={_NUM_PATTERN}')
_ACTUAL_RE = re.compile(f'Actual: rtol={_NUM_PATTERN}, atol={_NUM_PATTERN}')


def parse_checkpoints(lines, max_lookahead=20):
    """Extract checkpoint names and their accuracy figures from log lines.

    A checkpoint starts at a line matching ``<digits>_<name>:``. Starting at
    that header line, the following lines are scanned for a
    ``Tolerance: rtol=<n>, atol=<n>`` line and an ``Actual: rtol=<n>, atol=<n>``
    line. Scanning stops at the first parsable "Actual" line.

    The scan never crosses into the next checkpoint's header, so a checkpoint
    that has no "Actual" line of its own is dropped instead of being credited
    with the values of the checkpoint that follows it.

    Args:
        lines: Sequence of log lines (for example the result of ``readlines()``).
        max_lookahead: Maximum number of lines, counting the header line,
            inspected per checkpoint. ``None`` scans up to the next checkpoint
            header (or the end of the input).

    Returns:
        A list of ``(name, tolerance_rtol, tolerance_atol, actual_rtol,
        actual_atol)`` tuples in log order. Only checkpoints with a parsed
        "Actual" line are included; the two tolerance entries are ``None``
        when no "Tolerance" line was parsed before it.
    """
    # (line index, checkpoint name) for every header line, in order.
    starts = []
    for index, line in enumerate(lines):
        header = _CHECKPOINT_RE.match(line)
        if header:
            starts.append((index, header.group(1)))

    results = []
    for position, (start, name) in enumerate(starts):
        # The checkpoint's own section ends where the next header begins.
        if position + 1 < len(starts):
            section_end = starts[position + 1][0]
        else:
            section_end = len(lines)
        if max_lookahead is None:
            stop = section_end
        else:
            stop = min(start + max_lookahead, section_end)

        tolerance = (None, None)
        actual = None
        for line in lines[start:stop]:
            if 'Tolerance: rtol=' in line:
                found = _TOLERANCE_RE.search(line)
                if found:
                    tolerance = (float(found.group(1)), float(found.group(2)))
            elif 'Actual: rtol=' in line:
                found = _ACTUAL_RE.search(line)
                if found:
                    actual = (float(found.group(1)), float(found.group(2)))
                    break

        if actual is not None:
            results.append((name, *tolerance, *actual))
    return results
def _draw_metric_axis(ax, labels, actual, tolerance, *, metric, long_name, color, value_format):
    """Draw the actual and tolerance series for one metric onto ``ax``."""
    positions = range(len(labels))
    ax.plot(positions, actual, marker='o', linewidth=2, markersize=8,
            color=color, label=f'Actual {metric}')
    ax.plot(positions, tolerance, marker='s', linewidth=2, markersize=6,
            color='green', linestyle='--', label=f'Tolerance {metric}')
    ax.set_xlabel('Checkpoint', fontsize=14)
    ax.set_ylabel(f'{long_name} ({metric})', fontsize=14)
    ax.set_title(f'{long_name} ({metric}) Change Across Checkpoints',
                 fontsize=16, fontweight='bold', y=1.02)
    ax.grid(True, alpha=0.3)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=11)
    ax.legend(loc='upper right', fontsize=11)
    # Only strictly positive actual values get a numeric label above the marker.
    for position, value in enumerate(actual):
        if value > 0:
            ax.annotate(format(value, value_format), (position, value),
                        textcoords="offset points", xytext=(0, 10), ha='center',
                        fontsize=8, color=color, fontweight='bold')


def plot_accuracy(results, output_path):
    """Plot actual vs. tolerance rtol/atol per checkpoint and save the figure.

    Two stacked line charts are drawn: relative tolerance on top and absolute
    tolerance below, each showing the actual series and the tolerance series
    with one x tick per checkpoint.

    Args:
        results: Iterable of ``(checkpoint, tolerance_rtol, tolerance_atol,
            actual_rtol, actual_atol)`` tuples.
        output_path: Destination image path handed to ``savefig``.

    Returns:
        ``output_path`` unchanged.
    """
    checkpoints_list = [r[0] for r in results]
    tol_rtol_values = [r[1] for r in results]
    tol_atol_values = [r[2] for r in results]
    act_rtol_values = [r[3] for r in results]
    act_atol_values = [r[4] for r in results]

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))
    try:
        _draw_metric_axis(ax1, checkpoints_list, act_rtol_values, tol_rtol_values,
                          metric='rtol', long_name='Relative Tolerance',
                          color='blue', value_format='.4f')
        _draw_metric_axis(ax2, checkpoints_list, act_atol_values, tol_atol_values,
                          metric='atol', long_name='Absolute Tolerance',
                          color='red', value_format='.2f')
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # pyplot keeps every figure registered until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)
    return output_path
def print_summary(results):
    """Log a fixed-width summary table of the accuracy data.

    One row is logged per result with the tolerance pair, the actual pair and
    a status. The status is ``FAIL`` when either actual value exceeds its
    tolerance and ``PASS`` otherwise. When a tolerance value is missing
    (``None``) the row is reported with ``N/A`` instead of raising a
    ``TypeError`` on the comparison and number formatting.

    Args:
        results: Iterable of ``(checkpoint, tolerance_rtol, tolerance_atol,
            actual_rtol, actual_atol)`` tuples.
    """
    rule = "-" * 100
    logger.info("\n精度数据汇总:")
    logger.info(rule)
    logger.info(f"{'Checkpoint':<30} {'Tolerance':<25} {'Actual':<25} {'Status':<10}")
    logger.info(rule)
    for ckpt, tol_rtol, tol_atol, act_rtol, act_atol in results:
        if tol_rtol is None or tol_atol is None:
            # No tolerance available for this checkpoint: nothing to compare against.
            status = "N/A"
            tol_str = "rt=N/A, at=N/A"
        else:
            status = "FAIL" if (act_rtol > tol_rtol or act_atol > tol_atol) else "PASS"
            tol_str = f"rt={tol_rtol:.6f}, at={tol_atol:.6f}"
        act_str = f"rt={act_rtol:.6f}, at={act_atol:.6f}"
        logger.info(f"{ckpt:<30} {tol_str:<25} {act_str:<25} {status:<10}")
    logger.info(rule)
def main():
    """Command-line entry point.

    Reads the log file given as the first command-line argument, extracts the
    per-checkpoint accuracy data, logs it, saves the line chart next to the
    log file as ``<operator>_accuracy_change.png`` and logs a summary table.

    Exits with status 1 when no argument is given or the log file cannot be
    read.
    """
    if len(sys.argv) < 2:
        logger.error("错误: 必须传入 log 文件路径作为参数")
        logger.error("用法: python3 plot_accuracy.py <verify_result.log>")
        sys.exit(1)

    log_file = sys.argv[1]
    log_file_path = os.path.abspath(log_file)
    log_file_dir = os.path.dirname(log_file_path)
    operator_name = extract_operator_name(log_file_path)

    try:
        # The parsed markers are plain ASCII, so undecodable bytes elsewhere in
        # the log are replaced rather than aborting the whole run.
        with open(log_file, 'r', encoding='utf-8', errors='replace') as f:
            lines = f.readlines()
    except OSError as exc:
        logger.error(f"错误: 无法读取 log 文件 {log_file}: {exc}")
        sys.exit(1)

    results = parse_checkpoints(lines)
    logger.info(f'提取到 {len(results)} 个检查点的精度数据')
    for i, (ckpt, tol_rtol, tol_atol, act_rtol, act_atol) in enumerate(results):
        if tol_rtol is None or tol_atol is None:
            # A checkpoint may carry actual values without a parsed tolerance line.
            tol_str = 'Tolerance(rt=N/A, at=N/A)'
        else:
            tol_str = f'Tolerance(rt={tol_rtol:.6f}, at={tol_atol:.6f})'
        act_str = f'Actual(rt={act_rtol:.6f}, at={act_atol:.6f})'
        logger.info(f'{i+1}. {ckpt}: {tol_str} | {act_str}')

    output_image = os.path.join(log_file_dir, f'{operator_name}_accuracy_change.png')
    plot_accuracy(results, output_image)
    logger.info(f"\n精度变化折线图已保存至: {output_image}")
    print_summary(results)
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()