use ahash::{HashMap, HashSet};
use serde::{Serialize, Serializer, ser::SerializeStruct, Deserialize};
use smartstring::alias::String;
use crate::processors::geir::geir::AttrDef;
use self::AttrValue::*;
/// A graph of named nodes plus nested models.
///
/// Derives `Serialize`, so fields are emitted in the order declared here.
#[derive(Debug, Serialize, Clone)]
pub struct Model {
/// Name of this model.
pub name: String,
/// Nodes owned directly by this model, keyed by string. The edge endpoints
/// in `edges` and the entries of each node's `input` are looked up as keys
/// of this map by the `populate_*` methods.
pub nodes: HashMap<String, Node>,
/// Directed edges as `(source, target)` pairs of keys into `nodes`.
pub edges: Vec<(String, String)>,
/// String key/value parameters.
// NOTE(review): what these parameters mean is not visible in this file.
pub parameters: HashMap<String, String>,
/// Nested models; `find_node_by_key_mut` searches them after `clusters`.
pub subgraphes: HashMap<String, Model>,
/// Nested models; `find_node_by_key_mut` searches them before `subgraphes`.
pub clusters: HashMap<String, Model>,
}
impl Model {
    /// Deepest nesting level the recursive node lookup will enter.
    ///
    /// The lookup starts at depth 0 and gives up once the depth exceeds this
    /// value, so this model and up to five levels of nested models are searched.
    const MAX_DEPTH: usize = 5;

    /// Fills in both `output` and `input` of the nodes from `edges`.
    ///
    /// For every `(source, target)` edge, `target` is appended to the source
    /// node's `output` and `source` is appended to the target node's `input`.
    /// An endpoint that is not a key of `nodes` is skipped. Existing entries
    /// are kept, so calling this repeatedly appends the same names again.
    pub fn populate_node_connections(&mut self) {
        let Self { nodes, edges, .. } = self;
        for (source, target) in edges.iter() {
            if let Some(producer) = nodes.get_mut(source) {
                producer.output.push(target.clone());
            }
            if let Some(consumer) = nodes.get_mut(target) {
                consumer.input.push(source.clone());
            }
        }
    }

    /// Fills in only the `output` of the nodes from `edges`.
    ///
    /// For every `(source, target)` edge whose source is a key of `nodes`,
    /// `target` is appended to that node's `output`. Inputs are left untouched.
    pub fn populate_node_outputs(&mut self) {
        let Self { nodes, edges, .. } = self;
        edges.iter().for_each(|(source, target)| {
            if let Some(producer) = nodes.get_mut(source) {
                producer.output.push(target.clone());
            }
        });
    }

    /// Replaces `edges` with the pairs implied by the nodes' `input` lists.
    ///
    /// Each node contributes one `(input, node_key)` pair per entry of its
    /// `input` that is itself a key of `nodes`; other inputs are ignored.
    /// Pairs are produced in the map's iteration order, which should not be
    /// relied upon.
    pub fn populate_edges(&mut self) {
        let mut rebuilt = Vec::new();
        for (consumer, node) in self.nodes.iter() {
            rebuilt.extend(
                node.input
                    .iter()
                    .filter(|producer| self.nodes.contains_key(*producer))
                    .map(|producer| (producer.clone(), consumer.clone())),
            );
        }
        self.edges = rebuilt;
    }

    /// Looks up a node by key in this model or in any nested model.
    ///
    /// Returns a mutable reference to the first match, or `None` when the key
    /// is not found within `MAX_DEPTH` levels of nesting.
    pub fn find_node_by_key_mut(&mut self, key: &str) -> Option<&mut Node> {
        self.find_node_by_key_recursive_mut(key, 0)
    }

