UNPKG

12.5 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17import { add } from './engine';
18import * as tf from './index';
19import { ALL_ENVS, describeWithFlags } from './jasmine_util';
20import { zerosLike } from './ops/ops';
21import { backpropagateGradients, getFilteredNodesXToY } from './tape';
22import { expectArraysClose } from './test_util';
/**
 * Specs for `getFilteredNodesXToY(tape, xs, y)`.
 *
 * Every spec builds a small hand-written tape of plain node objects, calls
 * `getFilteredNodesXToY` with a list of input tensors and one output tensor,
 * and checks which nodes (and which of their inputs) are returned.
 */
describeWithFlags('getFilteredNodesXToY', ALL_ENVS, () => {
  /**
   * Builds a tape-node fixture whose `gradient` is `null`.
   *
   * @param {number} id Value stored in the node's `id` field.
   * @param {string} kernelName Value stored in the node's `kernelName` field.
   * @param {Object} inputs Tensors keyed by input name.
   * @param {Array} outputs Tensors listed as the node's outputs.
   * @returns {Object} A plain object with the five fields used by the specs.
   */
  const tapeNode = (id, kernelName, inputs, outputs) =>
      ({ id, kernelName, inputs, outputs, gradient: null });

  it('no paths from x to y', () => {
    const source = tf.scalar(1);
    const danglingOutput = tf.scalar(0);
    const unrelatedInput = tf.scalar(0);
    const target = tf.scalar(2);
    // node0's output is never used as an input of node1, so the two nodes
    // are not chained together.
    const tape = [
      tapeNode(0, 'node0', { x: source }, [danglingOutput]),
      tapeNode(1, 'node1', { intermediate2: unrelatedInput }, [target]),
    ];

    const result = getFilteredNodesXToY(tape, [source], target);

    expect(result.length).toBe(0);
    expect(result).toEqual([]);
  });

  it('one operation x => y', () => {
    const source = tf.scalar(1);
    const target = tf.scalar(2);
    const tape = [tapeNode(0, 'node0', { x: source }, [target])];

    const result = getFilteredNodesXToY(tape, [source], target);

    expect(result.length).toBe(1);
    expect(result).toEqual(tape);
  });

  it('1 operation [x0, x1] => y, all input paths', () => {
    const first = tf.scalar(0);
    const second = tf.scalar(1);
    const target = tf.scalar(2);
    const tape = [tapeNode(0, 'node0', { x0: first, x1: second }, [target])];

    const result = getFilteredNodesXToY(tape, [first, second], target);

    expect(result.length).toBe(1);
    expect(result).toEqual(tape);
  });

  it('one operation [x0, x1] => y, one input paths', () => {
    const first = tf.scalar(0);
    const second = tf.scalar(1);
    const target = tf.scalar(2);
    const tape = [tapeNode(0, 'node0', { x0: first, x1: second }, [target])];

    const result = getFilteredNodesXToY(tape, [first], target);

    expect(result.length).toBe(1);
    // Only x0 was requested, so x1 is expected to be absent from the
    // returned node's inputs.
    expect(result[0]).toEqual(tapeNode(0, 'node0', { x0: first }, [target]));
  });

  it('two operations x => intermediate => y', () => {
    const source = tf.scalar(1);
    const middle = tf.scalar(0);
    const target = tf.scalar(2);
    const tape = [
      tapeNode(0, 'node0', { x: source }, [middle]),
      tapeNode(1, 'node1', { intermediate: middle }, [target]),
    ];

    const result = getFilteredNodesXToY(tape, [source], target);

    expect(result.length).toBe(2);
    expect(result).toEqual(tape);
  });

  it('two operations [x0, x1], [x2] => ' +
      'intermediate => y', () => {
    const first = tf.scalar(1);
    const second = tf.scalar(2);
    const third = tf.scalar(3);
    const middle = tf.scalar(4);
    const target = tf.scalar(2);
    const tape = [
      tapeNode(0, 'node0', { x0: first, x1: second }, [middle]),
      tapeNode(1, 'node1', { x2: third, intermediate: middle }, [target]),
    ];

    const result = getFilteredNodesXToY(tape, [first, second, third], target);

    expect(result.length).toBe(2);
    expect(result).toEqual(tape);
  });

  it('x => y and x => orphan', () => {
    const source = tf.scalar(1);
    const orphanOutput = tf.scalar(0);
    const target = tf.scalar(2);
    const tape = [
      tapeNode(0, 'node0', { x: source }, [orphanOutput]),
      tapeNode(1, 'node1', { x: source }, [target]),
    ];

    const result = getFilteredNodesXToY(tape, [source], target);

    expect(result.length).toBe(1);
    // The node producing the orphan is expected to be dropped, leaving only
    // the node that produces y.
    expect(result[0]).toEqual(tape[1]);
  });

  it('x => y and orphan => y', () => {
    const source = tf.scalar(1);
    const orphanInput = tf.scalar(0);
    const target = tf.scalar(2);
    const tape = [
      tapeNode(0, 'node0', { x: source, orphan: orphanInput }, [target]),
    ];

    const result = getFilteredNodesXToY(tape, [source], target);

    expect(result.length).toBe(1);
    // The orphan is expected to be absent from the returned node's inputs.
    expect(result[0]).toEqual(tapeNode(0, 'node0', { x: source }, [target]));
  });

  it('1 op with 3 outputs x => y1, y2, y3', () => {
    const source = tf.scalar(1);
    const firstOut = tf.scalar(2);
    const secondOut = tf.scalar(2);
    const thirdOut = tf.scalar(2);
    const tape = [
      tapeNode(0, 'node0', { x: source }, [firstOut, secondOut, thirdOut]),
    ];

    // The same single node is expected back whichever output is targeted.
    for (const target of [firstOut, secondOut, thirdOut]) {
      const result = getFilteredNodesXToY(tape, [source], target);
      expect(result.length).toBe(1);
      expect(result).toEqual(tape);
    }
  });
});
/**
 * Specs for `backpropagateGradients(accumulatedGradientsMap, tape, tidy, add)`.
 *
 * Every spec seeds a gradient map with an entry keyed by an output tensor's
 * id, runs `backpropagateGradients` over a hand-written tape, and then reads
 * the entry keyed by the input tensor's id.
 */
