syntax = "proto3";

package flyteidl.plugins.kubeflow;

option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow";

import "flyteidl/core/tasks.proto";
import "flyteidl/plugins/kubeflow/common.proto";

// Proto for the plugin that enables distributed TensorFlow training using
// https://github.com/kubeflow/tf-operator.
//
// The task is described as up to four replica groups (worker, parameter
// server, chief, evaluator), each configured through its own
// DistributedTensorflowTrainingReplicaSpec, plus a job-wide RunPolicy.
message DistributedTensorflowTrainingTask {
  // Spec for the worker replica group.
  DistributedTensorflowTrainingReplicaSpec worker_replicas = 1;

  // Spec for the parameter server (PS) replica group.
  DistributedTensorflowTrainingReplicaSpec ps_replicas = 2;

  // Spec for the chief replica group.
  DistributedTensorflowTrainingReplicaSpec chief_replicas = 3;

  // RunPolicy encapsulates various runtime policies of the distributed training
  // job, for example how to clean up resources and how long the job can stay
  // active.
  // NOTE(review): RunPolicy is not declared in this file; presumably it comes
  // from flyteidl/plugins/kubeflow/common.proto -- confirm.
  RunPolicy run_policy = 4;

  // Spec for the evaluator replica group.
  DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5;
}

// Configuration of a single replica group of a
// DistributedTensorflowTrainingTask (used there for the worker, parameter
// server, chief and evaluator groups).
//
// Fields 1 through 4 are deprecated; use `common` (field 5) instead.
// NOTE(review): CommonReplicaSpec and RestartPolicy are not declared in this
// file; presumably they come from flyteidl/plugins/kubeflow/common.proto, and
// core.Resources from flyteidl/core/tasks.proto -- confirm.
message DistributedTensorflowTrainingReplicaSpec {
  // Deprecated: use `common` instead.
  // Number of replicas in the group.
  int32 replicas = 1 [deprecated = true];

  // Deprecated: use `common` instead.
  // Image used for the replica group.
  string image = 2 [deprecated = true];

  // Deprecated: use `common` instead.
  // Resources required for the replica group.
  core.Resources resources = 3 [deprecated = true];

  // Deprecated: use `common` instead.
  // Restart policy determines whether pods will be restarted when they exit.
  RestartPolicy restart_policy = 4 [deprecated = true];

  // The common replica spec; replaces deprecated fields 1 through 4.
  CommonReplicaSpec common = 5;
}
