import os
import json
import math
import argparse
from pathlib import Path

import pandas as pd
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
from vbench import VBench
from vbench.distributed import dist_init, print0


def parse_args():
    parser = argparse.ArgumentParser(description='VBench', formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        "--full_json_dir",
        type=str,
        default="./vbench/VBench_full_info.json",
        help="path to save the json file that contains the prompt and dimension information",
    )
    parser.add_argument(
        "--load_ckpt_from_local",
        type=bool,
        required=False,
        help="whether load checkpoints from local default paths (assuming you have downloaded the checkpoints locally",
    )
    parser.add_argument(
        "--read_frame",
        type=bool,
        required=False,
        help="whether directly read frames, or directly read videos",
    )
    parser.add_argument(
        "--mode",
        choices=['custom_input', 'vbench_standard', 'vbench_category'],
        default='custom_input',
        help="""This flags determine the mode of evaluations, choose one of the following:
        1. "custom_input": receive input prompt from either --prompt/--prompt_file flags or the filename
        2. "vbench_standard": evaluate on standard prompt suite of VBench
        3. "vbench_category": evaluate on specific category
        """,
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="None",
        help="""Specify the input prompt
        If not specified, filenames will be used as input prompts
        * Mutually exclusive to --prompt_file.
        ** This option must be used with --mode=custom_input flag
        """
    )
    parser.add_argument(
        "--prompt_file",
        type=str,
        required=False,
        help="""Specify the path of the file that contains prompt lists
        If not specified, filenames will be used as input prompts
        * Mutually exclusive to --prompt.
        ** This option must be used with --mode=custom_input flag
        """
    )
    parser.add_argument(
        "--videos_path",
        type=str,
        required=True,
        help="Videos path to be evaluated.",
    )
    parser.add_argument(
        "--num_inference_videos_per_sample",
        type=int,
        default=3,
        help="Number of videos generated by inference for each sample.",
    )
    parser.add_argument(
        "--category",
        type=str,
        required=False,
        help="""This is for mode=='vbench_category'
        The category to evaluate on, usage: --category=animal.
        """,
    )

    ## for dimension specific params ###
    parser.add_argument(
        "--imaging_quality_preprocessing_mode",
        type=str,
        required=False,
        default='longer',
        help="""This is for setting preprocessing in imaging_quality
        1. 'shorter': if the shorter side is more than 512, the image is resized so that the shorter side is 512.
        2. 'longer': if the longer side is more than 512, the image is resized so that the longer side is 512.
        3. 'shorter_centercrop': if the shorter side is more than 512, the image is resized so that the shorter side is 512.
        Then the center 512 x 512 after resized is used for evaluation.
        4. 'None': no preprocessing
        """,
    )
    args = parser.parse_args()
    return args


