UNPKG

10.1 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17import { convertToTensor } from '../tensor_util_env';
18import * as util from '../util';
19import { matMul } from './mat_mul';
20import { op } from './operation';
21import { reshape } from './reshape';
/**
 * Computes the dot product of two matrices and/or vectors, `t1` and `t2`.
 *
 * ```js
 * const a = tf.tensor1d([1, 2]);
 * const b = tf.tensor2d([[1, 2], [3, 4]]);
 * const c = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
 *
 * a.dot(b).print(); // or tf.dot(a, b)
 * b.dot(a).print();
 * b.dot(c).print();
 * ```
 *
 * Result shape by input ranks (k is the shared inner dimension):
 * - rank 1 [k]    . rank 1 [k]    -> scalar (rank 0)
 * - rank 1 [k]    . rank 2 [k, n] -> rank 1 [n]
 * - rank 2 [m, k] . rank 1 [k]    -> rank 1 [m]
 * - rank 2 [m, k] . rank 2 [k, n] -> rank 2 [m, n]
 *
 * @param t1 The first tensor in the dot operation.
 * @param t2 The second tensor in the dot operation.
 * @returns The dot product of `t1` and `t2`, shaped as listed above.
 * @throws Raised through `util.assert` (assumed to throw on failure) when an
 *     input is not rank 1 or 2, or when the inner dimensions differ.
 *
 * @doc {heading: 'Operations', subheading: 'Matrices'}
 */
