/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

// This schema is written in proto3 syntax.
syntax = "proto3";

// Every definition in this file belongs to the `domi` package.
package domi;

// Target selector used by ModelDef.target_type.
// MINI is the zero value, so a ModelDef that never sets target_type reads back as MINI.
// NOTE(review): what each target denotes is not described in this file — confirm.
enum TargetType {
    MINI = 0;
    TINY = 1;
    LITE = 2;
}

// Offline model: the top-level container holding the model's operators (OpDef)
// and its model-level attributes.
// Field numbers 3-9, 14, 16-19, 21-22 and 24-29 are not assigned in this definition.
message ModelDef {
    // Model name and version number.
    string name = 1;
    uint32 version = 2;

    // Model-wide sizes and counts.
    // NOTE(review): units are not stated in this file (the *_size fields are presumably
    // byte counts) — confirm with the producer of this message.
    uint64 memory_size = 10;
    uint32 stream_num = 11;
    uint32 event_num = 12;
    uint64 weight_size = 13;
    uint32 label_num = 15;

    // Operator definitions that make up the model.
    repeated OpDef op = 20;

    // Reads as MINI (the enum's zero value) when unset.
    TargetType target_type = 23;

    // Attributes keyed by attribute name.
    map<string, AttrDef> attr = 30;
}

// Operator definition: one entry of ModelDef.op.
// Field numbers 6, 7, 17-19 and 24-29 are not assigned in this definition.
message OpDef {
    // Operator name and operator type string.
    string name = 1;
    string type = 2;

    uint32 id = 3;
    uint32 stream_id = 4;

    repeated string input_name = 5;

    // NOTE(review): src_name/src_index (and dst_name/dst_index below) look like
    // index-aligned parallel lists — confirm against the code that fills them.
    repeated string src_name = 8;
    repeated int32 src_index = 9;

    // NOTE(review): `input`, `output` and `workspace` are plain int64 lists; their meaning
    // (presumably memory offsets) is not stated in this file — confirm.
    repeated int64 input = 10;
    repeated int64 output = 11;

    // Tensor descriptors for the operator's inputs and outputs.
    repeated TensorDescriptor input_desc = 12;
    repeated TensorDescriptor output_desc = 13;

    repeated WeightDef weights = 14;
    repeated string dst_name = 15;
    repeated int32 dst_index = 16;

    repeated int64 workspace = 20;
    repeated uint32 workspace_bytes = 21;

    repeated string weight_name = 22;
    repeated bool is_input_const = 23;

    // Attributes keyed by attribute name.
    map<string, AttrDef> attr = 30;

    QuantizeFactorParams quantize_factor = 31;

    // Operator-type-specific parameters. Being a oneof, at most one member is set at a
    // time. Member numbers start at 100 and advance in steps of 100.
    oneof op_params {
        SendOpParams sender_param = 100;
        RecvOpParams receiver_param = 200;
        ConvolutionOpParams convolution_param = 300;
        PoolingOpParams pooling_param = 400;
        EltwiseOpParams eltwise_param = 500;
        BatchNormOpParams batchnorm_param = 600;
        ScaleOpParams scale_param = 700;
        FullConnectionOpParams full_connection_param = 800;
        SoftmaxOpParams softmax_param = 900;
        ActivationOpParams activation_param = 1000;
        ReshapeOpParams reshape_param = 1100;
    }
}

// Parameters carried by OpDef.op_params as `sender_param`.
message SendOpParams {
    // NOTE(review): presumably matches the event_id of the corresponding RecvOpParams — confirm.
    uint32 event_id = 1;
}

// Parameters carried by OpDef.op_params as `receiver_param`.
message RecvOpParams {
    // NOTE(review): presumably matches the event_id of the corresponding SendOpParams — confirm.
    uint32 event_id = 1;
}

// Scale kind selected by QuantizeFactorParams.scale_type.
// VECTOR_SCALE is the zero value and therefore what an unset field reads as.
enum QuantizeScaleType {
    VECTOR_SCALE = 0;
    SCALAR_SCALE = 1;
}

// Scale mode selected by QuantizeFactor.scale_mode.
// NORMAL_MODE is the zero value and therefore what an unset field reads as.
enum QuantizeScaleMode {
    NORMAL_MODE = 0;
    SQRT_MODE = 1;
}

