use prost::Message;

use crate::{backend::DType, proto::pbonnx};

/// Activation function applied to the gate projection of an expert.
///
/// The ONNX graph emitted by `ExpertOnnxBuilder` in this file always lowers the
/// activation as SiLU (`Sigmoid` followed by `Mul`); this enum names that choice.
// `dead_code` is allowed because the builder does not consult this enum yet.
#[allow(dead_code)]
enum ActFn {
    /// Sigmoid Linear Unit: `x * sigmoid(x)`.
    SiLU,
}
/// Builder for the ONNX model of a single gated-MLP expert,
/// i.e. `down_proj(SiLU(gate_proj(x)) * up_proj(x))`.
///
/// The three projection weights are declared as *external* initializers: the
/// serialized model does not embed their data, so the caller has to bind the
/// weight tensors by name when creating the inference session.
pub struct ExpertOnnxBuilder {
    /// Feature width of the expert's input and output activations.
    pub hidden_size: i64,
    /// Width of the gate/up projections (inner dimension of the MLP).
    pub intermediate_size: i64,
    /// Element type shared by the graph input, output and all weight tensors.
    pub data_type: DType,
}

impl From<pbonnx::tensor_proto::DataType> for DType {
    /// Maps an ONNX `TensorProto.DataType` onto the backend [`DType`].
    ///
    /// # Panics
    ///
    /// Panics (via `unimplemented!`) for any ONNX element type other than
    /// `BFLOAT16`, `FLOAT`, `INT8` and `INT16`. The panic message names the
    /// rejected type, so the failure can be diagnosed without a debugger.
    fn from(value: pbonnx::tensor_proto::DataType) -> Self {
        match value {
            pbonnx::tensor_proto::DataType::Bfloat16 => Self::BFloat16,
            pbonnx::tensor_proto::DataType::Float => Self::Float,
            pbonnx::tensor_proto::DataType::Int8 => Self::Int8,
            pbonnx::tensor_proto::DataType::Int16 => Self::Int16,
            other => unimplemented!("unsupported ONNX tensor data type: {:?}", other),
        }
    }
}

impl From<DType> for pbonnx::tensor_proto::DataType {
    /// Maps a backend [`DType`] onto the matching ONNX `TensorProto.DataType`.
    ///
    /// # Panics
    ///
    /// Panics (via `unimplemented!`) for any `DType` other than `BFloat16`,
    /// `Float`, `Int8` and `Int16`, which are the only ones mapped so far.
    fn from(val: DType) -> Self {
        match val {
            DType::BFloat16 => Self::Bfloat16,
            DType::Float => Self::Float,
            DType::Int8 => Self::Int8,
            DType::Int16 => Self::Int16,
            _ => unimplemented!("DType has no ONNX TensorProto.DataType mapping yet"),
        }
    }
}

impl ExpertOnnxBuilder {
    fn tensor_pv(&self, p1: &str, v2: i64) -> pbonnx::type_proto::Value {
        pbonnx::type_proto::Value::TensorType(pbonnx::type_proto::Tensor {
            elem_type: Into::<pbonnx::tensor_proto::DataType>::into(self.data_type) as i32,
            shape: Some(pbonnx::TensorShapeProto {
                dim: vec![
                    pbonnx::tensor_shape_proto::Dimension {
                        value: Some(pbonnx::tensor_shape_proto::dimension::Value::DimParam(
                            p1.to_string(),
                        )),
                        ..Default::default()
                    },
                    pbonnx::tensor_shape_proto::Dimension {
                        value: Some(pbonnx::tensor_shape_proto::dimension::Value::DimValue(v2)),
                        ..Default::default()
                    },
                ],
            }),
        })
    }

