1 | /**
|
2 | * @license
|
3 | * Copyright 2020 Google Inc. All Rights Reserved.
|
4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | * you may not use this file except in compliance with the License.
|
6 | * You may obtain a copy of the License at
|
7 | *
|
8 | * http://www.apache.org/licenses/LICENSE-2.0
|
9 | *
|
10 | * Unless required by applicable law or agreed to in writing, software
|
11 | * distributed under the License is distributed on an "AS IS" BASIS,
|
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | * See the License for the specific language governing permissions and
|
14 | * limitations under the License.
|
15 | * =============================================================================
|
16 | */
|
17 | import { customGrad } from '../gradients';
|
18 | import { convertToTensor } from '../tensor_util_env';
|
19 | import { cast } from './cast';
|
20 | import { exp } from './exp';
|
21 | import { log } from './log';
|
22 | import { max } from './max';
|
23 | import { mul } from './mul';
|
24 | import { op } from './operation';
|
25 | import { sub } from './sub';
|
26 | import { sum } from './sum';
|
/**
 * Computes the log softmax.
 *
 * ```js
 * const a = tf.tensor1d([1, 2, 3]);
 *
 * a.logSoftmax().print(); // or tf.logSoftmax(a)
 * ```
 *
 * ```js
 * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]);
 *
 * a.logSoftmax().print(); // or tf.logSoftmax(a)
 * ```
 *
 * The forward pass evaluates `shifted - log(sum(exp(shifted)))` where
 * `shifted = logits - max(logits)` along `axis`, so every exponent fed to
 * `exp` is at most 0. The backward pass is supplied through `customGrad` and
 * propagates `dy - sum(dy) * exp(result)` along the same axis.
 *
 * @param logits The logits array.
 * @param axis The dimension softmax would be performed on. Defaults to `-1`
 *     which indicates the last dimension. Only the last dimension is
 *     currently supported.
 * @throws Error when `axis` is neither `-1` nor the index of the last
 *     dimension of `logits`.
 *
 * @doc {heading: 'Operations', subheading: 'Normalization'}
 */
