import {
    Bedrock, CreateModelCustomizationJobCommand, FoundationModelSummary, GetModelCustomizationJobCommand,
    GetModelCustomizationJobCommandOutput, ModelCustomizationJobStatus, ModelModality, StopModelCustomizationJobCommand
} from "@aws-sdk/client-bedrock";
import { BedrockRuntime, ConverseRequest, ConverseResponse, ConverseStreamOutput, InferenceConfiguration, Tool } from "@aws-sdk/client-bedrock-runtime";
import { S3Client } from "@aws-sdk/client-s3";
import { AwsCredentialIdentity, Provider } from "@aws-sdk/types";
import {
    AbstractDriver, AIModel, Completion, CompletionChunkObject, DataSource, DriverOptions, EmbeddingsOptions, EmbeddingsResult,
    ExecutionOptions, ExecutionTokenUsage, ImageGeneration, Modalities, PromptOptions, PromptSegment,
    TextFallbackOptions, ToolDefinition, ToolUse, TrainingJob, TrainingJobStatus, TrainingOptions,
    BedrockClaudeOptions, BedrockPalmyraOptions, getMaxTokensLimit, NovaCanvasOptions,
    modelModalitiesToArray, getModelCapabilities
} from "@llumiverse/core";
import { transformAsyncIterator } from "@llumiverse/core/async";
import { formatNovaPrompt, NovaMessagesPrompt } from "@llumiverse/core/formatters";
import { LRUCache } from "mnemonist";
import { converseConcatMessages, converseJSONprefill, converseSystemToMessages, formatConversePrompt } from "./converse.js";
import { formatNovaImageGenerationPayload, NovaImageGenerationTaskType } from "./nova-image-payload.js";
import { forceUploadFile } from "./s3.js";

// Module-level cache of "does this model support response streaming?" answers,
// keyed by the model string from ExecutionOptions.model (see BedrockDriver.canStream).
// Holds at most 4096 entries (mnemonist LRUCache capacity).
const supportStreamingCache = new LRUCache<string, boolean>(4096);

/**
 * Kind of Bedrock resource a model identifier points at.
 *
 * The driver guesses the kind from the identifier string. `Unknown` means no guess could
 * be made, so every lookup (foundation model, inference profile, custom model) is tried.
 */
enum BedrockModelType {
    /** A base foundation model. */
    FoundationModel = "foundation-model",
    /** A (cross-region or application) inference profile. */
    InferenceProfile = "inference-profile",
    /** A customized / fine-tuned model. */
    CustomModel = "custom-model",
    /** Kind could not be inferred from the identifier. */
    Unknown = "unknown",
}

/** Converse stop reasons that are renamed for llumiverse; every other reason passes through as-is. */
const CONVERSE_FINISH_REASON_ALIASES = new Map<string, string>([
    ["end_turn", "stop"],
    ["max_tokens", "length"],
]);

/**
 * Translates a Bedrock Converse `stopReason` into a llumiverse finish reason.
 *
 * Bedrock reports one of: end_turn | tool_use | max_tokens | stop_sequence |
 * guardrail_intervened | content_filtered.
 *
 * @param reason the raw stop reason, possibly missing
 * @returns "stop" for end_turn, "length" for max_tokens, the raw reason for anything else,
 *          and undefined when the reason is missing or empty.
 */
function converseFinishReason(reason: string | undefined) {
    return reason ? CONVERSE_FINISH_REASON_ALIASES.get(reason) ?? reason : undefined;
}

/**
 * Pairs a Bedrock model name with whether it supports streamed responses.
 * NOTE(review): not referenced elsewhere in this file; presumably exported as public API.
 */
export interface BedrockModelCapabilities {
    /** Model name — TODO confirm whether this is a display name or an identifier/ARN. */
    name: string;
    /** Whether the model supports response streaming. */
    canStream: boolean;
}

/**
 * Configuration accepted by the Bedrock driver.
 */
export interface BedrockDriverOptions extends DriverOptions {
    /**
     * The AWS region. Required: the driver constructor throws when it is empty.
     * Used for the runtime client, the S3 client, and as the default control-plane region.
     */
    region: string;
    /**
     * Name of the S3 bucket that training datasets are uploaded to; startTraining fails without it.
     * NOTE(review): the bucket is presumably created by the upload helper if it does not
     * already exist — verify in s3.ts (forceUploadFile).
     */
    training_bucket?: string;

    /**
     * IAM role ARN passed as `roleArn` when creating a model customization (training) job.
     */
    training_role_arn?: string;

