1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 | import { ENGINE } from '../engine';
|
18 | import { Conv3DBackpropInputV2 } from '../kernel_names';
|
19 | import * as util from '../util';
|
20 | import { op } from './operation';
|
21 | import { reshape } from './reshape';
|
22 |
|
23 |
|
24 |
|
25 |
|
26 |
|
27 |
|
28 |
|
29 |
|
30 |
|
31 |
|
32 |
|
33 |
|
34 |
|
35 |
|
36 |
|
37 |
|
38 |
|
39 |
|
/**
 * Gradient of a 3D convolution with respect to its input.
 *
 * `dy` may be rank 5 or rank 4. In the rank-4 case a leading dimension of
 * size 1 is prepended to both `dy` (via `reshape`) and `xShape` before the
 * `Conv3DBackpropInputV2` kernel runs. That leading dimension is then removed
 * from the kernel result before returning.
 *
 * @param xShape Shape of the forward convolution's input; its length must
 *     equal `dy.rank`.
 * @param dy Gradient flowing back into the convolution.
 * @param filter Filter tensor; must be rank 5. `filter.shape[3]` must equal
 *     the last entry of `xShape`, and `filter.shape[4]` must equal the last
 *     dimension of `dy`.
 * @param strides Forwarded unchanged as the kernel's `strides` attribute.
 * @param pad Forwarded unchanged as the kernel's `pad` attribute.
 * @returns The kernel result, reshaped to rank 4 when `dy` was rank 4.
 * Shape violations are reported through `util.assert`.
 */
function conv3DBackpropInput_(xShape, dy, filter, strides, pad) {
  util.assert(
    xShape.length === dy.rank,
    () => `Length of inShape (${xShape.length}) and rank of dy (${dy.rank}) must match`,
  );

  // Rank 4 means no leading batch dimension was supplied; add one of size 1
  // so the kernel only ever receives 5D operands.
  const isUnbatched = dy.rank === 4;
  const withUnitBatch = (s) => [1, s[0], s[1], s[2], s[3]];
  const dy5D = isUnbatched ? reshape(dy, withUnitBatch(dy.shape)) : dy;
  const xShape5D = isUnbatched ? withUnitBatch(xShape) : xShape;

  // Read eagerly (before the rank checks), matching the original ordering.
  const inDepth = xShape5D[4];
  const outDepth = dy5D.shape[4];

  util.assert(
    xShape5D.length === 5,
    () => `Error in conv3dDerInput: inShape must be length 5, but got length ${xShape5D.length}.`,
  );
  util.assert(
    dy5D.rank === 5,
    () => `Error in conv3dDerInput: dy must be rank 5, but got rank ${dy5D.rank}`,
  );
  util.assert(
    filter.rank === 5,
    () => `Error in conv3dDerInput: filter must be rank 5, but got rank ${filter.rank}`,
  );
  util.assert(
    inDepth === filter.shape[3],
    () => `Error in conv3dDerInput: depth of input (${inDepth}) must match input depth for filter ${filter.shape[3]}.`,
  );
  util.assert(
    outDepth === filter.shape[4],
    () => `Error in conv3dDerInput: depth of output (${outDepth}) must match output depth for filter ${filter.shape[4]}.`,
  );

  const result = ENGINE.runKernel(
    Conv3DBackpropInputV2,
    { dy: dy5D, filter },
    { pad, strides, inputShape: xShape5D },
  );
  if (!isUnbatched) {
    return result;
  }

  // Drop the size-1 leading dimension that was added above.
  const [, d1, d2, d3, d4] = result.shape;
  return reshape(result, [d1, d2, d3, d4]);
}
|
// Public entry point. `op` receives a single-entry object keyed by the
// implementation function's name; presumably it derives the exported op's
// name from that key (trailing underscore dropped) — confirm in ./operation.
export const conv3DBackpropInput = op({ conv3DBackpropInput_: conv3DBackpropInput_ });
|
73 |
|
\ | No newline at end of file |