1 | /**
|
2 | * @license
|
3 | * Copyright 2017 Google LLC. All Rights Reserved.
|
4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | * you may not use this file except in compliance with the License.
|
6 | * You may obtain a copy of the License at
|
7 | *
|
8 | * http://www.apache.org/licenses/LICENSE-2.0
|
9 | *
|
10 | * Unless required by applicable law or agreed to in writing, software
|
11 | * distributed under the License is distributed on an "AS IS" BASIS,
|
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | * See the License for the specific language governing permissions and
|
14 | * limitations under the License.
|
15 | * =============================================================================
|
16 | */
|
17 | import * as tf from '../index';
|
18 | import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
|
19 | import { expectArraysClose } from '../test_util';
|
describeWithFlags('conv3dTranspose', ALL_ENVS, () => {
  // Both tests transpose-convolve single-voxel inputs with the same 2x2x2
  // filter (one input channel, one output channel). With a 1x1x1 spatial
  // input, stride 1 and 'valid' padding, each output voxel is the input value
  // times the corresponding filter tap, so the expected arrays below are just
  // the filter values scaled by the input value(s).
  //
  // `orig*` names refer to the forward conv3d that this op transposes: `x`
  // carries origOutputDepth channels (the filter's last dim), and the requested
  // output carries origInputDepth channels (the filter's dim 3).
  const filterValues = [5, 4, 8, 7, 1, 2, 6, 3];
  const fSize = 2;
  const origInputDepth = 1;
  const origOutputDepth = 1;
  const filterShape = [fSize, fSize, fSize, origInputDepth, origOutputDepth];

  // Reference Python TensorFlow code
  // ```python
  // import numpy as np
  // import tensorflow as tf
  // tf.enable_eager_execution()
  // x = np.array([2], dtype = np.float32).reshape(1, 1, 1, 1, 1)
  // w = np.array([5, 4, 8, 7, 1, 2, 6, 3], dtype = np.float32).reshape(2, 2, 2, 1, 1)
  // tf.nn.conv3d_transpose(x, w, output_shape=[1, 2, 2, 2, 1], padding='VALID')
  // ```
  // NOTE: "2x2x2x1" in the test name is the output shape; the input is a single
  // voxel. The TF reference uses a batch of 1, whereas this test passes an
  // unbatched rank-4 tensor, so the result is rank 4 as well.
  it('input=2x2x2x1,d2=1,f=2,s=1,p=valid', async () => {
    const inputShape = [1, 1, 1, origOutputDepth];
    const origPad = 'valid';
    const origStride = 1;

    const x = tf.tensor4d([2], inputShape);
    const w = tf.tensor5d(filterValues, filterShape);
    const result = tf.conv3dTranspose(
        x, w, [2, 2, 2, origInputDepth], origStride, origPad);

    // 2 * [5, 4, 8, 7, 1, 2, 6, 3]
    const expected = [10, 8, 16, 14, 2, 4, 12, 6];
    expect(result.shape).toEqual([2, 2, 2, 1]);
    expectArraysClose(await result.data(), expected);
  });

  // Reference Python TensorFlow code
  // ```python
  // import numpy as np
  // import tensorflow as tf
  // tf.enable_eager_execution()
  // x = np.array([2, 3], dtype = np.float32).reshape(2, 1, 1, 1, 1)
  // w = np.array([5, 4, 8, 7, 1, 2, 6, 3], dtype = np.float32).reshape(2, 2, 2, 1, 1)
  // tf.nn.conv3d_transpose(x, w, output_shape=[2, 2, 2, 2, 1], padding='VALID')
  // ```
  // NOTE: as above, "2x2x2x1" in the test name is the per-example output shape.
  // Here the input is an explicit batch of two single-voxel examples (values 2
  // and 3), and each batch entry's output follows the previous one in memory.
  it('input=2x2x2x1,d2=1,f=2,s=1,p=valid, batch=2', async () => {
    const inputShape = [2, 1, 1, 1, origOutputDepth];
    const origPad = 'valid';
    const origStride = 1;

    const x = tf.tensor5d([2, 3], inputShape);
    const w = tf.tensor5d(filterValues, filterShape);
    const result = tf.conv3dTranspose(
        x, w, [2, 2, 2, 2, origInputDepth], origStride, origPad);

    // Batch 0: 2 * filterValues, followed by batch 1: 3 * filterValues.
    const expected = [10, 8, 16, 14, 2, 4, 12, 6, 15, 12, 24, 21, 3, 6, 18, 9];
    expect(result.shape).toEqual([2, 2, 2, 2, 1]);
    expectArraysClose(await result.data(), expected);
  });
});
|
70 | //# sourceMappingURL=conv3d_transpose_test.js.map |
\ | No newline at end of file |