// Copyright (c) 2025, Huawei Technologies Co., Ltd.
// All rights reserved.
//
// Licensed under the Apache License, Version 2.0  (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use ahash::{HashMap, HashSet};
use serde::{Serialize, Serializer, ser::SerializeStruct, Deserialize};
use smartstring::alias::String;
use crate::processors::geir::geir::AttrDef;
use self::AttrValue::*;

/// A graph model: a named set of nodes, the directed edges between them, and nested models.
///
/// Nested models are stored in two maps, `subgraphes` and `clusters`; both are searched by
/// `Model::find_node_by_key_mut`.
#[derive(Debug, Serialize, Clone)]
pub struct Model {
    /// Name of this model.
    pub name: String,
    /// Nodes owned directly by this model. The map key is the identifier used by the endpoints
    /// in `edges` and by the entries of `Node::input` (presumably the node's `name` — confirm).
    pub nodes: HashMap<String, Node>,
    /// Directed edges as `(source key, target key)` pairs.
    pub edges: Vec<(String, String)>,
    /// String-to-string parameters attached to this model.
    pub parameters: HashMap<String, String>,
    /// Nested models (subgraphs) owned by this model.
    pub subgraphes: HashMap<String, Model>,
    // When a subgraph is displayed inside the main graph, that subgraph is called a cluster.
    pub clusters: HashMap<String, Model>, // Nesting relationship of graph nodes, parent -> Set<node> (used when a subgraph is laid out flat). NOTE(review): this description does not match the field's type; it may be left over from an earlier field — confirm.
}

impl Model {
    /// Deepest nesting level that node lookup still inspects; the model the search starts
    /// from is level 0.
    const MAX_DEPTH: usize = 5;

    /// Records every edge on both of its endpoints.
    ///
    /// For each `(source, target)` pair in `edges`, `target` is appended to the source node's
    /// `output` and `source` is appended to the target node's `input`. An endpoint that is not
    /// present in `nodes` is skipped on that side only. Existing `input`/`output` entries are
    /// kept, so entries accumulate if this is called more than once.
    pub fn populate_node_connections(&mut self) {
        // `nodes` and `edges` are distinct fields, so they can be borrowed side by side.
        let nodes = &mut self.nodes;
        for (from, to) in self.edges.iter() {
            if let Some(producer) = nodes.get_mut(from) {
                producer.output.push(to.clone());
            }
            if let Some(consumer) = nodes.get_mut(to) {
                consumer.input.push(from.clone());
            }
        }
    }

    /// Records every edge on its source node only.
    ///
    /// For each `(source, target)` pair in `edges`, `target` is appended to the source node's
    /// `output`; `input` lists are left untouched. Sources missing from `nodes` are skipped.
    pub fn populate_node_outputs(&mut self) {
        let nodes = &mut self.nodes;
        for (from, to) in self.edges.iter() {
            if let Some(producer) = nodes.get_mut(from) {
                producer.output.push(to.clone());
            }
        }
    }

    /// Replaces `edges` with the edges implied by each node's `input` list.
    ///
    /// An `(input, node key)` pair is produced for every input that is itself a key of `nodes`;
    /// inputs that refer to unknown nodes are dropped. Pairs for one node keep the order of its
    /// `input` list, while the order across nodes follows the map's iteration order.
    pub fn populate_edges(&mut self) {
        let nodes = &self.nodes;
        let mut derived = Vec::new();

        for (consumer, node) in nodes {
            derived.extend(
                node.input
                    .iter()
                    // Only keep inputs that resolve to a node of this model.
                    .filter(|producer| nodes.contains_key(*producer))
                    .map(|producer| (producer.clone(), consumer.clone())),
            );
        }

        self.edges = derived;
    }

    /// Looks up a node by key in this model and, failing that, in its nested models.
    ///
    /// Returns a mutable reference to the first match, or `None` when the key is not found
    /// within the depth limit (see `MAX_DEPTH`).
    pub fn find_node_by_key_mut(&mut self, key: &str) -> Option<&mut Node> {
        self.find_node_by_key_recursive_mut(key, 0)
    }

    /// Depth-limited worker behind `find_node_by_key_mut`.
    ///
    /// Search order: this model's own `nodes`, then every model in `clusters`, then every
    /// model in `subgraphes`, each searched recursively at `depth + 1`. A call whose `depth`
    /// exceeds `MAX_DEPTH` returns `None` without looking at anything.
    fn find_node_by_key_recursive_mut(&mut self, key: &str, depth: usize) -> Option<&mut Node> {
        if depth > Self::MAX_DEPTH {
            return None;
        }
        if let Some(hit) = self.nodes.get_mut(key) {
            return Some(hit);
        }
        // Clusters are visited before subgraphs; the chain is lazy, so iteration stops at the
        // first nested model that yields a match.
        self.clusters
            .values_mut()
            .chain(self.subgraphes.values_mut())
            .find_map(|nested| nested.find_node_by_key_recursive_mut(key, depth + 1))
    }
}

