UNPKG

2.02 kBJavaScriptView Raw
1import { convertToTensor } from '../tensor_util_env';
2import * as util from '../util';
3import { batchNorm } from './batchnorm';
4import { op } from './operation';
/**
 * Batch normalization, strictly for 2D inputs. For the more relaxed version,
 * see `tf.batchNorm`.
 *
 * Each provided argument is passed through `convertToTensor`, the ranks of the
 * results are validated with `util.assert`, and the actual computation is
 * delegated to `batchNorm`.
 *
 * @param x The input Tensor. Must be rank 2.
 * @param mean A mean Tensor. Must be rank 1 or rank 2.
 * @param variance A variance Tensor. Must be rank 1 or rank 2.
 * @param offset An offset Tensor. Optional: when `null`/`undefined` it is
 *     neither converted nor validated. Otherwise must be rank 1 or rank 2.
 * @param scale A scale Tensor. Optional: when `null`/`undefined` it is
 *     neither converted nor validated. Otherwise must be rank 1 or rank 2.
 * @param varianceEpsilon A small float number to avoid dividing by 0.
 *     Forwarded unchanged to `batchNorm`.
 * @returns The result of `batchNorm` applied to the converted arguments.
 */
function batchNorm2d_(x, mean, variance, offset, scale, varianceEpsilon) {
    const $x = convertToTensor(x, 'x', 'batchNorm');
    const $mean = convertToTensor(mean, 'mean', 'batchNorm');
    const $variance = convertToTensor(variance, 'variance', 'batchNorm');
    // Optional arguments stay `undefined` when omitted so that `batchNorm`
    // receives exactly what it would have received before.
    const $scale = scale != null ?
        convertToTensor(scale, 'scale', 'batchNorm') :
        undefined;
    const $offset = offset != null ?
        convertToTensor(offset, 'offset', 'batchNorm') :
        undefined;

    // Shared check for every argument that may be either rank 1 or rank 2.
    // The message is built lazily, only when the assertion needs it.
    const assertRankOneOrTwo = (tensor, label) => {
        util.assert(tensor.rank === 2 || tensor.rank === 1, () => `Error in batchNorm2D: ${label} must be rank 2 or rank 1 ` +
            `but got rank ${tensor.rank}.`);
    };

    util.assert($x.rank === 2, () => `Error in batchNorm2D: x must be rank 2 but got rank ` +
        `${$x.rank}.`);
    assertRankOneOrTwo($mean, 'mean');
    assertRankOneOrTwo($variance, 'variance');
    if ($scale != null) {
        assertRankOneOrTwo($scale, 'scale');
    }
    if ($offset != null) {
        assertRankOneOrTwo($offset, 'offset');
    }

    // Note the argument order expected by `batchNorm`: offset precedes scale.
    return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);
}
/**
 * Public `batchNorm2d` export: the result of passing `batchNorm2d_` to `op`
 * under the key `batchNorm2d_`.
 */
export const batchNorm2d = op({ batchNorm2d_: batchNorm2d_ });
45//# sourceMappingURL=batchnorm2d.js.map
\No newline at end of file