syntax = "proto3";

package flyteidl.plugins.kubeflow;

option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow";

import "flyteidl/core/tasks.proto";
import "flyteidl/plugins/kubeflow/common.proto";

// Custom proto carrying the torch elastic configuration for distributed
// training. For the operator-side types, see
// https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
//
// NOTE(review): this file does not define the semantics of the individual
// fields; the per-field notes below are inferred from the field names and
// should be confirmed against the linked Go types and the consuming plugin.
message ElasticConfig {
  // Rendezvous backend identifier. Accepted values are not defined in this
  // file; presumably a torch elastic rendezvous backend name (TODO confirm).
  string rdzv_backend = 1;
  // Presumably the lower bound on the number of replicas (TODO confirm).
  int32 min_replicas = 2;
  // Presumably the upper bound on the number of replicas (TODO confirm).
  int32 max_replicas = 3;
  // Presumably the number of processes to launch per node (TODO confirm).
  int32 nproc_per_node = 4;
  // Presumably the maximum number of restarts allowed (TODO confirm).
  int32 max_restarts = 5;
}

// Proto for the plugin that enables distributed training using
// https://github.com/kubeflow/pytorch-operator
//
// A task is described by a worker replica group, a master replica group, a
// run policy, and an elastic configuration.
message DistributedPyTorchTrainingTask {
  // Worker replicas spec.
  DistributedPyTorchTrainingReplicaSpec worker_replicas = 1;

  // Master replicas spec. Master replicas can only have 1 replica.
  DistributedPyTorchTrainingReplicaSpec master_replicas = 2;

  // RunPolicy encapsulates various runtime policies of the distributed training
  // job, for example how to clean up resources and how long the job can stay
  // active.
  // NOTE(review): RunPolicy is not declared in this file; presumably it comes
  // from the imported flyteidl/plugins/kubeflow/common.proto.
  RunPolicy run_policy = 3;

  // Config for an elastic PyTorch job.
  ElasticConfig elastic_config = 4;
}

// Spec of a single replica group of a DistributedPyTorchTrainingTask; it is the
// type of both the worker_replicas and master_replicas fields of that message.
//
// Fields 1~4 are deprecated. Use `common` (field 5) instead.
message DistributedPyTorchTrainingReplicaSpec {
  // Deprecated: use `common` instead.
  // Number of replicas
  int32 replicas = 1 [deprecated = true];

  // Deprecated: use `common` instead.
  // Image used for the replica group
  string image = 2 [deprecated = true];

  // Deprecated: use `common` instead.
  // Resources required for the replica group
  core.Resources resources = 3 [deprecated = true];

  // Deprecated: use `common` instead.
  // Restart policy determines whether pods will be restarted when they exit
  RestartPolicy restart_policy = 4 [deprecated = true];

  // The common replica spec. This is the replacement for deprecated fields 1~4.
  // NOTE(review): CommonReplicaSpec is not declared in this file; presumably it
  // comes from the imported flyteidl/plugins/kubeflow/common.proto.
  CommonReplicaSpec common = 5;
}
