// api.proto
// Definition of various TensorFlow protobuf messages for use with the TensorFlow API.
//
// Assembled from these relevant proto sources:
// https://github.com/google/protobuf/blob/master/src/google/protobuf/any.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/types.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/tensor.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/tensor_shape.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/attr_value.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/versions.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/node_def.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/op_def.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/saver.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/meta_graph.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/saved_model.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/function.proto

syntax = "proto3";
package tensorflow;

// Container for an arbitrary payload tagged with a type identifier.
//
// Declared locally in package "tensorflow" with the same two fields as the
// any.proto source listed in the file header, so this file needs no imports.
message Any {
  // Identifier of the type carried in "value".
  // NOTE(review): presumably a type URL such as
  // "type.googleapis.com/<fully.qualified.Name>", as in upstream any.proto;
  // confirm against producers.
  string type_url = 1;

  // Raw payload bytes; presumably the serialized message named by "type_url".
  bytes value = 2;
}

// Element type of a tensor or of a type-valued attribute.
//
// The numeric values follow TensorFlow's types.proto (see the file header) and
// are part of the wire format: new entries may be appended, but an existing
// entry must never be renumbered or reused.
enum DataType {
  // Not a legal value for DataType.  Used to indicate a DataType field
  // has not been set.
  DT_INVALID = 0;

  // Data types that all computation devices are expected to be
  // capable of supporting.
  DT_FLOAT = 1;
  DT_DOUBLE = 2;
  DT_INT32 = 3;
  DT_UINT8 = 4;
  DT_INT16 = 5;
  DT_INT8 = 6;
  DT_STRING = 7;
  DT_COMPLEX64 = 8;  // Single-precision complex
  DT_INT64 = 9;
  DT_BOOL = 10;
  DT_QINT8 = 11;     // Quantized int8
  DT_QUINT8 = 12;    // Quantized uint8
  DT_QINT32 = 13;    // Quantized int32
  DT_BFLOAT16 = 14;  // Float32 truncated to 16 bits.  Only for cast ops.

  // Values 15-23 as numbered in upstream types.proto.  DT_UINT32 and
  // DT_UINT64 are the types referred to by Tensor.uint32_val and
  // Tensor.uint64_val; without these entries such tensors would only be
  // visible as unnamed numeric enum values.
  DT_QINT16 = 15;      // Quantized int16
  DT_QUINT16 = 16;     // Quantized uint16
  DT_UINT16 = 17;
  DT_COMPLEX128 = 18;  // Double-precision complex
  DT_HALF = 19;
  DT_RESOURCE = 20;
  DT_VARIANT = 21;     // Arbitrary C++ data types
  DT_UINT32 = 22;
  DT_UINT64 = 23;

  // Do not use!  These are only for parameters.  Every enum above
  // should have a corresponding value below (verified by types_test).
  // Each reference value equals the base value plus 100.
  DT_FLOAT_REF = 101;
  DT_DOUBLE_REF = 102;
  DT_INT32_REF = 103;
  DT_UINT8_REF = 104;
  DT_INT16_REF = 105;
  DT_INT8_REF = 106;
  DT_STRING_REF = 107;
  DT_COMPLEX64_REF = 108;
  DT_INT64_REF = 109;
  DT_BOOL_REF = 110;
  DT_QINT8_REF = 111;
  DT_QUINT8_REF = 112;
  DT_QINT32_REF = 113;
  DT_BFLOAT16_REF = 114;
  DT_QINT16_REF = 115;
  DT_QUINT16_REF = 116;
  DT_UINT16_REF = 117;
  DT_COMPLEX128_REF = 118;
  DT_HALF_REF = 119;
  DT_RESOURCE_REF = 120;
  DT_VARIANT_REF = 121;
  DT_UINT32_REF = 122;
  DT_UINT64_REF = 123;
}

