1 |
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
|
16 |
|
17 | import { add } from './engine';
|
18 | import * as tf from './index';
|
19 | import { ALL_ENVS, describeWithFlags } from './jasmine_util';
|
20 | import { zerosLike } from './ops/ops';
|
21 | import { backpropagateGradients, getFilteredNodesXToY } from './tape';
|
22 | import { expectArraysClose } from './test_util';
|
describeWithFlags('getFilteredNodesXToY', ALL_ENVS, () => {
  // Builds a tape node in the shape these cases record. None of the cases
  // below need a gradient function, so every node carries `gradient: null`.
  const tapeNode = (id, kernelName, inputs, outputs) =>
    ({ id, kernelName, inputs, outputs, gradient: null });

  it('no paths from x to y', () => {
    const x = tf.scalar(1);
    const intermediate1 = tf.scalar(0);
    const intermediate2 = tf.scalar(0);
    const y = tf.scalar(2);
    // x feeds intermediate1 and intermediate2 feeds y; nothing links the two.
    const tape = [
      tapeNode(0, 'node0', { x }, [intermediate1]),
      tapeNode(1, 'node1', { intermediate2 }, [y]),
    ];

    const filtered = getFilteredNodesXToY(tape, [x], y);

    expect(filtered.length).toBe(0);
    expect(filtered).toEqual([]);
  });

  it('one operation x => y', () => {
    const x = tf.scalar(1);
    const y = tf.scalar(2);
    const tape = [tapeNode(0, 'node0', { x }, [y])];

    const filtered = getFilteredNodesXToY(tape, [x], y);

    expect(filtered.length).toBe(1);
    expect(filtered).toEqual(tape);
  });

  it('1 operation [x0, x1] => y, all input paths', () => {
    const x0 = tf.scalar(0);
    const x1 = tf.scalar(1);
    const y = tf.scalar(2);
    const tape = [tapeNode(0, 'node0', { x0, x1 }, [y])];

    const filtered = getFilteredNodesXToY(tape, [x0, x1], y);

    expect(filtered.length).toBe(1);
    expect(filtered).toEqual(tape);
  });

  it('one operation [x0, x1] => y, one input paths', () => {
    const x0 = tf.scalar(0);
    const x1 = tf.scalar(1);
    const y = tf.scalar(2);
    const tape = [tapeNode(0, 'node0', { x0, x1 }, [y])];

    const filtered = getFilteredNodesXToY(tape, [x0], y);

    expect(filtered.length).toBe(1);
    // Only x0 was requested, so x1 is expected to be pruned from the inputs.
    expect(filtered[0]).toEqual(tapeNode(0, 'node0', { x0 }, [y]));
  });

  it('two operations x => intermediate => y', () => {
    const x = tf.scalar(1);
    const intermediate = tf.scalar(0);
    const y = tf.scalar(2);
    const tape = [
      tapeNode(0, 'node0', { x }, [intermediate]),
      tapeNode(1, 'node1', { intermediate }, [y]),
    ];

    const filtered = getFilteredNodesXToY(tape, [x], y);

    expect(filtered.length).toBe(2);
    expect(filtered).toEqual(tape);
  });

  it('two operations [x0, x1], [x2] => ' +
     'intermediate => y', () => {
    const x0 = tf.scalar(1);
    const x1 = tf.scalar(2);
    const x2 = tf.scalar(3);
    const intermediate = tf.scalar(4);
    const y = tf.scalar(2);
    const tape = [
      tapeNode(0, 'node0', { x0, x1 }, [intermediate]),
      tapeNode(1, 'node1', { x2, intermediate }, [y]),
    ];

    const filtered = getFilteredNodesXToY(tape, [x0, x1, x2], y);

    expect(filtered.length).toBe(2);
    expect(filtered).toEqual(tape);
  });

  it('x => y and x => orphan', () => {
    const x = tf.scalar(1);
    const orphan = tf.scalar(0);
    const y = tf.scalar(2);
    const tape = [
      tapeNode(0, 'node0', { x }, [orphan]),
      tapeNode(1, 'node1', { x }, [y]),
    ];

    const filtered = getFilteredNodesXToY(tape, [x], y);

    expect(filtered.length).toBe(1);
    // The node producing `orphan` does not lead to y and is expected to be
    // dropped, leaving only the second node.
    expect(filtered[0]).toEqual(tape[1]);
  });

  it('x => y and orphan => y', () => {
    const x = tf.scalar(1);
    const orphan = tf.scalar(0);
    const y = tf.scalar(2);
    const tape = [tapeNode(0, 'node0', { x, orphan }, [y])];

    const filtered = getFilteredNodesXToY(tape, [x], y);

    expect(filtered.length).toBe(1);
    // `orphan` was not requested, so only `x` is expected to remain as input.
    expect(filtered[0]).toEqual(tapeNode(0, 'node0', { x }, [y]));
  });

  it('1 op with 3 outputs x => y1, y2, y3', () => {
    const x = tf.scalar(1);
    const y1 = tf.scalar(2);
    const y2 = tf.scalar(2);
    const y3 = tf.scalar(2);
    const tape = [tapeNode(0, 'node0', { x }, [y1, y2, y3])];

    // Each output, taken as the target on its own, must keep the single node.
    for (const target of [y1, y2, y3]) {
      const filtered = getFilteredNodesXToY(tape, [x], target);
      expect(filtered.length).toBe(1);
      expect(filtered).toEqual(tape);
    }
  });
});
|
describeWithFlags('backpropagateGradients', ALL_ENVS, () => {
  // The `tidy` callback handed to backpropagateGradients in every case below.
  const runTidy = f => tf.tidy(f);

  // Returns a new accumulated-gradient map, keyed by tensor id, that holds
  // `dy` as the upstream gradient of `y`.
  const seedGradients = (y, dy) => ({ [y.id]: dy });

  // Gradient body shared by the simple cases: incoming gradient plus one.
  const plusOne = dy => dy.add(tf.scalar(1));

  it('Throws if gradient is not defined', () => {
    const x = tf.scalar(0);
    const y = tf.scalar(1);
    const dy = tf.scalar(1);
    const gradientsById = seedGradients(y, dy);
    // The only node has no gradient function, so backprop is expected to throw.
    const tape = [
      { id: 0, kernelName: 'node0', inputs: { x }, outputs: [y], gradient: null }
    ];

    const runBackprop = () =>
      backpropagateGradients(gradientsById, tape, runTidy, add);

    expect(runBackprop).toThrowError();
  });

  it('basic backprop with 1 node', async () => {
    const x = tf.scalar(0);
    const y = tf.scalar(1);
    const dy = tf.scalar(1);
    const gradientsById = seedGradients(y, dy);
    const tape = [{
      id: 0,
      kernelName: 'node0',
      inputs: { x },
      outputs: [y],
      gradient: dys => ({ x: () => plusOne(dys[0]) }),
    }];

    backpropagateGradients(gradientsById, tape, runTidy, add);

    // dx = dy + 1 = 2.
    expectArraysClose(await gradientsById[x.id].data(), [2]);
  });

  it('basic backprop with 2 nodes', async () => {
    const x = tf.scalar(0);
    const intermediate = tf.scalar(1);
    const y = tf.scalar(2);
    const dy = tf.scalar(1);
    const gradientsById = seedGradients(y, dy);
    const tape = [
      {
        id: 0,
        kernelName: 'node0',
        inputs: { x },
        outputs: [intermediate],
        gradient: dys => ({ x: () => plusOne(dys[0]) }),
      },
      {
        id: 1,
        kernelName: 'node1',
        inputs: { intermediate },
        outputs: [y],
        gradient: dys => ({ intermediate: () => plusOne(dys[0]) }),
      },
    ];

    backpropagateGradients(gradientsById, tape, runTidy, add);

    // d(intermediate) = dy + 1 = 2, then dx = 2 + 1 = 3.
    expectArraysClose(await gradientsById[x.id].data(), [3]);
  });

  it('basic backprop with a split node accumulates gradients', async () => {
    const x = tf.scalar(0);
    const intermediate1 = tf.scalar(1);
    const intermediate2 = tf.scalar(2);
    const y = tf.scalar(3);
    const dy = tf.scalar(1);
    const gradientsById = seedGradients(y, dy);
    const tape = [
      {
        id: 0,
        kernelName: 'node0',
        inputs: { x },
        outputs: [intermediate1],
        gradient: dys => ({ x: () => plusOne(dys[0]) }),
      },
      {
        id: 1,
        kernelName: 'node1',
        inputs: { x },
        outputs: [intermediate2],
        gradient: dys => ({ x: () => plusOne(dys[0]) }),
      },
      {
        id: 2,
        kernelName: 'node2',
        inputs: { intermediate1, intermediate2 },
        outputs: [y],
        gradient: dys => ({
          intermediate1: () => plusOne(dys[0]),
          intermediate2: () => plusOne(dys[0]),
        }),
      },
    ];

    backpropagateGradients(gradientsById, tape, runTidy, add);

    // Each branch gives x a gradient of (dy + 1) + 1 = 3. The two branches
    // are accumulated to 6, which equals dy + 5 because dy is 1 here.
    const actual = await gradientsById[x.id].data();
    const [dyValue] = await dy.data();
    expectArraysClose(actual, [dyValue + 5]);
  });

  it('backprop over 1 node with 3 outputs, w.r.t to the 2nd output', async () => {
    const x = tf.tensor1d([1, 1, 1]);
    const y1 = tf.scalar(1);
    const y2 = tf.scalar(1);
    const y3 = tf.scalar(1);
    const dy2 = tf.scalar(5);
    // Only y2 is seeded with an upstream gradient.
    const gradientsById = seedGradients(y2, dy2);

    // Captures the per-output gradients the node receives, for the checks below.
    let dys;
    const tape = [{
      id: 0,
      kernelName: 'node0',
      inputs: { x },
      outputs: [y1, y2, y3],
      gradient: received => {
        // Outputs without an accumulated gradient are replaced by zeros.
        dys = received.map(dy => dy || zerosLike(y1));
        return { x: () => tf.stack(dys) };
      },
    }];

    backpropagateGradients(gradientsById, tape, runTidy, add);

    expectArraysClose(await gradientsById[x.id].data(), [0, 5, 0]);
    const expectedDys = [[0], [5], [0]];
    for (let i = 0; i < expectedDys.length; i++) {
      expectArraysClose(await dys[i].data(), expectedDys[i]);
    }
  });
});
|
344 |
|
\ | No newline at end of file |