# -*- coding: utf-8 -*-
"""
主程序入口
==========

AtomVision 物理轨迹恢复系统
"""

from __future__ import print_function
import sys
import os

# Absolute path of the directory that contains this script; the data and
# output paths used later in this file are built relative to it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# Prepend ./scripts to sys.path (only once) so the helper modules imported
# below (data_loader, physics_fit, segment_fit, evaluate, visualize) can be
# imported by bare module name.
SCRIPTS_DIR = os.path.join(SCRIPT_DIR, 'scripts')
if SCRIPTS_DIR not in sys.path:
    sys.path.insert(0, SCRIPTS_DIR)

# 导入模块
try:
    from data_loader import (
        load_trajectory_data,
        create_missing_trajectory,
        get_trajectory_stats
    )
    from physics_fit import recover_trajectory_polyfit
    from segment_fit import (
        segment_fitting_pipeline,
        recover_missing_by_segment
    )
    from evaluate import calculate_segment_errors
    from visualize import (
        plot_original_trajectory,
        plot_trajectory_comparison,
        plot_segmented_trajectory,
        create_trajectory_animation
    )
    IMPORT_SUCCESS = True
except Exception as e:
    IMPORT_SUCCESS = False
    IMPORT_ERROR = str(e)

import numpy as np


def print_info(msg):
    """Write *msg* (plus a newline) to standard output and flush at once.

    Flushing right after each message keeps progress lines visible
    immediately even when stdout is block-buffered.
    """
    stream = sys.stdout
    print(msg, file=stream)
    stream.flush()


def setup_directories(base_dir=None):
    """Create the output directory tree and publish its paths as globals.

    Sets the module-level globals OUTPUT_DIR, FIGURES_DIR, ANIMATIONS_DIR and
    METRICS_DIR and creates each of those directories if it does not already
    exist. Calling it again is harmless.

    Args:
        base_dir: Directory under which ``outputs/`` is created. Defaults to
            SCRIPT_DIR (the directory containing this script).

    Raises:
        OSError: If a directory cannot be created, including when a regular
            file already occupies one of the directory paths.
    """
    global OUTPUT_DIR, FIGURES_DIR, ANIMATIONS_DIR, METRICS_DIR

    if base_dir is None:
        base_dir = SCRIPT_DIR

    OUTPUT_DIR = os.path.join(base_dir, 'outputs')
    FIGURES_DIR = os.path.join(OUTPUT_DIR, 'figures')
    ANIMATIONS_DIR = os.path.join(OUTPUT_DIR, 'animations')
    METRICS_DIR = os.path.join(OUTPUT_DIR, 'metrics')

    # exist_ok=True removes the check-then-create race of an
    # os.path.exists() test followed by os.makedirs(). Unlike that pattern,
    # it still raises when the path exists but is not a directory, instead of
    # silently skipping it and failing later when a file is written there.
    for d in (OUTPUT_DIR, FIGURES_DIR, ANIMATIONS_DIR, METRICS_DIR):
        os.makedirs(d, exist_ok=True)


def _find_data_path(candidates):
    """Return the first path in *candidates* that exists, or None if none do."""
    for path in candidates:
        if os.path.exists(path):
            return path
    return None


def _improvement_percent(global_mse, segment_mse):
    """Return how much lower *segment_mse* is than *global_mse*, in percent.

    A positive value means the segmented fit has the smaller error. Returns
    None when *global_mse* is zero, because the relative change is undefined
    there (the plain formula would divide by zero).
    """
    if global_mse == 0:
        return None
    return (global_mse - segment_mse) / global_mse * 100


def _write_metrics_report(path, bounce_frames, global_mse, segment_mse, improvement):
    """Write the UTF-8 text error report to *path*, overwriting it.

    Args:
        path: Destination file.
        bounce_frames: List of ints for the detected bounce points.
        global_mse: MSE of the global quadratic fit on the missing span.
        segment_mse: MSE of the segmented fit on the missing span.
        improvement: Value from ``_improvement_percent``; None is written
            as "N/A".
    """
    with open(path, 'w', encoding='utf-8') as f:
        f.write("=" * 50 + "\n")
        f.write("轨迹恢复误差分析报告\n")
        f.write("=" * 50 + "\n\n")

        f.write("【反弹点检测】\n")
        f.write("检测到反弹点数量: " + str(len(bounce_frames)) + "\n")
        f.write("反弹点位置(帧): " + str(bounce_frames) + "\n\n")

        f.write("【误差分析】\n")
        f.write("全局拟合 MSE: {0:.4f} 像素^2\n".format(global_mse))
        f.write("分段拟合 MSE: {0:.4f} 像素^2\n".format(segment_mse))
        if improvement is None:
            f.write("改进幅度: N/A\n")
        else:
            f.write("改进幅度: {0:.2f}%\n".format(improvement))