/// Serializable description of a tensor attached to a `Node` (see `Node::tensors`).
#[derive(Debug, Serialize, Clone)]
pub struct Tensor {
    /// Tensor name; omitted from the serialized output when empty.
    #[serde(skip_serializing_if = "String::is_empty")]
    pub name: String,
    /// Static category label. NOTE(review): the set of valid labels is not defined in this
    /// file — confirm against the code that constructs `Tensor`.
    pub category: &'static str,
    /// Textual representation of the tensor (presumably a human-readable summary — confirm).
    pub repr: String,
    /// Location string; omitted from the serialized output when empty.
    #[serde(skip_serializing_if = "String::is_empty")]
    pub location: String,
}

/// A single node of a `Model` graph.
///
/// `non_snake_case` is allowed because of the camelCase field `opType`; with the derived
/// `Serialize` the field name is also the serialized key.
#[allow(non_snake_case)]
#[derive(Debug, Serialize, Clone)]
pub struct Node {
    /// Node name.
    pub name: String,
    /// Operator type of the node.
    pub opType: String,
    /// Keys of the nodes feeding this node. `Model::populate_node_connections` appends edge
    /// sources here and `Model::populate_edges` reads it to derive edges.
    // Serialized even when empty: the `skip_serializing_if = "Vec::is_empty"` attribute is disabled.
    pub input: Vec<String>,
    /// Keys of the nodes fed by this node. `Model::populate_node_connections` and
    /// `Model::populate_node_outputs` append edge targets here.
    // Serialized even when empty: the `skip_serializing_if = "Vec::is_empty"` attribute is disabled.
    pub output: Vec<String>,
    /// Named attributes of the node.
    // Serialized even when empty: the `skip_serializing_if = "HashMap::is_empty"` attribute is disabled.
    pub attributes: HashMap<String, AttrValue>,
    /// Tensors attached to the node; omitted from the serialized output when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tensors: Vec<Tensor>,
    /// Boolean flag. NOTE(review): what "dynamic" means is not defined in this file — confirm
    /// with the producers of `Node`.
    pub dynamic: bool,
}

/// Value of a node attribute.
///
/// `Serialize` is implemented by hand in this file: every variant is written as a two-field
/// struct with a `type` tag (listed per variant below) and a `value` holding the payload.
#[derive(Debug, Clone)]
pub enum AttrValue {
    /// A single string; serialized with `type` = `"string-like"`.
    StringLike(String),
    /// A list of strings; serialized with `type` = `"string-like-array"`.
    StringLikeArray(Vec<String>),
    /// A single tensor value held as a string; serialized with `type` = `"tensor-val"`.
    TensorVal(String),
    /// A list of tensor values held as strings; serialized with `type` = `"tensor-vals"`.
    TensorVals(Vec<String>),
    /// A list of lists of tensor values held as strings; serialized with `type` = `"tensors-tuple"`.
    TensorsTuple(Vec<Vec<String>>),
}

/// Serializes every variant as a struct named `AttrValue` with two fields: a `type` tag that
/// identifies the variant and a `value` that carries the variant's payload.
impl Serialize for AttrValue {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Emits one `{ type, value }` record; shared by all variants so that only the tag and
        // the payload differ per arm.
        fn write_tagged<Ser, Val>(
            serializer: Ser,
            tag: &'static str,
            payload: &Val,
        ) -> Result<Ser::Ok, Ser::Error>
        where
            Ser: Serializer,
            Val: Serialize + ?Sized,
        {
            let mut record = serializer.serialize_struct("AttrValue", 2)?;
            record.serialize_field("type", tag)?;
            record.serialize_field("value", payload)?;
            record.end()
        }

        match self {
            StringLike(text) => write_tagged(serializer, "string-like", text),
            StringLikeArray(items) => write_tagged(serializer, "string-like-array", items),
            TensorVal(tensor) => write_tagged(serializer, "tensor-val", tensor),
            TensorVals(tensors) => write_tagged(serializer, "tensor-vals", tensors),
            TensorsTuple(groups) => write_tagged(serializer, "tensors-tuple", groups),
        }
    }
}

/// A directed edge between two nodes, serialized as an object with `source` and `target` fields.
#[derive(Debug, Serialize)]
pub struct Edge {
    /// Identifier of the node the edge starts from.
    pub source: String,
    /// Identifier of the node the edge points to.
    pub target: String,
}

impl Edge {
    pub fn new(s: String, t: String) -> Self {
        Self { source: s, target: t }
    }
}