import time
from functools import wraps
import cv2
import numpy as np
import torch
import torchvision.transforms as t
from PIL import Image, ImageDraw
import acl

# Return code that ``AscendPredictor.check_ret`` treats as success.
ACL_ERROR_NONE = 0
# Allocation policy handed to ``acl.rt.malloc``.
ACL_MEM_MALLOC_HUGE_FIRST = 0
# Copy directions handed to ``acl.rt.memcpy``.
ACL_MEMCPY_HOST_TO_DEVICE = 1
ACL_MEMCPY_DEVICE_TO_HOST = 2
# Not referenced by the code visible in this file; kept for compatibility.
NPY_BYTE = 1

# Data-type codes as reported by ``acl.mdl.get_output_data_type``.
# The values follow the ``aclDataType`` enumeration of the ACL headers.
ACL_FLOAT = 0
ACL_FLOAT16 = 1
ACL_INT8 = 2
ACL_INT32 = 3
ACL_UINT8 = 4
ACL_INT16 = 6
ACL_UINT16 = 7
ACL_UINT32 = 8
ACL_INT64 = 9
ACL_UINT64 = 10
ACL_DOUBLE = 11
ACL_BOOL = 12

# Maps an ACL data-type code to the numpy scalar type used for host arrays.
# Unsigned and double entries are included so that models producing such
# outputs are not rejected as "unknown datatype".
TYPE_MAP = {
    ACL_FLOAT: np.float32,
    ACL_FLOAT16: np.float16,
    ACL_INT8: np.int8,
    ACL_INT16: np.int16,
    ACL_INT32: np.int32,
    ACL_INT64: np.int64,
    ACL_UINT8: np.uint8,
    ACL_UINT16: np.uint16,
    ACL_UINT32: np.uint32,
    ACL_UINT64: np.uint64,
    ACL_DOUBLE: np.float64,
    ACL_BOOL: np.bool_,
}