def main():
    """Run the trajectory-recovery pipeline end to end.

    Loads the trajectory, removes a fixed span of frames, and recovers that
    span with a global quadratic fit and with a bounce-aware segmented fit.
    It then plots both, renders an animation, and writes an error report
    under the output directories created by ``setup_directories``.

    Returns:
        0 on success; 1 if the helper-module import, the data-file lookup,
        or data loading fails.
    """
    print("\n" + "=" * 60)
    print("    AtomVision 物理轨迹恢复系统")
    print("=" * 60)

    # The helper modules are imported when this file loads; if that failed,
    # report the captured message and stop.
    if not IMPORT_SUCCESS:
        print("\n[错误] 模块导入失败!")
        print("错误信息:", IMPORT_ERROR)
        return 1

    print("\n[初始化] 创建输出目录...")
    setup_directories()
    print_info("[完成]")

    print("\n[1/7] 加载轨迹数据...")

    # Candidate dataset locations, tried in order: two paths relative to
    # this script, then a developer-specific absolute Windows path.
    data_paths = [
        os.path.join(SCRIPT_DIR, '..', '..', 'BPT-V', 'dataset', 'video_001', 'data.json'),
        os.path.join(SCRIPT_DIR, '..', 'BPT-V', 'dataset', 'video_001', 'data.json'),
        r"E:\VScode\项目\物理比赛\BPT-V\dataset\video_001\data.json",
    ]

    data_path = _find_data_path(data_paths)
    if data_path is None:
        print("[错误] 数据文件不存在!")
        print("已尝试的路径:")
        for p in data_paths:
            print("  -", p)
        return 1

    print_info("  数据路径: " + data_path)

    # Top-level boundary: any load or stats failure is reported and aborts
    # the run.
    try:
        frame_ids, x_coords, y_coords = load_trajectory_data(data_path)
        stats = get_trajectory_stats(frame_ids, x_coords, y_coords)

        print_info("  [成功] 加载完成")
        print_info("  - 总帧数: " + str(stats['total_frames']))
        print_info("  - X范围: {0:.1f} ~ {1:.1f}".format(*stats['x_range']))
        print_info("  - Y范围: {0:.1f} ~ {1:.1f}".format(*stats['y_range']))
    except Exception as e:
        print("[错误] 数据加载失败:", str(e))
        return 1

    # Span of frames that is artificially removed and then recovered. The
    # evaluation mask below treats it as inclusive on both ends.
    MISSING_START = 30
    MISSING_END = 50

    print("\n[2/7] 制造轨迹缺失 (第" + str(MISSING_START) + "-" + str(MISSING_END) + "帧)...")

    missing_frame_ids, missing_x, missing_y = create_missing_trajectory(
        frame_ids, x_coords, y_coords,
        MISSING_START, MISSING_END
    )

    print_info("  [成功] 缺失点数: " + str(len(frame_ids) - len(missing_frame_ids)))

    print("\n[3/7] 全局二次拟合...")

    global_recovered_x, global_recovered_y, global_coeffs = recover_trajectory_polyfit(
        frame_ids, x_coords, y_coords,
        MISSING_START, MISSING_END
    )

    # Ground truth for the removed span, used to score the recovery.
    original_missing_mask = (frame_ids >= MISSING_START) & (frame_ids <= MISSING_END)
    original_missing_x = x_coords[original_missing_mask]
    original_missing_y = y_coords[original_missing_mask]

    global_errors = calculate_segment_errors(
        original_missing_x, original_missing_y,
        global_recovered_x, global_recovered_y
    )
    global_mse = global_errors['mse_combined']

    print_info("  [成功] 全局MSE: {0:.4f}".format(global_mse))

    print("\n[4/7] 分段拟合与反弹检测...")

    segment_result = segment_fitting_pipeline(frame_ids, x_coords, y_coords)

    bounce_indices = segment_result['bounce_indices']
    segments = segment_result['segments']
    segment_results = segment_result['segment_results']

    print_info("  [成功] 反弹点: " + str(len(bounce_indices)))
    print_info("  [成功] 分段数: " + str(len(segments)))

    segment_recovery = recover_missing_by_segment(
        frame_ids, x_coords, y_coords,
        MISSING_START, MISSING_END,
        segment_results, segments
    )
    segment_mse = segment_recovery['mse_total']

    print_info("  [成功] 分段MSE: {0:.4f}".format(segment_mse))

    print("\n[5/7] 生成静态图表...")

    plot_original_trajectory(
        x_coords, y_coords, frame_ids,
        save_path=os.path.join(FIGURES_DIR, 'original_trajectory.png')
    )
    print_info("  [1/3] 原始轨迹图")

    plot_trajectory_comparison(
        x_coords, y_coords, missing_x, missing_y,
        global_recovered_x, global_recovered_y, frame_ids,
        MISSING_START, MISSING_END,
        save_path=os.path.join(FIGURES_DIR, 'trajectory_comparison.png')
    )
    print_info("  [2/3] 轨迹对比图")

    plot_segmented_trajectory(
        x_coords, y_coords, segments, bounce_indices,
        frame_ids, segment_recovery['recovered_x'], segment_recovery['recovered_y'],
        MISSING_START, MISSING_END,
        {'segment_results': segment_results, 'mse_total': segment_mse},
        save_path=os.path.join(FIGURES_DIR, 'segmented_trajectory.png')
    )
    print_info("  [3/3] 分段轨迹图")

    print("\n[6/7] 生成动画...")

    create_trajectory_animation(
        frame_ids, x_coords, y_coords,
        missing_x, missing_y,
        segment_recovery['recovered_x'], segment_recovery['recovered_y'],
        MISSING_START, MISSING_END,
        fps=15,
        save_path=os.path.join(ANIMATIONS_DIR, 'trajectory_animation.gif')
    )

    print("\n[7/7] 保存误差报告...")

    # NOTE(review): these are reported as frame numbers, but the pipeline
    # names them "indices" -- confirm whether they are array positions or
    # frame ids.
    bounce_frames = [int(i) for i in bounce_indices]
    # Computed once and shared by the report and the console summary; None
    # when the global MSE is zero (relative change undefined).
    improvement = _improvement_percent(global_mse, segment_mse)

    metrics_path = os.path.join(METRICS_DIR, 'metrics.txt')
    _write_metrics_report(metrics_path, bounce_frames, global_mse, segment_mse, improvement)

    print_info("  [成功] 报告已保存")

    # Console summary.
    print("\n" + "=" * 60)
    print("    执行结果汇总")
    print("=" * 60)

    print("\n【反弹点检测】")
    print("  检测到反弹点: " + str(len(bounce_frames)) + " 个")
    print("  反弹点位置(帧): " + str(bounce_frames))

    print("\n【误差分析】")
    print("  全局拟合 MSE: {0:.4f} 像素^2".format(global_mse))
    print("  分段拟合 MSE: {0:.4f} 像素^2".format(segment_mse))

    # Only a reduction is announced; no change, a worse result, or an
    # undefined ratio prints nothing here.
    if improvement is not None and improvement > 0:
        print("  分段拟合误差降低: {0:.2f}%".format(improvement))

    print("\n【输出文件】")
    print("  原始轨迹图: " + os.path.join(FIGURES_DIR, 'original_trajectory.png'))
    print("  轨迹对比图: " + os.path.join(FIGURES_DIR, 'trajectory_comparison.png'))
    print("  分段轨迹图: " + os.path.join(FIGURES_DIR, 'segmented_trajectory.png'))
    print("  轨迹动画:   " + os.path.join(ANIMATIONS_DIR, 'trajectory_animation.gif'))
    print("  误差报告:   " + metrics_path)

    print("\n" + "=" * 60)
    print("    执行完毕!")
    print("=" * 60)
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status so callers and CI
    # can detect failure; sys.exit(None) exits with status 0.
    sys.exit(main())