    #[allow(unused)]
    fn tensor_vv(&self, v1: i64, v2: i64) -> pbonnx::type_proto::Value {
        pbonnx::type_proto::Value::TensorType(pbonnx::type_proto::Tensor {
            elem_type: Into::<pbonnx::tensor_proto::DataType>::into(self.data_type) as i32,
            shape: Some(pbonnx::TensorShapeProto {
                dim: vec![
                    pbonnx::tensor_shape_proto::Dimension {
                        value: Some(pbonnx::tensor_shape_proto::dimension::Value::DimValue(v1)),
                        ..Default::default()
                    },
                    pbonnx::tensor_shape_proto::Dimension {
                        value: Some(pbonnx::tensor_shape_proto::dimension::Value::DimValue(v2)),
                        ..Default::default()
                    },
                ],
            }),
        })
    }

    fn build_outputs(&self) -> Vec<pbonnx::ValueInfoProto> {
        let res = vec![pbonnx::ValueInfoProto {
            name: "output".to_string(),
            r#type: Some(pbonnx::TypeProto {
                value: Some(self.tensor_pv("batch_size", self.hidden_size)),
                ..Default::default()
            }),
            ..Default::default()
        }];
        res
    }

    fn build_inputs(&self) -> Vec<pbonnx::ValueInfoProto> {
        let inp = vec![pbonnx::ValueInfoProto {
            name: "input".to_string(),
            r#type: Some(pbonnx::TypeProto {
                value: Some(self.tensor_pv("batch_size", self.hidden_size)),
                ..Default::default()
            }),
            ..Default::default()
        }];

        inp
    }

    fn kv(&self, k: &str, v: &str) -> pbonnx::StringStringEntryProto {
        pbonnx::StringStringEntryProto {
            key: k.to_string(),
            value: v.to_string(),
        }
    }

    fn build_initializers(&self) -> Vec<pbonnx::TensorProto> {
        let init = vec![
            pbonnx::TensorProto {
                name: "onnx::MatMul_13".to_string(),
                data_type: Into::<pbonnx::tensor_proto::DataType>::into(self.data_type) as i32,
                dims: vec![self.hidden_size, self.intermediate_size],
                data_location: pbonnx::tensor_proto::DataLocation::External as i32,
                external_data: vec![self.kv("location", "#onnx::MatMul_13")],
                ..Default::default()
            },
            pbonnx::TensorProto {
                name: "onnx::MatMul_14".to_string(),
                data_type: Into::<pbonnx::tensor_proto::DataType>::into(self.data_type) as i32,
                data_location: pbonnx::tensor_proto::DataLocation::External as i32,
                dims: vec![self.hidden_size, self.intermediate_size],
                external_data: vec![self.kv("location", "#onnx::MatMul_14")],
                ..Default::default()
            },
            pbonnx::TensorProto {
                name: "onnx::MatMul_15".to_string(),
                data_type: Into::<pbonnx::tensor_proto::DataType>::into(self.data_type) as i32,
                data_location: pbonnx::tensor_proto::DataLocation::External as i32,
                external_data: vec![self.kv("location", "#onnx::MatMul_15")],
                dims: vec![self.intermediate_size, self.hidden_size],
                ..Default::default()
            },
        ];
        init
    }

    fn build_graph(&self) -> pbonnx::GraphProto {
        pbonnx::GraphProto {
            name: "main_graph".to_string(),
            node: self.build_nodes(),
            input: self.build_inputs(),
            output: self.build_outputs(),
            initializer: self.build_initializers(),
            ..Default::default()
        }
    }
    fn build_nodes(&self) -> Vec<pbonnx::NodeProto> {
        let node = vec![
            pbonnx::NodeProto {
                input: vec!["input".to_string(), "onnx::MatMul_13".to_string()],
                output: vec!["/gate_proj/MatMul_output_0".to_string()],
                name: "/gate_proj/MatMul".to_string(),
                op_type: "MatMul".to_string(),
                ..pbonnx::NodeProto::default()
            },
            pbonnx::NodeProto {
                input: vec!["/gate_proj/MatMul_output_0".to_string()],
                output: vec!["/act_fn/Sigmoid_output_0".to_string()],
                name: "/act_fn/Sigmoid".to_string(),
                op_type: "Sigmoid".to_string(),
                ..pbonnx::NodeProto::default()
            },
            pbonnx::NodeProto {
                input: vec![
                    "/gate_proj/MatMul_output_0".to_string(),
                    "/act_fn/Sigmoid_output_0".to_string(),
                ],
                output: vec!["/act_fn/Mul_output_0".to_string()],
                name: "/act_fn/Mul".to_string(),
                op_type: "Mul".to_string(),
                ..pbonnx::NodeProto::default()
            },
            pbonnx::NodeProto {
                input: vec!["input".to_string(), "onnx::MatMul_14".to_string()],
                output: vec!["/up_proj/MatMul_output_0".to_string()],
                name: "/up_proj/MatMul".to_string(),
                op_type: "MatMul".to_string(),
                ..pbonnx::NodeProto::default()
            },
            pbonnx::NodeProto {
                input: vec![
                    "/act_fn/Mul_output_0".to_string(),
                    "/up_proj/MatMul_output_0".to_string(),
                ],
                output: vec!["/Mul_output_0".to_string()],
                name: "/Mul".to_string(),
                op_type: "Mul".to_string(),
                ..pbonnx::NodeProto::default()
            },
            pbonnx::NodeProto {
                input: vec!["/Mul_output_0".to_string(), "onnx::MatMul_15".to_string()],
                output: vec!["output".to_string()],
                name: "/down_proj/MatMul".to_string(),
                op_type: "MatMul".to_string(),
                ..pbonnx::NodeProto::default()
            },
        ];
        node
    }
    pub fn build(&self) -> pbonnx::ModelProto {
        let graph = self.build_graph();

        pbonnx::ModelProto {
            ir_version: 9,
            opset_import: vec![pbonnx::OperatorSetIdProto {
                version: 20,
                ..Default::default()
            }],
            producer_name: "expert-kit".to_string(),
            producer_version: "dev".to_string(),
            doc_string: "expert onnx model file generated by expert-kit".to_string(),
            graph: Some(graph),
            ..Default::default()
        }
    }

    pub fn build_raw(&self) -> Vec<u8> {
        let msg = self.build();
        let mut buf = Vec::new();
        msg.encode(&mut buf).unwrap();
        buf
    }
}

