1 | /**
|
2 | * @license
|
3 | * Copyright 2018 Google LLC
|
4 | *
|
5 | * Use of this source code is governed by an MIT-style
|
6 | * license that can be found in the LICENSE file or at
|
7 | * https://opensource.org/licenses/MIT.
|
8 | * =============================================================================
|
9 | */
|
10 | /// <amd-module name="@tensorflow/tfjs-layers/dist/constraints" />
|
11 | import { serialization, Tensor } from '@tensorflow/tfjs-core';
|
/**
 * Abstract base for every weight constraint in this module.
 *
 * A constraint imposes restrictions on weight values: concrete subclasses
 * implement `apply`, mapping a weight `Tensor` to a `Tensor`, and expose a
 * `serialization.ConfigDict` through `getConfig`.
 *
 * @doc {
 * heading: 'Constraints',
 * subheading: 'Classes',
 * namespace: 'constraints'
 * }
 */
export declare abstract class Constraint extends serialization.Serializable {
    /** Returns the configuration of this constraint as a `ConfigDict`. */
    getConfig(): serialization.ConfigDict;
    /**
     * Imposes the constraint on a weight tensor.
     *
     * @param w Weight tensor to constrain.
     * @returns The constrained tensor.
     */
    abstract apply(w: Tensor): Tensor;
}
|
/**
 * Options accepted by the `MaxNorm` constraint constructor.
 */
export interface MaxNormArgs {
    /**
     * Maximum norm for incoming weights.
     */
    maxValue?: number;
    /**
     * Axis along which to calculate norms.
     *
     * For instance, in a `Dense` layer the weight matrix
     * has shape `[inputDim, outputDim]`;
     * set `axis` to `0` to constrain each weight vector
     * of length `[inputDim,]`.
     *
     * In a `Conv2D` layer with `dataFormat="channels_last"`,
     * the weight tensor has shape
     * `[rows, cols, inputDepth, outputDepth]`, and constraining
     * the weights of each filter tensor of size
     * `[rows, cols, inputDepth]` would mean computing norms over
     * axes `[0, 1, 2]`.
     *
     * NOTE(review): this field is declared as a single `number`, so the
     * multi-axis value `[0, 1, 2]` is not accepted by this type even though
     * it is the natural choice for `Conv2D`. Confirm against the
     * implementation whether an axis array should be supported.
     */
    axis?: number;
}
|
/**
 * Weight constraint configured through `MaxNormArgs`: a maximum norm
 * (`maxValue`) for incoming weights and the `axis` along which norms are
 * calculated.
 *
 * NOTE(review): the concrete clipping rule and the default values used when
 * an option is omitted are defined in the implementation, not in this
 * declaration.
 */
export declare class MaxNorm extends Constraint {
    /** @nocollapse */
    static readonly className = "MaxNorm";
    constructor(args: MaxNormArgs);
    /**
     * Applies the max-norm constraint to `w`.
     *
     * @param w Weight tensor to constrain.
     * @returns The constrained tensor.
     */
    apply(w: Tensor): Tensor;
    /** Returns this constraint's configuration as a `ConfigDict`. */
    getConfig(): serialization.ConfigDict;
    // Configuration held by the instance.
    private maxValue;
    private axis;
    // Presumably the fallbacks for options omitted from `MaxNormArgs`
    // (values are not visible here; confirm in the implementation).
    private readonly defaultMaxValue;
    private readonly defaultAxis;
}
|
/**
 * Options accepted by the `UnitNorm` constraint constructor.
 */
export interface UnitNormArgs {
    /**
     * Axis along which to calculate norms.
     *
     * For instance, in a `Dense` layer the weight matrix
     * has shape `[inputDim, outputDim]`;
     * set `axis` to `0` to constrain each weight vector
     * of length `[inputDim,]`.
     *
     * In a `Conv2D` layer with `dataFormat="channels_last"`,
     * the weight tensor has shape
     * `[rows, cols, inputDepth, outputDepth]`, and constraining
     * the weights of each filter tensor of size
     * `[rows, cols, inputDepth]` would mean computing norms over
     * axes `[0, 1, 2]`.
     *
     * NOTE(review): this field is declared as a single `number`, so the
     * multi-axis value `[0, 1, 2]` is not accepted by this type even though
     * it is the natural choice for `Conv2D`. Confirm against the
     * implementation whether an axis array should be supported.
     */
    axis?: number;
}
|
/**
 * Weight constraint configured through `UnitNormArgs`, i.e. by the `axis`
 * along which norms are calculated.
 *
 * NOTE(review): by its name this presumably rescales weights to unit norm;
 * the actual rule and the default axis live in the implementation.
 */
export declare class UnitNorm extends Constraint {
    /** @nocollapse */
    static readonly className = "UnitNorm";
    constructor(args: UnitNormArgs);
    /**
     * Applies the unit-norm constraint to `w`.
     *
     * @param w Weight tensor to constrain.
     * @returns The constrained tensor.
     */
    apply(w: Tensor): Tensor;
    /** Returns this constraint's configuration as a `ConfigDict`. */
    getConfig(): serialization.ConfigDict;
    // Axis the instance was configured with.
    private axis;
    // Presumably the fallback used when `UnitNormArgs.axis` is omitted.
    private readonly defaultAxis;
}
|
/**
 * Constraint that declares no constructor and no configuration fields of
 * its own; it only provides `apply` and relies on `Constraint` for
 * `getConfig`.
 *
 * NOTE(review): by its name this presumably forces weights to be
 * non-negative; the actual rule is in the implementation of `apply`.
 */
