#!/usr/bin/env python3

import argparse
import json
from pathlib import Path

import numpy as np
from PIL import Image
from ais_bench.infer.interface import InferSession
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval


# Lookup table from contiguous class indices (0-79) to COCO category ids.
# COCO category ids run from 1 to 90 with ten unused gaps; dropping those gaps
# leaves exactly 80 ids, and position i holds the id for contiguous class i.
# Used when the model emits 0-based labels (see --zero-based-label).
COCO80_TO_91 = [
    category_id
    for category_id in range(1, 91)
    if category_id not in {12, 26, 29, 30, 45, 66, 68, 69, 71, 83}
]


def parse_args(argv=None):
    """Parse command-line options for the COCO evaluation script.

    Args:
        argv: Argument strings to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, which is the
            original behavior. Passing a list allows programmatic use and
            testing.

    Returns:
        ``argparse.Namespace`` with attributes ``model``, ``image_dir``,
        ``ann_file``, ``size``, ``device``, ``output_json`` and
        ``zero_based_label``.
    """
    parser = argparse.ArgumentParser(
        description="Run an om detection model over COCO images and report bbox AP."
    )
    parser.add_argument("--model", required=True, help="om model path")
    parser.add_argument("--image-dir", required=True, help="COCO val2017 dir")
    parser.add_argument("--ann-file", required=True, help="instances_val2017.json")
    parser.add_argument("--size", type=int, default=640, help="input size")
    parser.add_argument("--device", type=int, default=0, help="device id")
    parser.add_argument("--output-json", default="predictions.json", help="coco result json")
    parser.add_argument(
        "--zero-based-label",
        action="store_true",
        help="set this if labels are 0~79 instead of COCO category_id",
    )
    return parser.parse_args(argv)


def preprocess(image, size):
    """Turn a PIL image into a model input tensor.

    The image is resized to ``size`` x ``size`` with bilinear resampling,
    converted to float32 and divided by 255. The channel axis is then moved
    to the front and a leading batch axis of length 1 is added. A 3-D pixel
    array is required, i.e. a multi-channel image such as RGB.

    Returns:
        A C-contiguous float32 array of shape ``(1, C, size, size)``.
    """
    resized = image.resize((size, size), Image.BILINEAR)
    scaled = np.asarray(resized, dtype=np.float32) / 255.0
    channels_first = np.transpose(scaled, (2, 0, 1))
    batched = np.expand_dims(channels_first, axis=0)
    return np.ascontiguousarray(batched)


def xyxy_to_xywh(box):
    """Convert a corner-format box ``(x1, y1, x2, y2)`` to ``[x, y, w, h]``.

    ``box`` may be any 4-element iterable, such as a list or a numpy row.
    Every value is coerced to a Python ``float`` so the result can be
    serialized to JSON. Width and height are clamped at zero, so a box with
    x2 < x1 or y2 < y1 yields a zero extent rather than a negative one.
    """
    left, top, right, bottom = map(float, box)
    span_x = right - left
    span_y = bottom - top
    return [
        left,
        top,
        span_x if span_x > 0.0 else 0.0,
        span_y if span_y > 0.0 else 0.0,
    ]


def map_label(label, zero_based):
    """Convert a model class label into a COCO ``category_id``.

    Args:
        label: Class label produced by the model. Any value accepted by
            ``int()`` works, e.g. a numpy integer or float scalar.
        zero_based: If true, ``label`` is a contiguous index into
            ``COCO80_TO_91`` and is translated to the official category id.
            Otherwise it is returned unchanged as a category id.

    Returns:
        The category id as a plain Python ``int``, which is JSON-serializable.

    Raises:
        ValueError: If ``zero_based`` is true and ``label`` is outside
            ``[0, len(COCO80_TO_91))``.
    """
    label = int(label)
    if not zero_based:
        return label
    # Check explicitly: a negative index would otherwise wrap silently
    # (e.g. -1 -> last category) and corrupt the evaluation instead of failing.
    if not 0 <= label < len(COCO80_TO_91):
        raise ValueError(
            f"zero-based label {label} is out of range [0, {len(COCO80_TO_91)})"
        )
    return COCO80_TO_91[label]


def _load_rgb(image_path):
    """Open ``image_path`` and return a loaded RGB copy, closing the file handle."""
    with Image.open(image_path) as raw:
        return raw.convert("RGB")


def _detections_for_image(session, image, image_id, size, zero_based_label):
    """Run the model on one RGB image and return its COCO-format result dicts."""
    width, height = image.size
    inputs = [
        preprocess(image, size),
        # Original image size as (width, height). Presumably the exported
        # model uses it to rescale boxes to original-image pixels; confirm.
        np.array([[width, height]], dtype=np.int64),
    ]
    labels, boxes, scores = session.infer(inputs)

    # Only batch entry 0 is read; the inputs above are built with batch size 1.
    labels = np.asarray(labels)[0].reshape(-1)
    boxes = np.asarray(boxes)[0].reshape(-1, 4)
    scores = np.asarray(scores)[0].reshape(-1)

    return [
        {
            "image_id": image_id,
            "category_id": map_label(label, zero_based_label),
            "bbox": xyxy_to_xywh(box),
            "score": float(score),
        }
        for label, box, score in zip(labels, boxes, scores)
    ]


def _evaluate(coco, results_path):
    """Run COCO bbox evaluation of ``results_path`` against ``coco``, printing the summary."""
    coco_dt = coco.loadRes(results_path)
    evaluator = COCOeval(coco, coco_dt, "bbox")
    evaluator.evaluate()
    evaluator.accumulate()
    evaluator.summarize()
    return evaluator


def main():
    """Evaluate the model on every image listed in ``--ann-file``.

    Detections are written to ``--output-json`` in COCO result format. The
    standard COCOeval summary is then printed, followed by AP and AP50 as
    percentages.

    Raises:
        SystemExit: If the model produced no detections at all. The (empty)
            JSON is still written, but evaluation is skipped because
            pycocotools cannot load an empty result list.
    """
    args = parse_args()

    coco = COCO(args.ann_file)
    session = InferSession(device_id=args.device, model_path=args.model)
    image_dir = Path(args.image_dir)
    results = []

    images = coco.dataset["images"]
    total = len(images)

    for i, info in enumerate(images, 1):
        image = _load_rgb(image_dir / info["file_name"])
        results.extend(
            _detections_for_image(
                session, image, int(info["id"]), args.size, args.zero_based_label
            )
        )

        if i % 100 == 0 or i == total:
            print(f"[{i}/{total}] processed", flush=True)

    with open(args.output_json, "w", encoding="utf-8") as f:
        json.dump(results, f)

    if not results:
        # COCO.loadRes inspects the first result entry, so an empty list would
        # crash with an opaque IndexError. Fail with a clear message instead.
        raise SystemExit(
            f"no detections were produced; wrote empty {args.output_json} "
            "and skipped COCO evaluation"
        )

    evaluator = _evaluate(coco, args.output_json)

    print(f"APval={evaluator.stats[0] * 100:.1f}")
    print(f"AP50val={evaluator.stats[1] * 100:.1f}")


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()