1 | /**
|
2 | * @license
|
3 | * Copyright 2018 Google LLC
|
4 | *
|
5 | * Use of this source code is governed by an MIT-style
|
6 | * license that can be found in the LICENSE file or at
|
7 | * https://opensource.org/licenses/MIT.
|
8 | * =============================================================================
|
9 | */
|
10 | /**
|
11 | * Optimizers.
|
12 | */
|
13 | import { train } from '@tensorflow/tfjs-core';
|
14 | import { epsilon } from './backend/common';
|
15 | import { ValueError } from './errors';
|
16 | // Add (de)serialize()
|
17 | // Porting note: This diverges from the PyKeras implementation and may need to
|
18 | // change based on (de)serialization requirements.
|
19 | export function getOptimizer(identifier) {
|
20 | const optimizerMap = {
|
21 | 'Adagrad': () => train.adagrad(0.01),
|
22 | 'Adadelta': () => train.adadelta(1, 0.95, epsilon()),
|
23 | 'Adam': () => train.adam(0.001, 0.9, 0.999, epsilon()),
|
24 | 'Adamax': () => train.adamax(0.002, 0.9, 0.999, epsilon(), 0),
|
25 | 'RMSProp': () => train.rmsprop(0.001, 0.9, 0, epsilon()),
|
26 | 'SGD': () => train.sgd(0.01)
|
27 | };
|
28 | optimizerMap['adagrad'] = optimizerMap['Adagrad'];
|
29 | optimizerMap['adadelta'] = optimizerMap['Adadelta'];
|
30 | optimizerMap['adam'] = optimizerMap['Adam'];
|
31 | optimizerMap['adamax'] = optimizerMap['Adamax'];
|
32 | optimizerMap['rmsprop'] = optimizerMap['RMSProp'];
|
33 | optimizerMap['sgd'] = optimizerMap['SGD'];
|
34 | if (identifier in optimizerMap) {
|
35 | return optimizerMap[identifier]();
|
36 | }
|
37 | throw new ValueError(`Unknown Optimizer ${identifier}`);
|
38 | }
|
39 | //# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoib3B0aW1pemVycy5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uL3RmanMtbGF5ZXJzL3NyYy9vcHRpbWl6ZXJzLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7OztHQVFHO0FBRUg7O0dBRUc7QUFFSCxPQUFPLEVBQVksS0FBSyxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFFdkQsT0FBTyxFQUFDLE9BQU8sRUFBQyxNQUFNLGtCQUFrQixDQUFDO0FBRXpDLE9BQU8sRUFBQyxVQUFVLEVBQUMsTUFBTSxVQUFVLENBQUM7QUFFcEMsc0JBQXNCO0FBRXRCLDhFQUE4RTtBQUM5RSxrREFBa0Q7QUFDbEQsTUFBTSxVQUFVLFlBQVksQ0FBQyxVQUFrQjtJQUM3QyxNQUFNLFlBQVksR0FBK0M7UUFDL0QsU0FBUyxFQUFFLEdBQUcsRUFBRSxDQUFDLEtBQUssQ0FBQyxPQUFPLENBQUMsSUFBSSxDQUFDO1FBQ3BDLFVBQVUsRUFBRSxHQUFHLEVBQUUsQ0FBQyxLQUFLLENBQUMsUUFBUSxDQUFDLENBQUMsRUFBRSxJQUFJLEVBQUUsT0FBTyxFQUFFLENBQUM7UUFDcEQsTUFBTSxFQUFFLEdBQUcsRUFBRSxDQUFDLEtBQUssQ0FBQyxJQUFJLENBQUMsS0FBSyxFQUFFLEdBQUcsRUFBRSxLQUFLLEVBQUUsT0FBTyxFQUFFLENBQUM7UUFDdEQsUUFBUSxFQUFFLEdBQUcsRUFBRSxDQUFDLEtBQUssQ0FBQyxNQUFNLENBQUMsS0FBSyxFQUFFLEdBQUcsRUFBRSxLQUFLLEVBQUUsT0FBTyxFQUFFLEVBQUUsQ0FBQyxDQUFDO1FBQzdELFNBQVMsRUFBRSxHQUFHLEVBQUUsQ0FBQyxLQUFLLENBQUMsT0FBTyxDQUFDLEtBQUssRUFBRSxHQUFHLEVBQUUsQ0FBQyxFQUFFLE9BQU8sRUFBRSxDQUFDO1FBQ3hELEtBQUssRUFBRSxHQUFHLEVBQUUsQ0FBQyxLQUFLLENBQUMsR0FBRyxDQUFDLElBQUksQ0FBQztLQUM3QixDQUFDO0lBQ0YsWUFBWSxDQUFDLFNBQVMsQ0FBQyxHQUFHLFlBQVksQ0FBQyxTQUFTLENBQUMsQ0FBQztJQUNsRCxZQUFZLENBQUMsVUFBVSxDQUFDLEdBQUcsWUFBWSxDQUFDLFVBQVUsQ0FBQyxDQUFDO0lBQ3BELFlBQVksQ0FBQyxNQUFNLENBQUMsR0FBRyxZQUFZLENBQUMsTUFBTSxDQUFDLENBQUM7SUFDNUMsWUFBWSxDQUFDLFFBQVEsQ0FBQyxHQUFHLFlBQVksQ0FBQyxRQUFRLENBQUMsQ0FBQztJQUNoRCxZQUFZLENBQUMsU0FBUyxDQUFDLEdBQUcsWUFBWSxDQUFDLFNBQVMsQ0FBQyxDQUFDO0lBQ2xELFlBQVksQ0FBQyxLQUFLLENBQUMsR0FBRyxZQUFZLENBQUMsS0FBSyxDQUFDLENBQUM7SUFFMUMsSUFBSSxVQUFVLElBQUksWUFBWSxFQUFFO1FBQzlCLE9BQU8sWUFBWSxDQUFDLFVBQVUsQ0FBQyxFQUFFLENBQUM7S0FDbkM7SUFDRCxNQUFNLElBQUksVUFBVSxDQUFDLHFCQUFxQixVQUFVLEVBQUUsQ0FBQyxDQUFDO0FBQzFELENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAxOCBHb29nbGUgTExDXG4gKlxuICogVXNlIG9mIHRoaXMgc291cmNlIGNvZGUgaXMgZ292ZXJuZWQgYnkgYW4gTUlULXN0eWxlXG4gKiBsaWNlbnNlIHRoYXQgY2FuIGJlIGZvdW5kIGluIHRoZSBMSUNFTlNFIGZpbGUgb3IgYXRcbiAqIGh0dHBzOi8vb3BlbnNvdXJjZS5vcmcvbGljZW5zZXMvTUlULlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG4vKipcbiAqIE9wdGltaXplcnMuXG4gKi9cblxuaW1wb3J0IHtPcHRpbWl6ZXIsIHRyYWlufSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuXG5pbXBvcnQge2Vwc2lsb259IGZyb20gJy4vYmFja2VuZC9jb21tb24nO1xuXG5pbXBvcnQge1ZhbHVlRXJyb3J9IGZyb20gJy4vZXJyb3JzJztcblxuLy8gQWRkIChkZSlzZXJpYWxpemUoKVxuXG4vLyBQb3J0aW5nIG5vdGU6IFRoaXMgZGl2ZXJnZXMgZnJvbSB0aGUgUHlLZXJhcyBpbXBsZW1lbnRhdGlvbiBhbmQgbWF5IG5lZWQgdG9cbi8vIGNoYW5nZSBiYXNlZCBvbiAoZGUpc2VyaWFsaXphdGlvbiByZXF1aXJlbWVudHMuXG5leHBvcnQgZnVuY3Rpb24gZ2V0T3B0aW1pemVyKGlkZW50aWZpZXI6IHN0cmluZyk6IE9wdGltaXplciB7XG4gIGNvbnN0IG9wdGltaXplck1hcDoge1tvcHRpbWl6ZXJOYW1lOiBzdHJpbmddOiAoKSA9PiBPcHRpbWl6ZXJ9ID0ge1xuICAgICdBZGFncmFkJzogKCkgPT4gdHJhaW4uYWRhZ3JhZCgwLjAxKSxcbiAgICAnQWRhZGVsdGEnOiAoKSA9PiB0cmFpbi5hZGFkZWx0YSgxLCAwLjk1LCBlcHNpbG9uKCkpLFxuICAgICdBZGFtJzogKCkgPT4gdHJhaW4uYWRhbSgwLjAwMSwgMC45LCAwLjk5OSwgZXBzaWxvbigpKSxcbiAgICAnQWRhbWF4JzogKCkgPT4gdHJhaW4uYWRhbWF4KDAuMDAyLCAwLjksIDAuOTk5LCBlcHNpbG9uKCksIDApLFxuICAgICdSTVNQcm9wJzogKCkgPT4gdHJhaW4ucm1zcHJvcCgwLjAwMSwgMC45LCAwLCBlcHNpbG9uKCkpLFxuICAgICdTR0QnOiAoKSA9PiB0cmFpbi5zZ2QoMC4wMSlcbiAgfTtcbiAgb3B0aW1pemVyTWFwWydhZGFncmFkJ10gPSBvcHRpbWl6ZXJNYXBbJ0FkYWdyYWQnXTtcbiAgb3B0aW1pemVyTWFwWydhZGFkZWx0YSddID0gb3B0aW1pemVyTWFwWydBZGFkZWx0YSddO1xuICBvcHRpbWl6ZXJNYXBbJ2FkYW0nXSA9IG9wdGltaXplck1hcFsnQWRhbSddO1xuICBvcHRpbWl6ZXJNYXBbJ2FkYW1heCddID0gb3B0aW1pemVyTWFwWydBZGFtYXgnXTtcbiAgb3B0aW1pemVyTWFwWydybXNwcm9wJ10gPSBvcHRpbWl6ZXJNYXBbJ1JNU1Byb3AnXTtcbiAgb3B0aW1pemVyTWFwWydzZ2QnXSA9IG9wdGltaXplck1hcFsnU0dEJ107XG5cbiAgaWYgKGlkZW50aWZpZXIgaW4gb3B0aW1pemVyTWFwKSB7XG4gICAgcmV0dXJuIG9wdGltaXplck1hcFtpZGVudGlmaWVyXSgpO1xuICB9XG4gIHRocm93IG5ldyBWYWx1ZUVycm9yKGBVbmtub3duIE9wdGltaXplciAke2lkZW50aWZpZXJ9YCk7XG59XG4iXX0= |
\ | No newline at end of file |