// Quantization algorithm selected by QuantizeFactorParams.quantize_algo.
// NON_OFFSET_ALGO is the zero value and therefore what an unset field reads as.
enum QuantizeAlgorithm {
    NON_OFFSET_ALGO = 0;
    HALF_OFFSET_ALGO = 1;
    ALL_OFFSET_ALGO = 2;
}
// One set of quantization factors; QuantizeFactorParams uses it for its quantize,
// dequantize and requantize parameters.
// After scale_mode, the fields come in pairs: a `bytes` value followed by an `int64` offset.
// NOTE(review): how each value/offset pair relates (inline data vs. a location elsewhere)
// is not stated in this file — confirm with the producer.
message QuantizeFactor {
    QuantizeScaleMode scale_mode = 1;

    bytes scale_value = 2;
    int64 scale_offset = 3;

    bytes offset_data_value = 4;
    int64 offset_data_offset = 5;

    bytes offset_weight_value = 6;
    int64 offset_weight_offset = 7;

    bytes offset_pad_value = 8;
    int64 offset_pad_offset = 9;
}

// Calculation factors referenced by QuantizeFactorParams.quantizecalc_param.
// Fields come in pairs: a `bytes` blob followed by an `int64` offset.
message QuantizeCalcFactor {
    bytes offsetw = 1;
    int64 offsetw_offset = 2;

    bytes offsetd = 3;
    int64 offsetd_offset = 4;

    // NOTE(review): `scaledreq_offset` does not follow the `<name>_offset` pattern of its
    // companion `scalereq`. The name is kept because renaming a field changes generated
    // accessors and JSON names.
    bytes scalereq = 5;
    int64 scaledreq_offset = 6;

    bytes offsetdnext = 7;
    int64 offsetdnext_offset = 8;
}

// Quantization settings attached to an operator through OpDef.quantize_factor.
message QuantizeFactorParams {
    // Enum selectors; each reads as its zero value when unset.
    QuantizeAlgorithm quantize_algo = 1;
    QuantizeScaleType scale_type = 2;

    // Separate factor sets for the quantize, dequantize and requantize steps.
    QuantizeFactor quantize_param = 3;
    QuantizeFactor dequantize_param = 4;
    QuantizeFactor requantize_param = 5;

    QuantizeCalcFactor quantizecalc_param = 6;
}

// Convolution parameters; carried by OpDef.op_params as `convolution_param`.
// Field numbers are assigned in sparse groups: 1-5, 10-13, 20-21, 40-41, 62, 70-72.
message ConvolutionOpParams {
    // Plain integer selectors; their value sets are not defined in this file.
    int32 mode = 1;
    int32 algo = 2;
    int32 pad_mode = 3;

    uint32 group = 4;
    uint32 num_output = 5;

    // Geometry lists.
    // NOTE(review): element count and axis order are not stated in this file — confirm.
    repeated uint32 pad = 10;
    repeated uint32 stride = 11;
    repeated uint32 dilation = 12;
    repeated uint32 kernel = 13;

    float alpha = 20;
    float beta = 21;

    // Filter and bias weights.
    WeightDef filter = 40;
    WeightDef bias = 41;

    bool relu_flag = 62;
    repeated uint32 adj = 70;
    repeated uint32 target_shape = 71;
    repeated uint32 before_pad = 72;
}

// Pooling parameters; carried by OpDef.op_params as `pooling_param`.
// Field numbers are assigned in sparse groups: 1-4, 10-14, 20-22.
message PoolingOpParams {
    // Plain integer selectors; their value sets are not defined in this file.
    int32 mode = 1;
    int32 nan_opt = 2;
    int32 pad_mode = 3;

    bool global_pooling = 4;

    // Geometry lists.
    // NOTE(review): element count and axis order are not stated in this file — confirm.
    repeated uint32 window = 10;
    repeated uint32 pad = 11;
    repeated uint32 stride = 12;

    bool ceil_mode = 13;
    int32 data_mode = 14;

    float alpha = 20;
    float beta = 21;
    repeated uint32 before_pad = 22;
}

// Element-wise operator parameters; carried by OpDef.op_params as `eltwise_param`.
message EltwiseOpParams {
    // Plain integer selector; its value set is not defined in this file.
    int32 mode = 1;

    // NOTE(review): presumably one coefficient per input — confirm.
    repeated float coeff = 2;

    float alpha = 3;
    float beta = 4;
    repeated WeightDef weight = 5;
    bool relu_flag = 6;
}

// Activation parameters; carried by OpDef.op_params as `activation_param`.
message ActivationOpParams {
    // Plain integer selector; its value set is not defined in this file.
    int32 mode = 1;

    float coef = 2;
    float alpha = 3;
    float beta = 4;
}

// Batch-normalization parameters; carried by OpDef.op_params as `batchnorm_param`.
message BatchNormOpParams {
    // Plain integer selector; its value set is not defined in this file.
    int32 mode = 1;

    float alpha = 2;
    float beta = 3;

    // proto3 has no per-field custom defaults, so an unset field below reads as 0 / false.
    // The defaults named in the trailing comments are the ones noted by the original
    // author and would have to be applied by the consumer.
    // NOTE(review): confirm where (and whether) those defaults are applied.
    double epsilon = 4;  // optional, [default = 1e-5]
    bool use_global_stats = 5;  // optional, true by default (testing mode)
    float moving_average_fraction = 6;  // optional, [default = .999]

    // Estimated statistics.
    WeightDef estimated_mean = 7;
    WeightDef estimated_variance = 8;

    // Scale and bias weights.
    WeightDef scale = 9;
    WeightDef bias = 10;
}

