import argparse
import json
import multiprocessing as mp
import os

import cv2
import numpy as np

import torch
from pytorch_msssim import ms_ssim

# Value passed as `n_lines` to compute_line_metric by the line-metric subprocess.
N_LINES = 25


def eval_script_uvdoc(gtdir, imdir, verbose=True, n_samples=50):
    """
    Compute the mean MS-SSIM between reference and predicted images.

    Samples are looked up by name as ``00000.png``, ``00001.png``, ... in both
    directories. Each pair is converted to grayscale, the reference is rescaled
    to a fixed pixel area, the prediction is resized to the reference size, and
    MS-SSIM is computed on the pair with values normalized to [0, 1].

    A sample is skipped (and left out of the mean) when the reference file is
    missing, the prediction file is missing, or either image cannot be decoded.

    Args:
        gtdir: Directory holding the reference images.
        imdir: Directory holding the predicted images.
        verbose: When true, print progress and missing-reference messages.
        n_samples: Number of consecutive sample indices (starting at 0) to try.

    Returns:
        The mean MS-SSIM over all evaluated samples as a float, or NaN when no
        sample could be evaluated.
    """
    # Pixel area every reference image is rescaled to before the comparison.
    target_area = 598400
    scores = []

    for k in range(n_samples):
        name = f"{k:05d}.png"
        gt_path = os.path.join(gtdir, name)
        if not os.path.exists(gt_path):
            if verbose:
                print(f"{gt_path} - Not file (ref/input)")
            continue

        pred_path = os.path.join(imdir, name)
        if not os.path.exists(pred_path):
            # Always reported: a skipped prediction changes what the mean covers.
            print(f"{pred_path} - Not file (prediction), skipped")
            continue

        if verbose:
            print(f"Running {k:05d}.png ...", flush=True)

        rimg = cv2.imread(gt_path)
        ximg = cv2.imread(pred_path)
        # cv2.imread signals a decode failure by returning None instead of raising.
        if rimg is None or ximg is None:
            print(f"{name} - Could not read image (ref or prediction), skipped")
            continue
        rimg = cv2.cvtColor(rimg, cv2.COLOR_BGR2GRAY)
        ximg = cv2.cvtColor(ximg, cv2.COLOR_BGR2GRAY)

        # Rescale the reference to the target area, then match the prediction to it.
        rh, rw = rimg.shape
        resize_factor = np.sqrt(target_area / (rh * rw))
        rimg = cv2.resize(rimg, (0, 0), fx=resize_factor, fy=resize_factor, interpolation=cv2.INTER_CUBIC)
        rh, rw = rimg.shape
        ximg = cv2.resize(ximg, (rw, rh), interpolation=cv2.INTER_CUBIC)

        # Preprocess: turn [H, W] into [1, 1, H, W] and normalize values to [0, 1].
        rimg_tensor = (torch.from_numpy(rimg).float() / 255.0).unsqueeze(0).unsqueeze(0)
        ximg_tensor = (torch.from_numpy(ximg).float() / 255.0).unsqueeze(0).unsqueeze(0)

        # MS-SSIM computation; stored as a plain float so the result is framework-free.
        ms_ssim_val = ms_ssim(rimg_tensor, ximg_tensor, data_range=1.0, size_average=True)
        scores.append(float(ms_ssim_val))

    if not scores:
        # Nothing evaluated: return NaN explicitly rather than via np.mean([]) and its warning.
        return float("nan")
    return float(np.mean(scores))


def visual_metrics_process(queue, uvdoc_path, preds_path, verbose):
    """
    Subprocess entry point for the visual metric.

    Runs eval_script_uvdoc on the given reference and prediction directories and
    puts a dict with the mean MS-SSIM under the ``ms`` key on ``queue``.
    """
    payload = {"ms": eval_script_uvdoc(uvdoc_path, preds_path, verbose)}
    queue.put(payload)


def ocr_process(queue, uvdoc_path, preds_path):
    """
    Subprocess entry point for the OCR metrics (CER and ED).

    Calls OCR_eval_UVDoc, writes the third value it returns to ``ocr_res.json``
    inside ``preds_path``, and puts the two means on ``queue`` under the ``cer``
    and ``ed`` keys.
    """
    # Imported here so the dependency is only loaded inside this subprocess.
    from eval.ocr_eval.ocr_eval import OCR_eval_UVDoc

    cer_mean, ed_mean, ocr_details = OCR_eval_UVDoc(uvdoc_path, preds_path)

    dump_path = os.path.join(preds_path, "ocr_res.json")
    with open(dump_path, "w") as fh:
        json.dump(ocr_details, fh)

    summary = {"cer": cer_mean, "ed": ed_mean}
    queue.put(summary)