// Shape of a tensor: an ordered list of dimensions plus an "unknown rank"
// flag.  Field number 1 is not used by this message.
message TensorShape {
  // One dimension of the tensor.
  message Dim {
    // Size of the tensor in that dimension.
    int64 size = 1;

    // Optional name of the tensor dimension.
    string name = 2;
  }

  // Dimensions of the tensor, such as {"input", 30}, {"output", 40} for a 30 x
  // 40 2D tensor.  The names are optional.
  //
  // The order of entries in "dim" matters: It indicates the layout of the
  // values in the tensor in-memory representation.
  //
  // The first entry in "dim" is the outermost dimension used to layout the
  // values, the last entry is the innermost dimension.  This matches the
  // in-memory layout of RowMajor Eigen tensors.
  repeated Dim dim = 2;

  // Flag marking the rank (number of dimensions) as unknown.
  // NOTE(review): presumably "dim" is expected to be empty when this is true;
  // confirm against upstream tensor_shape.proto.
  bool unknown_rank = 3;
}

// A tensor value: its element type, its shape, and the element data stored in
// exactly one of several alternative representations.
message Tensor {
  // Element type; selects which of the typed "xxx_val" fields may be set.
  DataType dtype = 1;

  // Shape of the tensor.  TODO(touts): sort out the 0-rank issues.
  TensorShape tensor_shape = 2;

  // Only one of the representations below is set: either "tensor_content" or
  // one of the "xxx_val" fields.  A oneof is not used because oneofs cannot
  // contain repeated fields, so it would require an extra set of wrapper
  // messages.

  // Version number.
  //
  // In version 0, if the "repeated xxx" representations contain only one
  // element, that element is repeated to fill the shape.  This makes it easy
  // to represent a constant Tensor with a single value.
  int32 version_number = 3;

  // Serialized content from TensorBase::Serialize().  This representation can
  // be used for all tensor types.
  bytes tensor_content = 4;

  // Type specific representations that make it easy to create tensor protos in
  // all languages.  Only the representation corresponding to "dtype" can
  // be set.  The values hold the flattened representation of the tensor in
  // row major order.

  // DT_FLOAT.
  repeated float float_val = 5 [packed = true];

  // DT_DOUBLE.
  repeated double double_val = 6 [packed = true];

  // DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
  repeated int32 int_val = 7 [packed = true];

  // DT_STRING.  Declared as bytes, so elements are not required to be UTF-8.
  repeated bytes string_val = 8;

  // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real
  // and imaginary parts of i-th single precision complex.
  repeated float scomplex_val = 9 [packed = true];

  // DT_INT64
  repeated int64 int64_val = 10 [packed = true];

  // DT_BOOL
  repeated bool bool_val = 11 [packed = true];

  // Field numbers 12 through 15 are not declared in this file.
  // NOTE(review): upstream tensor.proto may assign them to further
  // representations; do not reuse them for unrelated fields.

  // DT_UINT32
  repeated uint32 uint32_val = 16 [packed = true];

  // DT_UINT64
  repeated uint64 uint64_val = 17 [packed = true];
}

// Value of a single attribute.  Holds at most one of the alternatives in the
// "value" oneof: a scalar, a shape, a tensor, a function reference, a
// placeholder name, or a ListValue of homogeneous items.
message AttrValue {
  // List form of an attribute value: one repeated field per element kind.
  // Field number 1 is not used here.  Note that "func" is number 9 in this
  // message but number 10 in the enclosing oneof, where 9 is "placeholder".
  message ListValue {
    repeated bytes s = 2;                        // byte-string elements
    repeated int64 i = 3 [packed = true];        // integer elements
    repeated float f = 4 [packed = true];        // float elements
    repeated bool b = 5 [packed = true];         // boolean elements
    repeated DataType type = 6 [packed = true];  // data-type elements
    repeated TensorShape shape = 7;              // shape elements
    repeated Tensor tensor = 8;                  // tensor elements
    repeated NameAttrList func = 9;              // function-reference elements
  }

  // The attribute's value.  Being a oneof, setting one member clears any
  // other member that was set.
  oneof value {
    ListValue list = 1;     // any of the list forms above
    bytes s = 2;            // byte string
    int64 i = 3;            // integer
    float f = 4;            // float
    bool b = 5;             // boolean
    DataType type = 6;      // data type
    TensorShape shape = 7;  // tensor shape
    Tensor tensor = 8;      // tensor value
    // NOTE(review): presumably the name of an attribute to be substituted
    // later (e.g. when instantiating a function body); confirm against
    // upstream attr_value.proto.
    string placeholder = 9;
    // Function reference: a name plus attribute values (see NameAttrList).
    NameAttrList func = 10;
  }
}