describeWithFlags('backpropagateGradients', ALL_ENVS, () => {
  // Wrapper handed to backpropagateGradients as its third argument.
  const tidy = f => tf.tidy(f);

  // Shared body of the fixture gradients: the incoming gradient plus one.
  const plusOne = dy => dy.add(tf.scalar(1));

  /**
   * Builds a tape-node fixture.
   *
   * @param {number} id Value stored in the node's `id` field.
   * @param {string} kernelName Value stored in the node's `kernelName` field.
   * @param {Object} inputs Tensors keyed by input name.
   * @param {Array} outputs Tensors listed as the node's outputs.
   * @param {?Function} gradient Function of `dys` returning an object of
   *     thunks keyed by input name, or `null`.
   * @returns {Object} A plain object with the five fields used by the specs.
   */
  const tapeNode = (id, kernelName, inputs, outputs, gradient) =>
      ({ id, kernelName, inputs, outputs, gradient });

  it('Throws if gradient is not defined', () => {
    const input = tf.scalar(0);
    const output = tf.scalar(1);
    const seed = tf.scalar(1);
    const gradsById = { [output.id]: seed };
    const tape = [tapeNode(0, 'node0', { x: input }, [output], null)];

    const runBackprop = () =>
        backpropagateGradients(gradsById, tape, tidy, add);

    expect(runBackprop).toThrowError();
  });

  it('basic backprop with 1 node', async () => {
    const input = tf.scalar(0);
    const output = tf.scalar(1);
    const seed = tf.scalar(1);
    const gradsById = { [output.id]: seed };
    const tape = [
      tapeNode(0, 'node0', { x: input }, [output],
          dys => ({ x: () => plusOne(dys[0]) })),
    ];

    backpropagateGradients(gradsById, tape, tidy, add);

    expectArraysClose(await gradsById[input.id].data(), [2]);
  });

  it('basic backprop with 2 nodes', async () => {
    const input = tf.scalar(0);
    const middle = tf.scalar(1);
    const output = tf.scalar(2);
    const seed = tf.scalar(1);
    const gradsById = { [output.id]: seed };
    const tape = [
      tapeNode(0, 'node0', { x: input }, [middle],
          dys => ({ x: () => plusOne(dys[0]) })),
      tapeNode(1, 'node1', { intermediate: middle }, [output],
          dys => ({ intermediate: () => plusOne(dys[0]) })),
    ];

    backpropagateGradients(gradsById, tape, tidy, add);

    // dx = dy + 1 + 1
    expectArraysClose(await gradsById[input.id].data(), [3]);
  });

  it('basic backprop with a split node accumulates gradients', async () => {
    const input = tf.scalar(0);
    const leftBranch = tf.scalar(1);
    const rightBranch = tf.scalar(2);
    const output = tf.scalar(3);
    const seed = tf.scalar(1);
    const gradsById = { [output.id]: seed };
    const tape = [
      tapeNode(0, 'node0', { x: input }, [leftBranch],
          dys => ({ x: () => plusOne(dys[0]) })),
      tapeNode(1, 'node1', { x: input }, [rightBranch],
          dys => ({ x: () => plusOne(dys[0]) })),
      tapeNode(2, 'node2',
          { intermediate1: leftBranch, intermediate2: rightBranch }, [output],
          dys => ({
            intermediate1: () => plusOne(dys[0]),
            intermediate2: () => plusOne(dys[0]),
          })),
    ];

    backpropagateGradients(gradsById, tape, tidy, add);

    // dx = dy + 1 + 1 + 1 + 1 + 1
    const inputGradValues = await gradsById[input.id].data();
    const seedValues = await seed.data();
    expectArraysClose(inputGradValues, [seedValues[0] + 5]);
  });

  it('backprop over 1 node with 3 outputs, w.r.t to the 2nd output', async () => {
    const input = tf.tensor1d([1, 1, 1]);
    const firstOut = tf.scalar(1);
    const secondOut = tf.scalar(1);
    const thirdOut = tf.scalar(1);
    // Only the 2nd output is seeded with a gradient.
    const secondOutGrad = tf.scalar(5);
    const gradsById = { [secondOut.id]: secondOutGrad };
    // Captures the per-output gradients seen by the node's gradient function
    // after missing entries are replaced with zeros.
    let seenDys;
    const tape = [
      tapeNode(0, 'node0', { x: input }, [firstOut, secondOut, thirdOut],
          incomingDys => {
            seenDys = incomingDys.map(dy => dy || zerosLike(firstOut));
            return { x: () => tf.stack(seenDys) };
          }),
    ];

    backpropagateGradients(gradsById, tape, tidy, add);

    const expectedPerOutput = [0, 5, 0];
    expectArraysClose(await gradsById[input.id].data(), expectedPerOutput);
    for (const [index, expected] of expectedPerOutput.entries()) {
      expectArraysClose(await seenDys[index].data(), [expected]);
    }
  });
});
344//# sourceMappingURL=tape_test.js.map
\No newline at end of file