import os
import shutil
from typing import Dict, List, Optional

import numpy as np
import tensorflow as tf

from annc.optimize.graph import Graph, CustomAttr
from annc.optimize.op_type import OpType
from annc.optimize.data_pack import DataFormat, MatrixTiling, FormageChange


class LayoutMatmulRewriter:
    """Re-lay out the weight variables feeding MatMul/BatchMatMul nodes.

    The SavedModel at ``model_path`` (tag ``"serve"``) is loaded into a private
    session. Calling the instance on a :class:`Graph` packs every eligible
    variable operand of MatMul/BatchMatMul nodes in place inside that session
    and appends a ``lhs_format``/``rhs_format`` attribute to the node.
    :meth:`save` then writes the graph definition and the packed variables out
    as a SavedModel and closes the session.
    """

    # Vector length handed to FormageChange; also the minimum size each of the
    # first two dimensions of a variable must have for it to be packed.
    __vector_len__ = 12

    def __init__(self, model_path: str):
        """Load the ``"serve"``-tagged SavedModel at *model_path* into a new session."""
        tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
        self.model_path = model_path
        self.sess = tf.compat.v1.Session(graph=tf.Graph())
        # Variable name -> layout it has already been packed into, so a weight
        # consumed by several MatMul nodes is packed exactly once.
        self._packed_formats: Dict[str, DataFormat] = {}
        with self.sess.graph.as_default(), self.sess.as_default():
            tf.compat.v1.saved_model.loader.load(self.sess, ["serve"], model_path)
            self.vars = self.get_variables()

    def __call__(self, graph: Graph):
        """Pack the variable operands of every MatMul/BatchMatMul node in *graph*.

        The lhs operand is packed from ``mk`` to ``k4m4`` and the rhs operand
        from ``kn`` to ``kn4``. Each operand is handled independently: an
        operand that is not a variable, is already tagged, or is too small is
        skipped without affecting the other operand of the same node.

        Raises:
            ValueError: if one variable would have to be packed into two
                different layouts for different consumers.
        """
        matmul_types = (OpType.MatMul.value, OpType.BatchMatMul.value)
        with self.sess.graph.as_default(), self.sess.as_default():
            for node in graph.nodes:
                if node.type not in matmul_types:
                    continue
                self._fold_operand(node, 0, 'lhs_format',
                                   DataFormat.mk, DataFormat.k4m4, 's: "k4m4"')
                self._fold_operand(node, 1, 'rhs_format',
                                   DataFormat.kn, DataFormat.kn4, 's: "kn4"')

    def _fold_operand(self, node, index: int, key: str,
                      src_format: DataFormat, dst_format: DataFormat,
                      attr_value: str) -> None:
        """Pack operand *index* of *node* into *dst_format* and tag the node with *key*."""
        name = self.get_variable_operand(node.operands[index][0])
        if not name or not self.check_require_folding(node, key):
            return
        packed_format = self._packed_formats.get(name)
        if packed_format is None:
            if self.check_data_size(name):
                return
            self.pack_variable(name, src_format, dst_format)
        elif packed_format != dst_format:
            # Packing again would corrupt the data already laid out for the
            # other consumer; leaving it would make this node misread it.
            raise ValueError(
                f"variable {name!r} is already packed as {packed_format}, "
                f"but node {node.name!r} requires {dst_format} for {key}")
        # Either freshly packed, or already packed into this same layout by
        # another consumer: only the node attribute is still missing.
        node.attrs.append(CustomAttr(key, attr_value))

    def save(self, output_path: str):
        """Write the rewritten model to *output_path* and close the session.

        ``saved_model.pb`` is copied from the source model (when present) and
        the ``variables`` directory is rewritten from the session's current
        variable values. The session is closed even if writing fails, so the
        instance cannot be used afterwards.

        NOTE(review): only ``saved_model.pb`` and ``variables/`` are written;
        any other SavedModel contents (e.g. an ``assets/`` directory) are not
        copied. Confirm this is intended.
        """
        os.makedirs(output_path, exist_ok=True)

        src_pb = os.path.join(self.model_path, 'saved_model.pb')
        dst_pb = os.path.join(output_path, 'saved_model.pb')
        if os.path.exists(src_pb):
            try:
                shutil.copy(src_pb, dst_pb)
            except shutil.SameFileError:
                # Saving in place: the graph definition is already there.
                pass

        var_dir = os.path.join(output_path, 'variables')
        var_path = os.path.join(var_dir, 'variables')
        # Remove stale checkpoint shards. The values to write were restored
        # into the session at load time, so this is also done when saving
        # in place.
        if os.path.exists(var_dir):
            shutil.rmtree(var_dir)
        os.makedirs(var_dir, exist_ok=True)
        try:
            with self.sess.graph.as_default(), self.sess.as_default():
                saver = tf.compat.v1.train.Saver()
                # A SavedModel's variables directory holds only the checkpoint
                # data/index; skip the extra .meta and "checkpoint" state files.
                saver.save(self.sess, var_path,
                           write_meta_graph=False, write_state=False)
        finally:
            self.sess.close()

    def _variable_shape(self, name: str) -> tuple:
        """Return the concrete shape of variable *name*.

        The static shape is used when fully defined, which avoids copying the
        whole tensor out of the session; otherwise the value is evaluated.
        """
        variable = self.vars[name]
        static_shape = variable.get_shape()
        if static_shape.is_fully_defined():
            return tuple(static_shape.as_list())
        return np.shape(variable.eval(session=self.sess))

    def check_data_size(self, name: str) -> bool:
        """Return True when variable *name* is too small to be packed.

        A variable is too small when it has fewer than two dimensions, or when
        either of its first two dimensions is below ``__vector_len__``.
        """
        shape = self._variable_shape(name)
        if len(shape) < 2:
            # Scalars/vectors cannot be tiled as a matrix.
            return True
        return shape[0] < self.__vector_len__ or shape[1] < self.__vector_len__

    def get_variable_operand(self, node) -> Optional[str]:
        """Return the name of the variable that *node* reads, or None.

        Identity and ReadVariableOp nodes are followed through their first
        operand. VarHandleOp/VariableV2/Variable nodes yield their own name.
        Any other op type yields None.
        """
        passthrough_types = frozenset(
            [OpType.Identity.value, OpType.ReadVariableOp.value])
        variable_types = frozenset([
            OpType.VarHandleOp.value, OpType.VariableV2.value,
            OpType.Variable.value
        ])
        if node.type in passthrough_types:
            return self.get_variable_operand(node.operands[0][0])
        if node.type in variable_types:
            return node.name
        return None

    def pack_variable(self, name: str, src_format: DataFormat,
                      dst_format: DataFormat):
        """Re-lay out variable *name* from *src_format* to *dst_format* in the session."""
        variable = self.vars[name]
        with self.sess.graph.as_default():
            data: np.ndarray = variable.eval(session=self.sess)
            packed_data = FormageChange(self.__vector_len__).run(
                data,
                src_format=src_format,
                dst_format=dst_format,
                tiling_info=self.generate_tiling(variable))
            # Feed the packed value through a placeholder rather than assigning
            # a numpy constant. A constant would embed a full copy of the
            # weight in the GraphDef (memory bloat, 2 GB protobuf limit).
            feed = tf.compat.v1.placeholder(variable.dtype.base_dtype,
                                            shape=np.shape(packed_data))
            self.sess.run(tf.compat.v1.assign(variable, feed),
                          feed_dict={feed: packed_data})
        self._packed_formats[name] = dst_format

    def generate_tiling(self, value: tf.Variable) -> List[MatrixTiling]:
        """Return the tiling used when packing *value*.

        The tiling is currently hard-coded and *value* is ignored.
        NOTE(review): the arguments look like (dimension, tile size); confirm
        against MatrixTiling.
        """
        return [MatrixTiling(0, 352), MatrixTiling(1, 128)]

    def check_require_folding(self, node, key: str) -> bool:
        """Return True if *node* has no attribute named *key* yet."""
        return all(attr.key != key for attr in node.attrs)

    def get_variables(self) -> Dict[str, tf.Variable]:
        """Map each global variable's name (without the ``:N`` output suffix) to the variable."""
        return {
            variable.name.split(':')[0]: variable
            for variable in tf.compat.v1.global_variables()
        }