// A name paired with a set of named attribute values.  Used by AttrValue to
// represent its "func" alternative.
message NameAttrList {
  // The name being referred to (for "func" values, presumably a function
  // name).
  string name = 1;

  // Attribute values keyed by attribute name.  Map iteration order is not
  // defined by protobuf.
  map<string, AttrValue> attr = 2;
}

// A single node of a graph (see GraphDef.node and FunctionDef.node_def).
// Field meanings below are inferred from names and types; confirm details
// against upstream node_def.proto.
message NodeDef {
  // Name of this node.
  string name = 1;

  // Name of the operation this node performs; presumably matches an
  // OpDef.name.
  string op = 2;

  // Inputs of this node, in order.
  // NOTE(review): the string format (e.g. "node:output_index", control-input
  // prefixes) is not defined in this file.
  repeated string input = 3;

  // Device specification for this node; format not defined in this file.
  string device = 4;

  // Attribute values keyed by attribute name.
  map<string, AttrValue> attr = 5;
}

// Version information attached to a graph (see GraphDef.versions).
// NOTE(review): semantics below are inferred from field names; confirm
// against upstream versions.proto.
message VersionDef {
  // Version of the code that produced the data.
  int32 producer = 1;

  // Presumably the minimum consumer version able to read the data.
  int32 min_consumer = 2;

  // Presumably specific consumer versions that must not read the data.
  repeated int32 bad_consumers = 3;
}

// A graph: its nodes, version information, and a library of functions.
//
// Fields are not declared in numeric order, and field number 3 is not
// declared in this file; do not reuse it for an unrelated field.
message GraphDef {
  // The nodes of the graph.
  repeated NodeDef node = 1;

  // Version information for the graph.
  VersionDef versions = 4;

  // Functions and gradient mappings that accompany the graph.
  FunctionDefLibrary library = 2;
}

// A collection of values of a single kind.  Exactly which kind is chosen by
// the "kind" oneof; each alternative wraps a repeated field because oneof
// members cannot themselves be repeated.  Used as the value type of
// MetaGraphDef.collection_def.
message CollectionDef {
  // List of strings; presumably graph node names, given the message name.
  message NodeList {
    repeated string value = 1;
  }
  // List of raw byte strings.
  message BytesList {
    repeated bytes value = 1;
  }
  // List of 64-bit signed integers.
  message Int64List {
    repeated int64 value = 1 [packed = true];
  }
  // List of 32-bit floats.
  message FloatList {
    repeated float value = 1 [packed = true];
  }
  // List of Any payloads (the Any message declared in this file).
  message AnyList {
    repeated Any value = 1;
  }

  // The collection contents; at most one alternative is set.
  oneof kind {
    NodeList node_list = 1;
    BytesList bytes_list = 2;
    Int64List int64_list = 3;
    FloatList float_list = 4;
    AnyList any_list = 5;
  }
}

