1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 | import * as tf from '../index';
|
18 | import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
|
19 | import { expectArraysClose, expectArraysEqual } from '../test_util';
|
/**
 * Test suite for `tf.onesLike`, run against every environment in ALL_ENVS.
 *
 * For ranks 1 through 6 the specs check that the result has the same dtype
 * and shape as the input and that every element equals one. Further specs
 * cover the chainable form, argument validation, the gradient, and
 * tensor-like (nested array) input.
 *
 * Float results are compared with expectArraysClose; int32, bool and
 * complex64 results are compared with expectArraysEqual.
 */
describeWithFlags('onesLike', ALL_ENVS, () => {
  // ---- Rank 1 ----

  it('1D default dtype', async () => {
    const a = tf.tensor1d([1, 2, 3]);
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('float32');
    expect(b.shape).toEqual([3]);
    expectArraysClose(await b.data(), [1, 1, 1]);
  });

  it('chainable 1D default dtype', async () => {
    const a = tf.tensor1d([1, 2, 3]);
    // Same operation, invoked as a method on the tensor.
    const b = a.onesLike();
    expect(b.dtype).toBe('float32');
    expect(b.shape).toEqual([3]);
    expectArraysClose(await b.data(), [1, 1, 1]);
  });

  it('1D float32 dtype', async () => {
    const a = tf.tensor1d([1, 2, 3], 'float32');
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('float32');
    expect(b.shape).toEqual([3]);
    expectArraysClose(await b.data(), [1, 1, 1]);
  });

  it('1D int32 dtype', async () => {
    const a = tf.tensor1d([1, 2, 3], 'int32');
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('int32');
    expect(b.shape).toEqual([3]);
    expectArraysEqual(await b.data(), [1, 1, 1]);
  });

  it('1D bool dtype', async () => {
    const a = tf.tensor1d([1, 2, 3], 'bool');
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('bool');
    expect(b.shape).toEqual([3]);
    expectArraysEqual(await b.data(), [1, 1, 1]);
  });

  it('1D complex dtype', async () => {
    const real = tf.tensor1d([1, 2, 3], 'float32');
    const imag = tf.tensor1d([1, 2, 3], 'float32');
    const a = tf.complex(real, imag);
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('complex64');
    expect(b.shape).toEqual([3]);
    // Two values are expected per element: presumably interleaved
    // (real, imag) pairs, i.e. 1 + 0i for each entry.
    expectArraysEqual(await b.data(), [1, 0, 1, 0, 1, 0]);
  });

  // ---- Rank 2 ----

  it('2D default dtype', async () => {
    const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('float32');
    expect(b.shape).toEqual([2, 2]);
    expectArraysClose(await b.data(), [1, 1, 1, 1]);
  });

  it('2D float32 dtype', async () => {
    const a = tf.tensor2d([1, 2, 3, 4], [2, 2], 'float32');
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('float32');
    expect(b.shape).toEqual([2, 2]);
    expectArraysClose(await b.data(), [1, 1, 1, 1]);
  });

  it('2D int32 dtype', async () => {
    const a = tf.tensor2d([1, 2, 3, 4], [2, 2], 'int32');
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('int32');
    expect(b.shape).toEqual([2, 2]);
    expectArraysEqual(await b.data(), [1, 1, 1, 1]);
  });

  it('2D bool dtype', async () => {
    const a = tf.tensor2d([1, 2, 3, 4], [2, 2], 'bool');
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('bool');
    expect(b.shape).toEqual([2, 2]);
    expectArraysEqual(await b.data(), [1, 1, 1, 1]);
  });

  it('2D complex dtype', async () => {
    const real = tf.tensor2d([1, 2, 3, 4], [2, 2], 'float32');
    const imag = tf.tensor2d([1, 2, 3, 4], [2, 2], 'float32');
    const a = tf.complex(real, imag);
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('complex64');
    expect(b.shape).toEqual([2, 2]);
    expectArraysEqual(await b.data(), [1, 0, 1, 0, 1, 0, 1, 0]);
  });

  // ---- Rank 3 ----

  it('3D default dtype', async () => {
    const a = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('float32');
    expect(b.shape).toEqual([2, 2, 1]);
    expectArraysClose(await b.data(), [1, 1, 1, 1]);
  });

  it('3D float32 dtype', async () => {
    const a = tf.tensor3d([1, 2, 3, 4], [2, 2, 1], 'float32');
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('float32');
    expect(b.shape).toEqual([2, 2, 1]);
    expectArraysClose(await b.data(), [1, 1, 1, 1]);
  });

  it('3D int32 dtype', async () => {
    const a = tf.tensor3d([1, 2, 3, 4], [2, 2, 1], 'int32');
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('int32');
    expect(b.shape).toEqual([2, 2, 1]);
    expectArraysEqual(await b.data(), [1, 1, 1, 1]);
  });

  it('3D bool dtype', async () => {
    const a = tf.tensor3d([1, 2, 3, 4], [2, 2, 1], 'bool');
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('bool');
    expect(b.shape).toEqual([2, 2, 1]);
    expectArraysEqual(await b.data(), [1, 1, 1, 1]);
  });

  it('3D complex dtype', async () => {
    const real = tf.tensor3d([1, 2, 3, 4], [2, 2, 1], 'float32');
    const imag = tf.tensor3d([1, 2, 3, 4], [2, 2, 1], 'float32');
    const a = tf.complex(real, imag);
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('complex64');
    expect(b.shape).toEqual([2, 2, 1]);
    expectArraysEqual(await b.data(), [1, 0, 1, 0, 1, 0, 1, 0]);
  });

  // ---- Rank 4 ----

  it('4D default dtype', async () => {
    const a = tf.tensor4d([1, 2, 3, 4], [2, 2, 1, 1]);
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('float32');
    expect(b.shape).toEqual([2, 2, 1, 1]);
    expectArraysClose(await b.data(), [1, 1, 1, 1]);
  });

  it('4D float32 dtype', async () => {
    const a = tf.tensor4d([1, 2, 3, 4], [2, 2, 1, 1], 'float32');
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('float32');
    expect(b.shape).toEqual([2, 2, 1, 1]);
    expectArraysClose(await b.data(), [1, 1, 1, 1]);
  });

  it('4D int32 dtype', async () => {
    const a = tf.tensor4d([1, 2, 3, 4], [2, 2, 1, 1], 'int32');
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('int32');
    expect(b.shape).toEqual([2, 2, 1, 1]);
    expectArraysEqual(await b.data(), [1, 1, 1, 1]);
  });

  it('4D bool dtype', async () => {
    const a = tf.tensor4d([1, 2, 3, 4], [2, 2, 1, 1], 'bool');
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('bool');
    expect(b.shape).toEqual([2, 2, 1, 1]);
    expectArraysEqual(await b.data(), [1, 1, 1, 1]);
  });

  it('4D complex dtype', async () => {
    const real = tf.tensor4d([1, 2, 3, 4], [2, 2, 1, 1], 'float32');
    const imag = tf.tensor4d([1, 2, 3, 4], [2, 2, 1, 1], 'float32');
    const a = tf.complex(real, imag);
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('complex64');
    expect(b.shape).toEqual([2, 2, 1, 1]);
    expectArraysEqual(await b.data(), [1, 0, 1, 0, 1, 0, 1, 0]);
  });

  // ---- Rank 5 ----

  it('5D float32 dtype', async () => {
    const a = tf.tensor5d([1, 2, 3, 4], [1, 2, 2, 1, 1], 'float32');
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('float32');
    expect(b.shape).toEqual([1, 2, 2, 1, 1]);
    expectArraysClose(await b.data(), [1, 1, 1, 1]);
  });

  it('5D int32 dtype', async () => {
    const a = tf.tensor5d([1, 2, 3, 4], [1, 2, 2, 1, 1], 'int32');
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('int32');
    expect(b.shape).toEqual([1, 2, 2, 1, 1]);
    expectArraysEqual(await b.data(), [1, 1, 1, 1]);
  });

  it('5D bool dtype', async () => {
    const a = tf.tensor5d([1, 2, 3, 4], [1, 2, 2, 1, 1], 'bool');
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('bool');
    expect(b.shape).toEqual([1, 2, 2, 1, 1]);
    expectArraysEqual(await b.data(), [1, 1, 1, 1]);
  });

  it('5D default dtype', async () => {
    const a = tf.tensor5d([1, 2, 3, 4], [1, 2, 2, 1, 1]);
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('float32');
    expect(b.shape).toEqual([1, 2, 2, 1, 1]);
    expectArraysClose(await b.data(), [1, 1, 1, 1]);
  });

  it('5D complex dtype', async () => {
    const real = tf.tensor5d([1, 2, 3, 4], [1, 2, 2, 1, 1], 'float32');
    const imag = tf.tensor5d([1, 2, 3, 4], [1, 2, 2, 1, 1], 'float32');
    const a = tf.complex(real, imag);
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('complex64');
    expect(b.shape).toEqual([1, 2, 2, 1, 1]);
    expectArraysEqual(await b.data(), [1, 0, 1, 0, 1, 0, 1, 0]);
  });

  // ---- Rank 6 ----

  it('6D int32 dtype', async () => {
    const a = tf.tensor6d([1, 2, 3, 4], [1, 2, 2, 1, 1, 1], 'int32');
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('int32');
    expect(b.shape).toEqual(a.shape);
    expectArraysEqual(await b.data(), [1, 1, 1, 1]);
  });

  it('6D bool dtype', async () => {
    const a = tf.tensor6d([1, 2, 3, 4], [1, 2, 2, 1, 1, 1], 'bool');
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('bool');
    expect(b.shape).toEqual(a.shape);
    expectArraysEqual(await b.data(), [1, 1, 1, 1]);
  });

  it('6D default dtype', async () => {
    const a = tf.tensor6d([1, 2, 3, 4], [1, 2, 2, 1, 1, 1]);
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('float32');
    expect(b.shape).toEqual(a.shape);
    expectArraysClose(await b.data(), [1, 1, 1, 1]);
  });

  it('6D float32 dtype', async () => {
    const a = tf.tensor6d([1, 2, 3, 4], [1, 2, 2, 1, 1, 1], 'float32');
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('float32');
    expect(b.shape).toEqual(a.shape);
    expectArraysClose(await b.data(), [1, 1, 1, 1]);
  });

  it('6D complex dtype', async () => {
    const real = tf.tensor6d([1, 2, 3, 4], [1, 2, 2, 1, 1, 1], 'float32');
    const imag = tf.tensor6d([1, 2, 3, 4], [1, 2, 2, 1, 1, 1], 'float32');
    const a = tf.complex(real, imag);
    const b = tf.onesLike(a);
    expect(b.dtype).toBe('complex64');
    expect(b.shape).toEqual([1, 2, 2, 1, 1, 1]);
    expectArraysEqual(await b.data(), [1, 0, 1, 0, 1, 0, 1, 0]);
  });

  // ---- Validation, gradient and tensor-like input ----

  it('throws when passed a non-tensor', () => {
    // A plain object is neither a Tensor nor array-like data.
    expect(() => tf.onesLike({}))
        .toThrowError(/Argument 'x' passed to 'onesLike' must be a Tensor/);
  });

  it('onesLike gradient', async () => {
    const x = tf.tensor2d([[0, 1, 2], [4, 5, 6]]);
    const gradients = tf.grad(x => tf.onesLike(x))(x);
    expect(gradients.shape).toEqual([2, 3]);
    // The output is constant with respect to the input values, so the
    // gradient is expected to be zero everywhere.
    expectArraysEqual(await gradients.data(), [0, 0, 0, 0, 0, 0]);
  });

  it('accepts a tensor-like object', async () => {
    // A nested number array is passed directly instead of a Tensor.
    const res = tf.onesLike([[1, 2], [3, 4]]);
    expect(res.shape).toEqual([2, 2]);
    expectArraysEqual(await res.data(), [1, 1, 1, 1]);
  });
});
|
273 |
|
\ | No newline at end of file |