/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

syntax = "proto3";

package ge.mobile.proto;

// Element data types that can be referenced by a DataType field.
//
// Values 13 through 34 are not defined in this file.
enum DataType {
  // Used to indicate a DataType field has not been set.
  DT_UNDEFINED = 0;
  // float type
  DT_FLOAT = 1;
  // fp16 type
  DT_FLOAT16 = 2;
  // int8 type
  DT_INT8 = 3;
  // uint8 type
  DT_UINT8 = 4;
  // int16 type
  DT_INT16 = 5;
  // uint16 type
  DT_UINT16 = 6;
  // int32 type
  DT_INT32 = 7;
  // int64 type
  DT_INT64 = 8;
  // unsigned int32
  DT_UINT32 = 9;
  // unsigned int64
  DT_UINT64 = 10;
  // bool type
  DT_BOOL = 11;
  // double type
  DT_DOUBLE = 12;
  // float8 type
  DT_FLOAT8_E4M3FN = 35;
}

// A single attribute value. The payload lives in the `value` oneof, so at
// most one of its members is set at any time.
message AttrDef {
  // Holder for list-shaped attribute values; one repeated field per element
  // kind.
  message ListValue {
    // Tags naming the kinds of list a ListValue can carry.
    enum ListValueType {
      VT_LIST_NONE = 0;
      VT_LIST_STRING = 1;
      VT_LIST_INT = 2;
      VT_LIST_FLOAT = 3;
      VT_LIST_BOOL = 4;
      VT_LIST_BYTES = 5;
      VT_LIST_TENSOR_DESC = 6;
      VT_LIST_TENSOR = 7;
      VT_LIST_GRAPH = 8;
      VT_LIST_NAMED_ATTRS = 9;
    }

    // "list(string)"
    repeated bytes s = 2;
    // "list(int)"
    repeated int64 i = 3;
    // "list(float)"
    repeated float f = 4;
    // "list(bool)"
    repeated bool b = 5;
    repeated bytes bt = 7;
    repeated TensorDescriptor tf = 8;
    repeated TensorDef t = 9;
    repeated GraphDef g = 10;
    repeated NamedAttrs na = 11;

    // NOTE(review): presumably identifies which of the repeated fields above
    // is populated - confirm against the code that reads/writes this message.
    ListValueType val_type = 20;
  }

  // A list of integer lists.
  message ListListInt {
    // One inner list of integers.
    message ListInt {
      // list int
      repeated int64 list_i = 1;
    }

    // list list int
    repeated ListInt list_list_i = 1;
  }

  // A list of float lists.
  message ListListFloat {
    // One inner list of floats.
    message ListFloat {
      // list float
      repeated float list_f = 1;
    }

    // list list float
    repeated ListFloat list_list_f = 1;
  }

  oneof value {
    // "string"
    bytes s = 2;
    // "int"
    int64 i = 3;
    // "float"
    float f = 4;
    // "bool"
    bool b = 5;
    bytes bt = 7;
    // any "list(...)"
    ListValue list = 1;
    // Used to support nested attrs.
    NamedAttrs func = 10;
    // TensorDesc type
    TensorDescriptor td = 11;
    // Tensor type
    TensorDef t = 12;
    // Graph type
    GraphDef g = 13;
    // List List Int type
    ListListInt list_list_int = 14;
    // List List Float type
    ListListFloat list_list_float = 15;
  }
}

// A collection of attribute names and their values, with a string name
// attached to the collection as a whole. E.g., MatMul[T=float].
message NamedAttrs {
  // Name attached to the whole attribute collection.
  string name = 1;
  // Attribute values keyed by attribute name.
  map<string, AttrDef> attr = 2;
}

// Shape/dimension description, in row-major order.
message ShapeDef {
  // Size of each dimension.
  repeated int64 dim = 1;
}

// Description of multi-dimensional data.
message TensorDescriptor {
  // Optional; the name of the tensor.
  string name = 1;

  // Data type of the tensor.
  DataType dtype = 2;
  // Shape/dimensions.
  ShapeDef shape = 3;
  // Tensor format description, e.g. "NCHW", "NHWC", "CHW", "ND".
  string layout = 4;

  // NOTE(review): fields 9 through 21 carry no documentation in the original
  // schema. Their meaning and units should be confirmed with the code that
  // produces and consumes this message before relying on them.

  bool has_out_attr = 9;
  int64 size = 10;
  int64 weight_size = 11;
  bool reuse_input = 12;
  bool output_tensor = 13;
  string device_type = 14;
  bool input_tensor = 15;
  int64 real_dim_cnt = 16;
  int64 reuse_input_index = 17;
  int64 data_offset = 18;
  int64 cmps_size = 19;
  string cmps_tab = 20;
  int64 cmps_tab_offset = 21;

  // Collection of additional parameter fields.
  map<string, AttrDef> attr = 5;
}

// Tensor definition: a descriptor paired with the tensor's data.
message TensorDef {
  // Tensor description.
  TensorDescriptor desc = 1;
  // Tensor data.
  bytes data = 2;
}


// Operator description.
message OpDef {
  // Name of the operator.
  string name = 1;
  // Type of the operator.
  string type = 2;

  // For each input, the name of the originating op plus the index of its
  // output edge, written as op_name:index.
  repeated string input = 5;

  // Collection of operator parameter fields.
  map<string, AttrDef> attr = 10;

  // NOTE(review): fields 20 through 34 carry no documentation in the original
  // schema. Their meaning should be confirmed with the code that produces and
  // consumes this message before relying on them.

  bool has_out_attr = 20;
  int64 id = 21;
  int64 stream_id = 22;
  repeated string input_name = 23;
  repeated string src_name = 24;
  repeated int64 src_index = 25;
  repeated string dst_name = 26;
  repeated int64 dst_index = 27;
  repeated int64 input_i = 28;
  repeated int64 output_i = 29;
  repeated int64 workspace = 30;
  repeated int64 workspace_bytes = 31;
  repeated bool is_input_const = 32;
  repeated TensorDescriptor input_desc = 33;
  repeated TensorDescriptor output_desc = 34;
}

// Graph definition.
message GraphDef {
  // Name of the graph.
  string name = 1;

  // Inputs of the graph.
  repeated string input = 4;
  // Outputs of the graph.
  repeated string output = 5;

  // List of operators.
  repeated OpDef op = 6;

  // Extension fields.
  map<string, AttrDef> attr = 11;
}

// Model definition.
message ModelDef {
  // Name of the model.
  string name = 1;
  // IR proto version number.
  uint32 version = 2;
  // User model version number, passed in by the user.
  string custom_version = 3;

  // Graph definitions; graph[0] is the main graph of the ModelDef.
  repeated GraphDef graph = 7;

  // Extension fields.
  map<string, AttrDef> attr = 11;
}