function logSoftmax_(logits, axis = -1) {
  const $logits = convertToTensor(logits, 'logits', 'logSoftmax');

  // `-1` is shorthand for the last dimension; any other value is used as-is.
  const lastDim = $logits.rank - 1;
  const reduceAxis = axis === -1 ? lastDim : axis;
  if (reduceAxis !== lastDim) {
    throw new Error(
        'Log Softmax along a non-last dimension is not yet supported. ' +
        `Logits was rank ${$logits.rank} and axis was ${reduceAxis}`);
  }

  // Reductions keep the reduced axis so results broadcast back against x.
  const keepDims = true;

  // A custom gradient is used for numerical stability; the backward pass
  // reuses the forward result saved below instead of recomputing it.
  const logSoftmaxWithGrad = customGrad((x, save) => {
    const shifted = sub(x, max(x, reduceAxis, keepDims));
    const shiftedFloat = cast(shifted, 'float32');
    const logSumExp = log(sum(exp(shifted), reduceAxis, keepDims));
    const result = sub(shiftedFloat, logSumExp);
    save([result]);

    const gradFunc = (dy, [savedResult]) => {
      // exp(log softmax) recovers the softmax probabilities.
      const softmax = exp(savedResult);
      return sub(dy, mul(sum(dy, reduceAxis, keepDims), softmax));
    };
    return { value: result, gradFunc };
  });

  // TODO Use Engine.runKernel when CPU/WebGL/WASM backends implement this.
  return logSoftmaxWithGrad($logits);
}
|
/** The exported op: `logSoftmax_` wrapped by `op`. */
export const logSoftmax = op({ logSoftmax_: logSoftmax_ });
|
91 | //# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoibG9nX3NvZnRtYXguanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWNvcmUvc3JjL29wcy9sb2dfc29mdG1heC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsVUFBVSxFQUFDLE1BQU0sY0FBYyxDQUFDO0FBSXhDLE9BQU8sRUFBQyxlQUFlLEVBQUMsTUFBTSxvQkFBb0IsQ0FBQztBQUduRCxPQUFPLEVBQUMsSUFBSSxFQUFDLE1BQU0sUUFBUSxDQUFDO0FBQzVCLE9BQU8sRUFBQyxHQUFHLEVBQUMsTUFBTSxPQUFPLENBQUM7QUFDMUIsT0FBTyxFQUFDLEdBQUcsRUFBQyxNQUFNLE9BQU8sQ0FBQztBQUMxQixPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sT0FBTyxDQUFDO0FBQzFCLE9BQU8sRUFBQyxHQUFHLEVBQUMsTUFBTSxPQUFPLENBQUM7QUFDMUIsT0FBTyxFQUFDLEVBQUUsRUFBQyxNQUFNLGFBQWEsQ0FBQztBQUMvQixPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sT0FBTyxDQUFDO0FBQzFCLE9BQU8sRUFBQyxHQUFHLEVBQUMsTUFBTSxPQUFPLENBQUM7QUFFMUI7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7O0dBb0JHO0FBQ0gsU0FBUyxXQUFXLENBQW1CLE1BQW9CLEVBQUUsSUFBSSxHQUFHLENBQUMsQ0FBQztJQUNwRSxNQUFNLE9BQU8sR0FBRyxlQUFlLENBQUMsTUFBTSxFQUFFLFFBQVEsRUFBRSxZQUFZLENBQUMsQ0FBQztJQUVoRSxJQUFJLElBQUksS0FBSyxDQUFDLENBQUMsRUFBRTtRQUNmLElBQUksR0FBRyxPQUFPLENBQUMsSUFBSSxHQUFHLENBQUMsQ0FBQztLQUN6QjtJQUNELElBQUksSUFBSSxLQUFLLE9BQU8sQ0FBQyxJQUFJLEdBQUcsQ0FBQyxFQUFFO1FBQzdCLE1BQU0sS0FBSyxDQUNQLCtEQUErRDtZQUMvRCxtQkFBbUIsT0FBTyxDQUFDLElBQUksaUJBQWlCLElBQUksRUFBRSxDQUFDLENBQUM7S0FDN0Q7SUFFRCw0REFBNEQ7SUFDNUQsMkJBQTJCO0lBQzNCLDBDQUEwQztJQUMxQyx1Q0FBdUM7SUFDdkMsa0JBQWtCO0lBQ2xCLGtFQUFrRTtJQUNsRSxxQkFBcUI7SUFDckIsbUJBQW1CO0lBQ25CLGtCQUFrQjtJQUNsQixLQUFLO0lBRUwsaURBQWlEO0lBQ2pELE1BQU0sUUFBUSxHQUFHLFVBQVUsQ0FBQyxDQUFDLE1BQWMsRUFBRSxJQUFrQixFQUFFLEVBQUU7UUFDakUsTUFBTSxRQUFRLEdBQUcsSUFBSSxDQUFDO1FBQ3RCLE1BQU0sSUFBSSxHQUFHLEdBQUcsQ0FBQyxNQUFNLEVBQUUsSUFBSSxFQUFFLElBQUksQ0FBQyxDQUFDO1FBQ3JDLE1BQU0sT0FBTyxHQUFHLEdBQUcsQ0FBQyxNQUFNLEVBQUUsSUFBSSxDQUFDLENBQUM7UUFDbEMsTUFBTSxLQUFLLEdBQ1AsR0FBRyxDQUFDLElBQUksQ0FBQyxPQUFPLEVBQUUsU0FBUyxDQUFDLEVBQUUsR0FBRyxDQUFDLEdBQUcsQ0FBQyxHQUFHLENBQUMsT0FBTyxDQUFDLEVBQUUsSUFBSSxFQUFFLFFBQVEsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUMxRSxJQ
UFJLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDO1FBRWQsTUFBTSxRQUFRLEdBQUcsQ0FBQyxFQUFVLEVBQUUsS0FBZSxFQUFFLEVBQUU7WUFDL0MsTUFBTSxDQUFDLEtBQUssQ0FBQyxHQUFHLEtBQUssQ0FBQztZQUN0QixNQUFNLFFBQVEsR0FBRyxJQUFJLENBQUM7WUFDdEIsTUFBTSxPQUFPLEdBQUcsR0FBRyxDQUFDLEtBQUssQ0FBQyxDQUFDO1lBQzNCLE9BQU8sR0FBRyxDQUFDLEVBQUUsRUFBRSxHQUFHLENBQUMsR0FBRyxDQUFDLEVBQUUsRUFBRSxJQUFJLEVBQUUsUUFBUSxDQUFDLEVBQUUsT0FBTyxDQUFDLENBQUMsQ0FBQztRQUN4RCxDQUFDLENBQUM7UUFDRixPQUFPLEVBQUMsS0FBSyxFQUFFLFFBQVEsRUFBQyxDQUFDO0lBQzNCLENBQUMsQ0FBQyxDQUFDO0lBRUgsT0FBTyxRQUFRLENBQUMsT0FBTyxDQUFNLENBQUM7SUFFOUIseUVBQXlFO0lBQ3pFLHNEQUFzRDtJQUN0RCx5Q0FBeUM7SUFDekMsMkJBQTJCO0lBQzNCLHlEQUF5RDtJQUN6RCwyQ0FBMkM7QUFDN0MsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLFVBQVUsR0FBRyxFQUFFLENBQUMsRUFBQyxXQUFXLEVBQUMsQ0FBQyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIEluYy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2N1c3RvbUdyYWR9IGZyb20gJy4uL2dyYWRpZW50cyc7XG5cbmltcG9ydCB7VGVuc29yfSBmcm9tICcuLi90ZW5zb3InO1xuaW1wb3J0IHtHcmFkU2F2ZUZ1bmN9IGZyb20gJy4uL3RlbnNvcl90eXBlcyc7XG5pbXBvcnQge2NvbnZlcnRUb1RlbnNvcn0gZnJvbSAnLi4vdGVuc29yX3V0aWxfZW52Jztcbmltc
G9ydCB7VGVuc29yTGlrZX0gZnJvbSAnLi4vdHlwZXMnO1xuXG5pbXBvcnQge2Nhc3R9IGZyb20gJy4vY2FzdCc7XG5pbXBvcnQge2V4cH0gZnJvbSAnLi9leHAnO1xuaW1wb3J0IHtsb2d9IGZyb20gJy4vbG9nJztcbmltcG9ydCB7bWF4fSBmcm9tICcuL21heCc7XG5pbXBvcnQge211bH0gZnJvbSAnLi9tdWwnO1xuaW1wb3J0IHtvcH0gZnJvbSAnLi9vcGVyYXRpb24nO1xuaW1wb3J0IHtzdWJ9IGZyb20gJy4vc3ViJztcbmltcG9ydCB7c3VtfSBmcm9tICcuL3N1bSc7XG5cbi8qKlxuICogQ29tcHV0ZXMgdGhlIGxvZyBzb2Z0bWF4LlxuICpcbiAqIGBgYGpzXG4gKiBjb25zdCBhID0gdGYudGVuc29yMWQoWzEsIDIsIDNdKTtcbiAqXG4gKiBhLmxvZ1NvZnRtYXgoKS5wcmludCgpOyAgLy8gb3IgdGYubG9nU29mdG1heChhKVxuICogYGBgXG4gKlxuICogYGBganNcbiAqIGNvbnN0IGEgPSB0Zi50ZW5zb3IyZChbMiwgNCwgNiwgMSwgMiwgM10sIFsyLCAzXSk7XG4gKlxuICogYS5sb2dTb2Z0bWF4KCkucHJpbnQoKTsgIC8vIG9yIHRmLmxvZ1NvZnRtYXgoYSlcbiAqIGBgYFxuICpcbiAqIEBwYXJhbSBsb2dpdHMgVGhlIGxvZ2l0cyBhcnJheS5cbiAqIEBwYXJhbSBheGlzIFRoZSBkaW1lbnNpb24gc29mdG1heCB3b3VsZCBiZSBwZXJmb3JtZWQgb24uIERlZmF1bHRzIHRvIGAtMWBcbiAqICAgICB3aGljaCBpbmRpY2F0ZXMgdGhlIGxhc3QgZGltZW5zaW9uLlxuICpcbiAqIEBkb2Mge2hlYWRpbmc6ICdPcGVyYXRpb25zJywgc3ViaGVhZGluZzogJ05vcm1hbGl6YXRpb24nfVxuICovXG5mdW5jdGlvbiBsb2dTb2Z0bWF4XzxUIGV4dGVuZHMgVGVuc29yPihsb2dpdHM6IFR8VGVuc29yTGlrZSwgYXhpcyA9IC0xKTogVCB7XG4gIGNvbnN0ICRsb2dpdHMgPSBjb252ZXJ0VG9UZW5zb3IobG9naXRzLCAnbG9naXRzJywgJ2xvZ1NvZnRtYXgnKTtcblxuICBpZiAoYXhpcyA9PT0gLTEpIHtcbiAgICBheGlzID0gJGxvZ2l0cy5yYW5rIC0gMTtcbiAgfVxuICBpZiAoYXhpcyAhPT0gJGxvZ2l0cy5yYW5rIC0gMSkge1xuICAgIHRocm93IEVycm9yKFxuICAgICAgICAnTG9nIFNvZnRtYXggYWxvbmcgYSBub24tbGFzdCBkaW1lbnNpb24gaXMgbm90IHlldCBzdXBwb3J0ZWQuICcgK1xuICAgICAgICBgTG9naXRzIHdhcyByYW5rICR7JGxvZ2l0cy5yYW5rfSBhbmQgYXhpcyB3YXMgJHtheGlzfWApO1xuICB9XG5cbiAgLy8gY29uc3QgZm9yd2FyZDogRm9yd2FyZEZ1bmM8VGVuc29yPiA9IChiYWNrZW5kLCBzYXZlKSA9PiB7XG4gIC8vICAgY29uc3Qga2VlcERpbXMgPSB0cnVlO1xuICAvLyAgIGNvbnN0IHhNYXggPSBtYXgobG9naXRzLCBheGlzLCB0cnVlKTtcbiAgLy8gICBjb25zdCBzaGlmdGVkID0gc3ViKGxvZ2l0cywgeE1heCk7XG4gIC8vICAgY29uc3QgdmFsdWUgPVxuICAvLyAgICAgICBzdWIoY2FzdChzaGlmdGVkLCAnZmxvYXQzMicpLCBsb2coc3VtKGV4cChzaGlmdGVkKSwgYXhpcyxcbiAgLy8gICAgICAga2VlcERpbXMpKSk7XG4gIC8vI
CAgc2F2ZShbdmFsdWVdKTtcbiAgLy8gICByZXR1cm4gdmFsdWU7XG4gIC8vIH07XG5cbiAgLy8gVXNlIGEgY3VzdG9tIGdyYWRpZW50IGZvciBudW1lcmljYWwgc3RhYmlsaXR5LlxuICBjb25zdCBjdXN0b21PcCA9IGN1c3RvbUdyYWQoKGxvZ2l0czogVGVuc29yLCBzYXZlOiBHcmFkU2F2ZUZ1bmMpID0+IHtcbiAgICBjb25zdCBrZWVwRGltcyA9IHRydWU7XG4gICAgY29uc3QgeE1heCA9IG1heChsb2dpdHMsIGF4aXMsIHRydWUpO1xuICAgIGNvbnN0IHNoaWZ0ZWQgPSBzdWIobG9naXRzLCB4TWF4KTtcbiAgICBjb25zdCB2YWx1ZSA9XG4gICAgICAgIHN1YihjYXN0KHNoaWZ0ZWQsICdmbG9hdDMyJyksIGxvZyhzdW0oZXhwKHNoaWZ0ZWQpLCBheGlzLCBrZWVwRGltcykpKTtcbiAgICBzYXZlKFt2YWx1ZV0pO1xuXG4gICAgY29uc3QgZ3JhZEZ1bmMgPSAoZHk6IFRlbnNvciwgc2F2ZWQ6IFRlbnNvcltdKSA9PiB7XG4gICAgICBjb25zdCBbdmFsdWVdID0gc2F2ZWQ7XG4gICAgICBjb25zdCBrZWVwRGltcyA9IHRydWU7XG4gICAgICBjb25zdCBzb2Z0bWF4ID0gZXhwKHZhbHVlKTtcbiAgICAgIHJldHVybiBzdWIoZHksIG11bChzdW0oZHksIGF4aXMsIGtlZXBEaW1zKSwgc29mdG1heCkpO1xuICAgIH07XG4gICAgcmV0dXJuIHt2YWx1ZSwgZ3JhZEZ1bmN9O1xuICB9KTtcblxuICByZXR1cm4gY3VzdG9tT3AoJGxvZ2l0cykgYXMgVDtcblxuICAvLyBUT0RPIFVzZSBFbmdpbmUucnVuS2VybmVsIHdoZW4gQ1BVL1dlYkdML1dBU00gYmFja2VuZHMgaW1wbGVtZW50IHRoaXMuXG4gIC8vIGNvbnN0IGlucHV0czogTG9nU29mdG1heElucHV0cyA9IHtsb2dpdHM6ICRsb2dpdHN9O1xuICAvLyBjb25zdCBhdHRyczogTG9nU29mdG1heEF0dHJzID0ge2F4aXN9O1xuICAvLyByZXR1cm4gRU5HSU5FLnJ1bktlcm5lbChcbiAgLy8gICAgICAgICAgICBMb2dTb2Z0bWF4LCBpbnB1dHMgYXMge30gYXMgTmFtZWRUZW5zb3JNYXAsXG4gIC8vICAgICAgICAgICAgYXR0cnMgYXMge30gYXMgTmFtZWRBdHRyTWFwKTtcbn1cblxuZXhwb3J0IGNvbnN0IGxvZ1NvZnRtYXggPSBvcCh7bG9nU29mdG1heF99KTtcbiJdfQ== |
\ | No newline at end of file |