1 | import {
|
2 | PredictionsOptions,
|
3 | TranslateTextInput,
|
4 | TranslateTextOutput,
|
5 | TextToSpeechInput,
|
6 | ProviderOptions,
|
7 | TextToSpeechOutput,
|
8 | SpeechToTextInput,
|
9 | SpeechToTextOutput,
|
10 | IdentifyTextInput,
|
11 | IdentifyTextOutput,
|
12 | IdentifyLabelsOutput,
|
13 | IdentifyLabelsInput,
|
14 | IdentifyEntitiesInput,
|
15 | IdentifyEntitiesOutput,
|
16 | InterpretTextOutput,
|
17 | InterpretTextInput,
|
18 | } from './types';
|
19 | import {
|
20 | AbstractConvertPredictionsProvider,
|
21 | AbstractIdentifyPredictionsProvider,
|
22 | AbstractInterpretPredictionsProvider,
|
23 | AbstractPredictionsProvider,
|
24 | } from './types/Providers';
|
25 | import { Amplify, ConsoleLogger as Logger } from '@aws-amplify/core';
|
26 |
|
// Module-scoped logger for this file, constructed with the name 'Predictions'.
const logger: Logger = new Logger('Predictions');
|
28 |
|
/**
 * Predictions category front-end.
 *
 * Keeps the category options and a registry of provider plugins
 * ("pluggables"). It routes `convert`, `identify` and `interpret` calls to
 * the provider that implements the requested capability.
 */
export class PredictionsClass {
	private _options: PredictionsOptions;

	// One list per capability. A provider that implements several
	// capabilities is stored in each matching list.
	private _convertPluggables: AbstractConvertPredictionsProvider[];
	private _identifyPluggables: AbstractIdentifyPredictionsProvider[];
	private _interpretPluggables: AbstractInterpretPredictionsProvider[];

	/**
	 * @param options Initial category options. Later `configure` calls merge
	 *   on top of these.
	 */
	constructor(options: PredictionsOptions) {
		this._options = options;
		this._convertPluggables = [];
		this._identifyPluggables = [];
		this._interpretPluggables = [];
	}

	/** Returns the module name for this category: 'Predictions'. */
	public getModuleName() {
		return 'Predictions';
	}

	/**
	 * Registers a provider under every capability it implements.
	 *
	 * Capabilities are detected by duck-typing: the presence of a `convert`,
	 * `identify` or `interpret` method. The provider is configured with the
	 * current options once it has been added to at least one list.
	 *
	 * @param pluggable Provider to register.
	 * @throws Error if a provider with the same name is already registered.
	 */
	public addPluggable(pluggable: AbstractPredictionsProvider) {
		if (this.getPluggable(pluggable.getProviderName())) {
			throw new Error(
				`Pluggable with name ${pluggable.getProviderName()} has already been added.`
			);
		}
		let pluggableAdded = false;
		if (this.implementsConvertPluggable(pluggable)) {
			this._convertPluggables.push(pluggable);
			pluggableAdded = true;
		}
		if (this.implementsIdentifyPluggable(pluggable)) {
			this._identifyPluggables.push(pluggable);
			pluggableAdded = true;
		}
		if (this.implementsInterpretPluggable(pluggable)) {
			this._interpretPluggables.push(pluggable);
			pluggableAdded = true;
		}
		if (pluggableAdded) {
			this.configurePluggable(pluggable);
		} else {
			// Surface the otherwise-silent no-op so a misconfigured provider
			// is easier to diagnose.
			logger.debug(
				'Pluggable implements no supported capability, ignored=>',
				pluggable.getProviderName()
			);
		}
	}

	/**
	 * Looks up a registered provider by name, across all capabilities.
	 *
	 * @param providerName Name returned by the provider's `getProviderName()`.
	 * @returns The provider, or `null` if none is registered under that name.
	 */
	public getPluggable(
		providerName: string
	): AbstractPredictionsProvider | null {
		const pluggable = this.getAllProviders().find(
			candidate => candidate.getProviderName() === providerName
		);
		if (pluggable === undefined) {
			logger.debug('No plugin found with providerName=>', providerName);
			return null;
		}
		return pluggable;
	}

	/**
	 * Removes every registration of the named provider, in all capability
	 * lists. Unknown names are a no-op.
	 */
	public removePluggable(providerName: string) {
		this._convertPluggables = this._convertPluggables.filter(
			pluggable => pluggable.getProviderName() !== providerName
		);
		this._identifyPluggables = this._identifyPluggables.filter(
			pluggable => pluggable.getProviderName() !== providerName
		);
		this._interpretPluggables = this._interpretPluggables.filter(
			pluggable => pluggable.getProviderName() !== providerName
		);
	}

	/**
	 * Merges new options into the category options and reconfigures every
	 * registered provider.
	 *
	 * Takes `options.predictions` when present, otherwise `options` itself,
	 * and overlays the top-level `options` keys on top of it. The result is
	 * merged into the existing options.
	 */
	configure(options: PredictionsOptions) {
		let predictionsConfig = options ? options.predictions || options : {};
		predictionsConfig = { ...predictionsConfig, ...options };
		this._options = Object.assign({}, this._options, predictionsConfig);
		logger.debug('configure Predictions', this._options);
		this.getAllProviders().forEach(pluggable =>
			this.configurePluggable(pluggable)
		);
	}

