syntax = "proto3";

package flyteidl.plugins.kubeflow;

option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow";

import "flyteidl/core/tasks.proto";
import "flyteidl/plugins/kubeflow/common.proto";

// Custom task definition for the plugin that enables distributed training
// using the Kubeflow MPI operator (https://github.com/kubeflow/mpi-operator).
// A task is described by two replica groups (workers and a launcher), a run
// policy, and the number of slots available on each worker.
message DistributedMPITrainingTask {
  // Spec of the worker replica group.
  DistributedMPITrainingReplicaSpec worker_replicas = 1;

  // Spec of the launcher replica group.
  // NOTE(review): this comment used to say "Master"; the field name says
  // launcher, so the wording now follows the field name.
  DistributedMPITrainingReplicaSpec launcher_replicas = 2;
  
  // RunPolicy encapsulates various runtime policies of the distributed training
  // job, for example how to clean up resources and how long the job can stay
  // active.
  // NOTE(review): RunPolicy is not declared in this file; presumably it comes
  // from the imported flyteidl/plugins/kubeflow/common.proto - confirm.
  RunPolicy run_policy = 3;

  // Number of slots per worker.
  // NOTE(review): presumably forwarded to the MPI operator's per-worker slot
  // setting; unset and 0 are indistinguishable on the wire for a proto3
  // int32 - confirm how the plugin treats 0.
  int32 slots = 4; 
}

// Replica specification for one replica group of a distributed MPI training
// task.
//
// Fields 1 through 4 are deprecated; set the `common` field instead. The
// `command` field is not deprecated.
message DistributedMPITrainingReplicaSpec {
  // Number of replicas.
  // Deprecated: use `common` instead.
  int32 replicas = 1 [deprecated = true];

  // Image used for the replica group.
  // Deprecated: use `common` instead.
  string image = 2 [deprecated = true];

  // Resources required for the replica group.
  // Deprecated: use `common` instead.
  core.Resources resources = 3 [deprecated = true];

  // Restart policy determines whether pods will be restarted when they exit.
  // Deprecated: use `common` instead.
  RestartPolicy restart_policy = 4 [deprecated = true];

  // Command for this replica group. Kept per replica group because MPI
  // sometimes requires a different command set for different replica groups.
  repeated string command = 5;

  // The common replica spec; replaces deprecated fields 1 through 4.
  // NOTE(review): CommonReplicaSpec is not declared in this file; presumably
  // it comes from the imported flyteidl/plugins/kubeflow/common.proto -
  // confirm.
  CommonReplicaSpec common = 6;
}
