UNPKG

3.19 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2021 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17import { ENGINE } from '../engine';
18import { Einsum } from '../kernel_names';
19import { convertToTensor } from '../tensor_util_env';
20import { op } from './operation';
21/**
22 * Tensor contraction over specified indices and outer product.
23 *
24 * `einsum` allows defining Tensors by defining their element-wise computation.
25 * This computation is based on
26 * [Einstein summation](https://en.wikipedia.org/wiki/Einstein_notation).
27 *
28 * Some special cases include:
29 *
30 * Matrix multiplication:
31 * ```js
32 * const x = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
33 * const y = tf.tensor2d([[0, 1], [2, 3], [4, 5]]);
34 * x.print();
35 * y.print();
36 * tf.einsum('ij,jk->ik', x, y).print();
37 * ```
38 *
39 * Dot product:
40 * ```js
41 * const x = tf.tensor1d([1, 2, 3]);
42 * const y = tf.tensor1d([0, 1, 2]);
43 * x.print();
44 * y.print();
45 * tf.einsum('i,i->', x, y).print();
46 * ```
47 *
48 * Batch dot product:
49 * ```js
50 * const x = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
51 * const y = tf.tensor2d([[0, 1, 2], [3, 4, 5]]);
52 * x.print();
53 * y.print();
54 * tf.einsum('bi,bi->b', x, y).print();
55 * ```
56 *
57 * Outer product:
58 * ```js
59 * const x = tf.tensor1d([1, 3, 5]);
60 * const y = tf.tensor1d([2, 4, 6]);
61 * x.print();
62 * y.print();
63 * tf.einsum('i,j->ij', x, y).print();
64 * ```
65 *
66 * Matrix transpose:
67 * ```js
68 * const x = tf.tensor2d([[1, 2], [3, 4]]);
69 * x.print();
70 * tf.einsum('ij->ji', x).print();
71 * ```
72 *
73 * Batch matrix transpose:
74 * ```js
75 * const x = tf.tensor3d([[[1, 2], [3, 4]], [[-1, -2], [-3, -4]]]);
76 * x.print();
77 * tf.einsum('bij->bji', x).print();
78 * ```
79 *
80 * Limitations:
81 *
82 * This implementation of einsum has the following limitations:
83 *
84 * - Does not support >2 input tensors.
85 * - Does not support duplicate axes for any given input tensor. E.g., equation
86 * 'ii->' is not supported.
87 * - The `...` notation is not supported.
88 *
89 * @param equation a string describing the contraction, in the same format as
90 * [numpy.einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html).
91 * @param tensors the input(s) to contract (each one a Tensor), whose shapes
92 * should be consistent with equation.
93 * @returns The output tensor.
94 *
95 * @doc {heading: 'Tensors', subheading: 'Matrices'}
96 */
export function einsum_(equation, ...tensors) {
  // Pass every input through convertToTensor in call order, tagging each one
  // with its position (`tensors0`, `tensors1`, ...) and the op name 'einsum'.
  const inputs = [];
  for (const [index, tensor] of tensors.entries()) {
    inputs.push(convertToTensor(tensor, `tensors${index}`, 'einsum'));
  }
  // The equation string is the only attribute handed to the Einsum kernel.
  return ENGINE.runKernel(Einsum, inputs, { equation });
}
// Public `einsum` export: the `einsum_` implementation wrapped by `op`.
// NOTE(review): `op` presumably reads the key name (`einsum_`) from this
// object, so the key must stay spelled exactly like this — confirm in ./operation.
export const einsum = op({ einsum_: einsum_ });
103//# sourceMappingURL=einsum.js.map
\No newline at end of file