	/**
	 * Runs an interpret request on the selected interpret provider.
	 *
	 * @param options Optional `providerName`. It is required when more than
	 *   one interpret provider is registered.
	 * @throws Error if no suitable provider can be selected.
	 */
	public interpret(
		input: InterpretTextInput,
		options?: ProviderOptions
	): Promise<InterpretTextOutput>;
	public interpret(
		input: InterpretTextInput,
		options?: ProviderOptions
	): Promise<InterpretTextOutput> {
		const pluggableToExecute = this.getPluggableToExecute(
			this._interpretPluggables,
			options
		);
		return pluggableToExecute.interpret(input);
	}

	/**
	 * Runs a convert request (translate, text-to-speech or speech-to-text) on
	 * the selected convert provider.
	 *
	 * @param options Optional `providerName`. It is required when more than
	 *   one convert provider is registered.
	 * @throws Error if no suitable provider can be selected.
	 */
	public convert(
		input: TranslateTextInput,
		options?: ProviderOptions
	): Promise<TranslateTextOutput>;
	public convert(
		input: TextToSpeechInput,
		options?: ProviderOptions
	): Promise<TextToSpeechOutput>;
	public convert(
		input: SpeechToTextInput,
		options?: ProviderOptions
	): Promise<SpeechToTextOutput>;
	public convert(
		input: TranslateTextInput | TextToSpeechInput | SpeechToTextInput,
		options?: ProviderOptions
	): Promise<TranslateTextOutput | TextToSpeechOutput | SpeechToTextOutput> {
		const pluggableToExecute = this.getPluggableToExecute(
			this._convertPluggables,
			options
		);
		return pluggableToExecute.convert(input);
	}

	/**
	 * Runs an identify request (text, labels or entities) on the selected
	 * identify provider.
	 *
	 * @param options Optional `providerName`. It is required when more than
	 *   one identify provider is registered.
	 * @throws Error if no suitable provider can be selected.
	 */
	public identify(
		input: IdentifyTextInput,
		options?: ProviderOptions
	): Promise<IdentifyTextOutput>;
	public identify(
		input: IdentifyLabelsInput,
		options?: ProviderOptions
	): Promise<IdentifyLabelsOutput>;
	public identify(
		input: IdentifyEntitiesInput,
		options?: ProviderOptions
	): Promise<IdentifyEntitiesOutput>;
	public identify(
		input: IdentifyTextInput | IdentifyLabelsInput | IdentifyEntitiesInput,
		options?: ProviderOptions
	): Promise<
		IdentifyTextOutput | IdentifyLabelsOutput | IdentifyEntitiesOutput
	> {
		const pluggableToExecute = this.getPluggableToExecute(
			this._identifyPluggables,
			options
		);
		return pluggableToExecute.identify(input);
	}

	/**
	 * Chooses the provider that should serve a request.
	 *
	 * - If `providerOptions.providerName` is set, the provider with that name
	 *   in `pluggables` is used.
	 * - Otherwise the single registered provider is used.
	 *
	 * @throws Error if the named provider is not in `pluggables`, or if no
	 *   name is given and there is not exactly one provider.
	 */
	private getPluggableToExecute<T extends AbstractPredictionsProvider>(
		pluggables: T[],
		providerOptions?: ProviderOptions
	): T {
		if (providerOptions && providerOptions.providerName) {
			const { providerName } = providerOptions;
			const match = pluggables.find(
				pluggable => pluggable.getProviderName() === providerName
			);
			// Previously this returned undefined, and the caller then crashed
			// with an opaque TypeError; fail with a descriptive error instead.
			if (match === undefined) {
				throw new Error(
					`No pluggable with name ${providerName} has been added for this operation.`
				);
			}
			return match;
		}
		if (pluggables.length === 1) {
			return pluggables[0];
		}
		throw new Error(
			'More than one or no providers are configured, ' +
				'Either specify a provider name or configure exactly one provider'
		);
	}

	/**
	 * Returns every registered provider exactly once.
	 *
	 * A provider with several capabilities appears in several lists.
	 * De-duplicating the combined list stops callers such as `configure` from
	 * processing it more than once.
	 */
	private getAllProviders(): AbstractPredictionsProvider[] {
		const all: AbstractPredictionsProvider[] = [
			...this._convertPluggables,
			...this._identifyPluggables,
			...this._interpretPluggables,
		];
		// indexOf-based de-dup keeps first-seen order and avoids needing
		// Set iteration support; the lists are tiny.
		return all.filter((pluggable, index) => all.indexOf(pluggable) === index);
	}

	/**
	 * Configures a provider by merging two option sets: the shared
	 * `predictions` options, then the options stored under the provider's
	 * lower-cased category name. The category-specific values win.
	 */
	private configurePluggable(pluggable: AbstractPredictionsProvider) {
		const categoryConfig = Object.assign(
			{},
			this._options['predictions'],
			this._options[pluggable.getCategory().toLowerCase()]
		);
		pluggable.configure(categoryConfig);
	}

	/** Duck-typed check: does `obj` expose a `convert` method? */
	private implementsConvertPluggable(
		obj: any
	): obj is AbstractConvertPredictionsProvider {
		return obj && typeof obj.convert === 'function';
	}

	/** Duck-typed check: does `obj` expose an `identify` method? */
	private implementsIdentifyPluggable(
		obj: any
	): obj is AbstractIdentifyPredictionsProvider {
		return obj && typeof obj.identify === 'function';
	}

	/** Duck-typed check: does `obj` expose an `interpret` method? */
	private implementsInterpretPluggable(
		obj: any
	): obj is AbstractInterpretPredictionsProvider {
		return obj && typeof obj.interpret === 'function';
	}
}
|
247 |
|
/**
 * Shared Predictions category instance. It starts with empty options and is
 * registered with Amplify when this module is loaded.
 */
export const Predictions: PredictionsClass = new PredictionsClass({});
Amplify.register(Predictions);
|