// Configuration of a saver (see MetaGraphDef.saver_def).
// NOTE(review): field meanings below are inferred from names; confirm
// against upstream saver.proto.
message SaverDef {
  // Name of the tensor that supplies the checkpoint file name.
  string filename_tensor_name = 1;

  // Name of the tensor associated with the save operation.
  string save_tensor_name = 2;

  // Name of the operation that performs a restore.
  string restore_op_name = 3;

  // Presumably the maximum number of recent checkpoints to retain.
  int32 max_to_keep = 4;

  // Presumably whether checkpoints are written as multiple shards.
  bool sharded = 5;

  // Presumably an interval, in hours, at which a checkpoint is kept
  // regardless of max_to_keep.
  float keep_checkpoint_every_n_hours = 6;

  // Checkpoint format selector.  LEGACY is the zero value and therefore the
  // default when "version" is not set.  The value names are not prefixed with
  // the enum name, so they occupy the SaverDef scope directly.
  enum CheckpointFormatVersion {
    LEGACY = 0;
    V1 = 1;
    V2 = 2;
  }
  // Format version used for checkpoints.
  CheckpointFormatVersion version = 7;
}

// Describes a tensor by reference: how to locate it (by a single name, or as
// three component tensors of a sparse encoding), plus its element type and
// shape.  Used by SignatureDef and AssetFileDef.
message TensorInfo {
  // Names of the three tensors that together make up a sparse tensor.
  // NOTE(review): "Coo" presumably stands for coordinate-list format.
  message CooSparse {
    // Name of the tensor holding the values.
    string values_tensor_name = 1;
    // Name of the tensor holding the indices.
    string indices_tensor_name = 2;
    // Name of the tensor holding the dense shape.
    string dense_shape_tensor_name = 3;
  }

  // How the tensor is located; at most one alternative is set.  The members
  // use field numbers 1 and 4, interleaved with the plain fields 2 and 3.
  oneof encoding {
    // Name of a single tensor.
    string name = 1;
    // Component tensor names of a sparse tensor.
    CooSparse coo_sparse = 4;
  }
  // Element type of the tensor.
  DataType dtype = 2;
  // Shape of the tensor.
  TensorShape tensor_shape = 3;
}

// A named set of input and output tensors together with a method name.
// Used as the value type of MetaGraphDef.signature_def.
message SignatureDef {
  // Input tensors keyed by a string name.
  map<string, TensorInfo> inputs = 1;

  // Output tensors keyed by a string name.
  map<string, TensorInfo> outputs = 2;

  // Name of the method this signature describes; the set of valid names is
  // not defined in this file.
  string method_name = 3;
}

// Associates a tensor with an asset file name (see
// MetaGraphDef.asset_file_def).
message AssetFileDef {
  // The tensor bound to this asset.
  TensorInfo tensor_info = 1;

  // File name of the asset.
  // NOTE(review): presumably relative to the model's assets directory rather
  // than an absolute path; confirm.
  string filename = 2;
}

// Definition of an operation: its name, input and output arguments,
// attributes, documentation strings, and behavioral flags.  Also used as the
// signature of a FunctionDef.
//
// Fields are not declared in numeric order; field numbers 7 and 9-15 are not
// declared in this file.
message OpDef {
  // Name of the operation.
  string name = 1;

  // Description of one input or output argument.
  message ArgDef {
    // Name of the argument.
    string name = 1;
    // Human-readable description of the argument.
    string description = 2;
    // Element type of the argument.  The remaining type-related fields name
    // attributes instead.
    // NOTE(review): presumably exactly one way of specifying the type is
    // used per argument; confirm against upstream op_def.proto.
    DataType type = 3;
    string type_attr = 4;    // if specified, attr must have type "type"
    string number_attr = 5;  // if specified, attr must have type "int"
    // Name of an attribute; presumably one holding a list of types.
    string type_list_attr = 6;
    // Whether the argument is a reference.
    bool is_ref = 16;
  }
  // Input arguments, in order.
  repeated ArgDef input_arg = 2;
  // Output arguments, in order.
  repeated ArgDef output_arg = 3;

  // Description of one attribute accepted by the operation.
  message AttrDef {
    // Name of the attribute.
    string name = 1;
    // Type of the attribute, given as a string; the set of valid type
    // strings is not defined in this file.
    string type = 2;
    // Value used when the attribute is not specified.
    AttrValue default_value = 3;
    // Human-readable description of the attribute.
    string description = 4;
    // Whether "minimum" is meaningful.  Needed because an unset proto3 int64
    // is indistinguishable from 0.
    bool has_minimum = 5;
    // Lower bound for the attribute; consult only when has_minimum is true.
    int64 minimum = 6;
    // Presumably the set of values the attribute may take.
    AttrValue allowed_values = 7;
  }
  // Attributes accepted by the operation.
  repeated AttrDef attr = 4;

  // Deprecation information for an operation.
  message OpDeprecation {
    // Version associated with the deprecation; presumably the first graph
    // version in which the op is deprecated.
    int32 version = 1;
    // Explanation of the deprecation.
    string explanation = 2;
  }
  // Deprecation information; unset when the message field is absent.
  OpDeprecation deprecation = 8;

  // Short human-readable summary of the operation.
  string summary = 5;
  // Longer human-readable description of the operation.
  string description = 6;
  // Whether the operation is commutative.
  bool is_commutative = 18;
  bool is_aggregate = 16;  // for things like add
  bool is_stateful = 17;  // for things like variables, queue
  bool allows_uninitialized_input = 19;  // for Assign, etc.
}

