UNPKG

2.83 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17import { convertToTensor } from '../tensor_util_env';
18import * as util from '../util';
19import { matMul } from './mat_mul';
20import { op } from './operation';
21import { reshape } from './reshape';
/**
 * Computes the dot product of two matrices and/or vectors, `t1` and `t2`.
 *
 * Supported rank combinations and the rank of the result:
 *   - (1, 1): vector · vector -> scalar (rank 0)
 *   - (1, 2): vector · matrix -> vector (rank 1)
 *   - (2, 1): matrix · vector -> vector (rank 1)
 *   - (2, 2): matrix · matrix -> matrix (rank 2)
 *
 * ```js
 * const a = tf.tensor1d([1, 2]);
 * const b = tf.tensor2d([[1, 2], [3, 4]]);
 * const c = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
 *
 * a.dot(b).print(); // or tf.dot(a, b)
 * b.dot(a).print();
 * b.dot(c).print();
 * ```
 * @param t1 The first tensor in the dot operation.
 * @param t2 The second tensor in the dot operation.
 * @returns The dot product, with rank given by the table above.
 * @throws Via `util.assert`, if either input is not rank 1 or 2, or if the
 *     inner (contracted) dimensions of the inputs differ.
 *
 * @doc {heading: 'Operations', subheading: 'Matrices'}
 */
function dot_(t1, t2) {
  const $t1 = convertToTensor(t1, 't1', 'dot');
  const $t2 = convertToTensor(t2, 't2', 'dot');

  util.assert(
      ($t1.rank === 1 || $t1.rank === 2) && ($t2.rank === 1 || $t2.rank === 2),
      () => `Error in dot: inputs must all be rank 1 or 2, but got ranks ` +
          `${$t1.rank} and ${$t2.rank}.`);

  // The contracted ("inner") dimension: the whole vector for a rank-1 input,
  // the column count of t1 / the row count of t2 for a rank-2 input.
  const t1Inner = ($t1.rank === 1 ? $t1.size : $t1.shape[1]);
  const t2Inner = ($t2.rank === 1 ? $t2.size : $t2.shape[0]);

  util.assert(
      t1Inner === t2Inner,
      () => `Error in dot: inner dimensions of inputs must match, but got ` +
          `${t1Inner} and ${t2Inner}.`);

  // Promote vectors to matrices for matMul: t1 becomes a row [1, k] and t2 a
  // column [k, 1]. Rank-2 inputs go to matMul as-is; reshaping a matrix to its
  // own shape would only dispatch another kernel and allocate another tensor.
  const lhs = $t1.rank === 1 ? reshape($t1, [1, -1]) : $t1;
  const rhs = $t2.rank === 1 ? reshape($t2, [-1, 1]) : $t2;
  const product = matMul(lhs, rhs);

  if ($t1.rank === 1 && $t2.rank === 1) {
    // Row times column leaves a single element; return it as a scalar.
    return reshape(product, []);
  }
  if ($t1.rank === 1 || $t2.rank === 1) {
    // Exactly one vector input: drop the size-1 axis added during promotion.
    return reshape(product, [product.size]);
  }
  return product;
}
// Public `dot` op: the `dot_` implementation wrapped by `op` (from
// './operation'). NOTE(review): `op` receives a single-key object; presumably
// the key supplies the op's registered name — confirm against ./operation.
export const dot = op({ dot_: dot_ });
//# sourceMappingURL=dot.js.map
\No newline at end of file