UNPKG

3.21 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17import * as tf from '../index';
18import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
19import { expectArraysClose } from '../test_util';
describeWithFlags('conv3dTranspose', ALL_ENVS, () => {
  // Naming note: the shapes in the test titles follow the `orig*` convention
  // used by the constants below, i.e. they describe the original (forward)
  // convolution. "input=2x2x2x1" is therefore the shape this transpose
  // PRODUCES (see the `result.shape` expectations), not the shape of `x`,
  // which is a single voxel per batch entry.

  // Reference Python TensorFlow code
  // ```python
  // import numpy as np
  // import tensorflow as tf
  // tf.enable_eager_execution()
  // x = np.array([2], dtype = np.float32).reshape(1, 1, 1, 1, 1)
  // w = np.array([5, 4, 8, 7, 1, 2, 6, 3], dtype = np.float32).reshape(2, 2, 2,
  // 1, 1)
  // tf.nn.conv3d_transpose(x, w, output_shape=[1, 2, 2, 2, 1], padding='VALID')
  // ```
  it('input=2x2x2x1,d2=1,f=2,s=1,p=valid', async () => {
    const origInputDepth = 1;
    const origOutputDepth = 1;
    // Unlike the Python reference (batch of 1), this exercises the unbatched
    // rank-4 path: both `x` and the expected result have no batch dimension.
    const inputShape = [1, 1, 1, origOutputDepth];
    const fSize = 2;
    const origPad = 'valid';
    const origStride = 1;
    const x = tf.tensor4d([2], inputShape);
    const w = tf.tensor5d(
        [5, 4, 8, 7, 1, 2, 6, 3],
        [fSize, fSize, fSize, origInputDepth, origOutputDepth]);
    const result =
        tf.conv3dTranspose(x, w, [2, 2, 2, 1], origStride, origPad);
    // With a single input voxel, every output voxel is that scalar (2) times
    // the matching filter tap: 2 * [5, 4, 8, 7, 1, 2, 6, 3].
    const expected = [10, 8, 16, 14, 2, 4, 12, 6];
    expect(result.shape).toEqual([2, 2, 2, 1]);
    expectArraysClose(await result.data(), expected);
  });

  // Reference Python TensorFlow code
  // ```python
  // import numpy as np
  // import tensorflow as tf
  // tf.enable_eager_execution()
  // x = np.array([2, 3], dtype = np.float32).reshape(2, 1, 1, 1, 1)
  // w = np.array([5, 4, 8, 7, 1, 2, 6, 3], dtype = np.float32).reshape(2,
  // 2, 2, 1, 1)
  // tf.nn.conv3d_transpose(x, w, output_shape=[2, 2, 2, 2, 1], padding='VALID')
  // ```
  it('input=2x2x2x1,d2=1,f=2,s=1,p=valid, batch=2', async () => {
    const origInputDepth = 1;
    const origOutputDepth = 1;
    // Rank-5 (batched) input: two batch entries, each a single voxel.
    const inputShape = [2, 1, 1, 1, origOutputDepth];
    const fSize = 2;
    const origPad = 'valid';
    const origStride = 1;
    const x = tf.tensor5d([2, 3], inputShape);
    const w = tf.tensor5d(
        [5, 4, 8, 7, 1, 2, 6, 3],
        [fSize, fSize, fSize, origInputDepth, origOutputDepth]);
    const result =
        tf.conv3dTranspose(x, w, [2, 2, 2, 2, 1], origStride, origPad);
    // Batch entry 0 is 2 * taps, batch entry 1 is 3 * taps, laid out
    // one after the other.
    const expected = [10, 8, 16, 14, 2, 4, 12, 6, 15, 12, 24, 21, 3, 6, 18, 9];
    expect(result.shape).toEqual([2, 2, 2, 2, 1]);
    expectArraysClose(await result.data(), expected);
  });
});
//# sourceMappingURL=conv3d_transpose_test.js.map
\No newline at end of file