UNPKG

7.39 kB · Plain Text · View Raw
1import {
2 PredictionsOptions,
3 TranslateTextInput,
4 TranslateTextOutput,
5 TextToSpeechInput,
6 ProviderOptions,
7 TextToSpeechOutput,
8 SpeechToTextInput,
9 SpeechToTextOutput,
10 IdentifyTextInput,
11 IdentifyTextOutput,
12 IdentifyLabelsOutput,
13 IdentifyLabelsInput,
14 IdentifyEntitiesInput,
15 IdentifyEntitiesOutput,
16 InterpretTextOutput,
17 InterpretTextInput,
18} from './types';
19import {
20 AbstractConvertPredictionsProvider,
21 AbstractIdentifyPredictionsProvider,
22 AbstractInterpretPredictionsProvider,
23 AbstractPredictionsProvider,
24} from './types/Providers';
25import { Amplify, ConsoleLogger as Logger } from '@aws-amplify/core';
26
/** Logger instance named 'Predictions', used by the category class for debug output. */
const logger = new Logger('Predictions');
28
/**
 * Predictions category: keeps the registered provider plugins ("pluggables"),
 * forwards configuration to them, and routes convert / identify / interpret
 * calls to the right one.
 */
export class PredictionsClass {
	private _options: PredictionsOptions;

	// A single pluggable is pushed into every list whose method it implements,
	// so the same instance may appear in more than one of these arrays.
	private _convertPluggables: AbstractConvertPredictionsProvider[];
	private _identifyPluggables: AbstractIdentifyPredictionsProvider[];
	private _interpretPluggables: AbstractInterpretPredictionsProvider[];

	/**
	 * Initialize Predictions with AWS configurations
	 * @param options - Configuration object for Predictions
	 */
	constructor(options: PredictionsOptions) {
		this._options = options;
		this._convertPluggables = [];
		this._identifyPluggables = [];
		this._interpretPluggables = [];
	}

	/** @returns the category name, `'Predictions'`. */
	public getModuleName() {
		return 'Predictions';
	}

	/**
	 * Add a plugin/pluggable to the Predictions category.
	 *
	 * The pluggable is registered under each capability it implements
	 * (`convert`, `identify`, `interpret`) and, if it implements at least one,
	 * is configured immediately from the current options. A pluggable that
	 * implements none of them is not registered.
	 *
	 * @param pluggable - an instance of the plugin/pluggable
	 * @throws Error if a pluggable with the same provider name was already added
	 */
	public addPluggable(pluggable: AbstractPredictionsProvider) {
		const providerName = pluggable.getProviderName();
		if (this.getPluggable(providerName)) {
			throw new Error(
				`Pluggable with name ${providerName} has already been added.`
			);
		}
		let pluggableAdded = false;
		if (this.implementsConvertPluggable(pluggable)) {
			this._convertPluggables.push(pluggable);
			pluggableAdded = true;
		}
		if (this.implementsIdentifyPluggable(pluggable)) {
			this._identifyPluggables.push(pluggable);
			pluggableAdded = true;
		}
		if (this.implementsInterpretPluggable(pluggable)) {
			this._interpretPluggables.push(pluggable);
			pluggableAdded = true;
		}
		if (pluggableAdded) {
			this.configurePluggable(pluggable);
		} else {
			// Previously ignored silently; surface it so misregistered plugins are debuggable.
			logger.debug(
				'Pluggable implements no Predictions capability, ignoring:',
				providerName
			);
		}
	}

	/**
	 * Get the plugin object
	 * @param providerName - the name of the plugin
	 * @returns the first registered pluggable with that name, or `null` if none
	 */
	public getPluggable(providerName: string): AbstractPredictionsProvider {
		const pluggable = this.getAllProviders().find(
			candidate => candidate.getProviderName() === providerName
		);
		if (pluggable === undefined) {
			logger.debug('No plugin found with providerName=>', providerName);
			return null;
		}
		return pluggable;
	}

	/**
	 * Remove the plugin object from every capability list it was added to.
	 * Removing a name that is not registered is a no-op.
	 * @param providerName - the name of the plugin
	 */
	public removePluggable(providerName: string) {
		this._convertPluggables = this._convertPluggables.filter(
			pluggable => pluggable.getProviderName() !== providerName
		);
		this._identifyPluggables = this._identifyPluggables.filter(
			pluggable => pluggable.getProviderName() !== providerName
		);
		this._interpretPluggables = this._interpretPluggables.filter(
			pluggable => pluggable.getProviderName() !== providerName
		);
	}

	/**
	 * To make both top level providers and category level providers work with same interface and configuration
	 * this method duplicates Predictions config into parent level config (for top level provider) and
	 * category level config (such as convert, identify etc) and pass both to each provider.
	 *
	 * The result is merged over the previous options, then every registered
	 * pluggable is reconfigured.
	 *
	 * @param options - either the Predictions options themselves or a wider
	 *   config object carrying them under a `predictions` key
	 */
	public configure(options: PredictionsOptions) {
		// Prefer the `predictions` section when present, then spread the original
		// options on top so the top-level keys (including `predictions`) are kept too.
		const predictionsSection = options ? options.predictions || options : {};
		const predictionsConfig = { ...predictionsSection, ...options };
		this._options = Object.assign({}, this._options, predictionsConfig);
		logger.debug('configure Predictions', this._options);
		this.getAllProviders().forEach(pluggable =>
			this.configurePluggable(pluggable)
		);
	}

