{"version":3,"file":"sagemaker_endpoint.cjs","names":["LLM","SageMakerRuntimeClient","InvokeEndpointCommand","InvokeEndpointWithResponseStreamCommand","GenerationChunk"],"sources":["../../src/llms/sagemaker_endpoint.ts"],"sourcesContent":["import {\n  InvokeEndpointCommand,\n  InvokeEndpointWithResponseStreamCommand,\n  SageMakerRuntimeClient,\n  SageMakerRuntimeClientConfig,\n} from \"@aws-sdk/client-sagemaker-runtime\";\nimport { CallbackManagerForLLMRun } from \"@langchain/core/callbacks/manager\";\nimport { GenerationChunk } from \"@langchain/core/outputs\";\nimport {\n  type BaseLLMCallOptions,\n  type BaseLLMParams,\n  LLM,\n} from \"@langchain/core/language_models/llms\";\n\n/**\n * A handler class to transform input from LLM to a format that SageMaker\n * endpoint expects. Similarily, the class also handles transforming output from\n * the SageMaker endpoint to a format that LLM class expects.\n *\n * Example:\n * ```\n * class ContentHandler implements ContentHandlerBase<string, string> {\n *   contentType = \"application/json\"\n *   accepts = \"application/json\"\n *\n *   transformInput(prompt: string, modelKwargs: Record<string, unknown>) {\n *     const inputString = JSON.stringify({\n *       prompt,\n *      ...modelKwargs\n *     })\n *     return Buffer.from(inputString)\n *   }\n *\n *   transformOutput(output: Uint8Array) {\n *     const responseJson = JSON.parse(Buffer.from(output).toString(\"utf-8\"))\n *     return responseJson[0].generated_text\n *   }\n *\n * }\n * ```\n */\nexport abstract class BaseSageMakerContentHandler<InputType, OutputType> {\n  contentType = \"text/plain\";\n\n  accepts = \"text/plain\";\n\n  /**\n   * Transforms the prompt and model arguments into a specific format for sending to SageMaker.\n   * @param {InputType} prompt The prompt to be transformed.\n   * @param {Record<string, unknown>} modelKwargs Additional arguments.\n   * @returns {Promise<Uint8Array>} A promise that resolves to the formatted data for sending.\n   */\n  abstract transformInput(\n    prompt: InputType,\n    modelKwargs: Record<string, unknown>\n  ): Promise<Uint8Array>;\n\n  /**\n   * Transforms SageMaker output into a desired format.\n   * @param {Uint8Array} output The raw output from SageMaker.\n   * @returns {Promise<OutputType>} A promise that resolves to the transformed data.\n   */\n  abstract transformOutput(output: Uint8Array): Promise<OutputType>;\n}\n\nexport type SageMakerLLMContentHandler = BaseSageMakerContentHandler<\n  string,\n  string\n>;\n\n/**\n * The SageMakerEndpointInput interface defines the input parameters for\n * the SageMakerEndpoint class, which includes the endpoint name, client\n * options for the SageMaker client, the content handler, and optional\n * keyword arguments for the model and the endpoint.\n */\nexport interface SageMakerEndpointInput extends BaseLLMParams {\n  /**\n   * The name of the endpoint from the deployed SageMaker model. Must be unique\n   * within an AWS Region.\n   */\n  endpointName: string;\n  /**\n   * Options passed to the SageMaker client.\n   */\n  clientOptions: SageMakerRuntimeClientConfig;\n  /**\n   * Key word arguments to pass to the model.\n   */\n  modelKwargs?: Record<string, unknown>;\n  /**\n   * Optional attributes passed to the InvokeEndpointCommand\n   */\n  endpointKwargs?: Record<string, unknown>;\n  /**\n   * The content handler class that provides an input and output transform\n   * functions to handle formats between LLM and the endpoint.\n   */\n  contentHandler: SageMakerLLMContentHandler;\n  streaming?: boolean;\n}\n\n/**\n * The SageMakerEndpoint class is used to interact with SageMaker\n * Inference Endpoint models. It uses the AWS client for authentication,\n * which automatically loads credentials.\n * If a specific credential profile is to be used, the name of the profile\n * from the ~/.aws/credentials file must be passed. The credentials or\n * roles used should have the required policies to access the SageMaker\n * endpoint.\n */\nexport class SageMakerEndpoint extends LLM<BaseLLMCallOptions> {\n  lc_serializable = true;\n\n  static lc_name() {\n    return \"SageMakerEndpoint\";\n  }\n\n  get lc_secrets(): { [key: string]: string } | undefined {\n    return {\n      \"clientOptions.credentials.accessKeyId\": \"AWS_ACCESS_KEY_ID\",\n      \"clientOptions.credentials.secretAccessKey\": \"AWS_SECRET_ACCESS_KEY\",\n      \"clientOptions.credentials.sessionToken\": \"AWS_SESSION_TOKEN\",\n    };\n  }\n\n  endpointName: string;\n\n  modelKwargs?: Record<string, unknown>;\n\n  endpointKwargs?: Record<string, unknown>;\n\n  client: SageMakerRuntimeClient;\n\n  contentHandler: SageMakerLLMContentHandler;\n\n  streaming: boolean;\n\n  constructor(fields: SageMakerEndpointInput) {\n    super(fields);\n\n    if (!fields.clientOptions.region) {\n      throw new Error(\n        `Please pass a \"clientOptions\" object with a \"region\" field to the constructor`\n      );\n    }\n\n    const endpointName = fields?.endpointName;\n    if (!endpointName) {\n      throw new Error(`Please pass an \"endpointName\" field to the constructor`);\n    }\n\n    const contentHandler = fields?.contentHandler;\n    if (!contentHandler) {\n      throw new Error(\n        `Please pass a \"contentHandler\" field to the constructor`\n      );\n    }\n\n    this.endpointName = fields.endpointName;\n    this.contentHandler = fields.contentHandler;\n    this.endpointKwargs = fields.endpointKwargs;\n    this.modelKwargs = fields.modelKwargs;\n    this.streaming = fields.streaming ?? false;\n    this.client = new SageMakerRuntimeClient(fields.clientOptions);\n  }\n\n  _llmType() {\n    return \"sagemaker_endpoint\";\n  }\n\n  /**\n   * Calls the SageMaker endpoint and retrieves the result.\n   * @param {string} prompt The input prompt.\n   * @param {this[\"ParsedCallOptions\"]} options Parsed call options.\n   * @param {CallbackManagerForLLMRun} runManager Optional run manager.\n   * @returns {Promise<string>} A promise that resolves to the generated string.\n   */\n  /** @ignore */\n  async _call(\n    prompt: string,\n    options: this[\"ParsedCallOptions\"],\n    runManager?: CallbackManagerForLLMRun\n  ): Promise<string> {\n    return this.streaming\n      ? await this.streamingCall(prompt, options, runManager)\n      : await this.noStreamingCall(prompt, options);\n  }\n\n  private async streamingCall(\n    prompt: string,\n    options: this[\"ParsedCallOptions\"],\n    runManager?: CallbackManagerForLLMRun\n  ): Promise<string> {\n    const chunks = [];\n    for await (const chunk of this._streamResponseChunks(\n      prompt,\n      options,\n      runManager\n    )) {\n      chunks.push(chunk.text);\n    }\n    return chunks.join(\"\");\n  }\n\n  private async noStreamingCall(\n    prompt: string,\n    options: this[\"ParsedCallOptions\"]\n  ): Promise<string> {\n    const body = await this.contentHandler.transformInput(\n      prompt,\n      this.modelKwargs ?? {}\n    );\n    const { contentType, accepts } = this.contentHandler;\n\n    const response = await this.caller.call(() =>\n      this.client.send(\n        new InvokeEndpointCommand({\n          EndpointName: this.endpointName,\n          Body: body,\n          ContentType: contentType,\n          Accept: accepts,\n          ...this.endpointKwargs,\n        }),\n        { abortSignal: options.signal }\n      )\n    );\n\n    if (response.Body === undefined) {\n      throw new Error(\"Inference result missing Body\");\n    }\n    return this.contentHandler.transformOutput(response.Body);\n  }\n\n  /**\n   * Streams response chunks from the SageMaker endpoint.\n   * @param {string} prompt The input prompt.\n   * @param {this[\"ParsedCallOptions\"]} options Parsed call options.\n   * @returns {AsyncGenerator<GenerationChunk>} An asynchronous generator yielding generation chunks.\n   */\n  async *_streamResponseChunks(\n    prompt: string,\n    options: this[\"ParsedCallOptions\"],\n    runManager?: CallbackManagerForLLMRun\n  ): AsyncGenerator<GenerationChunk> {\n    const body = await this.contentHandler.transformInput(\n      prompt,\n      this.modelKwargs ?? {}\n    );\n    const { contentType, accepts } = this.contentHandler;\n\n    const stream = await this.caller.call(() =>\n      this.client.send(\n        new InvokeEndpointWithResponseStreamCommand({\n          EndpointName: this.endpointName,\n          Body: body,\n          ContentType: contentType,\n          Accept: accepts,\n          ...this.endpointKwargs,\n        }),\n        { abortSignal: options.signal }\n      )\n    );\n\n    if (!stream.Body) {\n      throw new Error(\"Inference result missing Body\");\n    }\n\n    for await (const chunk of stream.Body) {\n      if (chunk.PayloadPart && chunk.PayloadPart.Bytes) {\n        const text = await this.contentHandler.transformOutput(\n          chunk.PayloadPart.Bytes\n        );\n        yield new GenerationChunk({\n          text,\n          generationInfo: {\n            ...chunk,\n            response: undefined,\n          },\n        });\n        await runManager?.handleLLMNewToken(text);\n      } else if (chunk.InternalStreamFailure) {\n        throw new Error(chunk.InternalStreamFailure.message);\n      } else if (chunk.ModelStreamError) {\n        throw new Error(chunk.ModelStreamError.message);\n      }\n    }\n  }\n}\n"],"mappings":";;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;AAyCA,IAAsB,8BAAtB,MAAyE;CACvE,cAAc;CAEd,UAAU;;;;;;;;;;;AAmEZ,IAAa,oBAAb,cAAuCA,qCAAAA,IAAwB;CAC7D,kBAAkB;CAElB,OAAO,UAAU;AACf,SAAO;;CAGT,IAAI,aAAoD;AACtD,SAAO;GACL,yCAAyC;GACzC,6CAA6C;GAC7C,0CAA0C;GAC3C;;CAGH;CAEA;CAEA;CAEA;CAEA;CAEA;CAEA,YAAY,QAAgC;AAC1C,QAAM,OAAO;AAEb,MAAI,CAAC,OAAO,cAAc,OACxB,OAAM,IAAI,MACR,gFACD;AAIH,MAAI,CADiB,QAAQ,aAE3B,OAAM,IAAI,MAAM,yDAAyD;AAI3E,MAAI,CADmB,QAAQ,eAE7B,OAAM,IAAI,MACR,0DACD;AAGH,OAAK,eAAe,OAAO;AAC3B,OAAK,iBAAiB,OAAO;AAC7B,OAAK,iBAAiB,OAAO;AAC7B,OAAK,cAAc,OAAO;AAC1B,OAAK,YAAY,OAAO,aAAa;AACrC,OAAK,SAAS,IAAIC,kCAAAA,uBAAuB,OAAO,cAAc;;CAGhE,WAAW;AACT,SAAO;;;;;;;;;;CAWT,MAAM,MACJ,QACA,SACA,YACiB;AACjB,SAAO,KAAK,YACR,MAAM,KAAK,cAAc,QAAQ,SAAS,WAAW,GACrD,MAAM,KAAK,gBAAgB,QAAQ,QAAQ;;CAGjD,MAAc,cACZ,QACA,SACA,YACiB;EACjB,MAAM,SAAS,EAAE;AACjB,aAAW,MAAM,SAAS,KAAK,sBAC7B,QACA,SACA,WACD,CACC,QAAO,KAAK,MAAM,KAAK;AAEzB,SAAO,OAAO,KAAK,GAAG;;CAGxB,MAAc,gBACZ,QACA,SACiB;EACjB,MAAM,OAAO,MAAM,KAAK,eAAe,eACrC,QACA,KAAK,eAAe,EAAE,CACvB;EACD,MAAM,EAAE,aAAa,YAAY,KAAK;EAEtC,MAAM,WAAW,MAAM,KAAK,OAAO,WACjC,KAAK,OAAO,KACV,IAAIC,kCAAAA,sBAAsB;GACxB,cAAc,KAAK;GACnB,MAAM;GACN,aAAa;GACb,QAAQ;GACR,GAAG,KAAK;GACT,CAAC,EACF,EAAE,aAAa,QAAQ,QAAQ,CAChC,CACF;AAED,MAAI,SAAS,SAAS,KAAA,EACpB,OAAM,IAAI,MAAM,gCAAgC;AAElD,SAAO,KAAK,eAAe,gBAAgB,SAAS,KAAK;;;;;;;;CAS3D,OAAO,sBACL,QACA,SACA,YACiC;EACjC,MAAM,OAAO,MAAM,KAAK,eAAe,eACrC,QACA,KAAK,eAAe,EAAE,CACvB;EACD,MAAM,EAAE,aAAa,YAAY,KAAK;EAEtC,MAAM,SAAS,MAAM,KAAK,OAAO,WAC/B,KAAK,OAAO,KACV,IAAIC,kCAAAA,wCAAwC;GAC1C,cAAc,KAAK;GACnB,MAAM;GACN,aAAa;GACb,QAAQ;GACR,GAAG,KAAK;GACT,CAAC,EACF,EAAE,aAAa,QAAQ,QAAQ,CAChC,CACF;AAED,MAAI,CAAC,OAAO,KACV,OAAM,IAAI,MAAM,gCAAgC;AAGlD,aAAW,MAAM,SAAS,OAAO,KAC/B,KAAI,MAAM,eAAe,MAAM,YAAY,OAAO;GAChD,MAAM,OAAO,MAAM,KAAK,eAAe,gBACrC,MAAM,YAAY,MACnB;AACD,SAAM,IAAIC,wBAAAA,gBAAgB;IACxB;IACA,gBAAgB;KACd,GAAG;KACH,UAAU,KAAA;KACX;IACF,CAAC;AACF,SAAM,YAAY,kBAAkB,KAAK;aAChC,MAAM,sBACf,OAAM,IAAI,MAAM,MAAM,sBAAsB,QAAQ;WAC3C,MAAM,iBACf,OAAM,IAAI,MAAM,MAAM,iBAAiB,QAAQ"}