UNPKG

11.7 kBJavaScriptView Raw
1"use strict";
2/**
3 * @license
4 * Copyright 2018 Google LLC. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 * =============================================================================
17 */
18Object.defineProperty(exports, "__esModule", { value: true });
19var path = require("path");
20// tslint:disable-next-line:no-require-imports
21var binary = require('@mapbox/node-pre-gyp');
22var bindingPath = binary.find(path.resolve(path.join(__dirname, '../package.json')));
23// tslint:disable-next-line:no-require-imports
24var bindings = require(bindingPath);
25var binding = bindings;
// Checks that the native binding exposes the TF_DataType constants with the
// numeric values this suite expects.
describe('Exposes TF_DataType enum values', function () {
  // Each entry is [property name on the binding, expected numeric value].
  var expectedDataTypes = [
    ['TF_FLOAT', 1],
    ['TF_INT32', 3],
    ['TF_BOOL', 10],
    ['TF_COMPLEX64', 8],
    ['TF_STRING', 7]
  ];
  // One spec is registered per entry, in table order.
  expectedDataTypes.forEach(function (entry) {
    var enumName = entry[0];
    var enumValue = entry[1];
    it('contains ' + enumName, function () {
      expect(binding[enumName]).toEqual(enumValue);
    });
  });
});
// Checks that the native binding exposes the TF_AttrType constants with the
// numeric values this suite expects.
describe('Exposes TF_AttrType enum values', function () {
  // Listed so that each name's index in the array is its expected value (0..5).
  var attrTypeNamesByValue = [
    'TF_ATTR_STRING',
    'TF_ATTR_INT',
    'TF_ATTR_FLOAT',
    'TF_ATTR_BOOL',
    'TF_ATTR_TYPE',
    'TF_ATTR_SHAPE'
  ];
  attrTypeNamesByValue.forEach(function (attrTypeName, expectedValue) {
    it('contains ' + attrTypeName, function () {
      expect(binding[attrTypeName]).toEqual(expectedValue);
    });
  });
});
// Checks that the native binding exposes a TF_Version property.
describe('Exposes TF Version', function () {
  it('contains a version string', function () {
    // Only presence is asserted; the actual value is not inspected.
    var reportedVersion = binding.TF_Version;
    expect(reportedVersion).toBeDefined();
  });
});
// Exercises tensor creation, argument validation and deletion through the
// native binding.
describe('tensor management', function () {
  // Creates an int32 tensor of the given shape from a plain array of numbers
  // and returns whatever binding.createTensor returns.
  function createInt32Tensor(shape, data) {
    return binding.createTensor(shape, binding.TF_INT32, new Int32Array(data));
  }

  it('Creates and deletes a valid tensor', function () {
    var tensorId = createInt32Tensor([2], [1, 2]);
    expect(tensorId).toBeDefined();
    binding.deleteTensor(tensorId);
  });

  it('throws exception when shape does not match data', function () {
    // Three data values against a shape that is too small ([2]) and one that
    // is too large ([4]).
    [[2], [4]].forEach(function (mismatchedShape) {
      expect(function () {
        createInt32Tensor(mismatchedShape, [1, 2, 3]);
      }).toThrowError();
    });
  });

  it('throws exception with invalid dtype', function () {
    // 1000 is passed where a dtype constant is expected.
    var createWithBogusDtype = function () {
      // tslint:disable-next-line:no-unused-expression
      binding.createTensor([1], 1000, new Int32Array([1]));
    };
    expect(createWithBogusDtype).toThrowError();
  });

  it('works with 0-dim tensors', function () {
    // Reduce op (e.g 'Max') will produce a 0-dim TFE_Tensor.
    var valuesId = createInt32Tensor([3], [1, 2, 3]);
    var reductionAxesId = createInt32Tensor([1], [0]);
    var maxOpAttrs = [
      { name: 'keep_dims', type: binding.TF_ATTR_BOOL, value: false },
      { name: 'T', type: binding.TF_ATTR_TYPE, value: binding.TF_INT32 },
      { name: 'Tidx', type: binding.TF_ATTR_TYPE, value: binding.TF_INT32 }
    ];
    var results = binding.executeOp('Max', maxOpAttrs, [valuesId, reductionAxesId], 1);
    expect(results.length).toBe(1);
    var scalarResult = results[0];
    expect(scalarResult.id).toBeDefined();
    // An empty shape array is how the 0-dim result is reported.
    expect(scalarResult.shape).toEqual([]);
    expect(scalarResult.dtype).toEqual(binding.TF_INT32);
    var scalarData = binding.tensorDataSync(scalarResult.id);
    expect(scalarData).toEqual(new Int32Array([3]));
  });
});
// Exercises binding.executeOp: validation of the op name, attribute list,
// inputs and output count; validation of attribute values for each
// TF_ATTR_* type; and one successful MatMul.
describe('executeOp', function () {
  var name = 'MatMul';
  // Well-formed attribute list for a float MatMul without transposes.
  var matMulOpAttrs = [
    { name: 'transpose_a', type: binding.TF_ATTR_BOOL, value: false },
    { name: 'transpose_b', type: binding.TF_ATTR_BOOL, value: false },
    { name: 'T', type: binding.TF_ATTR_TYPE, value: binding.TF_FLOAT }
  ];
  // Two 2x2 float operands shared by every spec in this suite.
  var aId = binding.createTensor([2, 2], binding.TF_FLOAT, new Float32Array([1, 2, 3, 4]));
  var bId = binding.createTensor([2, 2], binding.TF_FLOAT, new Float32Array([4, 3, 2, 1]));
  var matMulInput = [aId, bId];

  /**
   * Expects MatMul execution to throw for every value in `badValues` when the
   * value is supplied for attribute 'T' declared with type `attrType`.
   *
   * @param {number} attrType One of the binding's TF_ATTR_* constants.
   * @param {Array} badValues Values considered invalid for that attribute type.
   */
  function expectBadAttrValuesToThrow(attrType, badValues) {
    badValues.forEach(function (badValue) {
      expect(function () {
        var badOpAttrs = [{ name: 'T', type: attrType, value: badValue }];
        binding.executeOp(name, badOpAttrs, matMulInput, 1);
      }).toThrowError();
    });
  }

  it('throws exception with invalid Op Name', function () {
    expect(function () {
      binding.executeOp(null, [], [], null);
    }).toThrowError();
  });
  it('throws exception with invalid TFEOpAttr', function () {
    expect(function () {
      binding.executeOp('Equal', null, [], null);
    }).toThrowError();
  });
  it('throws exception with invalid inputs', function () {
    expect(function () {
      binding.executeOp(name, matMulOpAttrs, [], null);
    }).toThrowError();
  });
  it('throws exception with invalid output number', function () {
    expect(function () {
      binding.executeOp(name, matMulOpAttrs, matMulInput, null);
    }).toThrowError();
  });
  it('throws exception with invalid TF_ATTR_STRING op attr', function () {
    expectBadAttrValuesToThrow(binding.TF_ATTR_STRING, [null, false, 1, {}, [1, 2, 3]]);
  });
  it('throws exception with invalid TF_ATTR_INT op attr', function () {
    expectBadAttrValuesToThrow(binding.TF_ATTR_INT, [null, false, {}, 'test']);
  });
  it('throws exception with invalid TF_ATTR_FLOAT op attr', function () {
    expectBadAttrValuesToThrow(binding.TF_ATTR_FLOAT, [null, false, {}, 'test']);
  });
  it('throws exception with invalid TF_ATTR_BOOL op attr', function () {
    expectBadAttrValuesToThrow(binding.TF_ATTR_BOOL, [null, 10, {}, 'test', [1, 2, 3]]);
  });
  it('throws exception with invalid TF_ATTR_TYPE op attr', function () {
    expectBadAttrValuesToThrow(binding.TF_ATTR_TYPE, [null, {}, 'test', [1, 2, 3]]);
  });
  it('throws exception with invalid TF_ATTR_SHAPE op attr', function () {
    // Uses TF_ATTR_SHAPE so this spec exercises shape-attribute validation
    // instead of repeating the TF_ATTR_TYPE cases above.
    expectBadAttrValuesToThrow(binding.TF_ATTR_SHAPE, [null, {}, 'test']);
  });
  it('should work for matmul', function () {
    var output = binding.executeOp(name, matMulOpAttrs, matMulInput, 1);
    // [[1, 2], [3, 4]] x [[4, 3], [2, 1]] = [[8, 5], [20, 13]]
    expect(binding.tensorDataSync(output[0].id)).toEqual(new Float32Array([
      8, 5, 20, 13
    ]));
  });
});