class AscendPredictor(object):
    """Load a model file through ``acl`` and run synchronous inference on it.

    Construction initialises the device, loads the model, and allocates one
    device buffer per model input and output plus one host numpy array per
    output. ``infer`` copies host inputs to the device, executes the model
    and copies every output back into those host arrays. Everything that was
    acquired is released again in ``__del__``.
    """

    def __init__(self, model_path, device_id=0):
        """Acquire the device and prepare all buffers for ``model_path``.

        Args:
            model_path: Path handed to ``acl.mdl.load_from_file``.
            device_id: Device index handed to ``acl.rt.set_device``.

        Raises:
            RuntimeError: If a checked ``acl`` call reports a non-zero code,
                or if an output uses a data type missing from ``TYPE_MAP``.
        """
        self.device_id = device_id
        self.input_data = []
        self.output_data = []
        self.result_numpy = {"ptr": [], "data": []}
        # Declare every handle up front so that __del__ can run safely even
        # when one of the initialisation steps below raises half-way through.
        self.context = None
        self.model_id = None
        self.model_desc = None
        self.input_dataset = None
        self.output_dataset = None
        self._init_device()
        self._init_model(model_path)
        self._init_input_dataset()
        self._init_output_dataset()

    def _init_device(self):
        """Initialise ACL, select the device and create a context on it."""
        # NOTE(review): the return code of acl.init() is left unchecked, as in
        # the original code; presumably a second initialisation in the same
        # process reports an error that should not be fatal -- confirm.
        acl.init()
        ret = acl.rt.set_device(self.device_id)
        self.check_ret("acl.rt.set_device", ret)
        context, ret = acl.rt.create_context(self.device_id)
        self.check_ret("acl.rt.create_context", ret)
        # Only publish the handle once it is known to be valid.
        self.context = context

    def _init_model(self, model_path=''):
        """Load the model file and fetch its description."""
        model_id, ret = acl.mdl.load_from_file(model_path)
        self.check_ret("acl.mdl.load_from_file", ret)
        # Stored only after the check so __del__ never unloads a failed load.
        self.model_id = model_id
        self.model_desc = acl.mdl.create_desc()
        ret = acl.mdl.get_desc(self.model_desc, self.model_id)
        self.check_ret("acl.mdl.get_desc", ret)

    def _init_input_dataset(self):
        """Create the input dataset with one device buffer per model input."""
        self.input_dataset = acl.mdl.create_dataset()
        self._input_size = acl.mdl.get_num_inputs(self.model_desc)
        for i in range(self._input_size):
            input_size = acl.mdl.get_input_size_by_index(self.model_desc, i)
            self._init_dataset(input_size, self.input_dataset, self.input_data)

    def _init_output_dataset(self):
        """Create the output dataset and the matching host result arrays."""
        self.output_dataset = acl.mdl.create_dataset()
        self._output_size = acl.mdl.get_num_outputs(self.model_desc)
        for i in range(self._output_size):
            output_size = acl.mdl.get_output_size_by_index(self.model_desc, i)
            self._init_dataset(output_size, self.output_dataset,
                               self.output_data)
            self._init_result_numpy(output_size, i)

    def _init_dataset(self, size, dataset, data_dict):
        """Allocate a ``size``-byte device buffer and attach it to ``dataset``.

        On success a ``{"buffer", "size"}`` record is appended to
        ``data_dict`` (a list, despite its name) so the buffer can be freed
        later.
        """
        buffer, ret = acl.rt.malloc(size, ACL_MEM_MALLOC_HUGE_FIRST)
        self.check_ret("acl.rt.malloc", ret)
        data = acl.create_data_buffer(buffer, size)
        _, ret = acl.mdl.add_dataset_buffer(dataset, data)
        if ret != ACL_ERROR_NONE:
            # The buffer is not tracked anywhere yet, so release it here;
            # otherwise it would leak when check_ret raises below.
            acl.destroy_data_buffer(data)
            acl.rt.free(buffer)
        self.check_ret("acl.mdl.add_dataset_buffer", ret)
        data_dict.append({"buffer": buffer, "size": size})

    def _init_result_numpy(self, size, index):
        """Allocate the host array that receives output ``index``.

        The array gets the dtype mapped from the model's output data type and
        the shape reported by ``acl.mdl.get_output_dims``; ``size`` is the
        output's byte size.
        """
        dims, ret = acl.mdl.get_output_dims(self.model_desc, index)
        self.check_ret("acl.mdl.get_output_dims", ret)
        np_shape = tuple(dims["dims"])
        datatype = acl.mdl.get_output_data_type(self.model_desc, index)
        if datatype not in TYPE_MAP:
            raise RuntimeError("unknown datatype %d" % datatype)
        np_type = TYPE_MAP[datatype]
        np_size = size // np.dtype(np_type).itemsize
        # reshape() raises if the reported dims do not match the byte size.
        output_tensor = np.zeros(np_size, dtype=np_type).reshape(np_shape)
        if not output_tensor.flags["C_CONTIGUOUS"]:
            output_tensor = np.ascontiguousarray(output_tensor)
        # The raw address is cached for memcpy; the array itself is stored
        # alongside it so the memory stays alive.
        self.result_numpy["ptr"].append(output_tensor.ctypes.data)
        self.result_numpy["data"].append(output_tensor)

    @staticmethod
    def check_ret(message, ret):
        """Raise ``RuntimeError`` when ``ret`` is not ``ACL_ERROR_NONE``."""
        if ret != ACL_ERROR_NONE:
            raise RuntimeError("{} failed ret ={}".format(message, ret))

    def infer(self, inputs):
        """Run the model once on ``inputs``.

        Args:
            inputs: Iterable with one array-like per model input, in model
                input order. Each is made C-contiguous before being copied.

        Returns:
            The list of host numpy arrays, one per model output. The same
            array objects are reused and overwritten on every call, so copy
            them if a result must survive the next ``infer``.

        Raises:
            ValueError: If the number of inputs does not match the model, or
                an input holds more bytes than its device buffer.
            RuntimeError: If an ``acl`` call reports a non-zero code.
        """
        inputs = list(inputs)
        if len(inputs) != len(self.input_data):
            raise ValueError(
                "model expects {} input(s), got {}".format(
                    len(self.input_data), len(inputs)))
        ret = acl.rt.set_context(self.context)
        self.check_ret("acl.rt.set_context", ret)
        for i, (slot, input_data) in enumerate(zip(self.input_data, inputs)):
            host = np.ascontiguousarray(input_data)
            if host.nbytes > slot["size"]:
                raise ValueError(
                    "input {} holds {} bytes but the model buffer is only "
                    "{} bytes".format(i, host.nbytes, slot["size"]))
            # Copy exactly the bytes the host array owns: requesting the full
            # buffer size would read past the end of a smaller array.
            ret = acl.rt.memcpy(slot["buffer"], slot["size"],
                                host.ctypes.data, host.nbytes,
                                ACL_MEMCPY_HOST_TO_DEVICE)
            self.check_ret("acl.rt.memcpy", ret)
        ret = acl.mdl.execute(self.model_id, self.input_dataset,
                              self.output_dataset)
        self.check_ret("acl.mdl.execute", ret)
        for slot, output_ptr in zip(self.output_data,
                                    self.result_numpy["ptr"]):
            ret = acl.rt.memcpy(output_ptr, slot["size"],
                                slot["buffer"], slot["size"],
                                ACL_MEMCPY_DEVICE_TO_HOST)
            self.check_ret("acl.rt.memcpy", ret)
        return self.result_numpy["data"]

    def __del__(self):
        """Release datasets, model, context and device, then finalize ACL."""
        if not hasattr(self, "device_id"):
            # __init__ never started, so nothing was acquired.
            return
        self._release_dataset(getattr(self, "input_data", []),
                              getattr(self, "input_dataset", None))
        self.input_dataset = None
        self._release_dataset(getattr(self, "output_data", []),
                              getattr(self, "output_dataset", None))
        self.output_dataset = None
        if getattr(self, "model_id", None) is not None:
            acl.mdl.unload(self.model_id)
            self.model_id = None
        if getattr(self, "model_desc", None):
            acl.mdl.destroy_desc(self.model_desc)
            self.model_desc = None
        if getattr(self, "context", None):
            acl.rt.destroy_context(self.context)
            self.context = None
        acl.rt.reset_device(self.device_id)
        acl.finalize()

    @staticmethod
    def _release_dataset(data, dataset):
        """Free the device buffers in ``data`` and destroy ``dataset``.

        ``dataset`` may be ``None`` when construction failed before the
        dataset was created; only the buffers are freed in that case.
        """
        while data:
            item = data.pop()
            acl.rt.free(item["buffer"])
        if dataset is None:
            return
        buffer_count = acl.mdl.get_dataset_num_buffers(dataset)
        for i in range(buffer_count):
            data_buf = acl.mdl.get_dataset_buffer(dataset, i)
            if data_buf:
                acl.destroy_data_buffer(data_buf)
        acl.mdl.destroy_dataset(dataset)