// Scale operator parameters; carried by OpDef.op_params as `scale_param`.
message ScaleOpParams {
    // Scale and bias weights.
    WeightDef scale = 1;
    WeightDef bias = 2;
}

// Reshape parameters; carried by OpDef.op_params as `reshape_param`.
message ReshapeOpParams {
    float alpha = 1;
    float beta = 2;

    // Shape given as a list of int64 dimensions (see ShapeDef).
    ShapeDef shape = 3;

    int32 axis = 4;
    int32 num_axes = 5;

    // Plain integer code; its value set is not defined in this file.
    int32 format = 6;
}

// Softmax parameters; carried by OpDef.op_params as `softmax_param`.
message SoftmaxOpParams {
    // Plain integer selectors; their value sets are not defined in this file.
    int32 algo = 1;
    int32 mode = 2;

    float alpha = 3;
    float beta = 4;
}

// Fully-connected layer parameters; carried by OpDef.op_params as `full_connection_param`.
// Field numbers 4-11 are not assigned in this definition.
message FullConnectionOpParams {
    // Filter and bias weights.
    WeightDef filter = 1;
    WeightDef bias = 2;

    uint32 num_output = 3;
    bool relu_flag = 12;
}

// Flatten parameters.
// No other definition visible in this file references this message.
message FlattenOpParams {
    float alpha = 1;
    float beta = 2;
    // NOTE(review): presumably the range of axes to flatten; whether end_axis is
    // inclusive is not stated in this file — confirm.
    int32 start_axis = 3;
    int32 end_axis = 4;
}

// Parameters for the "add limited" operator; its field set is identical to MulLimitedOpParams.
// No other definition visible in this file references this message.
// Field numbers 5-9 are not assigned in this definition.
message AddLimitedOpParams {
    float alpha = 1;
    float beta = 2;
    int32 axis = 3;
    bool broadcast = 4;

    repeated WeightDef weight = 10;
}

// Parameters for the "mul limited" operator; its field set is identical to AddLimitedOpParams.
// No other definition visible in this file references this message.
// Field numbers 5-9 are not assigned in this definition.
message MulLimitedOpParams {
    float alpha = 1;
    float beta = 2;
    int32 axis = 3;
    bool broadcast = 4;

    repeated WeightDef weight = 10;
}

// Add operator parameters; same field set as MulOpParams and SubOpParams.
// No other definition visible in this file references this message.
// Field numbers 3-9 are not assigned in this definition.
message AddOpParams {
    float alpha = 1;
    float beta = 2;

    repeated WeightDef weight = 10;
}

// Mul operator parameters; same field set as AddOpParams and SubOpParams.
// No other definition visible in this file references this message.
// Field numbers 3-9 are not assigned in this definition.
message MulOpParams {
    float alpha = 1;
    float beta = 2;

    repeated WeightDef weight = 10;
}

// Sub operator parameters; same field set as AddOpParams and MulOpParams.
// No other definition visible in this file references this message.
// Field numbers 3-9 are not assigned in this definition.
message SubOpParams {
    float alpha = 1;
    float beta = 2;

    repeated WeightDef weight = 10;
}

// BiasAdd operator parameters.
// No other definition visible in this file references this message.
// Field numbers 3-9 are not assigned in this definition.
message BiasAddOpParams {
    float alpha = 1;
    float beta = 2;

    // Bias weight.
    WeightDef bias = 10;
}

// MatMul operator parameters.
// No other definition visible in this file references this message.
// Field numbers 5-9 and 11 are not assigned in this definition.
message MatMulOpParams {
    float alpha = 1;
    float beta = 2;

    // NOTE(review): these two names are camelCase while most fields in this file use
    // lower_snake_case. They are kept because renaming a field changes generated
    // accessors and JSON names.
    bool transposeX = 3;
    bool transposeW = 4;

    // Filter and bias weights.
    WeightDef filter = 10;
    WeightDef bias = 12;
}

// Rsqrt operator parameters.
// No other definition visible in this file references this message.
message RsqrtOpParams {
    float alpha = 1;
    float beta = 2;
}


