1 | import { convertToTensor } from '../tensor_util_env';
|
2 | import * as util from '../util';
|
3 | import { batchNorm } from './batchnorm';
|
4 | import { op } from './operation';
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
/**
 * Batch normalization restricted to a rank-2 input.
 *
 * Every argument is run through `convertToTensor` (labelled with the
 * 'batchNorm' op name). Ranks are then validated, and the tensors are handed
 * unchanged to `batchNorm`.
 *
 * @param x Input; must have rank 2.
 * @param mean Must have rank 1 or 2.
 * @param variance Must have rank 1 or 2.
 * @param offset Optional. `null`/`undefined` skips conversion and the rank
 *     check; otherwise the converted value must have rank 1 or 2.
 * @param scale Optional, with the same rules as `offset`.
 * @param varianceEpsilon Forwarded to `batchNorm` untouched.
 * @returns The result of `batchNorm($x, $mean, $variance, $offset, $scale,
 *     varianceEpsilon)`.
 * @throws Whatever `util.assert` raises when a rank check fails.
 */
function batchNorm2d_(x, mean, variance, offset, scale, varianceEpsilon) {
  const opName = 'batchNorm';
  const toTensor = (value, argName) => convertToTensor(value, argName, opName);
  const toOptionalTensor = (value, argName) =>
    value == null ? undefined : toTensor(value, argName);

  // Conversion order is x, mean, variance, scale, offset. That way the first
  // conversion failure reported for bad input stays the same.
  const xTensor = toTensor(x, 'x');
  const meanTensor = toTensor(mean, 'mean');
  const varianceTensor = toTensor(variance, 'variance');
  const scaleTensor = toOptionalTensor(scale, 'scale');
  const offsetTensor = toOptionalTensor(offset, 'offset');

  util.assert(
    xTensor.rank === 2,
    () => `Error in batchNorm2D: x must be rank 2 but got rank ${xTensor.rank}.`,
  );

  // Shared check for every statistics-like argument. The message is built
  // lazily, so it is only formatted when the assertion fails.
  const requireRank1Or2 = (tensor, argName) => {
    util.assert(
      tensor.rank === 2 || tensor.rank === 1,
      () =>
        `Error in batchNorm2D: ${argName} must be rank 2 or rank 1 but got rank ${tensor.rank}.`,
    );
  };

  requireRank1Or2(meanTensor, 'mean');
  requireRank1Or2(varianceTensor, 'variance');
  if (scaleTensor != null) {
    requireRank1Or2(scaleTensor, 'scale');
  }
  if (offsetTensor != null) {
    requireRank1Or2(offsetTensor, 'offset');
  }

  return batchNorm(
    xTensor,
    meanTensor,
    varianceTensor,
    offsetTensor,
    scaleTensor,
    varianceEpsilon,
  );
}
|
// `op` receives `batchNorm2d_` keyed by its own name.
// NOTE(review): `op` (./operation) presumably derives the public op name from
// this key — confirm there before renaming the implementation function.
const batchNorm2dImpl = { batchNorm2d_ };
export const batchNorm2d = op(batchNorm2dImpl);
|
45 |
|
\ | No newline at end of file |