def new_line_metric_process(queue, uvdoc_path, preds_path, n_lines):
    """
    Subprocess entry point for the line metrics on the UVDoc benchmark.

    Calls compute_line_metric and puts its two results on ``queue`` under the
    ``hor_line`` and ``ver_line`` keys.
    """
    # Imported here so the dependency is only loaded inside this subprocess.
    from uvdocBenchmark_metric import compute_line_metric

    line_scores = compute_line_metric(uvdoc_path, preds_path, n_lines)
    horizontal, vertical = line_scores
    queue.put({"hor_line": horizontal, "ver_line": vertical})


def compute_metrics(uvdoc_path, pred_path, pred_type, verbose=False):
    """
    Compute all metrics in parallel subprocesses and print them.

    Three worker processes are started: the MS-SSIM evaluation, the line
    metrics, and the OCR metrics. Each worker puts one dict of results on a
    shared queue; a worker that fails puts nothing, and its metrics are printed
    as ``Not exist``.

    Args:
        uvdoc_path: Root directory of the UVDoc benchmark dataset.
        pred_path: Root directory of the predictions.
        pred_type: Name of the prediction sub-directory to compare.
        verbose: Forwarded to the MS-SSIM evaluation.

    Returns:
        A dict holding whichever of the keys ``ms``, ``cer``, ``ed``,
        ``hor_line`` and ``ver_line`` the workers produced.
    """
    # Local import: only the Empty exception is needed, and only in this function.
    import queue as queue_lib

    texture_dir = os.path.join(uvdoc_path, "texture_sample")
    unwarped_dir = os.path.join(pred_path, pred_type)
    q = mp.Queue()

    workers = [
        # MS-SSIM
        mp.Process(target=visual_metrics_process, args=(q, texture_dir, unwarped_dir, verbose)),
        # Line metrics
        mp.Process(
            target=new_line_metric_process,
            args=(q, uvdoc_path, os.path.join(pred_path, "bm"), N_LINES),
        ),
        # OCR metrics
        mp.Process(target=ocr_process, args=(q, texture_dir, unwarped_dir)),
    ]
    for worker in workers:
        worker.start()

    # Drain the queue BEFORE joining: joining a process that still has queued
    # data can block, and Queue.qsize() is unreliable (not implemented on macOS).
    res = {}
    outstanding = len(workers)  # each worker puts at most one item
    while outstanding:
        try:
            item = q.get(timeout=1.0)
        except queue_lib.Empty:
            if any(worker.is_alive() for worker in workers):
                continue
            # Every worker has exited: pick up anything that landed after the
            # timeout, then stop waiting for workers that produced nothing.
            while True:
                try:
                    res.update(q.get_nowait())
                except queue_lib.Empty:
                    break
            break
        else:
            res.update(item)
            outstanding -= 1

    for worker in workers:
        worker.join()

    # Print results
    print("--- Results ---")
    print(f"  Mean MS-SSIM      : {res.get('ms', 'Not exist')}")
    print(f"  Mean CER          : {res.get('cer', 'Not exist')}")
    print(f"  Mean ED           : {res.get('ed', 'Not exist')}")
    print(f"  Hor Line          : {res.get('hor_line', 'Not exist')}")
    print(f"  Ver Line          : {res.get('ver_line', 'Not exist')}")

    return res


if __name__ == "__main__":
    # Command-line entry point: parse the dataset/prediction locations and run all metrics.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--uvdoc-path", type=str, default="./data/UVDoc_benchmark/", help="Path to the uvdoc benchmark dataset"
    )
    # Required: there is no sensible default, and a missing value would otherwise
    # reach os.path.abspath(None) and fail with an unhelpful TypeError.
    parser.add_argument(
        "--pred-path",
        type=str,
        required=True,
        help="Path to the UVDoc benchmark predictions. Need to be absolute.",
    )
    parser.add_argument(
        "--pred-type",
        type=str,
        default="uwp_texture",
        choices=["uwp_texture", "uwp_img"],
        help="Which type of prediction to compare. Either the unwarped textures or the unwarped litted images.",
    )
    parser.add_argument("-v", "--verbose", action="store_true")
    args = parser.parse_args()

    compute_metrics(
        uvdoc_path=os.path.abspath(args.uvdoc_path),
        pred_path=os.path.abspath(args.pred_path),
        pred_type=args.pred_type,
        verbose=args.verbose,
    )