/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

syntax = "proto3";

package af.proto;
// Tensor element data type; used by TensorDescriptor.dtype.
// The numeric values are what is serialized on the wire, so existing numbers
// must never be renumbered or reused. Append new types at the end.
enum DataType {
  DT_UNDEFINED      = 0;   // Used to indicate a DataType field has not been set.
  DT_FLOAT          = 1;   // float type
  DT_FLOAT16        = 2;   // fp16 type
  DT_INT8           = 3;   // int8 type
  DT_UINT8          = 4;   // uint8 type
  DT_INT16          = 5;   // int16 type
  DT_UINT16         = 6;   // uint16 type
  DT_INT32          = 7;   // int32 type
  DT_INT64          = 8;   // int64 type
  DT_UINT32         = 9;   // unsigned int32
  DT_UINT64         = 10;  // unsigned int64
  DT_BOOL           = 11;  // bool type
  DT_DOUBLE         = 12;  // double type
  DT_STRING         = 13;  // string type
  DT_DUAL_SUB_INT8  = 14;  // dual output int8 type
  DT_DUAL_SUB_UINT8 = 15;  // dual output uint8 type
  DT_COMPLEX64      = 16;  // complex64 type
  DT_COMPLEX128     = 17;  // complex128 type
  DT_QINT8          = 18;  // qint8 type
  DT_QINT16         = 19;  // qint16 type
  DT_QINT32         = 20;  // qint32 type
  DT_QUINT8         = 21;  // quint8 type
  DT_QUINT16        = 22;  // quint16 type
  DT_RESOURCE       = 23;  // resource type
  DT_STRING_REF     = 24;  // string_ref type
  DT_DUAL           = 25;  // dual output type
  DT_VARIANT        = 26;  // variant type
  DT_BF16           = 27;  // bf16 type
  DT_INT4           = 28;  // int4 type
  DT_UINT1          = 29;  // uint1 type
  DT_INT2           = 30;  // int2 type
  DT_UINT2          = 31;  // uint2 type
  DT_COMPLEX32      = 32;  // complex32 type
  DT_HIFLOAT8       = 33;  // hifloat8 type
  DT_FLOAT8_E5M2    = 34;  // float8_e5m2 type
  DT_FLOAT8_E4M3FN  = 35;  // float8_e4m3fn type
  DT_FLOAT8_E8M0    = 36;  // float8_e8m0 type
  DT_FLOAT6_E3M2    = 37;  // float6_e3m2 type
  DT_FLOAT6_E2M3    = 38;  // float6_e2m3 type
  DT_FLOAT4_E2M1    = 39;  // float4_e2m1 type
  DT_FLOAT4_E1M2    = 40;  // float4_e1m2 type
  DT_HIFLOAT4       = 41;  // hifloat4 type
}

// A single attribute value. At most one member of the `value` oneof is set.
message AttrDef {
  // List-valued attribute; carried by AttrDef.list.
  message ListValue {
    // Element-kind tag for a ListValue.
    // NOTE(review): presumably identifies which repeated field below is
    // populated (STRING -> s, INT -> i, ...) -- confirm against the serializer.
    enum ListValueType {
      VT_LIST_NONE = 0;
      VT_LIST_STRING = 1;
      VT_LIST_INT = 2;
      VT_LIST_FLOAT = 3;
      VT_LIST_BOOL = 4;
      VT_LIST_BYTES = 5;
      VT_LIST_TENSOR_DESC = 6;
      VT_LIST_TENSOR = 7;
      VT_LIST_GRAPH = 8;
      VT_LIST_NAMED_ATTRS = 9;
      VT_LIST_DATA_TYPE = 10;
    }

    // "list(string)"
    repeated bytes s = 2;
    // "list(int)"
    repeated int64 i = 3;
    // "list(float)"
    repeated float f = 4;
    // "list(bool)"
    repeated bool b = 5;
    // List of byte strings.
    repeated bytes bt = 7;
    // List of tensor descriptors.
    repeated TensorDescriptor td = 8;
    // List of tensors.
    repeated TensorDef t = 9;
    // List of graphs.
    repeated GraphDef g = 10;
    // List of named attribute sets.
    repeated NamedAttrs na = 11;
    // List of ge::DataType values, stored as int64.
    repeated int64 dt = 12;

    // Element-kind tag of this list.
    ListValueType val_type = 20;
  }

  // A list of int64 lists ("list list int").
  message ListListInt {
    // One inner int64 list.
    message ListInt {
      // list int
      repeated int64 list_i = 1;
    }
    // list list int
    repeated ListInt list_list_i = 1;
  }

  // A list of float lists ("list list float").
  message ListListFloat {
    // One inner float list.
    message ListFloat {
      // list float
      repeated float list_f = 1;
    }
    // list list float
    repeated ListFloat list_list_f = 1;
  }

  oneof value {
    // "string"
    bytes s = 2;
    // "int"
    int64 i = 3;
    // "float"
    float f = 4;
    // "bool"
    bool b = 5;
    // Raw bytes.
    bytes bt = 7;
    // Any "list(...)".
    ListValue list = 1;
    // Nested named attributes; used to support attr nesting.
    NamedAttrs func = 10;
    // GeTensorDesc type.
    TensorDescriptor td = 11;
    // GeTensor type.
    TensorDef t = 12;
    // Graph type.
    GraphDef g = 13;
    // List-of-int-lists type.
    ListListInt list_list_int = 14;
    // ge::DataType, stored as int64.
    int64 dt = 15;
    // List-of-float-lists type.
    ListListFloat list_list_float = 16;
    // NOTE(review): serialized expression; its encoding is defined outside
    // this file.
    bytes expression = 17;
  }
}