def load_prompts(prompt_file):
    """Read one prompt per line from ``prompt_file``.

    Each line is stripped of surrounding whitespace. Blank lines are kept as
    empty prompts, so line ``i`` of the file is always prompt index ``i``.

    Args:
        prompt_file: path to a UTF-8 text file.

    Returns:
        list[str]: the stripped prompts, in file order.

    Raises:
        FileNotFoundError: if ``prompt_file`` is ``None`` or is not an existing file.
    """
    # ``None`` happens when --prompt_file is omitted; report it clearly instead of
    # letting os.path raise an opaque TypeError.
    if prompt_file is None or not os.path.isfile(prompt_file):
        raise FileNotFoundError(f"prompt_file {prompt_file} invalid")
    with open(prompt_file, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


def save_json(data, json_path):
    """Write ``data`` to ``json_path`` as JSON Lines, one record per line.

    ``data`` is anything ``pandas.DataFrame`` accepts (e.g. a list of dicts);
    each resulting row is serialised as one JSON object.
    """
    records_frame = pd.DataFrame(data)
    records_frame.to_json(path_or_buf=json_path, lines=True, orient='records')
    print("The data is successfully formatted and the file is saved as ", json_path)


def data_format(args, video_list, weighted_score):
    """Build chosen/rejected video pairs per prompt and save them as ``data.jsonl``.

    Videos are laid out prompt-major: with ``k = args.num_inference_videos_per_sample``,
    prompt ``i`` owns ``video_{i*k}.mp4`` .. ``video_{i*k + k - 1}.mp4`` under
    ``args.videos_path``. Among those that have a score, the highest-scoring one
    becomes ``file`` and the lowest-scoring one ``file_rejected`` (ties keep the
    first candidate). Prompts with no scored candidate are skipped with a warning.
    The output file is written next to the videos directory
    (``Path(args.videos_path).parent / 'data.jsonl'``).

    Args:
        args: parsed CLI arguments; reads ``videos_path``,
            ``num_inference_videos_per_sample`` and ``prompt_file``.
        video_list: video paths, aligned element-wise with ``weighted_score``.
        weighted_score: one score per entry of ``video_list``.
    """
    videos_path = args.videos_path
    per_prompt = args.num_inference_videos_per_sample
    prompts = load_prompts(args.prompt_file)
    scores_dict = dict(zip(video_list, weighted_score))

    res = []
    for i, prompt in enumerate(prompts):
        # Reset per prompt: otherwise a prompt without scored videos would raise
        # UnboundLocalError (first prompt) or silently reuse the previous prompt's files.
        max_score, min_score = -math.inf, math.inf
        max_score_file = min_score_file = None

        for j in range(per_prompt):
            index = i * per_prompt + j
            video = os.path.join(videos_path, f"video_{index}.mp4")
            if video not in scores_dict:
                continue
            score = scores_dict[video]
            if score > max_score:
                max_score = score
                max_score_file = f"./video_{index}.mp4"
            if score < min_score:
                min_score = score
                min_score_file = f"./video_{index}.mp4"

        if max_score_file is None or min_score_file is None:
            print(f"Warning: no scored video found for prompt {i}, skipping it")
            continue

        res.append({
            'file': max_score_file,
            'file_rejected': min_score_file,
            'captions': prompt,
            'score': max_score,
            'score_rejected': min_score,
        })

    base_path = Path(videos_path).parent
    save_json(res, os.path.join(base_path, 'data.jsonl'))


def video_evaluator(args, eval_result_path):
    """Run VBench on ``args.videos_path`` using the NPU device.

    Calls ``dist_init()`` first, then builds a ``VBench`` instance with
    ``args.full_json_dir`` and ``eval_result_path`` and evaluates the configured
    dimensions. Prompts are passed as an empty list, so they are read from filenames.

    Args:
        args: parsed CLI arguments; reads ``full_json_dir``, ``category``,
            ``imaging_quality_preprocessing_mode``, ``videos_path``, ``name``,
            ``dimension``, ``load_ckpt_from_local``, ``read_frame`` and ``mode``.
        eval_result_path: output directory handed to ``VBench``.
    """
    dist_init()
    device = torch.device("npu")
    m_VBench = VBench(device, args.full_json_dir, eval_result_path)

    print0('start evaluation')
    kwargs = {}

    # --category defaults to None when omitted; only forward a real, non-empty value
    # (the previous ``!= ""`` test let None through).
    if args.category:
        kwargs['category'] = args.category

    kwargs['imaging_quality_preprocessing_mode'] = args.imaging_quality_preprocessing_mode

    m_VBench.evaluate(
        videos_path=args.videos_path,
        name=args.name,
        prompt_list=[],  # pass in [] to read prompt from filename
        dimension_list=args.dimension,
        local=args.load_ckpt_from_local,
        read_frame=args.read_frame,
        mode=args.mode,
        **kwargs
    )
    print0('done')


def read_json_file(filename):
    """Parse the UTF-8 JSON document at ``filename`` and return the decoded object."""
    with open(filename, encoding='utf-8') as handle:
        return json.load(handle)


def extract_scores(data, scores_dict):
    """Merge per-video float scores from one VBench result mapping into ``scores_dict``.

    ``data`` maps a dimension name to a sequence whose element at index 1 is
    iterated as per-video records. A record contributes
    ``scores_dict[record['video_path']][dimension] = record['video_results']``
    only when it has a ``video_path`` and its ``video_results`` is a ``float``;
    other records are ignored. ``scores_dict`` is updated in place and returned.
    """
    for dimension, entry in data.items():
        for record in entry[1]:
            if 'video_results' not in record:
                continue
            result = record['video_results']
            if not isinstance(result, float) or 'video_path' not in record:
                continue
            scores_dict.setdefault(record['video_path'], {})[dimension] = result
    return scores_dict


def _histogram_bin_key(index, bin_width):
    """Return the histogram label ``"<lower>-<upper>"`` of integer bin ``index``."""
    lower = index * bin_width
    return f"{lower}-{lower + bin_width}"


def compute_histogram(all_scores, dimension_list, dimension_weight, bin_width=0.01, normalize=False):
    """
    Combine per-dimension video scores into one weighted score and bin the results.

    Args:
        all_scores: ``{video_path: {dimension: score}}``; every video needs a score
            for each entry of ``dimension_list``.
        dimension_list: dimensions to combine.
        dimension_weight: ``{dimension: weight}``; the combined score is the
            weighted average over ``dimension_list``.
        bin_width: histogram bin width. NOTE: when ``normalize`` is True this is
            overridden with 0.1 (kept for backward compatibility).
        normalize: min-max normalize each dimension to [0, 1] across videos before
            weighting. A dimension on which all videos tie maps to 0 instead of NaN.

    :return
        tuple ``(histogram, video_list, weighted_score)``:
        histogram maps ``"<lower>-<upper>"`` labels to counts, plus the metadata keys
        ``total_num``, ``max_num``, ``min_bin``, ``max_bin``, ``bin_width``, ``bin_size``;
        video_list is the list of video paths; weighted_score is the list of combined
        scores aligned with it. Empty ``all_scores`` yields ``({}, [], [])``.

    Raises:
        ValueError: if ``dimension_list`` is empty or the weights sum to zero.
    """
    if not all_scores:
        # Callers unpack three values, so never return a bare dict.
        return {}, [], []

    if normalize:
        bin_width = 0.1

    # Weighted average of multi-dimensional scoring
    video_list = list(all_scores)
    weighted_score = torch.zeros(len(video_list))
    weight_sum = 0.0
    for dimension in dimension_list:
        raw = [all_scores[video][dimension] for video in video_list]
        scores = torch.Tensor(raw)
        if normalize:  # Single dimension normalization to [0, 1]
            low, high = min(raw), max(raw)
            if high > low:
                scores = (scores - low) / (high - low)
            else:
                # All videos tie (e.g. a single video): avoid 0/0 = NaN.
                scores = torch.zeros_like(scores)
        weighted_score = weighted_score + scores * dimension_weight[dimension]
        weight_sum += dimension_weight[dimension]

    if weight_sum == 0:
        raise ValueError("dimension_list must be non-empty with a non-zero total weight")
    weighted_score = (weighted_score / weight_sum).tolist()

    # Calculate the bin range so that all data is covered.
    min_bin = math.floor(min(weighted_score) / bin_width)
    max_bin = math.ceil(max(weighted_score) / bin_width)
    if max_bin == min_bin:
        # Every score sits exactly on one bin boundary; give them a bin to land in.
        max_bin += 1
    bin_size = max_bin - min_bin

    histogram = {_histogram_bin_key(i, bin_width): 0 for i in range(min_bin, max_bin)}
    for score in weighted_score:
        # A maximum lying exactly on a boundary floors to max_bin; fold it into the last bin.
        bin_index = min(math.floor(score / bin_width), max_bin - 1)
        histogram[_histogram_bin_key(bin_index, bin_width)] += 1

    # Only bin counts are present at this point; there is at least one bin.
    max_num = max(histogram.values())
    total_num = len(weighted_score)

    histogram.update(
        {'total_num': total_num, 'max_num': max_num, 'min_bin': min_bin, 'max_bin': max_bin, 'bin_width': bin_width,
         'bin_size': bin_size})

    return histogram, video_list, weighted_score


def generate_score_histogram(eval_result_files, dimension_list, dimension_weight, output_file):
    """Aggregate VBench result files into weighted per-video scores and save a histogram.

    Args:
        eval_result_files: a path (``str`` or ``os.PathLike``) or a list of paths
            to VBench result JSON files.
        dimension_list: dimensions combined into the weighted score.
        dimension_weight: ``{dimension: weight}`` used for the weighted average.
        output_file: destination of the histogram dictionary, written as JSON.

    Returns:
        tuple ``(video_list, weighted_score)`` as produced by ``compute_histogram``.

    Raises:
        ValueError: if the result files contain no per-video float scores.
    """
    if isinstance(eval_result_files, (str, os.PathLike)):
        eval_result_files = [eval_result_files]

    scores = dict()
    for json_file in eval_result_files:
        data = read_json_file(json_file)
        scores = extract_scores(data, scores)

    if not scores:
        raise ValueError(f"no per-video scores found in {eval_result_files}")

    # Per-dimension min-max normalization. NOTE: with normalize=True,
    # compute_histogram ignores bin_width and uses 0.1-wide bins.
    histogram, video_list, weighted_score = compute_histogram(scores, dimension_list, dimension_weight,
                                                              bin_width=0.01, normalize=True)

    # Print histograms
    print("Histogram (dictionary):")
    for bin_range, count in sorted(histogram.items()):
        print(f"{bin_range}: {count}")
    # Save as a .json file
    with open(output_file, 'w', encoding='utf-8') as file:
        json.dump(histogram, file)
    print(f"Histogram saved to {output_file}")

    return video_list, weighted_score


if __name__ == '__main__':
    args = parse_args()
    print0(f'args: {args}')
    name = 'StepVideoDPO'
    dimension_list = ['subject_consistency', 'background_consistency', 'motion_smoothness', 'aesthetic_quality',
                      'imaging_quality']
    # 1. Evaluate the videos with VBench.
    videos_path = args.videos_path
    args.name = name
    args.dimension = dimension_list
    output_path = Path(videos_path).parent
    eval_result_path = os.path.join(output_path, 'evaluation_results')

    # exist_ok avoids the exists()/makedirs() race if several processes
    # (e.g. distributed ranks) start concurrently.
    os.makedirs(eval_result_path, exist_ok=True)

    video_evaluator(args, eval_result_path)
    eval_result_files = os.path.join(eval_result_path,
                                     f'{name}_eval_results.json')  # File path for storing video scores.

    # 2. Generate a score histogram of the video evaluation results.
    # Weights taken from VideoDPO.
    dimension_weight = {'subject_consistency': 4.0,
                        'background_consistency': 4.0,
                        'temporal_flickering': 4.0,
                        'motion_smoothness': 4.0,
                        'dynamic_degree': 4.0,
                        'aesthetic_quality': 4.0,
                        'imaging_quality': 4.0,
                        'object_class': 4.0,
                        'multiple_objects': 4.0,
                        'human_action': 4.0,
                        'color': 4.0,
                        'spatial_relationship': 4.0,
                        'scene': 4.0,
                        'temporal_style': 4.0,
                        'appearance_style': 4.0,
                        'overall_consistency': 1.0}

    histogram_output_file = os.path.join(output_path, 'video_score_histogram.json')
    video_list, weighted_score = generate_score_histogram(eval_result_files, dimension_list, dimension_weight,
                                                          histogram_output_file)

    # 3. Write the chosen/rejected video pairs and their scores to data.jsonl.
    data_format(args, video_list, weighted_score)