UNPKG

2.8 kBJavaScriptView Raw
1"use strict";
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
18Object.defineProperty(exports, "__esModule", { value: true });
19var tfjs_1 = require("@tensorflow/tfjs");
20var nodejs_kernel_backend_1 = require("../nodejs_kernel_backend");
21exports.fusedBatchNormConfig = {
22 kernelName: tfjs_1.FusedBatchNorm,
23 backendName: 'tensorflow',
24 kernelFunc: function (args) {
25 var _a = args.inputs, x = _a.x, mean = _a.mean, variance = _a.variance;
26 var _b = args.inputs, scale = _b.scale, offset = _b.offset;
27 var backend = args.backend;
28 var varianceEpsilon = args.attrs.varianceEpsilon;
29 return tfjs_1.tidy(function () {
30 if (mean.rank > 1) {
31 // Fused batch norm doesn't work with high-dim mean/var/scale/offset.
32 var inv = tfjs_1.rsqrt(tfjs_1.add(variance, tfjs_1.scalar(varianceEpsilon)));
33 if (scale != null) {
34 inv = tfjs_1.mul(inv, scale);
35 }
36 var xNorm = tfjs_1.mul(tfjs_1.sub(x, mean), inv);
37 return offset != null ? tfjs_1.add(xNorm, offset) : xNorm;
38 }
39 var dataFormat = 'NHWC';
40 var depth = x.shape[3];
41 var opAttrs = [
42 nodejs_kernel_backend_1.createTensorsTypeOpAttr('T', x.dtype),
43 {
44 name: 'epsilon',
45 type: backend.binding.TF_ATTR_FLOAT,
46 value: varianceEpsilon
47 },
48 {
49 name: 'data_format',
50 type: backend.binding.TF_ATTR_STRING,
51 value: dataFormat
52 },
53 { name: 'is_training', type: backend.binding.TF_ATTR_BOOL, value: false },
54 ];
55 var numOutputs = 5;
56 if (scale == null) {
57 scale = tfjs_1.fill([depth], 1);
58 }
59 if (offset == null) {
60 offset = tfjs_1.fill([depth], 0);
61 }
62 return backend.executeMultipleOutputs(tfjs_1.FusedBatchNorm, opAttrs, [x, scale, offset, mean, variance], numOutputs)[0];
63 });
64 }
65};