UNPKG

13.8 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2018 Google LLC
4 *
5 * Use of this source code is governed by an MIT-style
6 * license that can be found in the LICENSE file or at
7 * https://opensource.org/licenses/MIT.
8 * =============================================================================
9 */
10/* original source: keras/regularizers.py */
11import * as tfc from '@tensorflow/tfjs-core';
12import { abs, add, serialization, sum, tidy, zeros } from '@tensorflow/tfjs-core';
13import * as K from './backend/tfjs_backend';
14import { deserializeKerasObject, serializeKerasObject } from './utils/generic_utils';
/**
 * Validates the optional options argument passed to the L1/L2 regularizer
 * factories and constructor.
 *
 * `null`/`undefined` and any value whose `typeof` is `'object'` are accepted.
 * Every other value (numbers, strings, booleans, functions, ...) is rejected.
 *
 * @param args The candidate options argument.
 * @throws Error when `args` is non-null and not an object.
 */
function assertObjectArgs(args) {
  const isAbsent = args == null;
  if (isAbsent || typeof args === 'object') {
    return;
  }
  throw new Error(`Argument to L1L2 regularizer's constructor is expected to be an ` +
      `object, but received: ${args}`);
}
/**
 * Base class shared by every regularizer.
 *
 * Concrete subclasses (such as `L1L2`) implement `apply(x)` to produce a
 * scalar penalty for the tensor `x`. They register themselves through
 * `serialization.registerClass` so that they can be looked up by class name
 * during deserialization.
 */
export class Regularizer extends serialization.Serializable {}
/**
 * Regularizer that computes a weighted combination of L1 and L2 penalties:
 * `sum(l1 * |x|) + sum(l2 * K.square(x))`.
 *
 * A rate that is omitted or null defaults to 0.01. A rate of exactly 0
 * skips the corresponding term entirely.
 */
