1 | /**
|
2 | * @license
|
3 | * Copyright 2020 Google LLC. All Rights Reserved.
|
4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | * you may not use this file except in compliance with the License.
|
6 | * You may obtain a copy of the License at
|
7 | *
|
8 | * http://www.apache.org/licenses/LICENSE-2.0
|
9 | *
|
10 | * Unless required by applicable law or agreed to in writing, software
|
11 | * distributed under the License is distributed on an "AS IS" BASIS,
|
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | * See the License for the specific language governing permissions and
|
14 | * limitations under the License.
|
15 | * =============================================================================
|
16 | */
|
17 | import { convertToTensor } from '../tensor_util_env';
|
18 | import * as util from '../util';
|
19 | import { matMul } from './mat_mul';
|
20 | import { op } from './operation';
|
21 | import { reshape } from './reshape';
|
/**
 * Computes the dot product of two matrices and/or vectors, `t1` and `t2`.
 *
 * ```js
 * const a = tf.tensor1d([1, 2]);
 * const b = tf.tensor2d([[1, 2], [3, 4]]);
 * const c = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
 *
 * a.dot(b).print(); // or tf.dot(a, b)
 * b.dot(a).print();
 * b.dot(c).print();
 * ```
 * @param t1 The first tensor in the dot operation (rank 1 or 2).
 * @param t2 The second tensor in the dot operation (rank 1 or 2).
 * @returns A scalar when both inputs are vectors, a rank-1 tensor when
 *     exactly one input is a vector, and the `matMul` product of the two
 *     inputs when both are matrices.
 *
 * Fails a `util.assert` precondition if either input is not rank 1 or 2, or
 * if the inner (contracted) dimensions of the inputs differ.
 *
 * @doc {heading: 'Operations', subheading: 'Matrices'}
 */
function dot_(t1, t2) {
  const $t1 = convertToTensor(t1, 't1', 'dot');
  const $t2 = convertToTensor(t2, 't2', 'dot');

  util.assert(
      ($t1.rank === 1 || $t1.rank === 2) && ($t2.rank === 1 || $t2.rank === 2),
      () => `Error in dot: inputs must all be rank 1 or 2, but got ranks ` +
          `${$t1.rank} and ${$t2.rank}.`);

  // The contracted ("inner") dimension: the full length of a vector, the
  // column count of a left-hand matrix, or the row count of a right-hand one.
  const t1Inner = ($t1.rank === 1 ? $t1.size : $t1.shape[1]);
  const t2Inner = ($t2.rank === 1 ? $t2.size : $t2.shape[0]);

  util.assert(
      t1Inner === t2Inner,
      () => `Error in dot: inner dimensions of inputs must match, but got ` +
          `${t1Inner} and ${t2Inner}.`);

  // matMul works on matrices, so lift vector operands to 2-D: a left vector
  // becomes a single row, a right vector a single column. Rank-2 operands are
  // passed through as-is; reshaping a matrix to its own shape would only add
  // a redundant reshape op.
  const lhs = $t1.rank === 1 ? reshape($t1, [1, -1]) : $t1;
  const rhs = $t2.rank === 1 ? reshape($t2, [-1, 1]) : $t2;
  const product = matMul(lhs, rhs);

  if ($t1.rank === 1 && $t2.rank === 1) {
    // [1, n] x [n, 1] -> [1, 1]: collapse to a scalar.
    return reshape(product, []);
  }
  if ($t1.rank === 1 || $t2.rank === 1) {
    // Exactly one vector operand: the product is a single row or column,
    // so flatten it back to a vector.
    return reshape(product, [product.size]);
  }
  return product;
}

