/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

syntax = "proto3";

package debugger;

import "debug_graph.proto";

// RPC service with nine methods, all of which reply with a single EventReply.
// Methods whose request type is marked `stream` are client-streaming: the
// caller sends a sequence of messages and then receives one reply.
service EventListener {
  // Unary call: one Metadata in, one EventReply out.
  rpc WaitCMD(Metadata) returns (EventReply) {}

  // Unary call: one Metadata in, one EventReply out.
  rpc SendMetadata(Metadata) returns (EventReply) {}

  // Client-streaming call carrying a sequence of Chunk messages.
  rpc SendGraph(stream Chunk) returns (EventReply) {}

  // Client-streaming call carrying a sequence of TensorProto messages.
  rpc SendTensors(stream TensorProto) returns (EventReply) {}

  // Client-streaming call carrying a sequence of TensorBase messages.
  rpc SendTensorBase(stream TensorBase) returns (EventReply) {}

  // Client-streaming call carrying a sequence of TensorSummary messages.
  rpc SendTensorStats(stream TensorSummary) returns (EventReply) {}

  // Client-streaming call carrying a sequence of WatchpointHit messages.
  rpc SendWatchpointHits(stream WatchpointHit) returns (EventReply) {}

  // Client-streaming call carrying a sequence of Chunk messages.
  rpc SendMultiGraphs(stream Chunk) returns (EventReply) {}

  // Unary call: one Heartbeat in, one EventReply out.
  rpc SendHeartbeat(Heartbeat) returns (EventReply) {}
}

// Request payload of the WaitCMD and SendMetadata RPCs.
message Metadata {
  // Name of the device. The expected format is not specified in this file.
  string device_name = 1;
  // Current step.
  // NOTE(review): presumably the training step index - confirm with callers.
  int32 cur_step = 2;
  // The backend in use: "GPU" or "Ascend".
  string backend = 3;
  // The full name of the current node.
  string cur_node = 4;
  // Whether training is done.
  bool training_done = 5;
  // The total number of graphs.
  int32 graph_num = 6;
  // MindSpore version.
  string ms_version = 7;
}

// One element of the request stream of the SendGraph and SendMultiGraphs RPCs.
message Chunk {
  // Raw bytes carried by this chunk.
  // NOTE(review): presumably a slice of a larger serialized payload - confirm.
  bytes buffer = 1;
  // NOTE(review): presumably set on the last chunk of a payload - confirm
  // against the sender.
  bool finished = 2;
}

// Response type shared by every RPC of the EventListener service. It carries
// a status and, optionally, exactly one member of the `cmd` oneof.
message EventReply {
  // Outcome reported with the reply.
  enum Status {
    // Zero value: this is what a receiver reads when `status` was never set.
    OK = 0;
    FAILED = 1;
    PENDING = 2;
  }

  // Status of this reply; defaults to OK when unset.
  Status status = 1;

  // At most one of the following is set; setting one clears the others.
  oneof cmd {
    // NOTE(review): presumably asks the receiver to terminate - confirm.
    bool exit = 2;
    // Run command; see RunCMD.
    RunCMD run_cmd = 3;
    // Set command; see SetCMD.
    SetCMD set_cmd = 4;
    // View command; see ViewCMD.
    ViewCMD view_cmd = 5;
    // NOTE(review): presumably the result of a version check (Metadata carries
    // ms_version) - confirm.
    bool version_matched = 6;
  }
}

// Run command, carried in EventReply.run_cmd.
message RunCMD {
  // Run granularity: step level or node level, given as "step" or "node".
  string run_level = 1;
  // At most one of the following is set; setting one clears the other.
  oneof cmd {
    // NOTE(review): presumably the number of steps to run - confirm.
    int32 run_steps = 2;
    // The full name of the next node.
    string node_name = 3;
  }
}

// Set command, carried in EventReply.set_cmd. It groups a list of watch nodes,
// a watch condition, a delete flag and an id.
message SetCMD {
  // Nodes this command refers to.
  repeated WatchNode watch_nodes = 1;
  // Condition associated with this command.
  WatchCondition watch_condition = 2;
  // NOTE(review): presumably true when the watchpoint identified by `id` is to
  // be removed rather than created - confirm.
  // `delete` is a reserved word in several target languages, so generated
  // accessors may be renamed (e.g. with a trailing underscore); the name is
  // kept as-is for compatibility.
  bool delete = 3;
  // NOTE(review): presumably the watchpoint id - confirm.
  int32 id = 4;
}

// View command, carried in EventReply.view_cmd.
message ViewCMD {
  // Tensors this command refers to. TensorProto is not defined in this file;
  // it is expected to come from the imported debug_graph.proto.
  repeated TensorProto tensors = 1;
  // Level selector for the view command. The names line up with the three
  // tensor upload RPCs (SendTensors / SendTensorStats / SendTensorBase).
  // NOTE(review): that mapping is an assumption - confirm.
  enum Level {
    // Zero value: this is what a receiver reads when `level` was never set.
    value = 0;
    statistics = 1;
    base = 2;
  }
  // Requested level; defaults to `value` when unset.
  Level level = 2;
}