#[cfg(test)]
mod test {

    use ort::session::Session;
    use prost::Message;

    use crate::{
        backend::{DType, Device, ort::NDArrayTensor},
        ffn::meta::ExpertWeight,
    };

    /// Builds an `f32` expert model, binds random weights to its external
    /// initializers, and checks that ONNX Runtime loads the model with the
    /// expected inputs, outputs and producer metadata.
    #[test]
    fn test_basic_export() {
        ort::init().commit().unwrap();
        let builder = super::ExpertOnnxBuilder {
            intermediate_size: 7168,
            hidden_size: 2048,
            data_type: DType::Float,
        };
        let model = builder.build();
        let raw = model.encode_to_vec();

        let rand_weight: ExpertWeight<NDArrayTensor<f32>> =
            ExpertWeight::from_rand_matmul(2048, 7168, DType::Float, Device::CPU);

        // The graph wires `onnx::MatMul_13` into /gate_proj, `onnx::MatMul_14`
        // into /up_proj and `onnx::MatMul_15` into /down_proj.
        let session = Session::builder()
            .expect("failed to create session")
            .with_external_initializer("onnx::MatMul_13", rand_weight.gate_w.into())
            .expect("should load gate")
            .with_external_initializer("onnx::MatMul_14", rand_weight.up_w.into())
            .expect("should load up")
            .with_external_initializer("onnx::MatMul_15", rand_weight.down_w.into())
            .expect("should load down")
            .commit_from_memory(raw.as_slice())
            .unwrap();

        assert_eq!(session.inputs.len(), 1);
        assert_eq!(session.outputs.len(), 1);
        let meta = session.metadata().expect("should have metadata");
        assert_eq!(meta.producer().expect("should have producer"), "expert-kit");
    }
}