UNPKG

5 kBTypeScriptView Raw
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
10/// <amd-module name="@tensorflow/tfjs-layers/dist/constraints" />
11import { serialization, Tensor } from '@tensorflow/tfjs-core';
/**
 * Base class for functions that impose constraints on weight values.
 *
 * Subclasses implement `apply`; `getConfig` returns the configuration
 * dictionary for this constraint (presumably consumed by the
 * `serialization.Serializable` machinery this class extends).
 *
 * @doc {
 * heading: 'Constraints',
 * subheading: 'Classes',
 * namespace: 'constraints'
 * }
 */
export declare abstract class Constraint extends serialization.Serializable {
  /**
   * Applies this constraint to the weight tensor `w` and returns the
   * resulting tensor.
   */
  abstract apply(w: Tensor): Tensor;
  /** Returns this constraint's serializable configuration. */
  getConfig(): serialization.ConfigDict;
}
/** Constructor arguments for the `MaxNorm` constraint. */
export interface MaxNormArgs {
  /**
   * Maximum norm for incoming weights.
   *
   * Optional; `MaxNorm` declares a private `defaultMaxValue`, presumably
   * used when this is omitted (its value is not visible in this file).
   */
  maxValue?: number;
  /**
   * Axis along which to calculate norms.
   *
   * For instance, in a `Dense` layer the weight matrix
   * has shape `[inputDim, outputDim]`,
   * set `axis` to `0` to constrain each weight vector
   * of length `[inputDim,]`.
   * In a `Conv2D` layer with `dataFormat="channels_last"`,
   * the weight tensor has shape
   * `[rows, cols, inputDepth, outputDepth]`,
   * set `axis` to `[0, 1, 2]`
   * to constrain the weights of each filter tensor of size
   * `[rows, cols, inputDepth]`.
   *
   * NOTE(review): this field is declared as `number`, so the array form
   * `[0, 1, 2]` described above does not type-check against this
   * declaration. Confirm whether the implementation accepts an array
   * before widening the type to `number | number[]`.
   */
  axis?: number;
}
/**
 * Constraint configured by `MaxNormArgs`: a maximum norm for incoming
 * weights, computed along `axis`.
 *
 * Only the declaration is visible in this file. The norm computation and
 * the values of the private defaults live in the implementation.
 */
export declare class MaxNorm extends Constraint {
  /** @nocollapse */
  static readonly className = "MaxNorm";
  private maxValue;
  private axis;
  // Presumably the fallbacks for omitted `args.maxValue` / `args.axis`.
  private readonly defaultMaxValue;
  private readonly defaultAxis;
  constructor(args: MaxNormArgs);
  /** Applies the max-norm constraint to weight tensor `w`. */
  apply(w: Tensor): Tensor;
  /** Returns this constraint's serializable configuration. */
  getConfig(): serialization.ConfigDict;
}
/** Constructor arguments for the `UnitNorm` constraint. */
export interface UnitNormArgs {
  /**
   * Axis along which to calculate norms.
   *
   * For instance, in a `Dense` layer the weight matrix
   * has shape `[inputDim, outputDim]`,
   * set `axis` to `0` to constrain each weight vector
   * of length `[inputDim,]`.
   * In a `Conv2D` layer with `dataFormat="channels_last"`,
   * the weight tensor has shape
   * `[rows, cols, inputDepth, outputDepth]`,
   * set `axis` to `[0, 1, 2]`
   * to constrain the weights of each filter tensor of size
   * `[rows, cols, inputDepth]`.
   *
   * NOTE(review): this field is declared as `number`, so the array form
   * `[0, 1, 2]` described above does not type-check against this
   * declaration. Confirm whether the implementation accepts an array
   * before widening the type to `number | number[]`.
   */
  axis?: number;
}
/**
 * Constraint configured by `UnitNormArgs`.
 *
 * Judging by the name, it presumably rescales weights to unit norm along
 * `axis`. The implementation is not visible in this declaration file.
 */
export declare class UnitNorm extends Constraint {
  /** @nocollapse */
  static readonly className = "UnitNorm";
  private axis;
  // Presumably the fallback for an omitted `args.axis`.
  private readonly defaultAxis;
  constructor(args: UnitNormArgs);
  /** Applies the unit-norm constraint to weight tensor `w`. */
  apply(w: Tensor): Tensor;
  /** Returns this constraint's serializable configuration. */
  getConfig(): serialization.ConfigDict;
}
/**
 * Argument-free constraint: it declares no constructor arguments and no
 * `getConfig` override.
 *
 * Judging by the name, it presumably keeps weights non-negative. The
 * implementation is not visible in this declaration file.
 */