// Weight definition, used by OpDef.weights and by the per-operator parameter messages
// (filter, bias, scale, ... fields).
// Field number 9 is not assigned; cmps_tab_offset (10) is declared before cmps_info (8).
message WeightDef {
    // Plain integer codes; their value sets are not defined in this file.
    int32 format = 1;
    int32 data_type = 2;
    // Shape given as a list of int64 dimensions (see ShapeDef).
    ShapeDef shape = 3;
    // NOTE(review): `data` and `data_offset` are presumably alternatives (inline bytes vs.
    // a location in an external buffer) — confirm with the producer.
    bytes data = 4;
    int64 data_offset = 5;
    // NOTE(review): the cmps_* fields presumably describe compressed weights (see
    // CompressInfo) — confirm.
    uint32 cmps_size = 6;
    bytes cmps_tab = 7;
    int64 cmps_tab_offset = 10;
    CompressInfo cmps_info = 8;
    AllOffsetQuantizeInfo alloffset_quantize_info = 11;
}

// Shape expressed as a list of int64 dimensions; used by WeightDef.shape and
// ReshapeOpParams.shape.
message ShapeDef {
    // One entry per dimension. An empty list cannot be told apart from an unset one
    // (repeated fields carry no presence).
    repeated int64 dim = 1;
}

// Device selector used by TensorDescriptor.device_type.
// NPU is the zero value, so it is what an unset device_type reads as.
enum DeviceType {
  NPU = 0;                   // In default, we will use NPU.
  CPU = 1;                   // CPU
}

// Scale/offset pair referenced by WeightDef.alloffset_quantize_info and
// TensorDescriptor.alloffset_quantize_info.
// NOTE(review): presumably used with QuantizeAlgorithm.ALL_OFFSET_ALGO — confirm.
message AllOffsetQuantizeInfo {
    float scale = 1;
    int32 offset = 2;
}

// Tensor descriptor, used by OpDef.input_desc and OpDef.output_desc.
// Field number 6 is not assigned in this definition.
message TensorDescriptor {
    // Plain integer codes; their value sets are not defined in this file.
    int32 format = 1;
    int32 data_type = 2;
    // Dimensions as int64 values.
    repeated int64 dim = 3;
    // NOTE(review): unit is not stated in this file (presumably bytes) — confirm.
    uint32 size = 4;
    // NOTE(review): reuse_input presumably pairs with reuse_input_index (field 11) — confirm.
    bool reuse_input = 5;
    bool output_tensor = 7;
    // Reads as NPU (the enum's zero value) when unset.
    DeviceType device_type = 8;
    bool input_tensor = 9;
    uint32 real_dim_cnt = 10;
    uint32 reuse_input_index = 11;
    AllOffsetQuantizeInfo alloffset_quantize_info = 12;
}

// Compression layout information referenced by WeightDef.cmps_info.
// Field names are camelCase, unlike most fields in this file.
message CompressInfo {
    int32 blockRow = 1;                             // block row
    int32 blockCol = 2;                             // block col
    int32 fractalK = 3;                             // fractal K
    int32 fractalN = 4;                             // fractal N
    int32 lastFractalK = 5;                         // K of last fractal
    int32 lastFractalN = 6;                         // N of last fractal
    int32 cubeSize = 7;                             // cube's length
    int32 loadDir = 8;                              // data load direction, 0: column load, 1: row load
}

// Attribute value used as the map value type of ModelDef.attr, OpDef.attr and
// NamedAttrs.attr. Exactly which kind of value is held is given by the `value` oneof.
message AttrDef {
  // Holder for list-valued attributes. The field numbers mirror the scalar members of
  // the `value` oneof (2 = string, 3 = int, 4 = float, 5 = bool, 6 = uint, 7 = bytes).
  // The explicit [packed = true] options restate the proto3 default for repeated
  // numeric fields; they are kept so the file descriptor stays unchanged.
  message ListValue {
    repeated string s = 2;                        // "list(string)"
    repeated int64 i = 3 [packed = true];        // "list(int)"
    repeated float f = 4 [packed = true];        // "list(float)"
    repeated bool b = 5 [packed = true];         // "list(bool)"
    repeated uint32 u = 6 [packed = true];         // "list(uint)"
    // "list(bytes)"
    repeated bytes bt = 7;
  }

  // At most one member is set at a time. Field numbers 8 and 9 are not assigned, and
  // `list` (number 1) is declared after the scalar members.
  oneof value {
    string s = 2;                // "string"
    int64 i = 3;                 // "int"
    float f = 4;                 // "float"
    bool b = 5;                  // "bool"
    uint32 u = 6;                // "uint32"
    // "bytes"
    bytes bt = 7;
    ListValue list = 1;          // any "list(...)"
    // A named set of attributes (see NamedAttrs).
    NamedAttrs func = 10;
  }
}

// A named set of attributes: `attr` maps attribute names to their values and `name`
// labels the whole set, e.g. MatMul[T=float]. Used by AttrDef.func.
message NamedAttrs {
  // Name attached to the whole attribute set.
  string name = 1;
  // Attribute values keyed by attribute name; map iteration order is not defined.
  map<string, AttrDef> attr = 2;
}