    /**
     * Credentials (or a credential provider) passed to every AWS client the driver creates.
     * When omitted, the clients are created with `credentials: undefined`, which leaves
     * credential resolution to the AWS SDK defaults.
     */
    credentials?: AwsCredentialIdentity | Provider<AwsCredentialIdentity>;
}

/**
 * Prompt produced by BedrockDriver.formatPrompt: a Nova messages prompt for Nova Canvas
 * models (model id containing "canvas"), a Converse request for every other model.
 */
export type BedrockPrompt = NovaMessagesPrompt | ConverseRequest;

/**
 * llumiverse driver for AWS Bedrock.
 *
 * - Text completion and streaming: Bedrock Runtime Converse / ConverseStream.
 * - Image generation (Nova Canvas) and embeddings: Bedrock Runtime InvokeModel.
 * - Model listing and fine-tuning jobs: the Bedrock control-plane client.
 */
export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockPrompt> {

    static PROVIDER = "bedrock";

    provider = BedrockDriver.PROVIDER;

    /** Runtime client, created on first use by getExecutor(). */
    private _executor?: BedrockRuntime;
    /** Control-plane client, created by getService() and bound to `_service_region`. */
    private _service?: Bedrock;
    private _service_region?: string;

    /**
     * @param options driver options; `region` is mandatory.
     * @throws Error when `options.region` is empty.
     */
    constructor(options: BedrockDriverOptions) {
        super(options);
        if (!options.region) {
            throw new Error("No region found. Set the region in the environment's endpoint URL.");
        }
    }

    /** Returns the Bedrock Runtime client for the driver's region, creating it on first call. */
    getExecutor() {
        if (!this._executor) {
            this._executor = new BedrockRuntime({
                region: this.options.region,
                credentials: this.options.credentials,
            });
        }
        return this._executor;
    }

    /**
     * Returns a Bedrock control-plane client for `region` (defaults to the driver's region).
     * Only one client is cached: asking for a different region replaces it.
     */
    getService(region: string = this.options.region) {
        if (!this._service || this._service_region !== region) {
            this._service = new Bedrock({
                region: region,
                credentials: this.options.credentials,
            });
            this._service_region = region;
        }
        return this._service;
    }

    /**
     * Builds the provider prompt.
     * Nova Canvas models (id contains "canvas") get a Nova messages prompt; every other
     * model gets a Converse request.
     */
    protected async formatPrompt(segments: PromptSegment[], opts: PromptOptions): Promise<BedrockPrompt> {
        if (opts.model.includes("canvas")) {
            return await formatNovaPrompt(segments, opts.result_schema);
        }
        return await formatConversePrompt(segments, opts.result_schema);
    }

    /**
     * Extracts text, token usage and finish reason from a non-streaming Converse response.
     * Only content blocks that carry text contribute to `result`. Blocks without text
     * (tool use, reasoning, ...) are skipped so they do not add stray newlines.
     */
    static getExtractedExecution(result: ConverseResponse, _prompt?: BedrockPrompt): CompletionChunkObject {
        const texts = (result.output?.message?.content ?? [])
            .map(c => c.text)
            .filter((t): t is string => t !== undefined);
        return {
            result: texts.join("\n"),
            token_usage: {
                prompt: result.usage?.inputTokens,
                result: result.usage?.outputTokens,
                total: result.usage?.totalTokens,
            },
            finish_reason: converseFinishReason(result.stopReason),
        };
    }

    /**
     * Converts one ConverseStream event into a completion chunk.
     * Text deltas fill `result`, a messageStop event provides the finish reason, and a
     * metadata event provides token usage. Other events yield an empty chunk.
     */
    static getExtractedStream(result: ConverseStreamOutput, _prompt?: BedrockPrompt): CompletionChunkObject {
        let output: string = "";
        let stop_reason = "";
        let token_usage: ExecutionTokenUsage | undefined;
        if (result.contentBlockDelta) {
            output = result.contentBlockDelta.delta?.text ?? "";
        }
        if (result.messageStop) {
            stop_reason = result.messageStop.stopReason ?? "";
        }
        if (result.metadata) {
            token_usage = {
                prompt: result.metadata.usage?.inputTokens,
                result: result.metadata.usage?.outputTokens,
                total: result.metadata.usage?.totalTokens,
            };
        }
        return {
            result: output,
            token_usage: token_usage,
            finish_reason: converseFinishReason(stop_reason),
        };
    }

    /**
     * Runs a non-streaming Converse call.
     *
     * `prompt` is appended to `options.conversation`. The model's reply is then appended as
     * well, and the updated conversation is returned alongside the completion. Tool-use
     * requests are surfaced in `tool_use` when the model stopped with "tool_use".
     */
    async requestTextCompletion(prompt: ConverseRequest, options: ExecutionOptions): Promise<Completion> {
        let conversation = updateConversation(options.conversation as ConverseRequest, prompt);

        const payload = this.preparePayload(conversation, options);
        const executor = this.getExecutor();

        const res = await executor.converse({
            ...payload,
        });

        conversation = updateConversation(conversation, {
            messages: [res.output?.message ?? { content: [{ text: "" }], role: "assistant" }],
            modelId: prompt.modelId,
        });

        let tool_use: ToolUse[] | undefined = undefined;
        if (res.stopReason == "tool_use") {
            tool_use = res.output?.message?.content?.reduce((tools: ToolUse[], c) => {
                if (c.toolUse) {
                    tools.push({
                        tool_name: c.toolUse.name ?? "",
                        tool_input: c.toolUse.input as any,
                        id: c.toolUse.toolUseId ?? "",
                    } satisfies ToolUse);
                }
                return tools;
            }, []);
            // Report "no tools" as undefined rather than an empty list.
            if (tool_use && tool_use.length == 0) {
                tool_use = undefined;
            }
        }

        return {
            ...BedrockDriver.getExtractedExecution(res, prompt),
            original_response: options.include_original_response ? res : undefined,
            conversation: conversation,
            tool_use: tool_use,
        };
    }

    /**
     * Finds the AWS region referenced by a model identifier.
     * A full Bedrock ARN's region segment wins. Otherwise the first region-looking token
     * (e.g. "eu-west-1", "us-gov-west-1") is used. Falls back to `defaultRegion`.
     */
    extractRegion(modelString: string, defaultRegion: string): string {
        const arnMatch = modelString.match(/arn:aws[^:]*:bedrock:([^:]+):/);
        if (arnMatch) {
            return arnMatch[1];
        }

        const regionMatch = modelString.match(/(?:us|eu|ap|sa|ca|me|af|il|mx)(?:-gov)?-(?:east|west|central|south|north|southeast|southwest|northeast|northwest)-[1-9]/);
        if (regionMatch) {
            return regionMatch[0];
        }

        return defaultRegion;
    }

    /**
     * Asks the Bedrock control plane whether `model` supports response streaming.
     *
     * Depending on `type`, the identifier is looked up in up to three ways, in this order:
     * - as a foundation model;
     * - as an inference profile, resolved through the first model it lists;
     * - as a custom model, resolved through its base model.
     * `Unknown` tries all three. Lookups use the region embedded in the identifier, if any.
     *
     * Never throws: lookup failures are logged and reported as "cannot stream".
     */
    private async getCanStream(model: string, type: BedrockModelType): Promise<boolean> {
        let error: unknown = null;
        const region = this.extractRegion(model, this.options.region);
        if (type == BedrockModelType.FoundationModel || type == BedrockModelType.Unknown) {
            try {
                const response = await this.getService(region).getFoundationModel({
                    modelIdentifier: model
                });
                return response.modelDetails?.responseStreamingSupported ?? false;
            } catch (e) {
                error = e;
            }
        }
        if (type == BedrockModelType.InferenceProfile || type == BedrockModelType.Unknown) {
            try {
                const response = await this.getService(region).getInferenceProfile({
                    inferenceProfileIdentifier: model
                });
                return await this.getCanStream(response.models?.[0]?.modelArn ?? "", BedrockModelType.FoundationModel);
            } catch (e) {
                error = e;
            }
        }
        if (type == BedrockModelType.CustomModel || type == BedrockModelType.Unknown) {
            try {
                const response = await this.getService(region).getCustomModel({
                    modelIdentifier: model
                });
                return await this.getCanStream(response.baseModelArn ?? "", BedrockModelType.FoundationModel);
            } catch (e) {
                error = e;
            }
        }
        if (error) {
            this.logger.warn("Error on canStream check for model: " + model + " region detected: " + region, { error });
        }
        return false;
    }

    /**
     * Whether `options.model` supports streaming.
     * The answer is cached per model string in `supportStreamingCache`. The resource kind is
     * guessed from the identifier ("foundation-model", "inference-profile", "custom-model").
     */
    protected async canStream(options: ExecutionOptions): Promise<boolean> {
        let canStream = supportStreamingCache.get(options.model);
        if (canStream == null) {
            let type = BedrockModelType.Unknown;
            if (options.model.includes("foundation-model")) {
                type = BedrockModelType.FoundationModel;
            } else if (options.model.includes("inference-profile")) {
                type = BedrockModelType.InferenceProfile;
            } else if (options.model.includes("custom-model")) {
                type = BedrockModelType.CustomModel;
            }
            canStream = await this.getCanStream(options.model, type);
            supportStreamingCache.set(options.model, canStream);
        }
        return canStream;
    }

    /**
     * Runs a ConverseStream call and maps each stream event to a completion chunk.
     * @throws Error when the response carries no stream; SDK errors are logged and rethrown.
     */
    async requestTextCompletionStream(prompt: ConverseRequest, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
        const payload = this.preparePayload(prompt, options);
        const executor = this.getExecutor();
        return executor.converseStream({
            ...payload,
        }).then((res) => {
            const stream = res.stream;

            if (!stream) {
                throw new Error("[Bedrock] Stream not found in response");
            }

            return transformAsyncIterator(stream, (event: ConverseStreamOutput) => {
                return BedrockDriver.getExtractedStream(event, prompt);
            });

        }).catch((err) => {
            this.logger.error("[Bedrock] Failed to stream", err);
            throw err;
        });
    }

    /**
     * For models that do not accept a separate system prompt.
     * Appends the system prompt as a message, clears `system`, then normalises the message
     * list with converseConcatMessages. Mutates `prompt`. Does nothing when there is no
     * system prompt.
     */
    private static moveSystemIntoMessages(prompt: ConverseRequest): void {
        if (prompt.system && prompt.system.length != 0) {
            prompt.messages?.push(converseSystemToMessages(prompt.system));
            prompt.system = undefined;
            prompt.messages = converseConcatMessages(prompt.messages);
        }
    }

    /**
     * Builds the Converse request for `prompt`, applying per-model-family adjustments.
     *
     * NOTE: `prompt` is modified in place (JSON prefill, system prompt folded into the
     * messages). The caller's `options.model_options` is not modified: defaults are
     * applied to a shallow copy.
     *
     * Per model family:
     * - amazon: JSON prefill when a result schema is set; Nova models also get topK.
     * - claude: JSON prefill; claude-3-7 gets a reasoning config, plus the 128k-output beta
     *   for thinking budgets above 64000 tokens; max_tokens is defaulted; top_k is passed through.
     * - mistral "7b" models: top_k, system prompt folded into messages, JSON prefill.
     * - ai21 "j2" models: presence/frequency penalties, system prompt folded into messages.
     * - cohere command-r: k and penalties. Other cohere command models: k, system prompt
     *   folded into messages.
     * - palmyra: seed, penalties, min_tokens.
     * - meta, deepseek and other mistral/ai21 models: no additional fields.
     *
     * If the last message is exactly "```json", "```" is added to the stop sequences.
     * Tools are attached only when at least one is defined.
     */
    preparePayload(prompt: ConverseRequest, options: ExecutionOptions) {
        const requested_options: TextFallbackOptions = options.model_options as TextFallbackOptions ?? { _option_id: "text-fallback" };
        // Shallow copy: defaults filled in below must not leak into the caller's options object.
        const model_options: TextFallbackOptions = { ...requested_options };

        let additionalField = {};

        if (options.model.includes("amazon")) {
            if (options.result_schema) {
                prompt.messages = converseJSONprefill(prompt.messages);
            }
            // Titan models also exist but accept no additional fields.
            if (options.model.includes("nova")) {
                additionalField = { inferenceConfig: { topK: model_options.top_k } };
            }
        } else if (options.model.includes("claude")) {
            if (options.result_schema) {
                prompt.messages = converseJSONprefill(prompt.messages);
            }
            if (options.model.includes("claude-3-7")) {
                // model_options may be absent entirely, so read the Claude fields defensively.
                const claude_options = options.model_options as BedrockClaudeOptions | undefined;
                const thinking = claude_options?.thinking_mode ?? false;
                const budget_tokens = claude_options?.thinking_budget_tokens;
                additionalField = {
                    ...additionalField,
                    reasoning_config: {
                        type: thinking ? "enabled" : "disabled",
                        budget_tokens: budget_tokens,
                    }
                };
                // For thinking budgets above 64000 tokens, request the 128k-output beta.
                if (thinking && (budget_tokens ?? 0) > 64000) {
                    additionalField = {
                        ...additionalField,
                        anthropic_beta: ["output-128k-2025-02-19"]
                    };
                }
            }
            // Claude needs max_tokens to be set.
            if (!model_options.max_tokens) {
                model_options.max_tokens = getMaxTokensLimit(options.model, model_options);
            }
            additionalField = { ...additionalField, top_k: model_options.top_k };
        } else if (options.model.includes("meta")) {
            // LLaMA models support no additional options.
        } else if (options.model.includes("mistral")) {
            // 7B instruct and 8x7B instruct
            if (options.model.includes("7b")) {
                additionalField = { top_k: model_options.top_k };
                // Does not support system messages.
                BedrockDriver.moveSystemIntoMessages(prompt);
                if (options.result_schema) {
                    prompt.messages = converseJSONprefill(prompt.messages);
                }
            }
            // Other models such as Mistral Small, Large and Large 2 support no additional fields.
        } else if (options.model.includes("ai21")) {
            // Jamba models support no additional options; Jurassic 2 models do.
            if (options.model.includes("j2")) {
                additionalField = {
                    presencePenalty: { scale: model_options.presence_penalty },
                    frequencyPenalty: { scale: model_options.frequency_penalty },
                };
                // Does not support system messages.
                BedrockDriver.moveSystemIntoMessages(prompt);
            }
        } else if (options.model.includes("cohere.command")) {
            if (options.model.includes("cohere.command-r")) {
                // Command R and R+
                additionalField = {
                    k: model_options.top_k,
                    frequency_penalty: model_options.frequency_penalty,
                    presence_penalty: model_options.presence_penalty,
                };
            } else {
                // Command (non-R): does not support system messages.
                additionalField = { k: model_options.top_k };
                BedrockDriver.moveSystemIntoMessages(prompt);
            }
        } else if (options.model.includes("palmyra")) {
            const palmyraOptions = options.model_options as BedrockPalmyraOptions;
            additionalField = {
                seed: palmyraOptions?.seed,
                presence_penalty: palmyraOptions?.presence_penalty,
                frequency_penalty: palmyraOptions?.frequency_penalty,
                min_tokens: palmyraOptions?.min_tokens,
            };
        } else if (options.model.includes("deepseek")) {
            // DeepSeek models support no additional options.
        }

        // A trailing "```json" message is a JSON prefill: stop generation at the closing fence.
        const messages = prompt.messages;
        if (messages && messages.length > 0) {
            const lastText = messages[messages.length - 1]?.content?.[0]?.text;
            if (lastText === "```json") {
                const stopSeq = model_options.stop_sequence;
                if (!stopSeq) {
                    model_options.stop_sequence = ["```"];
                } else if (!stopSeq.includes("```")) {
                    // Copy rather than push: the array may belong to the caller's options.
                    model_options.stop_sequence = [...stopSeq, "```"];
                }
            }
        }

        const tool_defs = getToolDefinitions(options.tools);

        const request: ConverseRequest = {
            messages: prompt.messages,
            system: prompt.system,
            modelId: options.model,
            inferenceConfig: {
                maxTokens: model_options.max_tokens,
                temperature: model_options.temperature,
                topP: model_options.top_p,
                stopSequences: model_options.stop_sequence,
            } satisfies InferenceConfiguration,
            additionalModelRequestFields: {
                ...additionalField,
            }
        };

        // Converse rejects a toolConfig whose tools list is empty, so only attach non-empty lists.
        if (tool_defs && tool_defs.length > 0) {
            request.toolConfig = {
                tools: tool_defs,
            };
        }

        return request;
    }

    /**
     * Generates images with a Nova Canvas model through InvokeModel (5 minute request timeout).
     * Logs a warning when the model options are not "bedrock-nova-canvas" options. The task
     * type defaults to TEXT_IMAGE.
     * @throws Error when the output modality is not image, or the prompt is a plain string.
     */
    async requestImageGeneration(prompt: NovaMessagesPrompt, options: ExecutionOptions): Promise<Completion<ImageGeneration>> {
        if (options.output_modality !== Modalities.image) {
            throw new Error(`Image generation requires image output_modality`);
        }
        if (options.model_options?._option_id !== "bedrock-nova-canvas") {
            this.logger.warn("Invalid model options", { options: options.model_options });
        }
        const model_options = options.model_options as NovaCanvasOptions | undefined;

        const executor = this.getExecutor();
        const taskType = model_options?.taskType ?? NovaImageGenerationTaskType.TEXT_IMAGE;

        this.logger.info("Task type: " + taskType);

        if (typeof prompt === "string") {
            throw new Error("Bad prompt format");
        }

        const payload = await formatNovaImageGenerationPayload(taskType, prompt, options);

        const res = await executor.invokeModel({
            modelId: options.model,
            contentType: "application/json",
            accept: "application/json",
            body: JSON.stringify(payload),
        }, {
            requestTimeout: 60000 * 5
        });

        const body = new TextDecoder().decode(res.body);
        const result = JSON.parse(body);

        return {
            error: result.error,
            result: {
                images: result.images,
            }
        };
    }

    /**
     * Starts a Bedrock model customization job.
     * The dataset is uploaded to `training_bucket`, and the job is named
     * `<options.name>-job`. The initial job state is returned.
     * @throws Error when `training_bucket` is not configured, or Bedrock returns no job ARN.
     */
    async startTraining(dataset: DataSource, options: TrainingOptions): Promise<TrainingJob> {
        if (!this.options.training_bucket) {
            throw new Error("Training cannot be used since the 'training_bucket' property was not specified in driver options");
        }

        // hyperParameters is a string-to-string map: stringify every parameter value.
        const params: Record<string, string> = {};
        for (const [key, value] of Object.entries(options.params || {})) {
            params[key] = String(value);
        }

        const s3 = new S3Client({ region: this.options.region, credentials: this.options.credentials });
        const stream = await dataset.getStream();
        const upload = await forceUploadFile(s3, stream, this.options.training_bucket, dataset.name);

        const service = this.getService();
        const response = await service.send(new CreateModelCustomizationJobCommand({
            jobName: options.name + "-job",
            customModelName: options.name,
            roleArn: this.options.training_role_arn || undefined,
            baseModelIdentifier: options.model,
            clientRequestToken: "llumiverse-" + Date.now(),
            trainingDataConfig: {
                s3Uri: `s3://${upload.Bucket}/${upload.Key}`,
            },
            outputDataConfig: undefined,
            hyperParameters: params,
            //TODO not supported?
            //customizationType: "FINE_TUNING",
        }));

        const jobArn = response.jobArn;
        if (!jobArn) {
            throw new Error("[Bedrock] CreateModelCustomizationJob returned no job ARN");
        }

        const job = await service.send(new GetModelCustomizationJobCommand({
            jobIdentifier: jobArn
        }));

        return jobInfo(job, jobArn);
    }

    /** Stops a customization job and returns its state after the stop request. */
    async cancelTraining(jobId: string): Promise<TrainingJob> {
        const service = this.getService();
        await service.send(new StopModelCustomizationJobCommand({
            jobIdentifier: jobId
        }));
        const job = await service.send(new GetModelCustomizationJobCommand({
            jobIdentifier: jobId
        }));

        return jobInfo(job, jobId);
    }

    /** Returns the current state of a customization job. */
    async getTrainingJob(jobId: string): Promise<TrainingJob> {
        const service = this.getService();
        const job = await service.send(new GetModelCustomizationJobCommand({
            jobIdentifier: jobId
        }));

        return jobInfo(job, jobId);
    }

    // ===================== management API ==================

    /**
     * Always resolves to true: if the client could be created, the connection is
     * considered valid. No request is sent to AWS.
     */
    async validateConnection(): Promise<boolean> {
        const service = this.getService();
        this.logger.debug("[Bedrock] validating connection", service.config.credentials.name);
        return true;
    }

    /** Lists the foundation models that support FINE_TUNING customization. */
    async listTrainableModels(): Promise<AIModel<string>[]> {
        this.logger.debug("[Bedrock] listing trainable models");
        return this._listModels(m => m.customizationsSupported ? m.customizationsSupported.includes("FINE_TUNING") : false);
    }

    /**
     * Lists executable models: on-demand foundation models that are not embedding
     * models, plus custom models and inference profiles.
     */
    async listModels(): Promise<AIModel[]> {
        this.logger.debug("[Bedrock] listing models");
        const filter = (m: FoundationModelSummary) => (m.inferenceTypesSupported?.includes("ON_DEMAND") && !m.outputModalities?.includes("EMBEDDING")) ?? false;
        return this._listModels(filter);
    }

    /**
     * Lists foundation models, custom models and inference profiles.
     *
     * Foundation models are narrowed by `foundationFilter` and by the supported-publisher
     * list. Inference profiles are matched against the same publisher list using their
     * name or id. Custom-model and profile listing failures are only logged.
     *
     * @throws Error when the foundation model list cannot be fetched.
     */
    async _listModels(foundationFilter?: (m: FoundationModelSummary) => boolean): Promise<AIModel[]> {
        const service = this.getService();
        const [foundationModelsList, customModelsList, inferenceProfilesList] = await Promise.all([
            service.listFoundationModels({}).catch(() => {
                this.logger.warn("[Bedrock] Can't list foundation models. Check if the user has the right permissions.");
                return undefined;
            }),
            service.listCustomModels({}).catch(() => {
                this.logger.warn("[Bedrock] Can't list custom models. Check if the user has the right permissions.");
                return undefined;
            }),
            service.listInferenceProfiles({}).catch(() => {
                this.logger.warn("[Bedrock] Can't list inference profiles. Check if the user has the right permissions.");
                return undefined;
            }),
        ]);

        if (!foundationModelsList?.modelSummaries) {
            throw new Error("Foundation models not found");
        }

        let foundationModels = foundationModelsList.modelSummaries;
        if (foundationFilter) {
            foundationModels = foundationModels.filter(foundationFilter);
        }

        const supportedPublishers = ["amazon", "anthropic", "cohere", "ai21", "mistral", "meta", "deepseek", "writer"];
        const unsupportedModelsByPublisher = {
            amazon: ["titan-image-generator", "nova-reel", "nova-sonic", "rerank"],
            anthropic: [],
            cohere: ["rerank"],
            ai21: [],
            mistral: [],
            meta: [],
            deepseek: [],
            writer: [],
        };

        // True when the provider is supported and the model id matches none of that
        // provider's unsupported-model substrings.
        const shouldIncludeModel = (modelId?: string, providerName?: string): boolean => {
            if (!modelId || !providerName) return false;

            const normalizedProvider = providerName.toLowerCase();
            const provider = supportedPublishers.find(p => normalizedProvider.includes(p));
            if (!provider) return false;

            const unsupportedModels = unsupportedModelsByPublisher[provider as keyof typeof unsupportedModelsByPublisher] || [];
            const normalizedModelId = modelId.toLowerCase();
            return !unsupportedModels.some(unsupported => normalizedModelId.includes(unsupported));
        };

        foundationModels = foundationModels.filter(m =>
            shouldIncludeModel(m.modelId, m.providerName)
        );

        const aiModels: AIModel[] = foundationModels.map((m) => {
            if (!m.modelId) {
                throw new Error("modelId not found");
            }

            const id = m.modelArn ?? m.modelId;
            const modelCapability = getModelCapabilities(id, this.provider);

            const model: AIModel = {
                id: id,
                name: `${m.providerName} ${m.modelName}`,
                provider: this.provider,
                owner: m.providerName,
                can_stream: m.responseStreamingSupported ?? false,
                // Prefer the modalities Bedrock reports; fall back to the capability table.
                input_modalities: m.inputModalities ? formatAmazonModalities(m.inputModalities) : modelModalitiesToArray(modelCapability.input),
                output_modalities: m.outputModalities ? formatAmazonModalities(m.outputModalities) : modelModalitiesToArray(modelCapability.output),
                tool_support: modelCapability.tool_support,
            };

            return model;
        });

        // Custom models
        for (const m of customModelsList?.modelSummaries ?? []) {
            if (!m.modelArn) {
                throw new Error("Model ID not found");
            }

            const modelCapability = getModelCapabilities(m.modelArn, this.provider);

            const model: AIModel = {
                id: m.modelArn,
                name: m.modelName ?? m.modelArn,
                provider: this.provider,
                description: `Custom model from ${m.baseModelName}`,
                is_custom: true,
                input_modalities: modelModalitiesToArray(modelCapability.input),
                output_modalities: modelModalitiesToArray(modelCapability.output),
                tool_support: modelCapability.tool_support,
            };

            aiModels.push(model);
        }

        // Inference profiles
        for (const p of inferenceProfilesList?.inferenceProfileSummaries ?? []) {
            const profileArn = p.inferenceProfileArn;
            if (!profileArn) {
                throw new Error("Profile ARN not found");
            }

            const profileId = p.inferenceProfileId || "";
            const profileName = p.inferenceProfileName || "";

            // Derive the provider from the profile name or id, then apply the same filtering as above.
            const providerName = supportedPublishers.find(provider =>
                profileName.toLowerCase().includes(provider) || profileId.toLowerCase().includes(provider)
            ) ?? "";

            if (!providerName || !shouldIncludeModel(profileId, providerName)) {
                continue;
            }

            const modelCapability = getModelCapabilities(profileArn, this.provider);

            const model: AIModel = {
                id: profileArn,
                name: p.inferenceProfileName ?? profileArn,
                provider: this.provider,
                input_modalities: modelModalitiesToArray(modelCapability.input),
                output_modalities: modelModalitiesToArray(modelCapability.output),
                tool_support: modelCapability.tool_support,
            };

            aiModels.push(model);
        }

        return aiModels;
    }

    /**
     * Computes an embedding with a Titan embedding model via InvokeModel.
     * When no model is given, the default is amazon.titan-embed-image-v1 if an image is
     * provided, otherwise amazon.titan-embed-text-v2:0.
     * @throws Error when the response contains no `embedding`.
     */
    async generateEmbeddings({ text, image, model }: EmbeddingsOptions): Promise<EmbeddingsResult> {

        this.logger.info("[Bedrock] Generating embeddings with model " + model);
        const defaultModel = image ? "amazon.titan-embed-image-v1" : "amazon.titan-embed-text-v2:0";
        const modelID = model ?? defaultModel;

        const invokeBody = {
            inputText: text,
            inputImage: image
        };

        const executor = this.getExecutor();
        const res = await executor.invokeModel({
            modelId: modelID,
            contentType: "application/json",
            body: JSON.stringify(invokeBody),
        });

        const body = new TextDecoder().decode(res.body);
        const result = JSON.parse(body);

        if (!result.embedding) {
            throw new Error("Embeddings not found");
        }

        return {
            values: result.embedding,
            model: modelID,
            token_count: result.inputTextTokenCount
        };
    }
}