def time_logger(prefix=""):
    """Build a decorator that prints how long each call of a function takes.

    Args:
        prefix: Text placed in front of the `` cost: <seconds> s.`` message.

    Returns:
        A decorator. The wrapped function keeps its metadata (via
        ``functools.wraps``) and its return value; the timing line is printed
        only when the call returns normally.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter() is monotonic and high resolution, so the measured
            # interval cannot be distorted by system clock adjustments the
            # way a time.time() difference can.
            started = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed = time.perf_counter() - started
            print(f"{prefix} cost: {elapsed:.6f} s.")
            return result

        return wrapper

    return decorator


def resize_with_cv2(image, size, interpolation=cv2.INTER_LINEAR_EXACT):
    """Letterbox ``image`` onto a black ``size`` x ``size`` canvas.

    The image is scaled by a single factor so that its longer side fits
    ``size``, then centred on a zero-filled ``uint8`` canvas with three
    channels.

    Args:
        image: Array whose first two axes are height and width.
            NOTE(review): the canvas assignment assumes three channels
            (e.g. the result of ``cv2.imread``) -- confirm for other inputs.
        size: Side length of the square output canvas, in pixels.
        interpolation: OpenCV interpolation flag for ``cv2.resize``. The
            previous default, ``cv2.INTER_BITS``, is a bit-count constant
            rather than an interpolation flag; it has the same numeric value
            as ``cv2.INTER_LINEAR_EXACT``, which is now named explicitly, so
            results are unchanged.

    Returns:
        Tuple ``(canvas, ratio, pad_w, pad_h)``: the letterboxed image, the
        scale factor applied, and the left and top padding in pixels.
    """
    h, w = image.shape[:2]
    ratio = min(size / w, size / h)
    # Truncation can yield 0 for extreme aspect ratios, which cv2.resize
    # rejects; keep at least one pixel along each axis.
    new_width = max(1, int(w * ratio))
    new_height = max(1, int(h * ratio))
    resized = cv2.resize(image, (new_width, new_height),
                         interpolation=interpolation)
    canvas = np.zeros((size, size, 3), dtype=np.uint8)
    h_offset = (size - new_height) // 2
    w_offset = (size - new_width) // 2
    canvas[h_offset:h_offset + new_height,
           w_offset:w_offset + new_width, :] = resized
    return canvas, ratio, w_offset, h_offset


@time_logger(prefix="preprocess new")
def preprocess_new(im_np, size=640):
    """Turn a BGR image into the model's input tensors.

    The image is letterboxed to ``size`` x ``size``, converted from BGR to
    RGB, scaled to ``[0, 1]`` as ``float32`` and rearranged from HWC to a
    batched ``(1, C, H, W)`` array.

    Args:
        im_np: Image array in BGR channel order.
        size: Side length of the square network input.

    Returns:
        Tuple ``(im_data, orig_size, ratio, pad_w, pad_h)``. ``orig_size`` is
        a ``(1, 2)`` integer array holding ``[size, size]``, i.e. the size of
        the letterboxed canvas rather than of the source image. ``ratio``,
        ``pad_w`` and ``pad_h`` are the letterbox scale and padding.
    """
    letterboxed, ratio, pad_w, pad_h = resize_with_cv2(im_np, size)
    rgb = cv2.cvtColor(letterboxed, cv2.COLOR_BGR2RGB)
    im_data = np.asarray(rgb, dtype=np.float32) / 255.0
    im_data = np.expand_dims(im_data.transpose((2, 0, 1)), axis=0)
    # Follow the requested canvas size instead of a hard-coded 640 so the
    # tensor stays consistent when a different ``size`` is passed.
    orig_size = np.array([[size, size]])
    return im_data, orig_size, ratio, pad_w, pad_h


@time_logger(prefix="inference: ")
def inference(sess, im_data, orig_size):
    """Run ``sess.infer`` on the image tensor and size tensor, in that order.

    Returns whatever ``sess.infer`` returns; the call is timed by the
    ``time_logger`` decorator.
    """
    model_inputs = [im_data, orig_size]
    return sess.infer(model_inputs)


@time_logger(prefix="process image total: ")
def process_image(sess, input_path, conf_thres):
    """Read one image, run it through the model and print the raw outputs.

    Args:
        sess: Predictor object exposing ``infer``.
        input_path: Path of the image file to read with ``cv2.imread``.
        conf_thres: Confidence threshold.
            NOTE(review): accepted but not applied anywhere in this function;
            the outputs are printed unfiltered -- confirm whether filtering
            by score was intended.

    Raises:
        OSError: If ``cv2.imread`` cannot load ``input_path``.
    """
    st = time.perf_counter()
    im = cv2.imread(input_path)
    print(f"read cost {time.perf_counter() - st:.6f} s")
    if im is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with a clear message rather than deep inside resizing.
        raise OSError(f"failed to read image: {input_path}")
    im_data, orig_size, _, _, _ = preprocess_new(im)

    output = inference(sess, im_data, orig_size)
    # The model is expected to yield exactly three outputs, in this order.
    labels, boxes, scores = output
    print(f"labels: {labels}, boxes: {boxes}, scores: {scores}")


def main(args):
    """Create a predictor from the parsed CLI arguments and process one image.

    ``args`` must provide the attributes ``device``, ``om``, ``input`` and
    ``conf_thres``.
    """
    predictor = AscendPredictor(device_id=args.device, model_path=args.om)
    print(f"Using device: {args.device}")
    process_image(predictor, args.input, conf_thres=args.conf_thres)


if __name__ == "__main__":
    import argparse

    # Command-line interface: model path and input image are mandatory,
    # device id and confidence threshold fall back to defaults.
    parser = argparse.ArgumentParser()
    parser.add_argument("--om", "-m", type=str, required=True,
                        help="Path to the om model file.")
    parser.add_argument("--input", "-i", type=str, required=True,
                        help="Path to the input image file.")
    parser.add_argument("--device", "-d", type=int, default=0,
                        help="Device ID.")
    parser.add_argument("--conf_thres", "-t", type=float, default=0.4,
                        help="conf threshold.")
    args = parser.parse_args()
    main(args)