// Intended to replace AttrDef in ge_ir.proto in the future.
// Do not use it yet: it is currently only a stub.
message AttributeDef {
  // At most one member is set.
  oneof value {
    bytes s = 1;   // "string"
    int64 i = 2;   // "int"
    float f = 3;   // "float"
    bool b = 4;    // "bool"
    bytes bt = 5;  // "bytes"
  }
}

// Free-form attribute group, keyed by attribute name.
// Carried by AttrGroups.other_group_def.
message OtherGroupDef {
  // Attribute name -> value.
  map<string, AttributeDef> attr = 1;
}

// Symbolic-shape attribute group of a tensor descriptor.
// Carried by AttrGroupDef.tensor_attr_group.
message TensorDescAttrGroupsDef {
  // Symbolic origin shape.
  repeated string origin_symbol_shape = 1;
  // Symbolic value.
  repeated string symbolic_value = 2;
}

// Source location of a symbol; value type of
// ShapeEnvAttrGroupsDef.symbol_to_source.
// NOTE(review): presumably "dimension dim_idx of graph input input_data_idx"
// -- confirm against the producer.
message InputSourceDef {
  int32 input_data_idx = 1;
  int64 dim_idx = 2;
}

// Replacement entry; value type of ShapeEnvAttrGroupsDef.replacements.
message ReplacementDef {
  // Replacement expression, serialized as a string.
  string replace_expr = 1;
  // NOTE(review): meaning of rank is defined by the producer -- confirm.
  int32 rank = 2;
}

// A recorded symbolic check. Used by ShapeEnvAttrGroupsDef for both
// symbol_check_infos and symbol_assert_infos.
message SymbolCheckInfoDef {
  // Checked expression, serialized as a string.
  string expr = 1;
  // NOTE(review): file/line presumably record where the check was created.
  string file = 2;
  int64 line = 3;
  // Diagnostic (DFX) information.
  string dfx = 4;
}

// Settings of a shape environment; carried by
// ShapeEnvAttrGroupsDef.shape_setting.
message ShapeEnvSettingDef {
  bool specialize_zero_one = 1;
  // Integer-coded mode; the value mapping is defined outside this file.
  int32 dynamic_mode = 2;
}

// A set of symbol names; value type of ShapeEnvAttrGroupsDef.value_to_symbol.
message SymbolInfoDef {
  repeated string symbols = 1;
}

// Shape-environment attribute group.
// Carried by AttrGroupDef.shape_env_attr_group.
message ShapeEnvAttrGroupsDef {
  // Symbol name -> int64 value.
  map<string, int64> symbol_to_value = 1;
  // int64 value -> symbols.
  map<int64, SymbolInfoDef> value_to_symbol = 2;
  // Symbol name -> input it originates from.
  map<string, InputSourceDef> symbol_to_source = 3;
  // Key -> replacement entry.
  map<string, ReplacementDef> replacements = 4;
  // Recorded checks.
  repeated SymbolCheckInfoDef symbol_check_infos = 5;
  // Recorded assertions.
  repeated SymbolCheckInfoDef symbol_assert_infos = 6;
  // Environment settings.
  ShapeEnvSettingDef shape_setting = 7;
  // NOTE(review): presumably a counter for generating unique symbol ids.
  uint64 unique_sym_id = 8;
}