export declare class NonNeg extends Constraint {
  /** @nocollapse */
  static readonly className = "NonNeg";
  /** Applies the non-negativity constraint to weight tensor `w`. */
  apply(w: Tensor): Tensor;
}
/** Constructor arguments for the `MinMaxNorm` constraint. */
export interface MinMaxNormArgs {
  /**
   * Minimum norm for incoming weights.
   */
  minValue?: number;
  /**
   * Maximum norm for incoming weights.
   */
  maxValue?: number;
  /**
   * Axis along which to calculate norms.
   *
   * For instance, in a `Dense` layer the weight matrix
   * has shape `[inputDim, outputDim]`,
   * set `axis` to `0` to constrain each weight vector
   * of length `[inputDim,]`.
   * In a `Conv2D` layer with `dataFormat="channels_last"`,
   * the weight tensor has shape
   * `[rows, cols, inputDepth, outputDepth]`,
   * set `axis` to `[0, 1, 2]`
   * to constrain the weights of each filter tensor of size
   * `[rows, cols, inputDepth]`.
   *
   * NOTE(review): this field is declared as `number`, so the array form
   * `[0, 1, 2]` described above does not type-check against this
   * declaration. Confirm whether the implementation accepts an array
   * before widening the type to `number | number[]`.
   */
  axis?: number;
  /**
   * Rate for enforcing the constraint: weights will be rescaled to yield:
   * `(1 - rate) * norm + rate * norm.clip(minValue, maxValue)`.
   * Effectively, this means that rate=1.0 stands for strict
   * enforcement of the constraint, while rate<1.0 means that
   * weights will be rescaled at each step to slowly move
   * towards a value inside the desired interval.
   */
  rate?: number;
}
/**
 * Constraint configured by `MinMaxNormArgs`: keeps weight norms, computed
 * along `axis`, between `minValue` and `maxValue`, enforced at `rate`.
 * See `MinMaxNormArgs.rate` for the rescaling formula.
 *
 * Only the declaration is visible in this file. The values of the private
 * defaults live in the implementation.
 */
export declare class MinMaxNorm extends Constraint {
  /** @nocollapse */
  static readonly className = "MinMaxNorm";
  private minValue;
  private maxValue;
  private rate;
  private axis;
  // Presumably the fallbacks for the corresponding omitted `args` fields.
  private readonly defaultMinValue;
  private readonly defaultMaxValue;
  private readonly defaultRate;
  private readonly defaultAxis;
  constructor(args: MinMaxNormArgs);
  /** Applies the min/max-norm constraint to weight tensor `w`. */
  apply(w: Tensor): Tensor;
  /** Returns this constraint's serializable configuration. */
  getConfig(): serialization.ConfigDict;
}
/**
 * One of the constraint identifiers listed below, or any other string.
 *
 * NOTE: because the union includes `string`, TypeScript reduces the whole
 * type to plain `string`. The listed literals document the built-in names
 * but do not restrict the values the type accepts.
 *
 * @docinline
 */
export declare type ConstraintIdentifier = 'maxNorm' | 'minMaxNorm' | 'nonNeg' | 'unitNorm' | string;
/**
 * Lookup table from constraint identifiers to strings. The values are
 * presumably the registered class names (e.g. `'maxNorm'` -> `'MaxNorm'`).
 * The actual entries are not visible in this declaration file.
 *
 * Because `ConstraintIdentifier` includes `string`, this mapped type is
 * equivalent to a `string` index signature: the type checker accepts any
 * string key, not just the built-in identifiers.
 */
export declare const CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP: {
  [identifier in ConstraintIdentifier]: string;
};
/**
 * Converts `constraint` to a `serialization.ConfigDictValue`. This is
 * presumably the form that `deserializeConstraint` accepts; the
 * implementation is not visible here.
 *
 * @param constraint - The constraint to serialize.
 * @returns The serialized representation.
 */
export declare function serializeConstraint(constraint: Constraint): serialization.ConfigDictValue;
/**
 * Reconstructs a `Constraint` from a serialized `config`.
 *
 * @param config - Serialized constraint configuration.
 * @param customObjects - Optional extra entries. Presumably this maps the
 *   names of user-defined constraint classes to their constructors; this is
 *   not visible in this declaration file.
 * @returns The reconstructed constraint.
 */
export declare function deserializeConstraint(config: serialization.ConfigDict, customObjects?: serialization.ConfigDict): Constraint;
/**
 * Resolves `identifier` to a `Constraint`.
 *
 * @param identifier - A `ConstraintIdentifier` string, a serialized
 *   `serialization.ConfigDict`, or an existing `Constraint` instance.
 * @returns A `Constraint`. How each input form is handled (for example,
 *   whether an existing instance is returned as-is) is not visible in this
 *   declaration.
 */
export declare function getConstraint(identifier: ConstraintIdentifier | serialization.ConfigDict | Constraint): Constraint;