function dot_(t1, t2) {
  const $t1 = convertToTensor(t1, 't1', 'dot');
  const $t2 = convertToTensor(t2, 't2', 'dot');
  util.assert(
      ($t1.rank === 1 || $t1.rank === 2) && ($t2.rank === 1 || $t2.rank === 2),
      () => `Error in dot: inputs must all be rank 1 or 2, but got ranks ` +
          `${$t1.rank} and ${$t2.rank}.`);
  // The contracted axis: a vector contributes its whole length, a matrix
  // contributes its columns (t1) or its rows (t2).
  const t1Inner = ($t1.rank === 1 ? $t1.size : $t1.shape[1]);
  const t2Inner = ($t2.rank === 1 ? $t2.size : $t2.shape[0]);
  util.assert(
      t1Inner === t2Inner,
      () => `Error in dot: inner dimensions of inputs must match, but got ` +
          `${t1Inner} and ${t2Inner}.`);
  // Promote vectors so every case is a single 2-D matMul: t1 becomes a row
  // vector [1, k] and t2 a column vector [k, 1]. Rank-2 inputs are used as-is;
  // reshaping a matrix to its own shape would only add a redundant kernel.
  const t1Matrix = $t1.rank === 1 ? reshape($t1, [1, -1]) : $t1;
  const t2Matrix = $t2.rank === 1 ? reshape($t2, [-1, 1]) : $t2;
  const product = matMul(t1Matrix, t2Matrix);
  if ($t1.rank === 1 && $t2.rank === 1) {
    // [1, 1] -> scalar.
    return reshape(product, []);
  }
  if ($t1.rank === 1 || $t2.rank === 1) {
    // [1, n] or [m, 1] -> flat vector of length n or m.
    return reshape(product, [product.size]);
  }
  return product;
}
// Public `dot` op: `dot_` is handed to the `op` wrapper imported from
// './operation' under its own key name.
export const dot = op({ dot_: dot_ });
72//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiZG90LmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvZG90LnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUdILE9BQU8sRUFBQyxlQUFlLEVBQUMsTUFBTSxvQkFBb0IsQ0FBQztBQUVuRCxPQUFPLEtBQUssSUFBSSxNQUFNLFNBQVMsQ0FBQztBQUVoQyxPQUFPLEVBQUMsTUFBTSxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBQ2pDLE9BQU8sRUFBQyxFQUFFLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFDL0IsT0FBTyxFQUFDLE9BQU8sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUVsQzs7Ozs7Ozs7Ozs7Ozs7OztHQWdCRztBQUNILFNBQVMsSUFBSSxDQUFDLEVBQXFCLEVBQUUsRUFBcUI7SUFDeEQsTUFBTSxHQUFHLEdBQUcsZUFBZSxDQUFDLEVBQUUsRUFBRSxJQUFJLEVBQUUsS0FBSyxDQUFDLENBQUM7SUFDN0MsTUFBTSxHQUFHLEdBQUcsZUFBZSxDQUFDLEVBQUUsRUFBRSxJQUFJLEVBQUUsS0FBSyxDQUFDLENBQUM7SUFFN0MsSUFBSSxDQUFDLE1BQU0sQ0FDUCxDQUFDLEdBQUcsQ0FBQyxJQUFJLEtBQUssQ0FBQyxJQUFJLEdBQUcsQ0FBQyxJQUFJLEtBQUssQ0FBQyxDQUFDLElBQUksQ0FBQyxHQUFHLENBQUMsSUFBSSxLQUFLLENBQUMsSUFBSSxHQUFHLENBQUMsSUFBSSxLQUFLLENBQUMsQ0FBQyxFQUN4RSxHQUFHLEVBQUUsQ0FBQyw4REFBOEQ7UUFDaEUsR0FBRyxHQUFHLENBQUMsSUFBSSxRQUFRLEdBQUcsQ0FBQyxJQUFJLEdBQUcsQ0FBQyxDQUFDO0lBRXhDLE1BQU0sT0FBTyxHQUFHLENBQUMsR0FBRyxDQUFDLElBQUksS0FBSyxDQUFDLENBQUMsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxJQUFJLENBQUMsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztJQUMzRCxNQUFNLE9BQU8sR0FBRyxDQUFDLEdBQUcsQ0FBQyxJQUFJLEtBQUssQ0FBQyxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7SUFFM0QsSUFBSSxDQUFDLE1BQU0sQ0FDUCxPQUFPLEtBQUssT0FBTyxFQUNuQixHQUFHLEVBQUUsQ0FBQywrREFBK0Q7UUFDakUsR0FBRyxPQUFPLFFBQVEsT0FBTyxHQUFHLENBQUMsQ0FBQztJQUV0QyxJQUFJLEdBQUcsQ0FBQyxJQUFJLEtBQUssQ0FBQyxJQUFJLEdBQUcsQ0FBQyxJQUFJLEtBQUssQ0FBQyxFQUFFO1FBQ3BDLE1BQU0sSUFBSSxHQUFHLE9BQU8sQ0FBQyxHQUFHLEVBQUUsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBQ25DLE1BQU0sSUFBSSxHQUFHLE9BQU8sQ0FBQyxHQUFHLEVBQUUsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBQ25DLE1BQU0sSUFBSSxHQUFHLE1BQU0sQ0FBQyxJQUFJLEVBQUUsSUFBSSxDQUFD
LENBQUM7UUFDaEMsT0FBTyxPQUFPLENBQUMsSUFBSSxFQUFFLEVBQUUsQ0FBQyxDQUFDO0tBQzFCO1NBQU0sSUFBSSxHQUFHLENBQUMsSUFBSSxLQUFLLENBQUMsSUFBSSxHQUFHLENBQUMsSUFBSSxLQUFLLENBQUMsRUFBRTtRQUMzQyxNQUFNLElBQUksR0FBRyxPQUFPLENBQUMsR0FBRyxFQUFFLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUNuQyxNQUFNLElBQUksR0FBRyxPQUFPLENBQUMsR0FBRyxFQUFFLENBQUMsR0FBRyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsRUFBRSxHQUFHLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUN4RCxNQUFNLElBQUksR0FBRyxNQUFNLENBQUMsSUFBSSxFQUFFLElBQUksQ0FBQyxDQUFDO1FBQ2hDLE9BQU8sT0FBTyxDQUFDLElBQUksRUFBRSxDQUFDLElBQUksQ0FBQyxJQUFJLENBQUMsQ0FBQyxDQUFDO0tBQ25DO1NBQU0sSUFBSSxHQUFHLENBQUMsSUFBSSxLQUFLLENBQUMsSUFBSSxHQUFHLENBQUMsSUFBSSxLQUFLLENBQUMsRUFBRTtRQUMzQyxNQUFNLElBQUksR0FBRyxPQUFPLENBQUMsR0FBRyxFQUFFLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUNuQyxNQUFNLElBQUksR0FBRyxNQUFNLENBQUMsR0FBRyxFQUFFLElBQUksQ0FBQyxDQUFDO1FBQy9CLE9BQU8sT0FBTyxDQUFDLElBQUksRUFBRSxDQUFDLElBQUksQ0FBQyxJQUFJLENBQUMsQ0FBQyxDQUFDO0tBQ25DO1NBQU07UUFDTCxNQUFNLElBQUksR0FBRyxPQUFPLENBQUMsR0FBRyxFQUFFLENBQUMsR0FBRyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsRUFBRSxHQUFHLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUN4RCxNQUFNLElBQUksR0FBRyxNQUFNLENBQUMsR0FBRyxFQUFFLElBQUksQ0FBQyxDQUFDO1FBQy9CLE9BQU8sSUFBSSxDQUFDO0tBQ2I7QUFDSCxDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sR0FBRyxHQUFHLEVBQUUsQ0FBQyxFQUFDLElBQUksRUFBQyxDQUFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxc
biAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7VGVuc29yLH0gZnJvbSAnLi4vdGVuc29yJztcbmltcG9ydCB7Y29udmVydFRvVGVuc29yfSBmcm9tICcuLi90ZW5zb3JfdXRpbF9lbnYnO1xuaW1wb3J0IHtUZW5zb3JMaWtlfSBmcm9tICcuLi90eXBlcyc7XG5pbXBvcnQgKiBhcyB1dGlsIGZyb20gJy4uL3V0aWwnO1xuXG5pbXBvcnQge21hdE11bH0gZnJvbSAnLi9tYXRfbXVsJztcbmltcG9ydCB7b3B9IGZyb20gJy4vb3BlcmF0aW9uJztcbmltcG9ydCB7cmVzaGFwZX0gZnJvbSAnLi9yZXNoYXBlJztcblxuLyoqXG4gKiBDb21wdXRlcyB0aGUgZG90IHByb2R1Y3Qgb2YgdHdvIG1hdHJpY2VzIGFuZC9vciB2ZWN0b3JzLCBgdDFgIGFuZCBgdDJgLlxuICpcbiAqIGBgYGpzXG4gKiBjb25zdCBhID0gdGYudGVuc29yMWQoWzEsIDJdKTtcbiAqIGNvbnN0IGIgPSB0Zi50ZW5zb3IyZChbWzEsIDJdLCBbMywgNF1dKTtcbiAqIGNvbnN0IGMgPSB0Zi50ZW5zb3IyZChbWzEsIDIsIDNdLCBbNCwgNSwgNl1dKTtcbiAqXG4gKiBhLmRvdChiKS5wcmludCgpOyAgLy8gb3IgdGYuZG90KGEsIGIpXG4gKiBiLmRvdChhKS5wcmludCgpO1xuICogYi5kb3QoYykucHJpbnQoKTtcbiAqIGBgYFxuICogQHBhcmFtIHQxIFRoZSBmaXJzdCB0ZW5zb3IgaW4gdGhlIGRvdCBvcGVyYXRpb24uXG4gKiBAcGFyYW0gdDIgVGhlIHNlY29uZCB0ZW5zb3IgaW4gdGhlIGRvdCBvcGVyYXRpb24uXG4gKlxuICogQGRvYyB7aGVhZGluZzogJ09wZXJhdGlvbnMnLCBzdWJoZWFkaW5nOiAnTWF0cmljZXMnfVxuICovXG5mdW5jdGlvbiBkb3RfKHQxOiBUZW5zb3J8VGVuc29yTGlrZSwgdDI6IFRlbnNvcnxUZW5zb3JMaWtlKTogVGVuc29yIHtcbiAgY29uc3QgJHQxID0gY29udmVydFRvVGVuc29yKHQxLCAndDEnLCAnZG90Jyk7XG4gIGNvbnN0ICR0MiA9IGNvbnZlcnRUb1RlbnNvcih0MiwgJ3QyJywgJ2RvdCcpO1xuXG4gIHV0aWwuYXNzZXJ0KFxuICAgICAgKCR0MS5yYW5rID09PSAxIHx8ICR0MS5yYW5rID09PSAyKSAmJiAoJHQyLnJhbmsgPT09IDEgfHwgJHQyLnJhbmsgPT09IDIpLFxuICAgICAgKCkgPT4gYEVycm9yIGluIGRvdDogaW5wdXRzIG11c3QgYWxsIGJlIHJhbmsgMSBvciAyLCBidXQgZ290IHJhbmtzIGAgK1xuICAgICAgICAgIGAkeyR0MS5yYW5rfSBhbmQgJHskdDIucmFua30uYCk7XG5cbiAgY29uc3QgdDFJbm5lciA9ICgkdDEucmFuayA9PT0gMSA/ICR0MS5zaXplIDogJHQxLnNoYXBlWzFdKTtcbiAgY29u
c3QgdDJJbm5lciA9ICgkdDIucmFuayA9PT0gMSA/ICR0Mi5zaXplIDogJHQyLnNoYXBlWzBdKTtcblxuICB1dGlsLmFzc2VydChcbiAgICAgIHQxSW5uZXIgPT09IHQySW5uZXIsXG4gICAgICAoKSA9PiBgRXJyb3IgaW4gZG90OiBpbm5lciBkaW1lbnNpb25zIG9mIGlucHV0cyBtdXN0IG1hdGNoLCBidXQgZ290IGAgK1xuICAgICAgICAgIGAke3QxSW5uZXJ9IGFuZCAke3QySW5uZXJ9LmApO1xuXG4gIGlmICgkdDEucmFuayA9PT0gMSAmJiAkdDIucmFuayA9PT0gMSkge1xuICAgIGNvbnN0IHQxMkQgPSByZXNoYXBlKCR0MSwgWzEsIC0xXSk7XG4gICAgY29uc3QgdDIyRCA9IHJlc2hhcGUoJHQyLCBbLTEsIDFdKTtcbiAgICBjb25zdCB0MXQyID0gbWF0TXVsKHQxMkQsIHQyMkQpO1xuICAgIHJldHVybiByZXNoYXBlKHQxdDIsIFtdKTtcbiAgfSBlbHNlIGlmICgkdDEucmFuayA9PT0gMSAmJiAkdDIucmFuayA9PT0gMikge1xuICAgIGNvbnN0IHQxMkQgPSByZXNoYXBlKCR0MSwgWzEsIC0xXSk7XG4gICAgY29uc3QgdDIyRCA9IHJlc2hhcGUoJHQyLCBbJHQyLnNoYXBlWzBdLCAkdDIuc2hhcGVbMV1dKTtcbiAgICBjb25zdCB0MXQyID0gbWF0TXVsKHQxMkQsIHQyMkQpO1xuICAgIHJldHVybiByZXNoYXBlKHQxdDIsIFt0MXQyLnNpemVdKTtcbiAgfSBlbHNlIGlmICgkdDEucmFuayA9PT0gMiAmJiAkdDIucmFuayA9PT0gMSkge1xuICAgIGNvbnN0IHQyMkQgPSByZXNoYXBlKCR0MiwgWy0xLCAxXSk7XG4gICAgY29uc3QgdDF0MiA9IG1hdE11bCgkdDEsIHQyMkQpO1xuICAgIHJldHVybiByZXNoYXBlKHQxdDIsIFt0MXQyLnNpemVdKTtcbiAgfSBlbHNlIHtcbiAgICBjb25zdCB0MjJEID0gcmVzaGFwZSgkdDIsIFskdDIuc2hhcGVbMF0sICR0Mi5zaGFwZVsxXV0pO1xuICAgIGNvbnN0IHQxdDIgPSBtYXRNdWwoJHQxLCB0MjJEKTtcbiAgICByZXR1cm4gdDF0MjtcbiAgfVxufVxuXG5leHBvcnQgY29uc3QgZG90ID0gb3Aoe2RvdF99KTtcbiJdfQ==
\No newline at end of file