"""
主程序入口
==========
AtomVision 物理轨迹恢复系统
"""
from __future__ import print_function
import sys
import os
# Absolute directory containing this entry script; other paths are built from it.
SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
# Helper modules are imported from the sibling "scripts" folder, which is
# prepended to the import search path at most once.
SCRIPTS_DIR = os.path.join(SCRIPT_DIR, 'scripts')
if sys.path.count(SCRIPTS_DIR) == 0:
    sys.path.insert(0, SCRIPTS_DIR)
# Import the project modules defensively so main() can report a readable
# error message instead of the script crashing at import time.
# IMPORT_ERROR is always defined: None on success, the error text on failure.
IMPORT_ERROR = None
try:
    from data_loader import (
        load_trajectory_data,
        create_missing_trajectory,
        get_trajectory_stats
    )
    from physics_fit import recover_trajectory_polyfit
    from segment_fit import (
        segment_fitting_pipeline,
        recover_missing_by_segment
    )
    from evaluate import calculate_segment_errors
    from visualize import (
        plot_original_trajectory,
        plot_trajectory_comparison,
        plot_segmented_trajectory,
        create_trajectory_animation
    )
    IMPORT_SUCCESS = True
except Exception as e:
    # Deliberately broad: importing a module executes its top-level code,
    # which can fail with errors other than ImportError.
    IMPORT_SUCCESS = False
    IMPORT_ERROR = str(e)
import numpy as np
def print_info(msg):
    """Print *msg* on its own line, then flush stdout right away.

    The explicit flush makes progress lines appear immediately even when
    standard output is buffered.
    """
    stream = sys.stdout
    print(msg, file=stream)
    stream.flush()
def setup_directories(base_dir=None):
    """Create the output directory tree and publish its paths as globals.

    Sets the module globals OUTPUT_DIR (``<base_dir>/outputs``), FIGURES_DIR,
    ANIMATIONS_DIR and METRICS_DIR (the ``figures``, ``animations`` and
    ``metrics`` subdirectories of OUTPUT_DIR). Creates any of them that are
    missing. Calling it repeatedly is safe.

    Args:
        base_dir: Directory under which ``outputs`` is created. Defaults to
            SCRIPT_DIR, the directory containing this script.
    """
    global OUTPUT_DIR, FIGURES_DIR, ANIMATIONS_DIR, METRICS_DIR
    if base_dir is None:
        base_dir = SCRIPT_DIR
    OUTPUT_DIR = os.path.join(base_dir, 'outputs')
    FIGURES_DIR = os.path.join(OUTPUT_DIR, 'figures')
    ANIMATIONS_DIR = os.path.join(OUTPUT_DIR, 'animations')
    METRICS_DIR = os.path.join(OUTPUT_DIR, 'metrics')
    for d in (OUTPUT_DIR, FIGURES_DIR, ANIMATIONS_DIR, METRICS_DIR):
        # exist_ok=True avoids the race between an exists() check and
        # makedirs() when the directory appears in between.
        os.makedirs(d, exist_ok=True)
def _find_data_path(candidates):
    """Return the first path in *candidates* that exists on disk, or None."""
    for path in candidates:
        if os.path.exists(path):
            return path
    return None


def _improvement_percent(global_mse, segment_mse):
    """Return the relative MSE reduction of the segmented fit over the global fit, in percent.

    Returns None when the global MSE is zero, because the ratio is undefined.
    """
    if global_mse == 0:
        return None
    return (global_mse - segment_mse) / global_mse * 100


def _write_metrics_report(metrics_path, bounce_indices, global_mse, segment_mse, improvement):
    """Write the plain-text error-analysis report (UTF-8) to *metrics_path*.

    *improvement* is a percentage, or None when it is undefined; None is
    reported as "N/A".
    """
    lines = [
        "=" * 50,
        "轨迹恢复误差分析报告",
        "=" * 50,
        "",
        "【反弹点检测】",
        "检测到反弹点数量: " + str(len(bounce_indices)),
        "反弹点位置(帧): " + str([int(i) for i in bounce_indices]),
        "",
        "【误差分析】",
        "全局拟合 MSE: {0:.4f} 像素^2".format(global_mse),
        "分段拟合 MSE: {0:.4f} 像素^2".format(segment_mse),
    ]
    if improvement is None:
        lines.append("改进幅度: N/A")
    else:
        lines.append("改进幅度: {0:.2f}%".format(improvement))
    with open(metrics_path, 'w', encoding='utf-8') as f:
        f.write("\n".join(lines) + "\n")