export const dot = op({ dot_ });
|
72 | //# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiZG90LmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvZG90LnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUdILE9BQU8sRUFBQyxlQUFlLEVBQUMsTUFBTSxvQkFBb0IsQ0FBQztBQUVuRCxPQUFPLEtBQUssSUFBSSxNQUFNLFNBQVMsQ0FBQztBQUVoQyxPQUFPLEVBQUMsTUFBTSxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBQ2pDLE9BQU8sRUFBQyxFQUFFLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFDL0IsT0FBTyxFQUFDLE9BQU8sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUVsQzs7Ozs7Ozs7Ozs7Ozs7OztHQWdCRztBQUNILFNBQVMsSUFBSSxDQUFDLEVBQXFCLEVBQUUsRUFBcUI7SUFDeEQsTUFBTSxHQUFHLEdBQUcsZUFBZSxDQUFDLEVBQUUsRUFBRSxJQUFJLEVBQUUsS0FBSyxDQUFDLENBQUM7SUFDN0MsTUFBTSxHQUFHLEdBQUcsZUFBZSxDQUFDLEVBQUUsRUFBRSxJQUFJLEVBQUUsS0FBSyxDQUFDLENBQUM7SUFFN0MsSUFBSSxDQUFDLE1BQU0sQ0FDUCxDQUFDLEdBQUcsQ0FBQyxJQUFJLEtBQUssQ0FBQyxJQUFJLEdBQUcsQ0FBQyxJQUFJLEtBQUssQ0FBQyxDQUFDLElBQUksQ0FBQyxHQUFHLENBQUMsSUFBSSxLQUFLLENBQUMsSUFBSSxHQUFHLENBQUMsSUFBSSxLQUFLLENBQUMsQ0FBQyxFQUN4RSxHQUFHLEVBQUUsQ0FBQyw4REFBOEQ7UUFDaEUsR0FBRyxHQUFHLENBQUMsSUFBSSxRQUFRLEdBQUcsQ0FBQyxJQUFJLEdBQUcsQ0FBQyxDQUFDO0lBRXhDLE1BQU0sT0FBTyxHQUFHLENBQUMsR0FBRyxDQUFDLElBQUksS0FBSyxDQUFDLENBQUMsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxJQUFJLENBQUMsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztJQUMzRCxNQUFNLE9BQU8sR0FBRyxDQUFDLEdBQUcsQ0FBQyxJQUFJLEtBQUssQ0FBQyxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7SUFFM0QsSUFBSSxDQUFDLE1BQU0sQ0FDUCxPQUFPLEtBQUssT0FBTyxFQUNuQixHQUFHLEVBQUUsQ0FBQywrREFBK0Q7UUFDakUsR0FBRyxPQUFPLFFBQVEsT0FBTyxHQUFHLENBQUMsQ0FBQztJQUV0QyxJQUFJLEdBQUcsQ0FBQyxJQUFJLEtBQUssQ0FBQyxJQUFJLEdBQUcsQ0FBQyxJQUFJLEtBQUssQ0FBQyxFQUFFO1FBQ3BDLE1BQU0sSUFBSSxHQUFHLE9BQU8sQ0FBQyxHQUFHLEVBQUUsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBQ25DLE1BQU0sSUFBSSxHQUFHLE9BQU8sQ0FBQyxHQUFHLEVBQUUsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBQ25DLE1BQU0sSUFBSSxHQUFHLE1BQU0sQ0FBQyxJQUFJLEVBQUUsSUFBSSxDQ
UFDLENBQUM7UUFDaEMsT0FBTyxPQUFPLENBQUMsSUFBSSxFQUFFLEVBQUUsQ0FBQyxDQUFDO0tBQzFCO1NBQU0sSUFBSSxHQUFHLENBQUMsSUFBSSxLQUFLLENBQUMsSUFBSSxHQUFHLENBQUMsSUFBSSxLQUFLLENBQUMsRUFBRTtRQUMzQyxNQUFNLElBQUksR0FBRyxPQUFPLENBQUMsR0FBRyxFQUFFLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUNuQyxNQUFNLElBQUksR0FBRyxPQUFPLENBQUMsR0FBRyxFQUFFLENBQUMsR0FBRyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsRUFBRSxHQUFHLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUN4RCxNQUFNLElBQUksR0FBRyxNQUFNLENBQUMsSUFBSSxFQUFFLElBQUksQ0FBQyxDQUFDO1FBQ2hDLE9BQU8sT0FBTyxDQUFDLElBQUksRUFBRSxDQUFDLElBQUksQ0FBQyxJQUFJLENBQUMsQ0FBQyxDQUFDO0tBQ25DO1NBQU0sSUFBSSxHQUFHLENBQUMsSUFBSSxLQUFLLENBQUMsSUFBSSxHQUFHLENBQUMsSUFBSSxLQUFLLENBQUMsRUFBRTtRQUMzQyxNQUFNLElBQUksR0FBRyxPQUFPLENBQUMsR0FBRyxFQUFFLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUNuQyxNQUFNLElBQUksR0FBRyxNQUFNLENBQUMsR0FBRyxFQUFFLElBQUksQ0FBQyxDQUFDO1FBQy9CLE9BQU8sT0FBTyxDQUFDLElBQUksRUFBRSxDQUFDLElBQUksQ0FBQyxJQUFJLENBQUMsQ0FBQyxDQUFDO0tBQ25DO1NBQU07UUFDTCxNQUFNLElBQUksR0FBRyxPQUFPLENBQUMsR0FBRyxFQUFFLENBQUMsR0FBRyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsRUFBRSxHQUFHLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUN4RCxNQUFNLElBQUksR0FBRyxNQUFNLENBQUMsR0FBRyxFQUFFLElBQUksQ0FBQyxDQUFDO1FBQy9CLE9BQU8sSUFBSSxDQUFDO0tBQ2I7QUFDSCxDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sR0FBRyxHQUFHLEVBQUUsQ0FBQyxFQUFDLElBQUksRUFBQyxDQUFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJU
yxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7VGVuc29yLH0gZnJvbSAnLi4vdGVuc29yJztcbmltcG9ydCB7Y29udmVydFRvVGVuc29yfSBmcm9tICcuLi90ZW5zb3JfdXRpbF9lbnYnO1xuaW1wb3J0IHtUZW5zb3JMaWtlfSBmcm9tICcuLi90eXBlcyc7XG5pbXBvcnQgKiBhcyB1dGlsIGZyb20gJy4uL3V0aWwnO1xuXG5pbXBvcnQge21hdE11bH0gZnJvbSAnLi9tYXRfbXVsJztcbmltcG9ydCB7b3B9IGZyb20gJy4vb3BlcmF0aW9uJztcbmltcG9ydCB7cmVzaGFwZX0gZnJvbSAnLi9yZXNoYXBlJztcblxuLyoqXG4gKiBDb21wdXRlcyB0aGUgZG90IHByb2R1Y3Qgb2YgdHdvIG1hdHJpY2VzIGFuZC9vciB2ZWN0b3JzLCBgdDFgIGFuZCBgdDJgLlxuICpcbiAqIGBgYGpzXG4gKiBjb25zdCBhID0gdGYudGVuc29yMWQoWzEsIDJdKTtcbiAqIGNvbnN0IGIgPSB0Zi50ZW5zb3IyZChbWzEsIDJdLCBbMywgNF1dKTtcbiAqIGNvbnN0IGMgPSB0Zi50ZW5zb3IyZChbWzEsIDIsIDNdLCBbNCwgNSwgNl1dKTtcbiAqXG4gKiBhLmRvdChiKS5wcmludCgpOyAgLy8gb3IgdGYuZG90KGEsIGIpXG4gKiBiLmRvdChhKS5wcmludCgpO1xuICogYi5kb3QoYykucHJpbnQoKTtcbiAqIGBgYFxuICogQHBhcmFtIHQxIFRoZSBmaXJzdCB0ZW5zb3IgaW4gdGhlIGRvdCBvcGVyYXRpb24uXG4gKiBAcGFyYW0gdDIgVGhlIHNlY29uZCB0ZW5zb3IgaW4gdGhlIGRvdCBvcGVyYXRpb24uXG4gKlxuICogQGRvYyB7aGVhZGluZzogJ09wZXJhdGlvbnMnLCBzdWJoZWFkaW5nOiAnTWF0cmljZXMnfVxuICovXG5mdW5jdGlvbiBkb3RfKHQxOiBUZW5zb3J8VGVuc29yTGlrZSwgdDI6IFRlbnNvcnxUZW5zb3JMaWtlKTogVGVuc29yIHtcbiAgY29uc3QgJHQxID0gY29udmVydFRvVGVuc29yKHQxLCAndDEnLCAnZG90Jyk7XG4gIGNvbnN0ICR0MiA9IGNvbnZlcnRUb1RlbnNvcih0MiwgJ3QyJywgJ2RvdCcpO1xuXG4gIHV0aWwuYXNzZXJ0KFxuICAgICAgKCR0MS5yYW5rID09PSAxIHx8ICR0MS5yYW5rID09PSAyKSAmJiAoJHQyLnJhbmsgPT09IDEgfHwgJHQyLnJhbmsgPT09IDIpLFxuICAgICAgKCkgPT4gYEVycm9yIGluIGRvdDogaW5wdXRzIG11c3QgYWxsIGJlIHJhbmsgMSBvciAyLCBidXQgZ290IHJhbmtzIGAgK1xuICAgICAgICAgIGAkeyR0MS5yYW5rfSBhbmQgJHskdDIucmFua30uYCk7XG5cbiAgY29uc3QgdDFJbm5lciA9ICgkdDEucmFuayA9PT0gMSA/ICR0MS5zaXplIDogJHQxLnNoYXBlWzFdKTtcbiAgY
29uc3QgdDJJbm5lciA9ICgkdDIucmFuayA9PT0gMSA/ICR0Mi5zaXplIDogJHQyLnNoYXBlWzBdKTtcblxuICB1dGlsLmFzc2VydChcbiAgICAgIHQxSW5uZXIgPT09IHQySW5uZXIsXG4gICAgICAoKSA9PiBgRXJyb3IgaW4gZG90OiBpbm5lciBkaW1lbnNpb25zIG9mIGlucHV0cyBtdXN0IG1hdGNoLCBidXQgZ290IGAgK1xuICAgICAgICAgIGAke3QxSW5uZXJ9IGFuZCAke3QySW5uZXJ9LmApO1xuXG4gIGlmICgkdDEucmFuayA9PT0gMSAmJiAkdDIucmFuayA9PT0gMSkge1xuICAgIGNvbnN0IHQxMkQgPSByZXNoYXBlKCR0MSwgWzEsIC0xXSk7XG4gICAgY29uc3QgdDIyRCA9IHJlc2hhcGUoJHQyLCBbLTEsIDFdKTtcbiAgICBjb25zdCB0MXQyID0gbWF0TXVsKHQxMkQsIHQyMkQpO1xuICAgIHJldHVybiByZXNoYXBlKHQxdDIsIFtdKTtcbiAgfSBlbHNlIGlmICgkdDEucmFuayA9PT0gMSAmJiAkdDIucmFuayA9PT0gMikge1xuICAgIGNvbnN0IHQxMkQgPSByZXNoYXBlKCR0MSwgWzEsIC0xXSk7XG4gICAgY29uc3QgdDIyRCA9IHJlc2hhcGUoJHQyLCBbJHQyLnNoYXBlWzBdLCAkdDIuc2hhcGVbMV1dKTtcbiAgICBjb25zdCB0MXQyID0gbWF0TXVsKHQxMkQsIHQyMkQpO1xuICAgIHJldHVybiByZXNoYXBlKHQxdDIsIFt0MXQyLnNpemVdKTtcbiAgfSBlbHNlIGlmICgkdDEucmFuayA9PT0gMiAmJiAkdDIucmFuayA9PT0gMSkge1xuICAgIGNvbnN0IHQyMkQgPSByZXNoYXBlKCR0MiwgWy0xLCAxXSk7XG4gICAgY29uc3QgdDF0MiA9IG1hdE11bCgkdDEsIHQyMkQpO1xuICAgIHJldHVybiByZXNoYXBlKHQxdDIsIFt0MXQyLnNpemVdKTtcbiAgfSBlbHNlIHtcbiAgICBjb25zdCB0MjJEID0gcmVzaGFwZSgkdDIsIFskdDIuc2hhcGVbMF0sICR0Mi5zaGFwZVsxXV0pO1xuICAgIGNvbnN0IHQxdDIgPSBtYXRNdWwoJHQxLCB0MjJEKTtcbiAgICByZXR1cm4gdDF0MjtcbiAgfVxufVxuXG5leHBvcnQgY29uc3QgZG90ID0gb3Aoe2RvdF99KTtcbiJdfQ== |
\ | No newline at end of file |