	/**
	 * Run an interpret request on the selected interpret pluggable.
	 * @param input - the text to interpret
	 * @param options - optional `providerName` selecting a specific pluggable
	 * @throws Error (synchronously) when no unique pluggable can be selected
	 */
	public interpret(
		input: InterpretTextInput,
		options?: ProviderOptions
	): Promise<InterpretTextOutput> {
		const pluggableToExecute = this.getPluggableToExecute(
			this._interpretPluggables,
			options
		);
		return pluggableToExecute.interpret(input);
	}

	public convert(
		input: TranslateTextInput,
		options?: ProviderOptions
	): Promise<TranslateTextOutput>;
	public convert(
		input: TextToSpeechInput,
		options?: ProviderOptions
	): Promise<TextToSpeechOutput>;
	public convert(
		input: SpeechToTextInput,
		options?: ProviderOptions
	): Promise<SpeechToTextOutput>;
	/**
	 * Run a convert request (translate, text-to-speech, speech-to-text) on the
	 * selected convert pluggable.
	 * @throws Error (synchronously) when no unique pluggable can be selected
	 */
	public convert(
		input: TranslateTextInput | TextToSpeechInput | SpeechToTextInput,
		options?: ProviderOptions
	): Promise<TranslateTextOutput | TextToSpeechOutput | SpeechToTextOutput> {
		const pluggableToExecute = this.getPluggableToExecute(
			this._convertPluggables,
			options
		);
		return pluggableToExecute.convert(input);
	}

	public identify(
		input: IdentifyTextInput,
		options?: ProviderOptions
	): Promise<IdentifyTextOutput>;
	public identify(
		input: IdentifyLabelsInput,
		options?: ProviderOptions
	): Promise<IdentifyLabelsOutput>;
	public identify(
		input: IdentifyEntitiesInput,
		options?: ProviderOptions
	): Promise<IdentifyEntitiesOutput>;
	/**
	 * Run an identify request (text, labels, entities) on the selected identify
	 * pluggable.
	 * @throws Error (synchronously) when no unique pluggable can be selected
	 */
	public identify(
		input: IdentifyTextInput | IdentifyLabelsInput | IdentifyEntitiesInput,
		options?: ProviderOptions
	): Promise<
		IdentifyTextOutput | IdentifyLabelsOutput | IdentifyEntitiesOutput
	> {
		const pluggableToExecute = this.getPluggableToExecute(
			this._identifyPluggables,
			options
		);
		return pluggableToExecute.identify(input);
	}

	/**
	 * Choose which pluggable serves a call.
	 *
	 * An explicit `providerName` takes precedence, even when only one provider
	 * is registered, so that a name that matches nothing fails loudly instead of
	 * silently using a different provider. Without a name, exactly one candidate
	 * must be registered.
	 *
	 * @param pluggables - the candidates registered for the requested capability
	 * @param providerOptions - optional options carrying a `providerName`
	 * @throws Error when the named provider is not registered for this
	 *   capability, or when no name is given and there is not exactly one candidate
	 */
	private getPluggableToExecute<T extends AbstractPredictionsProvider>(
		pluggables: T[],
		providerOptions?: ProviderOptions
	): T {
		if (providerOptions && providerOptions.providerName) {
			const { providerName } = providerOptions;
			const pluggable = pluggables.find(
				candidate => candidate.getProviderName() === providerName
			);
			if (pluggable === undefined) {
				// Previously returned undefined, which surfaced later as an opaque TypeError.
				throw new Error(
					`Cannot find a provider named ${providerName} for this operation. ` +
						'Make sure the pluggable has been added and supports this operation.'
				);
			}
			return pluggable;
		}
		if (pluggables.length === 1) {
			return pluggables[0];
		}
		throw new Error(
			'More than one or no providers are configured, ' +
				'Either specify a provider name or configure exactly one provider'
		);
	}

	/**
	 * All registered pluggables across capabilities. A pluggable implementing
	 * several capabilities appears once per capability list it is in.
	 */
	private getAllProviders() {
		return [
			...this._convertPluggables,
			...this._identifyPluggables,
			...this._interpretPluggables,
		];
	}

	/**
	 * Configure one pluggable with the parent `predictions` options overlaid by
	 * the options stored under its lower-cased category name (category-level
	 * keys win because they are assigned last).
	 */
	private configurePluggable(pluggable: AbstractPredictionsProvider) {
		const categoryConfig = Object.assign(
			{},
			this._options['predictions'], // Parent predictions config for the top level provider
			this._options[pluggable.getCategory().toLowerCase()] // Actual category level config
		);
		pluggable.configure(categoryConfig);
	}

	/** Duck-typed check: `obj` exposes a callable `convert`. */
	private implementsConvertPluggable(
		obj: any
	): obj is AbstractConvertPredictionsProvider {
		return !!obj && typeof obj.convert === 'function';
	}

	/** Duck-typed check: `obj` exposes a callable `identify`. */
	private implementsIdentifyPluggable(
		obj: any
	): obj is AbstractIdentifyPredictionsProvider {
		return !!obj && typeof obj.identify === 'function';
	}

	/** Duck-typed check: `obj` exposes a callable `interpret`. */
	private implementsInterpretPluggable(
		obj: any
	): obj is AbstractInterpretPredictionsProvider {
		return !!obj && typeof obj.interpret === 'function';
	}
}
247
/**
 * Default Predictions category instance, created with empty options and
 * registered with Amplify. NOTE(review): registration presumably lets Amplify's
 * top-level configuration reach `Predictions.configure` — confirm in @aws-amplify/core.
 */
export const Predictions = new PredictionsClass({});
Amplify.register(Predictions);