/**
 * Maps a GetModelCustomizationJob response onto llumiverse's TrainingJob.
 *
 * - COMPLETED -> succeeded
 * - FAILED -> failed, with `details` set to the failure message (or "error")
 * - STOPPED -> cancelled
 * - any other status -> running, with the raw Bedrock status in `details`
 *
 * @param job the Bedrock job description
 * @param jobId the identifier to report back as `id`
 * @returns the TrainingJob; `model` is the job's output model ARN, when present.
 */
function jobInfo(job: GetModelCustomizationJobCommandOutput, jobId: string): TrainingJob {
    const jobStatus = job.status;
    let status = TrainingJobStatus.running;
    let details: string | undefined;
    if (jobStatus === ModelCustomizationJobStatus.COMPLETED) {
        status = TrainingJobStatus.succeeded;
    } else if (jobStatus === ModelCustomizationJobStatus.FAILED) {
        status = TrainingJobStatus.failed;
        details = job.failureMessage || "error";
    } else if (jobStatus === ModelCustomizationJobStatus.STOPPED) {
        status = TrainingJobStatus.cancelled;
    } else {
        // Not finished (or an unrecognised status): keep "running" and surface the raw value.
        details = jobStatus;
    }
    return {
        id: jobId,
        model: job.outputModelArn,
        status,
        details,
    };
}