// Scheduling information of a node; carried by AscNodeAttrGroupsDef.sched.
// NOTE(review): axis/loop_axis presumably hold AxisDef.id values -- confirm.
message SchedInfoDef {
  int64 exec_order = 1;
  repeated int64 axis = 2;
  int64 loop_axis = 3;
  // Integer-coded condition; mapping defined outside this file.
  int32 exec_condition = 4;
}

// API information of a node; carried by AscNodeAttrGroupsDef.api.
// All three fields are integer-coded; their mappings are defined outside
// this file.
message ApiInfoDef {
  int32 type = 1;
  int32 compute_type = 2;
  int32 unit = 3;
}

// Memory attributes. Carried by AscTensorAttrGroupsDef.mem and
// TmpBufferGroupDef.mem.
message MemAttrDef {
  int64 tensor_id = 1;
  // alloc_type, position and hardware are integer-coded; their mappings are
  // defined outside this file.
  int32 alloc_type = 2;
  int32 position = 3;
  int32 hardware = 4;
  repeated int64 buf_ids = 5;
  string name = 6;
  int64 reuse_id = 7;
}

// Queue attributes of a tensor; carried by AscTensorAttrGroupsDef.que.
message MemQueueAttrDef {
  int64 id = 1;
  int64 depth = 2;
  int64 buf_num = 3;
  string name = 4;
}

// Buffer attributes of a tensor; carried by AscTensorAttrGroupsDef.buf.
message MemBufAttrDef {
  int64 id = 1;
  string name = 2;
}

// Memory-optimization attributes of a tensor; carried by
// AscTensorAttrGroupsDef.opt.
message MemOptAttrDef {
  int64 reuse_id = 1;
  int64 ref_tensor = 2;
  int64 merge_scope = 3;
}

// An axis of an ASC graph; element type of AscGraphAttrGroupsDef.axis.
message AxisDef {
  int64 id = 1;
  string name = 2;
  // Integer-coded axis type; mapping defined outside this file.
  int32 axis_type = 3;
  bool bind_block = 4;
  // Axis size, serialized as an expression string.
  string size = 5;
  string align = 6;
  // NOTE(review): presumably ids of the axes this axis was derived from.
  repeated int64 from = 7;
  int64 split_pair_other_id = 8;
  bool allow_oversize_axis = 9;
  bool allow_unaligned_tail = 10;
}

// Operator attribute group; carried by AttrGroupDef.op_attr_group.
message AscendCIROpAttrGroupsDef {
  string name = 1;  // operator name
  string type = 2;  // operator type
}

// ASC tensor attribute group; carried by AttrGroupDef.asc_tensor_attr_group.
message AscTensorAttrGroupsDef {
  // NOTE(review): presumably a DataType value stored as int64 -- confirm.
  int64 dtype = 1;
  repeated int64 axis_ids = 2;
  // Per-entry expression strings.
  repeated string repeats = 3;
  // Per-entry expression strings.
  repeated string strides = 4;
  repeated int64 vectorized_axis = 5;
  repeated string vectorized_strides = 6;
  // Memory placement.
  MemAttrDef mem = 7;
  // Queue attributes.
  MemQueueAttrDef que = 8;
  // Buffer attributes.
  MemBufAttrDef buf = 9;
  // Memory-optimization attributes.
  MemOptAttrDef opt = 10;
}

// ASC graph attribute group; carried by AttrGroupDef.asc_graph_attr_group.
message AscGraphAttrGroupsDef {
  int64 tiling_key = 1;
  // Axes of the graph.
  repeated AxisDef axis = 2;
  int64 type = 3;
  repeated string size_var = 4;
}

// IR attributes of an ASC node; carried by AscNodeAttrGroupsDef.ir_attr_def.
message AscIrAttrDef {
  // Attribute name -> value.
  map<string, AttrDef> attr = 1;
}


// Descriptor of a temporary buffer; carried by TmpBufferGroupDef.buf_desc.
message TmpBufDescDef {
  // Buffer size, serialized as an expression string.
  string size = 1;
  int64 life_time_axis_id = 2;
}

// A temporary buffer of an ASC node; element type of
// AscNodeAttrGroupsDef.tmp_buffers.
message TmpBufferGroupDef {
  TmpBufDescDef buf_desc = 1;
  MemAttrDef mem = 2;
  int64 id = 3;
}

// ASC node attribute group; carried by AttrGroupDef.asc_node_attr_group.
message AscNodeAttrGroupsDef {
  string name = 1;
  string type = 2;
  // Scheduling information.
  SchedInfoDef sched = 3;
  // API information.
  ApiInfoDef api = 4;
  // IR attributes.
  AscIrAttrDef ir_attr_def = 5;
  // Temporary buffers.
  repeated TmpBufferGroupDef tmp_buffers = 6;
}