export declare class NonNeg extends Constraint {
    /** @nocollapse */
    static readonly className = "NonNeg";
    /**
     * Applies the non-negativity constraint to `w`.
     *
     * @param w Weight tensor to constrain.
     * @returns The constrained tensor.
     */
    apply(w: Tensor): Tensor;
}
|
/**
 * Options accepted by the `MinMaxNorm` constraint constructor.
 */
export interface MinMaxNormArgs {
    /**
     * Minimum norm for incoming weights.
     */
    minValue?: number;
    /**
     * Maximum norm for incoming weights.
     */
    maxValue?: number;
    /**
     * Axis along which to calculate norms.
     *
     * For instance, in a `Dense` layer the weight matrix
     * has shape `[inputDim, outputDim]`;
     * set `axis` to `0` to constrain each weight vector
     * of length `[inputDim,]`.
     *
     * In a `Conv2D` layer with `dataFormat="channels_last"`,
     * the weight tensor has shape
     * `[rows, cols, inputDepth, outputDepth]`, and constraining
     * the weights of each filter tensor of size
     * `[rows, cols, inputDepth]` would mean computing norms over
     * axes `[0, 1, 2]`.
     *
     * NOTE(review): this field is declared as a single `number`, so the
     * multi-axis value `[0, 1, 2]` is not accepted by this type even though
     * it is the natural choice for `Conv2D`. Confirm against the
     * implementation whether an axis array should be supported.
     */
    axis?: number;
    /**
     * Rate for enforcing the constraint: weights will be rescaled to yield
     * `(1 - rate) * norm + rate * norm.clip(minValue, maxValue)`.
     *
     * Effectively, `rate = 1.0` means strict enforcement of the constraint,
     * while `rate < 1.0` means that weights are rescaled at each step to
     * move gradually towards a value inside the desired interval.
     */
    rate?: number;
}
|
/**
 * Weight constraint configured through `MinMaxNormArgs`: lower and upper
 * norm bounds (`minValue`, `maxValue`), the `axis` along which norms are
 * calculated, and a `rate` controlling how strictly the bounds are
 * enforced.
 *
 * NOTE(review): the default values used for omitted options are defined in
 * the implementation, not in this declaration.
 */
export declare class MinMaxNorm extends Constraint {
    /** @nocollapse */
    static readonly className = "MinMaxNorm";
    constructor(args: MinMaxNormArgs);
    /**
     * Applies the min/max-norm constraint to `w`.
     *
     * @param w Weight tensor to constrain.
     * @returns The constrained tensor.
     */
    apply(w: Tensor): Tensor;
    /** Returns this constraint's configuration as a `ConfigDict`. */
    getConfig(): serialization.ConfigDict;
    // Configuration held by the instance.
    private minValue;
    private maxValue;
    private rate;
    private axis;
    // Presumably the fallbacks for options omitted from `MinMaxNormArgs`.
    private readonly defaultMinValue;
    private readonly defaultMaxValue;
    private readonly defaultRate;
    private readonly defaultAxis;
}
|
// Identifiers of the built-in constraints. Because the union also includes
// plain `string`, TypeScript reduces the whole type to `string`: the literal
// members document the built-in names but do not restrict callers.
/** @docinline */
export declare type ConstraintIdentifier =
    | 'maxNorm'
    | 'minMaxNorm'
    | 'nonNeg'
    | 'unitNorm'
    | string;
|
/**
 * Table keyed by `ConstraintIdentifier` whose values are strings.
 *
 * NOTE(review): presumably maps each identifier to the registered class
 * name of its constraint; the actual entries are defined in the
 * implementation, not in this declaration.
 */
export declare const CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP: Record<ConstraintIdentifier, string>;
|
/**
 * Produces a serialization config value for a constraint.
 *
 * @param constraint The constraint instance to serialize.
 * @returns A `serialization.ConfigDictValue` describing `constraint`.
 */
export declare function serializeConstraint(
    constraint: Constraint,
): serialization.ConfigDictValue;
|
/**
 * Builds a `Constraint` from a serialized config.
 *
 * @param config Serialized constraint configuration.
 * @param customObjects Optional additional config entries; presumably used
 *     to resolve user-supplied constraint classes (confirm against the
 *     implementation).
 * @returns The reconstructed `Constraint`.
 */
export declare function deserializeConstraint(
    config: serialization.ConfigDict,
    customObjects?: serialization.ConfigDict,
): Constraint;
|
/**
 * Turns any supported way of specifying a constraint into a `Constraint`.
 *
 * @param identifier A `ConstraintIdentifier` string, a serialized
 *     `serialization.ConfigDict`, or an existing `Constraint` instance.
 * @returns The `Constraint` corresponding to `identifier`.
 */
export declare function getConstraint(
    identifier: ConstraintIdentifier | serialization.ConfigDict | Constraint,
): Constraint;
|