/**
 * Converts llumiverse tool definitions to Bedrock Converse tool specs.
 *
 * @param tools the tools to expose to the model, if any
 * @returns the converted tools, or undefined when `tools` is missing or empty. The Converse
 *          API requires at least one tool in a toolConfig, so an empty list must be omitted
 *          rather than sent.
 */
function getToolDefinitions(tools?: ToolDefinition[]): Tool[] | undefined {
    if (!tools || tools.length === 0) {
        return undefined;
    }
    return tools.map(getToolDefinition);
}

/**
 * Maps one llumiverse ToolDefinition onto a Converse `toolSpec` entry.
 * The tool's input schema is forwarded unchanged as the JSON schema document.
 */
function getToolDefinition(tool: ToolDefinition): Tool.ToolSpecMember {
    const { name, description, input_schema } = tool;
    return {
        toolSpec: {
            name,
            description,
            inputSchema: { json: input_schema as any },
        },
    };
}

/**
 * Appends a request to an existing conversation, returning a new ConverseRequest.
 *
 * - Fields of `prompt` override those of `conversation`.
 * - Messages are concatenated, with the conversation's messages first.
 * - `system` comes from `prompt` when truthy, otherwise from `conversation`.
 *
 * @param conversation the conversation so far. It may be undefined at runtime on the first
 *   turn, since callers pass `options.conversation` cast to ConverseRequest.
 * @param prompt the request to append
 * @returns the merged request; neither input is modified.
 */
function updateConversation(conversation: ConverseRequest, prompt: ConverseRequest): ConverseRequest {
    const previousMessages = conversation?.messages || [];
    const newMessages = prompt.messages || [];
    const system = prompt.system || conversation?.system;
    return {
        ...conversation,
        ...prompt,
        messages: previousMessages.concat(newMessages),
        system,
    };
}

/**
 * Normalises Bedrock modality values to llumiverse's lowercase modality names.
 * SPEECH is reported as "audio". TEXT, IMAGE, EMBEDDING and VIDEO map to their lowercase
 * form, and any other value is lowercased as-is.
 */
function formatAmazonModalities(modalities: ModelModality[]): string[] {
    const knownModalities = new Map<string, string>([
        [ModelModality.TEXT, "text"],
        [ModelModality.IMAGE, "image"],
        [ModelModality.EMBEDDING, "embedding"],
        ["SPEECH", "audio"],
        ["VIDEO", "video"],
    ]);
    return modalities.map((modality) =>
        knownModalities.get(modality) ?? (modality as string).toString().toLowerCase()
    );
}