export class L1L2 extends Regularizer {
  /**
   * @param args Optional `{l1, l2}` options object. A non-null, non-object
   *     value causes `assertObjectArgs` to throw.
   */
  constructor(args) {
    super();
    assertObjectArgs(args);
    const opts = args == null ? {} : args;
    this.l1 = opts.l1 == null ? 0.01 : opts.l1;
    this.l2 = opts.l2 == null ? 0.01 : opts.l2;
    // Zero rates are recorded here so that apply() can skip the tensor work.
    this.hasL1 = this.l1 !== 0;
    this.hasL2 = this.l2 !== 0;
  }
  /**
   * Porting note: Renamed from __call__.
   * @param x Variable of which to calculate the regularization score.
   * @returns The penalty, reshaped to a rank-0 (scalar) tensor. The work runs
   *     inside `tidy`, so intermediate tensors are cleaned up.
   */
  apply(x) {
    return tidy(() => {
      let total = zeros([1]);
      if (this.hasL1) {
        const l1Term = sum(tfc.mul(this.l1, abs(x)));
        total = add(total, l1Term);
      }
      if (this.hasL2) {
        const l2Term = sum(tfc.mul(this.l2, K.square(x)));
        total = add(total, l2Term);
      }
      return tfc.reshape(total, []);
    });
  }
  /** Returns the serializable configuration: the two rates. */
  getConfig() {
    const config = { 'l1': this.l1, 'l2': this.l2 };
    return config;
  }
  /**
   * Rebuilds a regularizer of type `cls` from a config produced by
   * `getConfig()`.
   * @nocollapse
   */
  static fromConfig(cls, config) {
    const l1Rate = config['l1'];
    const l2Rate = config['l2'];
    return new cls({ l1: l1Rate, l2: l2Rate });
  }
}
/** @nocollapse */
L1L2.className = 'L1L2';
// Registering the class lets deserializeRegularizer resolve the 'L1L2' class name.
serialization.registerClass(L1L2);
63export function l1(args) {
64 assertObjectArgs(args);
65 return new L1L2({ l1: args != null ? args.l1 : null, l2: 0 });
66}
67export function l2(args) {
68 assertObjectArgs(args);
69 return new L1L2({ l2: args != null ? args.l2 : null, l1: 0 });
70}
// Short, JavaScript-style regularizer aliases mapped to the Keras class names
// they resolve to (consulted by getRegularizer for string identifiers).
export const REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
  l1l2: 'L1L2',
};
75export function serializeRegularizer(constraint) {
76 return serializeKerasObject(constraint);
77}
78export function deserializeRegularizer(config, customObjects = {}) {
79 return deserializeKerasObject(config, serialization.SerializationMap.getMap().classNameMap, customObjects, 'regularizer');
80}
81export function getRegularizer(identifier) {
82 if (identifier == null) {
83 return null;
84 }
85 if (typeof identifier === 'string') {
86 const className = identifier in REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ?
87 REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :
88 identifier;
89 const config = { className, config: {} };
90 return deserializeRegularizer(config);
91 }
92 else if (identifier instanceof Regularizer) {
93 return identifier;
94 }
95 else {
96 return deserializeRegularizer(identifier);
97 }
98}
99//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"regularizers.js","sourceRoot":"","sources":["../../../../../tfjs-layers/src/regularizers.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH,4CAA4C;AAE5C,OAAO,KAAK,GAAG,MAAM,uBAAuB,CAAC;AAC7C,OAAO,EAAC,GAAG,EAAE,GAAG,EAAU,aAAa,EAAE,GAAG,EAAU,IAAI,EAAE,KAAK,EAAC,MAAM,uBAAuB,CAAC;AAChG,OAAO,KAAK,CAAC,MAAM,wBAAwB,CAAC;AAC5C,OAAO,EAAC,sBAAsB,EAAE,oBAAoB,EAAC,MAAM,uBAAuB,CAAC;AAEnF,SAAS,gBAAgB,CAAC,IAA4B;IACpD,IAAI,IAAI,IAAI,IAAI,IAAI,OAAO,IAAI,KAAK,QAAQ,EAAE;QAC5C,MAAM,IAAI,KAAK,CACX,kEAAkE;YAClE,yBAAyB,IAAI,EAAE,CAAC,CAAC;KACtC;AACH,CAAC;AAED;;GAEG;AACH,MAAM,OAAgB,WAAY,SAAQ,aAAa,CAAC,YAAY;CAEnE;AAmBD,MAAM,OAAO,IAAK,SAAQ,WAAW;IAQnC,YAAY,IAAe;QACzB,KAAK,EAAE,CAAC;QAER,gBAAgB,CAAC,IAAI,CAAC,CAAC;QAEvB,IAAI,CAAC,EAAE,GAAG,IAAI,IAAI,IAAI,IAAI,IAAI,CAAC,EAAE,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,EAAE,CAAC;QAC3D,IAAI,CAAC,EAAE,GAAG,IAAI,IAAI,IAAI,IAAI,IAAI,CAAC,EAAE,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,EAAE,CAAC;QAC3D,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,EAAE,KAAK,CAAC,CAAC;QAC3B,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,EAAE,KAAK,CAAC,CAAC;IAC7B,CAAC;IAED;;;OAGG;IACH,KAAK,CAAC,CAAS;QACb,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,IAAI,cAAc,GAAW,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YACxC,IAAI,IAAI,CAAC,KAAK,EAAE;gBACd,cAAc,GAAG,GAAG,CAAC,cAAc,EAAE,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,EAAE,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;aACrE;YACD,IAAI,IAAI,CAAC,KAAK,EAAE;gBACd,cAAc;oBACV,GAAG,CAAC,cAAc,EAAE,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,EAAE,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;aAC7D;YACD,OAAO,GAAG,CAAC,OAAO,CAAC,cAAc,EAAE,EAAE,CAAC,CAAC;QACzC,CAAC,CAAC,CAAC;IACL,CAAC;IAED,SAAS;QACP,OAAO,EAAC,IAAI,EAAE,IAAI,CAAC,EAAE,EAAE,IAAI,EAAE,IAAI,CAAC,EAAE,EAAC,CAAC;IACxC,CAAC;IAED,kBAAkB;IAClB,MAAM,CAAU,UAAU,CACtB,GAA6C,EAC7C,MAAgC;QAClC,OAAO,IAAI,GAAG,CAAC,EAAC,EAAE,EAAE,MAAM,CAAC,IAAI,CAAW,EAAE,EAAE,EAAE,MAAM,CAAC,IAAI,CAAW,EAAC,CAAC,CAAC;IAC3E,CAAC;;AA7CD,kBAAkB;AACX,cAAS,GAAG,MAAM,CA
AC;AA8C5B,aAAa,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;AAElC,MAAM,UAAU,EAAE,CAAC,IAAa;IAC9B,gBAAgB,CAAC,IAAI,CAAC,CAAC;IACvB,OAAO,IAAI,IAAI,CAAC,EAAC,EAAE,EAAE,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,EAAE,EAAE,EAAE,CAAC,EAAC,CAAC,CAAC;AAC9D,CAAC;AAED,MAAM,UAAU,EAAE,CAAC,IAAY;IAC7B,gBAAgB,CAAC,IAAI,CAAC,CAAC;IACvB,OAAO,IAAI,IAAI,CAAC,EAAC,EAAE,EAAE,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,EAAE,EAAE,EAAE,CAAC,EAAC,CAAC,CAAC;AAC9D,CAAC;AAKD,+EAA+E;AAC/E,MAAM,CAAC,MAAM,0CAA0C,GACD;IAChD,MAAM,EAAE,MAAM;CACf,CAAC;AAEN,MAAM,UAAU,oBAAoB,CAAC,UAAuB;IAE1D,OAAO,oBAAoB,CAAC,UAAU,CAAC,CAAC;AAC1C,CAAC;AAED,MAAM,UAAU,sBAAsB,CAClC,MAAgC,EAChC,gBAA0C,EAAE;IAC9C,OAAO,sBAAsB,CACzB,MAAM,EAAE,aAAa,CAAC,gBAAgB,CAAC,MAAM,EAAE,CAAC,YAAY,EAC5D,aAAa,EAAE,aAAa,CAAC,CAAC;AACpC,CAAC;AAED,MAAM,UAAU,cAAc,CAAC,UAEW;IACxC,IAAI,UAAU,IAAI,IAAI,EAAE;QACtB,OAAO,IAAI,CAAC;KACb;IACD,IAAI,OAAO,UAAU,KAAK,QAAQ,EAAE;QAClC,MAAM,SAAS,GAAG,UAAU,IAAI,0CAA0C,CAAC,CAAC;YACxE,0CAA0C,CAAC,UAAU,CAAC,CAAC,CAAC;YACxD,UAAU,CAAC;QACf,MAAM,MAAM,GAAG,EAAC,SAAS,EAAE,MAAM,EAAE,EAAE,EAAC,CAAC;QACvC,OAAO,sBAAsB,CAAC,MAAM,CAAC,CAAC;KACvC;SAAM,IAAI,UAAU,YAAY,WAAW,EAAE;QAC5C,OAAO,UAAU,CAAC;KACnB;SAAM;QACL,OAAO,sBAAsB,CAAC,UAAU,CAAC,CAAC;KAC3C;AACH,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/* original source: keras/regularizers.py */\n\nimport * as tfc from '@tensorflow/tfjs-core';\nimport {abs, add, Scalar, serialization, sum, Tensor, tidy, zeros} from '@tensorflow/tfjs-core';\nimport * as K from './backend/tfjs_backend';\nimport {deserializeKerasObject, serializeKerasObject} from './utils/generic_utils';\n\nfunction assertObjectArgs(args: L1Args|L2Args|L1L2Args): void {\n  if (args != null && typeof args !== 'object') 
{\n    throw new Error(\n        `Argument to L1L2 regularizer's constructor is expected to be an ` +\n        `object, but received: ${args}`);\n  }\n}\n\n/**\n * Regularizer base class.\n */\nexport abstract class Regularizer extends serialization.Serializable {\n  abstract apply(x: Tensor): Scalar;\n}\n\nexport interface L1L2Args {\n  /** L1 regularization rate. Defaults to 0.01. */\n  l1?: number;\n  /** L2 regularization rate. Defaults to 0.01. */\n  l2?: number;\n}\n\nexport interface L1Args {\n  /** L1 regularization rate. Defaults to 0.01. */\n  l1: number;\n}\n\nexport interface L2Args {\n  /** L2 regularization rate. Defaults to 0.01. */\n  l2: number;\n}\n\nexport class L1L2 extends Regularizer {\n  /** @nocollapse */\n  static className = 'L1L2';\n\n  private readonly l1: number;\n  private readonly l2: number;\n  private readonly hasL1: boolean;\n  private readonly hasL2: boolean;\n  constructor(args?: L1L2Args) {\n    super();\n\n    assertObjectArgs(args);\n\n    this.l1 = args == null || args.l1 == null ? 0.01 : args.l1;\n    this.l2 = args == null || args.l2 == null ? 
0.01 : args.l2;\n    this.hasL1 = this.l1 !== 0;\n    this.hasL2 = this.l2 !== 0;\n  }\n\n  /**\n   * Porting note: Renamed from __call__.\n   * @param x Variable of which to calculate the regularization score.\n   */\n  apply(x: Tensor): Scalar {\n    return tidy(() => {\n      let regularization: Tensor = zeros([1]);\n      if (this.hasL1) {\n        regularization = add(regularization, sum(tfc.mul(this.l1, abs(x))));\n      }\n      if (this.hasL2) {\n        regularization =\n            add(regularization, sum(tfc.mul(this.l2, K.square(x))));\n      }\n      return tfc.reshape(regularization, []);\n    });\n  }\n\n  getConfig(): serialization.ConfigDict {\n    return {'l1': this.l1, 'l2': this.l2};\n  }\n\n  /** @nocollapse */\n  static override fromConfig<T extends serialization.Serializable>(\n      cls: serialization.SerializableConstructor<T>,\n      config: serialization.ConfigDict): T {\n    return new cls({l1: config['l1'] as number, l2: config['l2'] as number});\n  }\n}\nserialization.registerClass(L1L2);\n\nexport function l1(args?: L1Args) {\n  assertObjectArgs(args);\n  return new L1L2({l1: args != null ? args.l1 : null, l2: 0});\n}\n\nexport function l2(args: L2Args) {\n  assertObjectArgs(args);\n  return new L1L2({l2: args != null ? 
args.l2 : null, l1: 0});\n}\n\n/** @docinline */\nexport type RegularizerIdentifier = 'l1l2'|string;\n\n// Maps the JavaScript-like identifier keys to the corresponding keras symbols.\nexport const REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP:\n    {[identifier in RegularizerIdentifier]: string} = {\n      'l1l2': 'L1L2'\n    };\n\nexport function serializeRegularizer(constraint: Regularizer):\n    serialization.ConfigDictValue {\n  return serializeKerasObject(constraint);\n}\n\nexport function deserializeRegularizer(\n    config: serialization.ConfigDict,\n    customObjects: serialization.ConfigDict = {}): Regularizer {\n  return deserializeKerasObject(\n      config, serialization.SerializationMap.getMap().classNameMap,\n      customObjects, 'regularizer');\n}\n\nexport function getRegularizer(identifier: RegularizerIdentifier|\n                               serialization.ConfigDict|\n                               Regularizer): Regularizer {\n  if (identifier == null) {\n    return null;\n  }\n  if (typeof identifier === 'string') {\n    const className = identifier in REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ?\n        REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :\n        identifier;\n    const config = {className, config: {}};\n    return deserializeRegularizer(config);\n  } else if (identifier instanceof Regularizer) {\n    return identifier;\n  } else {\n    return deserializeRegularizer(identifier);\n  }\n}\n"]}
\No newline at end of file