syntax = "proto3";

package flyteidl.plugins.sagemaker;

import "google/protobuf/duration.proto";

option go_package = "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins";

// The input mode that the algorithm supports.
//   - File input mode: SageMaker downloads the training data from S3 to the provisioned ML storage
//     volume and mounts the directory to a docker volume for the training container.
//   - Pipe input mode: Amazon SageMaker streams data directly from S3 to the container.
// See: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html
// For the input modes that different SageMaker algorithms support, see:
// https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-algo-docker-registry-paths.html
message InputMode {
    // Wrapped in a message so the value names are scoped as InputMode.Value.
    enum Value {
        // Training data is downloaded from S3 to the provisioned ML storage volume.
        // As the zero value, this is what an unset field of this type reads as in proto3.
        FILE = 0;
        // Training data is streamed directly from S3 to the container.
        PIPE = 1;
    }
}

// The algorithm name is used for deciding which pre-built image to point to.
// It is only required for use cases where SageMaker's built-in algorithm mode is used.
// Only a subset of the built-in algorithms is currently supported; more will be added to the list.
// See: https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html
message AlgorithmName {
    // Wrapped in a message so the value names are scoped as AlgorithmName.Value.
    enum Value {
        // The user supplies a custom algorithm instead of a built-in one.
        // As the zero value, this is what an unset field of this type reads as in proto3.
        CUSTOM = 0;
        // Presumably selects the pre-built image of SageMaker's built-in XGBoost algorithm
        // (inferred from the value name; confirm against the plugin implementation).
        XGBOOST = 1;
    }
}


// Specifies the type of file for input data. Different SageMaker built-in algorithms require
// different file types of input data.
// See: https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html
// and: https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-algo-docker-registry-paths.html
message InputContentType {
    // Wrapped in a message so the value names are scoped as InputContentType.Value.
    enum Value {
        // CSV text input; presumably corresponds to the "text/csv" content type (confirm against
        // the plugin implementation). This is currently the only value, and as the zero value it is
        // what an unset field of this type reads as in proto3.
        TEXT_CSV = 0;
    }
}

// Specifies a metric that the training algorithm writes to stderr or stdout.
// This object is a pass-through.
// For details, see: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_MetricDefinition.html
message MetricDefinition {
    // User-defined name of the metric.
    string name = 1;
    // Regular expression used to find this metric in the algorithm's output. SageMaker hyperparameter
    // tuning parses the algorithm's stdout and stderr streams to find algorithm metrics.
    string regex = 2;
}


// Specifies the training algorithm to be used in the training job.
// This object is mostly a pass-through, with a couple of exceptions:
//   (1) In Flyte, users don't need to specify TrainingImage. To use the built-in algorithm mode, they
//       use Flytekit's Simple Training Job and specify an algorithm name and an algorithm version.
//   (2) Users who want to supply a custom algorithm should set the algorithm_name field to CUSTOM.
//       In this case, the value of the algorithm_version field has no effect.
// For pass-through use cases, refer to the official AWS documentation for more details:
// https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html
message AlgorithmSpecification {
    // The input mode; can be either PIPE or FILE.
    InputMode.Value input_mode = 1;

    // The algorithm name, used for deciding which pre-built image to point to.
    AlgorithmName.Value algorithm_name = 2;
    // The algorithm version, used for deciding which pre-built image to point to.
    // This is only needed for use cases where SageMaker's built-in algorithm mode is chosen.
    string algorithm_version = 3;

    // A list of metric definitions for SageMaker to evaluate/track the progress of the training job.
    // See: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html
    // and: https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-metrics.html
    repeated MetricDefinition metric_definitions = 4;

    // The content type of the input data.
    // See: https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html
    // and: https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-algo-docker-registry-paths.html
    InputContentType.Value input_content_type = 5;
}

// When enabling distributed training on a training job, the user should use this message to tell
// Flyte and SageMaker which distributed protocol to use to distribute the work.
message DistributedProtocol {
    // Wrapped in a message so the value names are scoped as DistributedProtocol.Value.
    enum Value {
        // Use this value to rely on framework-native distributed training interfaces.
        // If this value is used, Flyte won't configure SageMaker to initialize unnecessary components
        // such as OpenMPI or Parameter Server.
        // As the zero value, this is what an unset field of this type reads as in proto3.
        UNSPECIFIED = 0;
        // Use this value to use MPI as the underlying protocol for the distributed training job.
        // MPI is a framework-agnostic distributed protocol with multiple implementations. Currently,
        // only the OpenMPI implementation has been tested; it is the recommended implementation for
        // Horovod.
        MPI = 1;
    }
}

// TrainingJobResourceConfig is a pass-through, specifying the instance type to use for the training
// job, the number of instances to launch, and the size of the ML storage volume the user wants to
// provision.
// Refer to the official SageMaker documentation for more details:
// https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html
message TrainingJobResourceConfig {
    // The number of ML compute instances to use. For distributed training, provide a value greater
    // than 1.
    int64 instance_count = 1;
    // The ML compute instance type.
    string instance_type = 2;
    // The size, in GB, of the ML storage volume that you want to provision.
    int64 volume_size_in_gb = 3;
    // When users specify an instance_count > 1, Flyte will try to configure SageMaker to enable
    // distributed training. If users wish to use a framework-agnostic distributed protocol such as
    // MPI or Parameter Server, this field should be set to the corresponding enum value.
    // Note: MPI is the only such protocol currently defined in DistributedProtocol.Value.
    DistributedProtocol.Value distributed_protocol = 4;
}

// The spec of a training job. This is mostly a pass-through object.
// See: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html
message TrainingJob {
    // The training algorithm to be used in the training job, together with its input mode,
    // metric definitions, and input content type.
    AlgorithmSpecification algorithm_specification = 1;
    // The instance type, instance count, storage volume size, and distributed protocol to use for
    // the training job.
    TrainingJobResourceConfig training_job_resource_config = 2;
}