def main(data_path=None, missing_start=30, missing_end=50):
    """Run the full trajectory-recovery pipeline and write all outputs.

    Steps:
        1. Load the trajectory.
        2. Remove the frames ``missing_start..missing_end`` to create a gap.
        3. Recover the gap with a global quadratic fit.
        4. Recover it again with segmented fitting that detects bounce points.
        5-7. Save static figures, an animation and a metrics report under the
           output directories created by setup_directories().

    Args:
        data_path: Optional explicit path to the trajectory ``data.json``.
            When omitted, a fixed list of candidate locations is searched.
        missing_start: First frame id of the artificially removed gap.
        missing_end: Last frame id of the gap. It is inclusive when the
            ground-truth points are selected for error evaluation.

    Returns:
        0 on success. 1 when the module imports failed, no data file was
        found, or loading the data failed.
    """
    print("\n" + "=" * 60)
    print(" AtomVision 物理轨迹恢复系统")
    print("=" * 60)
    if not IMPORT_SUCCESS:
        print("\n[错误] 模块导入失败!")
        print("错误信息:", IMPORT_ERROR)
        return 1

    print("\n[初始化] 创建输出目录...")
    setup_directories()
    print_info("[完成]")

    # ---- [1/7] Locate and load the trajectory data -------------------------
    print("\n[1/7] 加载轨迹数据...")
    if data_path is not None:
        data_paths = [data_path]
    else:
        data_paths = [
            os.path.join(SCRIPT_DIR, '..', '..', 'BPT-V', 'dataset', 'video_001', 'data.json'),
            os.path.join(SCRIPT_DIR, '..', 'BPT-V', 'dataset', 'video_001', 'data.json'),
            # Absolute fallback path on a specific developer machine.
            r"E:\VScode\项目\物理比赛\BPT-V\dataset\video_001\data.json",
        ]
    resolved_path = _find_data_path(data_paths)
    if resolved_path is None:
        print("[错误] 数据文件不存在!")
        print("已尝试的路径:")
        for p in data_paths:
            print(" -", p)
        return 1
    print_info(" 数据路径: " + resolved_path)
    try:
        frame_ids, x_coords, y_coords = load_trajectory_data(resolved_path)
        stats = get_trajectory_stats(frame_ids, x_coords, y_coords)
        print_info(" [成功] 加载完成")
        print_info(" - 总帧数: " + str(stats['total_frames']))
        print_info(" - X范围: {0:.1f} ~ {1:.1f}".format(*stats['x_range']))
        print_info(" - Y范围: {0:.1f} ~ {1:.1f}".format(*stats['y_range']))
    except Exception as e:
        print("[错误] 数据加载失败:", str(e))
        return 1

    # ---- [2/7] Artificially remove a range of frames -----------------------
    print("\n[2/7] 制造轨迹缺失 (第" + str(missing_start) + "-" + str(missing_end) + "帧)...")
    missing_frame_ids, missing_x, missing_y = create_missing_trajectory(
        frame_ids, x_coords, y_coords,
        missing_start, missing_end
    )
    print_info(" [成功] 缺失点数: " + str(len(frame_ids) - len(missing_frame_ids)))

    # ---- [3/7] Global quadratic fit over the whole trajectory --------------
    print("\n[3/7] 全局二次拟合...")
    global_recovered_x, global_recovered_y, global_coeffs = recover_trajectory_polyfit(
        frame_ids, x_coords, y_coords,
        missing_start, missing_end
    )
    # Ground-truth points inside the removed range, used to score recovery.
    original_missing_mask = (frame_ids >= missing_start) & (frame_ids <= missing_end)
    original_missing_x = x_coords[original_missing_mask]
    original_missing_y = y_coords[original_missing_mask]
    global_errors = calculate_segment_errors(
        original_missing_x, original_missing_y,
        global_recovered_x, global_recovered_y
    )
    global_mse = global_errors['mse_combined']
    print_info(" [成功] 全局MSE: {0:.4f}".format(global_mse))

    # ---- [4/7] Bounce detection and per-segment fitting --------------------
    print("\n[4/7] 分段拟合与反弹检测...")
    segment_result = segment_fitting_pipeline(frame_ids, x_coords, y_coords)
    bounce_indices = segment_result['bounce_indices']
    segments = segment_result['segments']
    segment_results = segment_result['segment_results']
    print_info(" [成功] 反弹点: " + str(len(bounce_indices)))
    print_info(" [成功] 分段数: " + str(len(segments)))
    segment_recovery = recover_missing_by_segment(
        frame_ids, x_coords, y_coords,
        missing_start, missing_end,
        segment_results, segments
    )
    segment_mse = segment_recovery['mse_total']
    print_info(" [成功] 分段MSE: {0:.4f}".format(segment_mse))

    # Output file paths, built once and reused for saving and for the summary.
    original_fig = os.path.join(FIGURES_DIR, 'original_trajectory.png')
    comparison_fig = os.path.join(FIGURES_DIR, 'trajectory_comparison.png')
    segmented_fig = os.path.join(FIGURES_DIR, 'segmented_trajectory.png')
    animation_path = os.path.join(ANIMATIONS_DIR, 'trajectory_animation.gif')
    metrics_path = os.path.join(METRICS_DIR, 'metrics.txt')

    # ---- [5/7] Static figures ----------------------------------------------
    print("\n[5/7] 生成静态图表...")
    plot_original_trajectory(
        x_coords, y_coords, frame_ids,
        save_path=original_fig
    )
    print_info(" [1/3] 原始轨迹图")
    plot_trajectory_comparison(
        x_coords, y_coords, missing_x, missing_y,
        global_recovered_x, global_recovered_y, frame_ids,
        missing_start, missing_end,
        save_path=comparison_fig
    )
    print_info(" [2/3] 轨迹对比图")
    plot_segmented_trajectory(
        x_coords, y_coords, segments, bounce_indices,
        frame_ids, segment_recovery['recovered_x'], segment_recovery['recovered_y'],
        missing_start, missing_end,
        {'segment_results': segment_results, 'mse_total': segment_mse},
        save_path=segmented_fig
    )
    print_info(" [3/3] 分段轨迹图")

    # ---- [6/7] Animation ---------------------------------------------------
    print("\n[6/7] 生成动画...")
    create_trajectory_animation(
        frame_ids, x_coords, y_coords,
        missing_x, missing_y,
        segment_recovery['recovered_x'], segment_recovery['recovered_y'],
        missing_start, missing_end,
        fps=15,
        save_path=animation_path
    )

    # ---- [7/7] Metrics report ----------------------------------------------
    print("\n[7/7] 保存误差报告...")
    # Computed once and guarded against a zero global MSE.
    improvement = _improvement_percent(global_mse, segment_mse)
    _write_metrics_report(metrics_path, bounce_indices, global_mse, segment_mse, improvement)
    print_info(" [成功] 报告已保存")

    # ---- Console summary ---------------------------------------------------
    bounce_frames = [int(i) for i in bounce_indices]
    print("\n" + "=" * 60)
    print(" 执行结果汇总")
    print("=" * 60)
    print("\n【反弹点检测】")
    print(" 检测到反弹点: " + str(len(bounce_indices)) + " 个")
    print(" 反弹点位置(帧): " + str(bounce_frames))
    print("\n【误差分析】")
    print(" 全局拟合 MSE: {0:.4f} 像素^2".format(global_mse))
    print(" 分段拟合 MSE: {0:.4f} 像素^2".format(segment_mse))
    if improvement is not None and improvement > 0:
        print(" 分段拟合误差降低: {0:.2f}%".format(improvement))
    print("\n【输出文件】")
    print(" 原始轨迹图: " + original_fig)
    print(" 轨迹对比图: " + comparison_fig)
    print(" 分段轨迹图: " + segmented_fig)
    print(" 轨迹动画: " + animation_path)
    print(" 误差报告: " + metrics_path)
    print("\n" + "=" * 60)
    print(" 执行完毕!")
    print("=" * 60)
    return 0
# Script entry point. main()'s return value becomes the process exit status,
# so failures are visible to shells and CI. sys.exit(None) exits with 0.
if __name__ == "__main__":
    sys.exit(main())