// A list of operation definitions (see MetaInfoDef.stripped_op_list).
message OpList {
  // The operation definitions.
  repeated OpDef op = 1;
}

// A graph bundled with the metadata that accompanies it: descriptive info,
// saver configuration, collections, signatures, and asset files.
message MetaGraphDef {
  // Descriptive metadata about the meta graph.
  message MetaInfoDef {
    // Version string for this meta graph; format not defined in this file.
    string meta_graph_version = 1;
    // List of operation definitions.
    // NOTE(review): "stripped" presumably means documentation fields have
    // been removed from the OpDefs; confirm.
    OpList stripped_op_list = 2;
    // Arbitrary additional payload, using the Any message declared in this
    // file.
    Any any_info = 3;
    // Tags attached to this meta graph.
    repeated string tags = 4;
    // TensorFlow version string.
    string tensorflow_version = 5;
    // TensorFlow git version string.
    string tensorflow_git_version = 6;
  }
  // Descriptive metadata.
  MetaInfoDef meta_info_def = 1;
  // The graph itself.
  GraphDef graph_def = 2;
  // Saver configuration.
  SaverDef saver_def = 3;
  // Collections keyed by collection name.
  map<string, CollectionDef> collection_def = 4;
  // Signatures keyed by signature name.
  map<string, SignatureDef> signature_def = 5;
  // Asset files referenced by the graph.
  repeated AssetFileDef asset_file_def = 6;
}

// Top-level container: a schema version plus one or more meta graphs.
message SavedModel {
  // Version number of the saved-model schema.
  int64 saved_model_schema_version = 1;

  // The meta graphs contained in the model.
  repeated MetaGraphDef meta_graphs = 2;
}

// A library of function definitions and gradient mappings (see
// GraphDef.library).
message FunctionDefLibrary {
  // The function definitions.
  repeated FunctionDef function = 1;

  // Mappings from a function name to its gradient function's name.
  repeated GradientDef gradient = 2;
}

// Definition of a function: a signature, attributes, a body made of nodes,
// and a mapping that describes its return values.
//
// Fields are not declared in numeric order; field number 2 is reserved.
message FunctionDef {
  // Signature of the function, expressed as an OpDef.
  OpDef signature = 1;

  // Attribute values keyed by attribute name.
  map<string, AttrValue> attr = 5;

  // Field number 2 is reserved so that it cannot be reused by a new field.
  reserved 2;

  // The nodes forming the function body.
  repeated NodeDef node_def = 3;

  // String-to-string mapping for return values.
  // NOTE(review): presumably maps each output argument name in "signature"
  // to the node output that produces it; confirm against upstream
  // function.proto.
  map<string, string> ret = 4;
}

// Associates a function with the function that computes its gradient (see
// FunctionDefLibrary.gradient).  Both fields are names, not definitions.
message GradientDef {
  string function_name = 1;  // The function name.
  string gradient_func = 2;  // The gradient function's name.
}
