# -*- coding: utf-8 -*-
"""用于多先验融合、事件监督和不确定性训练的数据集构建。"""

from __future__ import annotations

import json
import os
import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset


def _list_video_jsons(dataset_dir: str) -> List[Tuple[int, str]]:
    """Find the ``data.json`` file of every ``video_<N>`` directory.

    Args:
        dataset_dir: Directory whose direct children are scanned.

    Returns:
        ``(N, path)`` pairs, one per child directory named ``video_<digits>``
        that contains a ``data.json`` entry, sorted by ascending ``N``.
    """
    name_pattern = re.compile(r"^video_(\d+)$")
    found: List[Tuple[int, str]] = []
    for entry in os.listdir(dataset_dir):
        match = name_pattern.match(entry)
        if match is None:
            continue
        video_dir = os.path.join(dataset_dir, entry)
        if not os.path.isdir(video_dir):
            continue
        json_path = os.path.join(video_dir, "data.json")
        if os.path.exists(json_path):
            found.append((int(match.group(1)), json_path))
    # Order by the numeric id, not by the directory name string.
    return sorted(found, key=lambda pair: pair[0])


def _load_frames(json_path: str) -> Dict[str, np.ndarray]:
    """Read one video's JSON file into per-frame arrays.

    The file must contain a ``"frames"`` list whose items carry ``"frame_id"``,
    ``"x"`` and ``"y"``. ``metadata.fps`` is optional and defaults to 30.0.

    Args:
        json_path: Path of the UTF-8 encoded JSON file.

    Returns:
        Dict with ``"frame_ids"`` (int64 array), ``"x"`` and ``"y"`` (float32
        arrays) and ``"fps"`` (a plain float).
    """
    with open(json_path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    records = payload["frames"]

    def _column(key, dtype):
        # Pull one field out of every frame record, in file order.
        return np.array([rec[key] for rec in records], dtype=dtype)

    return {
        "frame_ids": _column("frame_id", np.int64),
        "x": _column("x", np.float32),
        "y": _column("y", np.float32),
        "fps": float(payload.get("metadata", {}).get("fps", 30.0)),
    }


def _safe_missing_interval(frame_ids: np.ndarray, start: int, end: int) -> Optional[Tuple[int, int]]:
    """Clamp a requested missing interval so it lies strictly inside the video.

    The interval is shrunk to ``[min_id + 1, max_id - 1]`` so that the smallest
    and largest frame ids always stay outside it.

    Args:
        frame_ids: Frame ids of the video.
        start: Requested first missing frame id (inclusive).
        end: Requested last missing frame id (inclusive).

    Returns:
        The clamped ``(start, end)`` pair as Python ints, or ``None`` when
        ``frame_ids`` is empty or the clamped interval would contain fewer than
        two frame ids.
    """
    # An empty video has no min/max; report "no usable interval" instead of
    # letting the reduction raise ValueError.
    if frame_ids.size == 0:
        return None
    min_f = int(frame_ids.min())
    max_f = int(frame_ids.max())
    s = max(min_f + 1, int(start))
    e = min(max_f - 1, int(end))
    if s >= e:
        return None
    return s, e


def _compute_common_confidence(fit_rmse: float, miss_ratio: float, x: np.ndarray, y: np.ndarray) -> float:
    """Turn a fit error and a missing ratio into a confidence in ``[0, 1]``.

    The result is ``exp(-fit_rmse / scale) * exp(-2 * miss_ratio)``, where
    ``scale`` is the standard deviation over all x and y values pooled
    together (plus a small epsilon), clipped to ``[0, 1]``.

    Args:
        fit_rmse: Fit error of the prior; larger values lower the confidence.
        miss_ratio: Missing-frame ratio; larger values lower the confidence.
        x: Trajectory x coordinates.
        y: Trajectory y coordinates.

    Returns:
        The confidence as a Python float.
    """
    spread = float(np.std(np.stack([x, y], axis=-1)) + 1e-6)
    denom = max(spread, 1e-6)
    fit_term = np.exp(-float(fit_rmse) / denom)
    gap_term = np.exp(-2.0 * miss_ratio)
    confidence = np.clip(fit_term * gap_term, 0.0, 1.0)
    return float(confidence)


def _local_velocity(frame_ids: np.ndarray, x: np.ndarray, y: np.ndarray, idx0: int, idx1: int) -> Tuple[float, float]:
    """Finite-difference velocity between two samples, in units per frame id.

    Args:
        frame_ids: Frame ids of the trajectory.
        x: Trajectory x coordinates.
        y: Trajectory y coordinates.
        idx0: Array index of the first sample.
        idx1: Array index of the second sample.

    Returns:
        ``(vx, vy)`` as Python floats. If the two frame ids are (almost)
        equal, a step of 1.0 is used instead to avoid dividing by zero.
    """
    frame_gap = float(frame_ids[idx1] - frame_ids[idx0])
    if not abs(frame_gap) > 1e-6:
        frame_gap = 1.0
    vel_x = (x[idx1] - x[idx0]) / frame_gap
    vel_y = (y[idx1] - y[idx0]) / frame_gap
    return float(vel_x), float(vel_y)


def _fit_poly_prior(
    frame_ids: np.ndarray,
    x: np.ndarray,
    y: np.ndarray,
    missing_start: int,
    missing_end: int,
    deg: int,
    edge_weighted: bool,
):
    """Fit x(t) and y(t) polynomials on the observed frames and fill the gap.

    Args:
        frame_ids: Frame ids of the trajectory.
        x: Trajectory x coordinates.
        y: Trajectory y coordinates.
        missing_start: First frame id of the gap (inclusive).
        missing_end: Last frame id of the gap (inclusive).
        deg: Polynomial degree.
        edge_weighted: If true, observed frames nearer to either gap edge get
            larger fit weights (weights range over ``[0.5, 2.0]``); otherwise
            all observed frames weigh 1.

    Returns:
        ``(prior, fit_rmse)``. ``prior`` is a float32 array of shape
        ``[missing_end - missing_start + 1, 2]`` holding the polynomials
        evaluated at every integer frame id of the gap. ``fit_rmse`` is the
        root-mean-square Euclidean residual on the observed frames.
        ``(None, None)`` is returned when fewer than ``deg + 1`` frames are
        observed.
    """
    observed = (frame_ids < missing_start) | (frame_ids > missing_end)
    obs_f, obs_x, obs_y = frame_ids[observed], x[observed], y[observed]
    if len(obs_f) < (deg + 1):
        return None, None

    if not edge_weighted:
        weights = np.ones_like(obs_f, dtype=np.float64)
    else:
        dist_to_start = np.abs(obs_f - float(missing_start))
        dist_to_end = np.abs(obs_f - float(missing_end))
        gap_dist = np.minimum(dist_to_start, dist_to_end)
        # The decay length adapts to how far the observed frames typically sit from the gap.
        decay_len = float(np.percentile(gap_dist, 60) + 1.0)
        decay = np.exp(-gap_dist / max(decay_len, 1e-6)).astype(np.float64)
        weights = 0.5 + 1.5 * (decay / max(float(np.max(decay)), 1e-8))

    coef_x = np.polyfit(obs_f, obs_x, deg=deg, w=weights)
    coef_y = np.polyfit(obs_f, obs_y, deg=deg, w=weights)
    poly_x = np.poly1d(coef_x)
    poly_y = np.poly1d(coef_y)

    gap_frames = np.arange(missing_start, missing_end + 1, dtype=np.int64)
    prior = np.stack(
        [poly_x(gap_frames).astype(np.float32), poly_y(gap_frames).astype(np.float32)],
        axis=-1,
    )

    sq_residual = (poly_x(obs_f) - obs_x) ** 2 + (poly_y(obs_f) - obs_y) ** 2
    return prior, float(np.sqrt(np.mean(sq_residual)))


def _fit_constant_velocity_bridge(
    frame_ids: np.ndarray,
    x: np.ndarray,
    y: np.ndarray,
    missing_start: int,
    missing_end: int,
):
    """Bridge the gap by blending two constant-velocity extrapolations.

    One line is extrapolated forward from the last observed sample before the
    gap, another backward from the first observed sample after it. The two
    are cross-faded linearly across the gap (pure left line at the first gap
    frame, pure right line at the last).

    Args:
        frame_ids: Frame ids of the trajectory.
        x: Trajectory x coordinates.
        y: Trajectory y coordinates.
        missing_start: First frame id of the gap (inclusive).
        missing_end: Last frame id of the gap (inclusive).

    Returns:
        ``(prior, fit_rmse_proxy)``. ``prior`` is a float32 array of shape
        ``[missing_end - missing_start + 1, 2]``. ``fit_rmse_proxy`` is the
        Euclidean norm of the difference between the left and right
        velocities. ``(None, None)`` is returned when either side of the gap
        has no observed sample, or when neither side has a second sample to
        estimate a velocity from.
    """
    before = np.where(frame_ids < missing_start)[0]
    after = np.where(frame_ids > missing_end)[0]
    if len(before) == 0 or len(after) == 0:
        return None, None

    left = int(before[-1])
    right = int(after[0])
    n_frames = len(frame_ids)
    if left <= 0 and right >= n_frames - 1:
        return None, None

    def _slope(i0, i1):
        # Per-frame-id finite difference; a zero step is replaced by 1.0.
        step = float(frame_ids[i1] - frame_ids[i0])
        if not abs(step) > 1e-6:
            step = 1.0
        return float((x[i1] - x[i0]) / step), float((y[i1] - y[i0]) / step)

    # Left side: velocity from the two samples just before the gap, or zero
    # when only a single sample exists there.
    vx_left, vy_left = _slope(left - 1, left) if left >= 1 else (0.0, 0.0)
    # Right side: velocity from the two samples just after the gap, or the
    # left velocity when only a single sample exists there.
    if right + 1 < n_frames:
        vx_right, vy_right = _slope(right, right + 1)
    else:
        vx_right, vy_right = vx_left, vy_left

    gap_t = np.arange(missing_start, missing_end + 1, dtype=np.float32)
    t_left = float(frame_ids[left])
    t_right = float(frame_ids[right])

    fwd_x = float(x[left]) + vx_left * (gap_t - t_left)
    fwd_y = float(y[left]) + vy_left * (gap_t - t_left)
    back_x = float(x[right]) - vx_right * (t_right - gap_t)
    back_y = float(y[right]) - vy_right * (t_right - gap_t)

    if len(gap_t) > 1:
        span = max(float(gap_t.max() - gap_t.min()), 1e-6)
        blend = (gap_t - gap_t.min()) / span
    else:
        blend = np.zeros_like(gap_t)

    mixed_x = (1.0 - blend) * fwd_x + blend * back_x
    mixed_y = (1.0 - blend) * fwd_y + blend * back_y
    prior = np.stack([mixed_x.astype(np.float32), mixed_y.astype(np.float32)], axis=-1)

    velocity_gap = np.sqrt((vx_left - vx_right) ** 2 + (vy_left - vy_right) ** 2)
    return prior, float(velocity_gap)


def _build_event_targets(target_xy: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Derive per-frame event labels and loss weights from the y trajectory.

    A frame is an event candidate when all of the following hold:
      * y is a local peak or valley there (non-strict comparison),
      * the y velocity on its two sides has opposite sign (or one is zero),
      * its absolute y acceleration reaches the sequence's acceleration
        threshold (70th percentile, or the mean for sequences of <= 6 frames),
      * the faster of its two neighbouring speeds reaches the speed threshold
        (40th percentile, or the mean for sequences of <= 6 frames) and is
        strictly positive.

    Each candidate is then expanded to a neighbourhood of radius 1 (radius 2
    for sequences longer than 21 frames). Labels are 1 inside a
    neighbourhood; weights decay linearly from 3.0 at the candidate towards
    1.0 with distance, keeping the maximum where neighbourhoods overlap.

    Args:
        target_xy: Array of shape ``[T, 2]``; only column 1 (y) is used.

    Returns:
        ``(labels, weights)``, both float32 arrays of shape ``[T, 1]``.
        Sequences shorter than 5 frames, and sequences without candidates,
        yield all-zero labels and all-one weights.
    """
    t = target_xy.shape[0]
    labels = np.zeros((t, 1), dtype=np.float32)
    weights = np.ones((t, 1), dtype=np.float32)
    if t < 5:
        return labels, weights

    py = target_xy[:, 1]
    vy = np.gradient(py)
    ay = np.gradient(vy)
    abs_ay = np.abs(ay)
    abs_vy = np.abs(vy)

    # Dual criterion:
    # 1) turning-point peak in y with sign reversal in velocity
    # 2) high acceleration impulse around the same time
    if t > 6:
        peak_thr = float(np.percentile(abs_vy, 40))
        acc_thr = float(np.percentile(abs_ay, 70))
    else:
        peak_thr = float(np.mean(abs_vy))
        acc_thr = float(np.mean(abs_ay))

    cand = []
    for i in range(1, t - 1):
        is_local_peak_or_valley = (py[i] >= py[i - 1] and py[i] >= py[i + 1]) or (
            py[i] <= py[i - 1] and py[i] <= py[i + 1]
        )
        sign_flip = vy[i - 1] * vy[i + 1] <= 0.0
        strong_acc = abs_ay[i] >= acc_thr
        speed_context = max(abs_vy[i - 1], abs_vy[i + 1])
        # Require real motion on at least one side. Every other test here is
        # non-strict, so without this guard a perfectly flat stretch (all
        # velocities and thresholds equal to zero) would satisfy all of them
        # and a stationary trajectory would be labelled as events throughout.
        enough_speed_context = speed_context >= peak_thr and speed_context > 0.0
        if is_local_peak_or_valley and sign_flip and strong_acc and enough_speed_context:
            cand.append(i)

    if not cand:
        return labels, weights

    # Neighbourhood expansion (adaptive with sequence length).
    radius = 1 if t <= 21 else 2
    for i in cand:
        lo = max(0, i - radius)
        hi = min(t, i + radius + 1)
        labels[lo:hi, 0] = 1.0
        for j in range(lo, hi):
            dist = abs(j - i)
            weights[j, 0] = max(weights[j, 0], 1.0 + 2.0 * (1.0 - dist / (radius + 1)))
    return labels, weights


@dataclass
class Sample:
    """One training sample for a single missing interval, as assembled by ``build_samples``.

    Shape legend used in the field comments: ``T`` is the number of frames in
    the missing interval, ``C`` the number of feature channels and ``K`` the
    number of prior trajectories.
    """

    feats: np.ndarray  # [T, C] per-frame input features
    prior: np.ndarray  # [T, 2] confidence-weighted mix of the K priors
    priors: np.ndarray  # [T, K, 2] individual prior trajectories
    prior_confidences: np.ndarray  # [K] mixing weights of the priors
    target: np.ndarray  # [T, 2] ground-truth coordinates in the missing interval
    event_labels: np.ndarray  # [T, 1] per-frame event labels (0 or 1)
    event_weights: np.ndarray  # [T, 1] per-frame weights paired with event_labels
    video_id: int  # id of the video this sample was built from
    prior_mse: float  # mean squared error between `prior` and `target`
    prior_confidence: float  # largest entry of `prior_confidences`
    missing_ratio: float  # frames in the missing interval divided by all frames of the video
    event_density: float  # mean of `event_labels`
    difficulty_score: float  # weighted sum of scaled prior_mse, event_density and missing_ratio


def build_samples(
    dataset_dir: str,
    video_ids: Sequence[int],
    missing_start: int = 30,
    missing_end: int = 50,
) -> List[Sample]:
    """Build one ``Sample`` per usable video by hiding a frame interval.

    For every requested video the frames inside the (clamped) missing interval
    become the target, and three priors are fitted on the remaining frames: an
    edge-weighted quadratic, an unweighted line, and a constant-velocity
    bridge. Each prior gets a confidence, the confidences are normalized into
    mixing weights, and the weighted mix plus derived quantities are packed
    into a 21-channel per-frame feature matrix.

    A video is silently skipped when:
      * its id has no ``video_<id>/data.json`` under ``dataset_dir``,
      * the clamped missing interval is empty,
      * the quadratic prior cannot be fitted,
      * the number of recorded frames inside the interval differs from the
        interval length (i.e. some frame ids in the interval are absent).

    Args:
        dataset_dir: Directory containing the ``video_<N>`` sub-directories.
        video_ids: Ids of the videos to process, in output order.
        missing_start: Requested first hidden frame id (inclusive).
        missing_end: Requested last hidden frame id (inclusive).

    Returns:
        The samples that could be built, in the order of ``video_ids``.
    """
    all_jsons = dict(_list_video_jsons(dataset_dir))
    samples: List[Sample] = []
    for vid in video_ids:
        if vid not in all_jsons:
            continue
        item = _load_frames(all_jsons[vid])
        frame_ids = item["frame_ids"]
        x = item["x"]
        y = item["y"]
        fps = item["fps"]

        interval = _safe_missing_interval(frame_ids, missing_start, missing_end)
        if interval is None:
            continue
        s, e = interval

        # Three candidate priors over the hidden interval [s, e].
        p_quad, rmse_quad = _fit_poly_prior(frame_ids, x, y, s, e, deg=2, edge_weighted=True)
        p_lin, rmse_lin = _fit_poly_prior(frame_ids, x, y, s, e, deg=1, edge_weighted=False)
        p_cv, rmse_cv = _fit_constant_velocity_bridge(frame_ids, x, y, s, e)
        if p_quad is None:
            continue
        # Fallback chain: a missing linear prior copies the quadratic one, a
        # missing constant-velocity prior copies the linear one.
        if p_lin is None:
            p_lin = p_quad.copy()
            rmse_lin = rmse_quad
        if p_cv is None:
            p_cv = p_lin.copy()
            rmse_cv = rmse_lin

        # Ground truth inside the interval; must line up frame-for-frame with the priors.
        gt_mask = (frame_ids >= s) & (frame_ids <= e)
        gt_xy = np.stack([x[gt_mask], y[gt_mask]], axis=-1).astype(np.float32)
        if len(gt_xy) != len(p_quad):
            continue

        miss_ratio = float(len(gt_xy) / max(1, len(frame_ids)))
        conf_quad = _compute_common_confidence(float(rmse_quad), miss_ratio, x, y)
        conf_lin = _compute_common_confidence(float(rmse_lin), miss_ratio, x, y)

        # Same formula as _compute_common_confidence, applied to the velocity-gap
        # proxy of the bridge. Here traj_scale is taken over ALL frames.
        traj_scale = float(np.std(np.stack([x, y], axis=-1)) + 1e-6)
        conf_cv = float(np.exp(-float(rmse_cv) / max(traj_scale, 1e-6)) * np.exp(-2.0 * miss_ratio))
        conf_cv = float(np.clip(conf_cv, 0.0, 1.0))

        prior_list = [p_quad, p_lin, p_cv]
        valid_mask = (frame_ids < s) | (frame_ids > e)
        valid_xy = np.stack([x[valid_mask], y[valid_mask]], axis=-1).astype(np.float32)
        # From here on traj_scale is the spread of the OBSERVED frames only.
        traj_scale = float(np.std(valid_xy) + 1e-6)

        # Expected positions at the first and last hidden frame, extrapolated
        # from the two nearest observed samples on each side.
        left_idx = np.where(frame_ids < s)[0]
        right_idx = np.where(frame_ids > e)[0]
        if len(left_idx) >= 2 and len(right_idx) >= 2:
            li0, li1 = int(left_idx[-2]), int(left_idx[-1])
            ri0, ri1 = int(right_idx[0]), int(right_idx[1])
            lvx, lvy = _local_velocity(frame_ids, x, y, li0, li1)
            rvx, rvy = _local_velocity(frame_ids, x, y, ri0, ri1)
            left_dt = float(s - frame_ids[li1])
            right_dt = float(frame_ids[ri0] - e)
            left_expect = np.array([x[li1] + lvx * left_dt, y[li1] + lvy * left_dt], dtype=np.float32)
            right_expect = np.array([x[ri0] - rvx * right_dt, y[ri0] - rvy * right_dt], dtype=np.float32)
        else:
            # Not enough neighbours: use the quadratic prior's own end points.
            left_expect = prior_list[0][0]
            right_expect = prior_list[0][-1]

        # Per-prior boundary score (agreement with the expected end points) and
        # disagreement (mean distance to the per-frame median of the priors).
        stacked_priors = np.stack(prior_list, axis=1).astype(np.float32)
        median_prior = np.median(stacked_priors, axis=1)
        boundary_scores = []
        prior_disagreements = []
        for p in prior_list:
            edge_err = np.linalg.norm(p[0] - left_expect) + np.linalg.norm(p[-1] - right_expect)
            boundary_scores.append(float(np.exp(-edge_err / max(2.0 * traj_scale, 1e-6))))
            prior_disagreements.append(float(np.mean(np.linalg.norm(p - median_prior, axis=-1))))

        disagreement_scores = np.exp(-np.asarray(prior_disagreements, dtype=np.float32) / max(traj_scale, 1e-6))
        boundary_scores_np = np.asarray(boundary_scores, dtype=np.float32)

        # Combine the three scores per prior; the 1e-4 floor keeps the
        # normalizing sum strictly positive.
        conf_arr = np.array([conf_quad, conf_lin, conf_cv], dtype=np.float32)
        conf_arr = conf_arr * np.clip(boundary_scores_np, 1e-4, 1.0) * np.clip(disagreement_scores, 1e-4, 1.0)
        conf_arr = np.clip(conf_arr, 1e-4, 1.0)
        conf_weights = conf_arr / np.sum(conf_arr)

        priors = stacked_priors  # [T, K=3, 2]
        prior_mix = np.sum(priors * conf_weights.reshape(1, -1, 1), axis=1).astype(np.float32)

        prior_mse = float(np.mean((prior_mix - gt_xy) ** 2))
        prior_conf = float(np.max(conf_weights))

        event_labels, event_weights = _build_event_targets(gt_xy)
        event_density = float(np.mean(event_labels))
        prior_disagreement = float(np.mean(prior_disagreements) / max(traj_scale, 1e-6))
        difficulty_score = float(0.55 * prior_mse / max(traj_scale**2, 1e-6) + 0.25 * event_density + 0.20 * miss_ratio)

        # Feature order: time, mixed / individual prior coordinates, kinematics,
        # prior confidences, boundary quality, disagreement, occlusion ratio.
        t = np.arange(s, e + 1, dtype=np.float32)
        t_norm = (t - t.mean()) / (t.std() + 1e-6)

        # Velocity and acceleration of the mixed prior; each per-frame
        # difference is multiplied by fps.
        mix_x = prior_mix[:, 0]
        mix_y = prior_mix[:, 1]
        vx = np.gradient(mix_x) * fps
        vy = np.gradient(mix_y) * fps
        ax = np.gradient(vx) * fps
        ay = np.gradient(vy) * fps

        # Columns 0-12 vary per frame; columns 13-20 are per-sample constants
        # broadcast over time. The confidence columns hold the raw per-prior
        # confidences, not the normalized mixing weights.
        feats = np.stack(
            [
                t_norm,
                mix_x,
                mix_y,
                p_quad[:, 0],
                p_quad[:, 1],
                p_lin[:, 0],
                p_lin[:, 1],
                p_cv[:, 0],
                p_cv[:, 1],
                vx,
                vy,
                ax,
                ay,
                np.full_like(t_norm, conf_quad),
                np.full_like(t_norm, conf_lin),
                np.full_like(t_norm, conf_cv),
                np.full_like(t_norm, boundary_scores_np[0]),
                np.full_like(t_norm, boundary_scores_np[1]),
                np.full_like(t_norm, boundary_scores_np[2]),
                np.full_like(t_norm, prior_disagreement),
                np.full_like(t_norm, miss_ratio),
            ],
            axis=-1,
        ).astype(np.float32)

        samples.append(
            Sample(
                feats=feats,
                prior=prior_mix,
                priors=priors,
                prior_confidences=conf_weights.astype(np.float32),
                target=gt_xy,
                event_labels=event_labels,
                event_weights=event_weights,
                video_id=int(vid),
                prior_mse=prior_mse,
                prior_confidence=prior_conf,
                missing_ratio=miss_ratio,
                event_density=event_density,
                difficulty_score=difficulty_score,
            )
        )
    return samples


class TrajectoryResidualDataset(Dataset):
    """Torch dataset that serves samples as tensors, with optional augmentation.

    Augmentation touches only the feature matrix. It assumes feature columns
    1-8 hold coordinates (mixed prior and individual priors) and columns 9-12
    hold kinematic quantities; the remaining columns are left untouched.

    Args:
        samples: Samples to serve; copied into a list.
        augment_prob: Per-item probability of augmenting, clamped to ``[0, 1]``.
        max_occlusion_ratio: Upper bound on the pseudo-occlusion length as a
            fraction of the sequence, clamped to ``[0.1, 0.8]``.
        noise_std_ratio: Noise standard deviation as a fraction of the feature
            group's standard deviation; negative values are treated as 0.
    """

    def __init__(
        self,
        samples: Sequence[Sample],
        augment_prob: float = 0.0,
        max_occlusion_ratio: float = 0.35,
        noise_std_ratio: float = 0.01,
    ):
        self.samples = list(samples)
        prob_capped = min(1.0, augment_prob)
        self.augment_prob = float(max(0.0, prob_capped))
        ratio_capped = min(0.8, max_occlusion_ratio)
        self.max_occlusion_ratio = float(max(0.1, ratio_capped))
        self.noise_std_ratio = float(max(0.0, noise_std_ratio))

    def __len__(self) -> int:
        return len(self.samples)

    def _apply_pseudo_occlusion(self, feats: np.ndarray) -> np.ndarray:
        """Return a copy with a random run of frames frozen to a neighbouring frame.

        Sequences shorter than 4 frames are returned as an unmodified copy.
        """
        masked = feats.copy()
        n_frames = masked.shape[0]
        if n_frames < 4:
            return masked

        longest = max(2, int(round(n_frames * self.max_occlusion_ratio)))
        length = int(np.random.randint(2, longest + 1))
        first = int(np.random.randint(0, max(1, n_frames - length + 1)))
        stop = min(n_frames, first + length)

        # Dynamic features: mixed / individual prior coordinates plus kinematics.
        dynamic_cols = list(range(1, 13))

        # Hold the frame just before the run; if the run starts at frame 0,
        # hold the frame just after it instead.
        if first > 0:
            held = masked[first - 1, dynamic_cols]
        elif stop < n_frames:
            held = masked[stop, dynamic_cols]
        else:
            held = np.zeros((len(dynamic_cols),), dtype=masked.dtype)
        masked[first:stop, dynamic_cols] = held
        return masked

    def _apply_feature_noise(self, feats: np.ndarray) -> np.ndarray:
        """Return a copy with Gaussian noise added to coordinate and kinematic columns.

        When ``noise_std_ratio`` is not positive the input itself is returned.
        """
        if self.noise_std_ratio <= 0.0:
            return feats
        noisy = feats.copy()

        coord_cols = slice(1, 9)  # mixed and individual prior coordinates
        kin_cols = slice(9, 13)  # kinematic features
        coord_vals = noisy[:, coord_cols]
        kin_vals = noisy[:, kin_cols]

        # Each group gets noise scaled by its own spread.
        coord_sigma = float(np.std(coord_vals) + 1e-6) * self.noise_std_ratio
        kin_sigma = float(np.std(kin_vals) + 1e-6) * self.noise_std_ratio

        coord_noise = np.random.normal(0.0, coord_sigma, size=coord_vals.shape).astype(np.float32)
        noisy[:, coord_cols] = coord_vals + coord_noise
        kin_noise = np.random.normal(0.0, kin_sigma, size=kin_vals.shape).astype(np.float32)
        noisy[:, kin_cols] = kin_vals + kin_noise
        return noisy

    def __getitem__(self, idx: int):
        """Return the tensors for sample ``idx``.

        The tuple is ``(feats, prior, residual, target, video_id, priors,
        prior_confidences, event_labels, event_weights)`` where
        ``residual = target - prior``. All arrays are served as float32 tensors.
        """
        sample = self.samples[idx]
        feats = sample.feats
        if self.augment_prob > 0 and np.random.rand() < self.augment_prob:
            feats = self._apply_feature_noise(self._apply_pseudo_occlusion(feats))

        def _to_tensor(arr):
            return torch.from_numpy(arr.astype(np.float32, copy=False))

        prior = _to_tensor(sample.prior)  # [T, 2]
        target = _to_tensor(sample.target)  # [T, 2]
        return (
            _to_tensor(feats),  # [T, C]
            prior,
            target - prior,
            target,
            sample.video_id,
            _to_tensor(sample.priors),  # [T, K, 2]
            _to_tensor(sample.prior_confidences),  # [K]
            _to_tensor(sample.event_labels),  # [T, 1]
            _to_tensor(sample.event_weights),  # [T, 1]
        )