syntax = "proto3";

package flyteidl.plugins;

option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins";

// Custom proto for torch elastic config for distributed training using
// https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
//
// All fields are proto3 scalars declared without `optional`, so they have
// implicit presence: an unset field cannot be told apart from one explicitly
// set to its zero value (0 or "").
message ElasticConfig {
  // Backend identifier, given as a free-form string.
  // NOTE(review): "rdzv" presumably abbreviates "rendezvous"; the set of
  // accepted values is not defined in this file — confirm against the consumer.
  string rdzv_backend = 1;

  // Minimum replica count.
  // NOTE(review): presumably the lower bound on worker replicas for the
  // elastic job — confirm.
  int32 min_replicas = 2;

  // Maximum replica count.
  // NOTE(review): presumably the upper bound on worker replicas for the
  // elastic job; whether it must be >= min_replicas is not enforced here.
  int32 max_replicas = 3;

  // Process count per node.
  // NOTE(review): presumably the number of training processes launched on
  // each replica — confirm.
  int32 nproc_per_node = 4;

  // Maximum restart count.
  // NOTE(review): what gets restarted (worker group vs. whole job) and how 0
  // is interpreted are not visible from this file — confirm.
  int32 max_restarts = 5;
}

// Custom proto for plugin that enables distributed training using https://github.com/kubeflow/pytorch-operator
message DistributedPyTorchTrainingTask {
  // Number of worker replicas spawned in the cluster for this job.
  // NOTE(review): how this relates to elastic_config.min_replicas /
  // max_replicas when elastic_config is set is not specified here — confirm.
  int32 workers = 1;

  // Config for an elastic PyTorch job.
  // As a message-typed field it has explicit presence, so "not set" can be
  // distinguished from an ElasticConfig whose fields are all zero values.
  // NOTE(review): presumably leaving this unset requests a non-elastic job —
  // confirm against the plugin implementation.
  ElasticConfig elastic_config = 2;
}
