UNPKG

20.2 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as tf from '../index';
import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
import { expectArraysClose, expectArraysEqual } from '../test_util';
/**
 * Tests for `tf.stack`. These cover:
 * - joining same-shape, same-dtype tensors into one tensor with a new
 *   dimension inserted at `axis`;
 * - rejecting mismatched shapes, mismatched dtypes and out-of-range axes;
 * - accepting tensor-like inputs and string tensors;
 * - the chained `Tensor.stack` form.
 * Runs in every registered backend environment (ALL_ENVS).
 */
describeWithFlags('stack', ALL_ENVS, () => {
  it('scalars 3, 5 and 7', async () => {
    const a = tf.scalar(3);
    const b = tf.scalar(5);
    const c = tf.scalar(7);
    const res = tf.stack([a, b, c]);
    expect(res.shape).toEqual([3]);
    expectArraysClose(await res.data(), [3, 5, 7]);
  });

  it('scalars 3, 5 and 7 along axis=1 throws error', () => {
    const a = tf.scalar(3);
    const b = tf.scalar(5);
    const c = tf.scalar(7);
    // Stacking scalars yields a rank-1 result, so axis=1 is out of range.
    const f = () => tf.stack([a, b, c], 1);
    expect(f).toThrowError();
  });

  it('non matching shapes throws error', () => {
    const a = tf.scalar(3);
    const b = tf.tensor1d([5]);
    const f = () => tf.stack([a, b]);
    expect(f).toThrowError();
  });

  it('non matching dtypes throws error', () => {
    const a = tf.scalar(3);
    const b = tf.scalar(5, 'bool');
    const f = () => tf.stack([a, b]);
    expect(f).toThrowError();
  });

  it('2d but axis=3 throws error', () => {
    const a = tf.zeros([2, 2]);
    const b = tf.zeros([2, 2]);
    // Stacking rank-2 inputs yields a rank-3 result, so axis=3 is out of range.
    const f = () => tf.stack([a, b], 3 /* axis */);
    expect(f).toThrowError();
  });

  it('[1,2], [3,4] and [5,6], axis=0', async () => {
    const a = tf.tensor1d([1, 2]);
    const b = tf.tensor1d([3, 4]);
    const c = tf.tensor1d([5, 6]);
    const res = tf.stack([a, b, c], 0 /* axis */);
    expect(res.shape).toEqual([3, 2]);
    expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]);
  });

  it('[1,2], [3,4] and [5,6], axis=1', async () => {
    const a = tf.tensor1d([1, 2]);
    const b = tf.tensor1d([3, 4]);
    const c = tf.tensor1d([5, 6]);
    const res = tf.stack([a, b, c], 1 /* axis */);
    expect(res.shape).toEqual([2, 3]);
    // Along axis=1 the inputs are interleaved: row i is [a[i], b[i], c[i]].
    expectArraysClose(await res.data(), [1, 3, 5, 2, 4, 6]);
  });

  it('[[1,2],[3,4]] and [[5, 6], [7, 8]], axis=0', async () => {
    const a = tf.tensor2d([[1, 2], [3, 4]]);
    const b = tf.tensor2d([[5, 6], [7, 8]]);
    const res = tf.stack([a, b], 0 /* axis */);
    expect(res.shape).toEqual([2, 2, 2]);
    expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6, 7, 8]);
  });

  // Title lists all three inputs; the original omitted [[9, 10], [11, 12]].
  it('[[1,2],[3,4]], [[5, 6], [7, 8]] and [[9, 10], [11, 12]], axis=2',
     async () => {
       const a = tf.tensor2d([[1, 2], [3, 4]]);
       const b = tf.tensor2d([[5, 6], [7, 8]]);
       const c = tf.tensor2d([[9, 10], [11, 12]]);
       const res = tf.stack([a, b, c], 2 /* axis */);
       expect(res.shape).toEqual([2, 2, 3]);
       // Innermost dimension holds the element at the same [i, j] from a, b, c.
       expectArraysClose(
           await res.data(), [1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12]);
     });

  it('single tensor', async () => {
    const a = tf.tensor2d([[1, 2], [3, 4]]);
    const res = tf.stack([a], 2 /* axis */);
    expect(res.shape).toEqual([2, 2, 1]);
    expectArraysClose(await res.data(), [1, 2, 3, 4]);
  });

  it('throws when passed a non-tensor', () => {
    expect(() => tf.stack([{}]))
        .toThrowError(
            /Argument 'tensors\[0\]' passed to 'stack' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    const a = [[1, 2], [3, 4]];
    const res = tf.stack([a], 2 /* axis */);
    expect(res.shape).toEqual([2, 2, 1]);
    expectArraysClose(await res.data(), [1, 2, 3, 4]);
  });

  it('accepts string.', async () => {
    const a = tf.scalar('three', 'string');
    const b = tf.scalar('five', 'string');
    const c = tf.scalar('seven', 'string');
    const res = tf.stack([a, b, c]);
    expect(res.shape).toEqual([3]);
    // Use an exact comparison. expectArraysClose is a numeric-tolerance check.
    // NOTE(review): for non-numeric strings it appears to treat any two values
    // as "close", because isFinite() coerces them to NaN.
    expectArraysEqual(await res.data(), ['three', 'five', 'seven']);
  });

  it('chain api', async () => {
    const a = tf.tensor([1, 2]);
    // Chained form: stacks the receiver with the argument (default axis).
    const res = a.stack(tf.tensor([3, 4]));
    expect(res.shape).toEqual([2, 2]);
    expectArraysClose(await res.data(), [1, 2, 3, 4]);
  });
});
116//# sourceMappingURL=data:application/json;base64,
\No newline at end of file