/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package domi.tensorflow;

// Expected to provide the AttrValue, NodeDef and OpDef message types that
// are referenced by the messages in this file.
import "attr_value.proto";
import "node_def.proto";
import "op_def.proto";

// Code-generation options: arena allocation for C++, and for Java the
// target package, the outer class name, and one source file per top-level
// message.
option cc_enable_arenas = true;
option java_multiple_files = true;
option java_outer_classname = "FunctionProtos";
option java_package = "org.tensorflow.framework";

// A library is a set of named functions, together with the gradient
// definitions that relate those functions to their gradient functions.
message FunctionDefLibrary {
  // The functions contained in this library.
  repeated FunctionDef function = 1;
  // Gradient definitions. Each GradientDef pairs a function name with the
  // name of its gradient function.
  repeated GradientDef gradient = 2;
}

// A function can be instantiated when the runtime can bind every attr
// with a value. When a GraphDef has a call to a function, it must
// have a binding for every attr defined in the signature.
//
// TODO:
//   * device spec, etc.
message FunctionDef {
  // The definition of the function's name, arguments, return values,
  // attrs etc.
  OpDef signature = 1;

  // Attributes specific to this function definition.
  map<string, AttrValue> attr = 5;

  // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21.
  // The number stays reserved so that it is never reused.
  reserved 2;

  // In both of the following fields, there is the need to specify an
  // output that is used as either the input to another node (in
  // `node_def`) or as a return value of the function (in `ret`).
  // Unlike the NodeDefs in GraphDef, we need to be able to specify a
  // list in some cases (instead of just single outputs).  Also, we
  // need to be able to deal with lists of unknown length (so the
  // output index may not be known at function definition time).  So
  // we use the following format instead:
  // * "fun_in" where "fun_in" is the name of a function input arg in
  //   the `signature` field above.  This represents that input, whether
  //   it is a single tensor or a list.
  // * "fun_in:0" gives the first element of a function input arg (a
  //   non-list input is considered a list of length 1 for these
  //   purposes).
  // * "node:out" where "node" is the name of a node in `node_def` and
  //   "out" is the name of one of its op's output arguments (the name
  //   comes from the OpDef of the node's op). This represents that
  //   node's output, whether it is a single tensor or a list.
  //   Note: We enforce that an op's output arguments are never
  //   renamed in the backwards-compatibility test.
  // * "node:out:0" gives the first element of a node output arg (a
  //   non-list output is considered a list of length 1 for these
  //   purposes).
  //
  // NOT CURRENTLY SUPPORTED (but may be in the future):
  // * "node:out:-1" gives the last element in a node output list
  // * "node:out:1:" gives a list with all but the first element in a
  //   node output list
  // * "node:out::-1" gives a list with all but the last element in a
  //   node output list

  // The body of the function.  Unlike the NodeDefs in a GraphDef, attrs
  // may have values of type `placeholder` and the `input` field uses
  // the "output" format above.

  // By convention, "op" in node_def is resolved by consulting with a
  // user-defined library first. If not resolved there, "op" is assumed
  // to be a builtin op.
  repeated NodeDef node_def = 3;

  // A mapping from the output arg names from `signature` to the
  // outputs from `node_def` that should be returned by the function.
  // The values use the "output" format described above.
  map<string, string> ret = 4;
}

// GradientDef defines the gradient function of a function defined in
// a function library.
//
// A gradient function g (specified by gradient_func) for a function f
// (specified by function_name) must satisfy the following:
//
// The function 'f' must be a numerical function which takes N inputs
// and produces M outputs. Its gradient function 'g' is a function
// taking N + M inputs and producing N outputs.
//
// I.e. if we have
//    (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
// then, g is
//    (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
//                                      dL/dy1, dL/dy2, ..., dL/dy_M),
// where L is a scalar-valued function of (x1, x2, ..., x_N) (e.g., the
// loss function). dL/dx_i is the partial derivative of L with respect
// to x_i.
message GradientDef {
  // The name of the function f whose gradient is being defined.
  string function_name = 1;
  // The name of the gradient function g for f.
  string gradient_func = 2;
}