// One typed attribute group. At most one member of `attr_group` is set.
// Field number 1 is not assigned here.
// NOTE(review): check history before reusing field number 1.
message AttrGroupDef {
  oneof attr_group {
    AscendCIROpAttrGroupsDef op_attr_group = 2;
    TensorDescAttrGroupsDef tensor_attr_group = 3;
    ShapeEnvAttrGroupsDef shape_env_attr_group = 4;
    AscGraphAttrGroupsDef asc_graph_attr_group = 5;
    AscNodeAttrGroupsDef asc_node_attr_group = 6;
    AscTensorAttrGroupsDef asc_tensor_attr_group = 7;
  }
}

// Container of attribute groups. Attached as `attr_groups` to
// TensorDescriptor, OpDef, GraphDef and ModelDef.
message AttrGroups {
  // Free-form, name-keyed attributes.
  OtherGroupDef other_group_def = 1;
  // Typed attribute groups.
  repeated AttrGroupDef attr_group_def = 2;
}

// A set of attribute name/value pairs labelled with a single name,
// e.g. MatMul[T=float].
message NamedAttrs {
  // Label of the whole attribute set.
  string name = 1;
  // Attribute name -> value.
  map<string, AttrDef> attr = 2;
}

// Shape (dimension) description, in row-major order.
message ShapeDef {
  // Size of each dimension, one entry per dimension.
  repeated int64 dim = 1;
}

// Description of multidimensional data (a tensor).
message TensorDescriptor {
  // Optional tensor name.
  string name = 1;

  // Element data type.
  DataType dtype = 2;
  // Shape / dimensions.
  ShapeDef shape = 3;
  // Tensor format, e.g. "NCHW", "NHWC", "CHW", "ND".
  string layout = 4;

  // Fields 9-21: additional descriptor properties. Their semantics are
  // defined by the code that fills them.
  bool has_out_attr = 9;
  int64 size = 10;
  int64 weight_size = 11;
  bool reuse_input = 12;
  bool output_tensor = 13;
  string device_type = 14;
  bool input_tensor = 15;
  int64 real_dim_cnt = 16;
  int64 reuse_input_index = 17;
  int64 data_offset = 18;
  int64 cmps_size = 19;
  string cmps_tab = 20;
  int64 cmps_tab_offset = 21;

  // Extra parameter fields.
  map<string, AttrDef> attr = 5;
  // Attribute groups.
  AttrGroups attr_groups = 6;
}

// GeTensor definition: a descriptor plus the tensor's raw data.
message TensorDef {
  // Tensor description.
  TensorDescriptor desc = 1;
  // Tensor data.
  bytes data = 2;
}


// Operator description.
message OpDef {
  // Operator name.
  string name = 1;
  // Operator type.
  string type = 2;

  // Inputs, each written as "<source op name>:<output index>".
  repeated string input = 5;

  // Operator parameter fields.
  map<string, AttrDef> attr = 10;
  // Attribute groups.
  AttrGroups attr_groups = 11;

  bool has_out_attr = 20;
  int64 id = 21;
  int64 stream_id = 22;
  repeated string input_name = 23;
  repeated string src_name = 24;
  repeated int64 src_index = 25;
  repeated string dst_name = 26;
  repeated int64 dst_index = 27;
  repeated int64 input_i = 28;
  repeated int64 output_i = 29;
  repeated int64 workspace = 30;
  repeated int64 workspace_bytes = 31;
  repeated bool is_input_const = 32;
  // Descriptors of the operator's inputs.
  repeated TensorDescriptor input_desc = 33;
  // Descriptors of the operator's outputs.
  repeated TensorDescriptor output_desc = 34;
  repeated string subgraph_name = 35;
}

// Graph definition.
message GraphDef {
  // Graph name.
  string name = 1;

  // Graph inputs.
  repeated string input = 4;
  // Graph outputs.
  repeated string output = 5;

  // Operators of the graph.
  repeated OpDef op = 6;

  // Extended fields.
  map<string, AttrDef> attr = 11;
  // Attribute groups.
  AttrGroups attr_groups = 12;
}

// Model definition.
message ModelDef {
  // Model name.
  string name = 1;
  // IR proto version.
  uint32 version = 2;
  // User model version number, supplied by the user.
  string custom_version = 3;

  // Graphs of the model; graph[0] is the main graph.
  repeated GraphDef graph = 7;

  // Extended fields.
  map<string, AttrDef> attr = 11;
  // Attribute groups.
  AttrGroups attr_groups = 12;
}