UNPKG

2.13 kBJavaScriptView Raw
1"use strict";
2/**
3 * @license
4 * Copyright 2020 Google LLC. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 * =============================================================================
17 */
18Object.defineProperty(exports, "__esModule", { value: true });
19var tfjs_1 = require("@tensorflow/tfjs");
20exports.fillConfig = {
21 kernelName: tfjs_1.Fill,
22 backendName: 'tensorflow',
23 kernelFunc: function (args) {
24 var backend = args.backend;
25 var _a = args.attrs, shape = _a.shape, value = _a.value;
26 var dtype = args.attrs.dtype;
27 // TODO(cais, nkreeger): Investigate whether backend can be made into
28 // a dtype helper method. The underlying op kernel doesn't accept undefined
29 // or null dtype.
30 if (dtype == null) {
31 if (typeof value === 'number') {
32 dtype = 'float32';
33 }
34 else {
35 dtype = 'string';
36 }
37 }
38 var shapeTensor = tfjs_1.tensor1d(shape, 'int32');
39 var valueTensor = tfjs_1.scalar(value, dtype);
40 var opAttrs = [
41 {
42 name: 'T',
43 type: backend.binding.TF_ATTR_TYPE,
44 value: backend.getDTypeInteger(dtype)
45 },
46 {
47 name: 'index_type',
48 type: backend.binding.TF_ATTR_TYPE,
49 value: backend.binding.TF_INT32
50 }
51 ];
52 var res = backend.executeSingleOutput(tfjs_1.Fill, opAttrs, [shapeTensor, valueTensor]);
53 shapeTensor.dispose();
54 valueTensor.dispose();
55 return res;
56 }
57};