    /// Depth-first lookup behind [`Model::find_node_by_key_mut`].
    ///
    /// Checks this model's own `nodes` first, then every model in `clusters`,
    /// then every model in `subgraphes`, descending into each before moving on.
    /// `depth` is the nesting level of `self`; levels beyond `MAX_DEPTH` are
    /// not searched.
    fn find_node_by_key_recursive_mut(&mut self, key: &str, depth: usize) -> Option<&mut Node> {
        if depth > Self::MAX_DEPTH {
            return None;
        }
        if let Some(node) = self.nodes.get_mut(key) {
            return Some(node);
        }
        // Clusters are chained ahead of subgraphs to keep the search order.
        self.clusters
            .values_mut()
            .chain(self.subgraphes.values_mut())
            .find_map(|nested| nested.find_node_by_key_recursive_mut(key, depth + 1))
    }
}
/// A tensor attached to a [`Node`] (see `Node::tensors`).
///
/// Derives `Serialize`, so fields are emitted in the order declared here.
#[derive(Debug, Serialize, Clone)]
pub struct Tensor {
/// Tensor name; left out of the serialized output when empty.
#[serde(skip_serializing_if = "String::is_empty")]
pub name: String,
/// Static string label for the tensor.
// NOTE(review): the set of valid categories is not visible in this file.
pub category: &'static str,
/// String representation of the tensor.
// NOTE(review): presumably a printable form of the value/shape; confirm with the producer.
pub repr: String,
/// Location string; left out of the serialized output when empty.
#[serde(skip_serializing_if = "String::is_empty")]
pub location: String,
}
/// A single node of a [`Model`] graph.
///
/// Derives `Serialize`, so fields are emitted in the order declared here and
/// under exactly the names written here.
// The lint is silenced for the camelCase `opType` field below.
// NOTE(review): presumably kept so the serialized key is "opType"; confirm with consumers.
#[allow(non_snake_case)]
#[derive(Debug, Serialize, Clone)]
pub struct Node {
/// Name of the node.
pub name: String,
/// Operator type of the node.
pub opType: String,
/// Names feeding this node. `Model::populate_edges` treats each entry as a
/// possible key of `Model::nodes`; `Model::populate_node_connections`
/// appends edge sources here.
pub input: Vec<String>,
/// Names fed by this node; the `Model::populate_node_*` methods append
/// edge targets here.
pub output: Vec<String>,
/// Attributes of the node, keyed by string.
pub attributes: HashMap<String, AttrValue>,
/// Tensors attached to the node; left out of the serialized output when empty.
#[serde(skip_serializing_if = "Vec::is_empty")]
pub tensors: Vec<Tensor>,
/// Boolean flag.
// NOTE(review): what "dynamic" means for a node is not visible in this file.
pub dynamic: bool,
}
/// Value of a node attribute.
///
/// Serialization is hand-written (see the `Serialize` impl for this type):
/// each variant is emitted with a `"type"` tag and a `"value"` payload.
#[derive(Debug, Clone)]
pub enum AttrValue {
/// A single string; serialized with type tag `"string-like"`.
StringLike(String),
/// A list of strings; serialized with type tag `"string-like-array"`.
StringLikeArray(Vec<String>),
/// A single string; serialized with type tag `"tensor-val"`.
TensorVal(String),
/// A list of strings; serialized with type tag `"tensor-vals"`.
TensorVals(Vec<String>),
/// A list of string lists; serialized with type tag `"tensors-tuple"`.
TensorsTuple(Vec<Vec<String>>),
}
impl Serialize for AttrValue {
    /// Serializes the value as a two-field struct named `AttrValue`.
    ///
    /// The first field, `"type"`, is a tag naming the variant; the second,
    /// `"value"`, is the variant's payload serialized as-is.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Wire tag identifying which variant is being written.
        let tag = match self {
            Self::StringLike(_) => "string-like",
            Self::StringLikeArray(_) => "string-like-array",
            Self::TensorVal(_) => "tensor-val",
            Self::TensorVals(_) => "tensor-vals",
            Self::TensorsTuple(_) => "tensors-tuple",
        };

        let mut out = serializer.serialize_struct("AttrValue", 2)?;
        out.serialize_field("type", tag)?;
        // Variants sharing a payload type share an arm.
        match self {
            Self::StringLike(text) | Self::TensorVal(text) => {
                out.serialize_field("value", text)?
            }
            Self::StringLikeArray(items) | Self::TensorVals(items) => {
                out.serialize_field("value", items)?
            }
            Self::TensorsTuple(groups) => out.serialize_field("value", groups)?,
        }
        out.end()
    }
}
/// A directed edge between two named endpoints.
///
/// Derives `Serialize`, so it is emitted with the fields `source` and
/// `target`, in that order.
#[derive(Debug, Serialize)]
pub struct Edge {
/// Name of the endpoint the edge starts from.
pub source: String,
/// Name of the endpoint the edge points to.
pub target: String,
}
impl Edge {
pub fn new(s: String, t: String) -> Self {
Self { source: s, target: t }
}
}