// A watch condition: a condition kind, a single float value and a list of
// named parameters. Used by SetCMD and WatchpointHit.
// NOTE(review): field number 3 is not used in this message. If it belonged to
// a removed field, add `reserved 3;` so it cannot be reused by accident.
message WatchCondition {
  // Kind of condition. Value names are lowercase and unprefixed; they are part
  // of the generated API and must not be renamed.
  enum Condition {
    // Zero value: this is what a receiver reads when `condition` was never set.
    nan = 0;
    inf = 1;
    overflow = 2;
    max_gt = 3;
    max_lt = 4;
    min_gt = 5;
    min_lt = 6;
    max_min_gt = 7;
    max_min_lt = 8;
    mean_gt = 9;
    mean_lt = 10;
    sd_gt = 11;
    sd_lt = 12;
    tensor_general_overflow = 13;
    tensor_initialization = 14;
    tensor_too_large = 15;
    tensor_too_small = 16;
    tensor_all_zero = 17;
    tensor_change_too_large = 18;
    tensor_change_too_small = 19;
    tensor_not_changed = 20;
    tensor_range = 21;
  }
  // Selected condition kind; defaults to `nan` when unset.
  Condition condition = 1;
  // NOTE(review): presumably the threshold for the comparison-style conditions
  // (e.g. max_gt) - confirm. Single precision, unlike Parameter.value below.
  float value = 2;
  // A named parameter of the condition.
  message Parameter {
    // Parameter name.
    string name = 1;
    // NOTE(review): presumably true when this parameter is ignored - confirm.
    bool disabled = 2;
    // Parameter value (double precision).
    double value = 3;
    bool hit = 4;  // Whether this parameter is hit when checking tensor.
    // NOTE(review): presumably the value observed on the tensor when the
    // check ran - confirm.
    double actual_value = 5;
  }
  // Parameters attached to this condition.
  repeated Parameter params = 4;
}

// A node reference used in SetCMD.watch_nodes.
message WatchNode {
  // Name of the node.
  // NOTE(review): presumably the full node name, as in Metadata.cur_node -
  // confirm.
  string node_name = 1;
  // Type of the node. The set of valid values is not specified in this file.
  string node_type = 2;
}

// One element of the request stream of the SendWatchpointHits RPC.
message WatchpointHit {
  // Tensor associated with the hit. TensorProto is not defined in this file;
  // it is expected to come from the imported debug_graph.proto.
  TensorProto tensor = 1;
  // Watch condition associated with the hit.
  WatchCondition watch_condition = 2;
  // NOTE(review): presumably the watchpoint id, matching SetCMD.id - confirm.
  int32 id = 3;
  // Error code. The meaning of individual codes is not specified in this file.
  int32 error_code = 4;
}

// Request payload of the SendHeartbeat RPC.
message Heartbeat {
  // Free-form text carried with the heartbeat.
  string message = 1;
  // Heartbeat period. The unit is not stated in this file.
  // NOTE(review): presumably seconds - confirm.
  int32 period = 2;
}

// One element of the request stream of the SendTensorStats RPC: a tensor's
// basic description paired with its statistics.
message TensorSummary {
  // Basic description of the tensor; see TensorBase.
  TensorBase tensor_base = 1;
  // Statistics of the tensor; see Statistics.
  Statistics statistics = 2;
}

// Statistics of a tensor, carried in TensorSummary.statistics.
// NOTE(review): the per-field meanings below are read off the field names and
// should be confirmed against the producer.
message Statistics {
  // Presumably true when the tensor holds boolean data.
  bool is_bool = 1;
  // Maximum, minimum and average value (single precision).
  float max_value = 2;
  float min_value = 3;
  float avg_value = 4;
  // Presumably the total number of elements.
  int32 count = 5;
  // Counters for special categories of elements.
  int32 neg_zero_count = 6;
  int32 pos_zero_count = 7;
  int32 nan_count = 8;
  int32 neg_inf_count = 9;
  int32 pos_inf_count = 10;
  int32 zero_count = 11;
}

// Basic description of a tensor. Streamed on its own by the SendTensorBase RPC
// and embedded in TensorSummary.
message TensorBase {
  // Numeric code of the tensor's data type. The mapping from code to type is
  // not defined in this file.
  int32 data_type = 1;
  // Tensor shape, one entry per dimension.
  repeated int64 shape = 2;
  // Size of the tensor data. The unit is not stated in this file.
  // NOTE(review): presumably bytes - confirm.
  int64 data_size = 3;
}