UNPKG

2.08 kBJavaScriptView Raw
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17import { ENGINE } from '../engine';
18import { SquaredDifference } from '../kernel_names';
19import { makeTypesMatch } from '../tensor_util';
20import { convertToTensor } from '../tensor_util_env';
21import { assertAndGetBroadcastShape } from './broadcast_util';
22import { op } from './operation';
/**
 * Computes the element-wise squared difference `(a - b) * (a - b)`.
 * Supports broadcasting.
 *
 * ```js
 * const a = tf.tensor1d([1, 4, 3, 16]);
 * const b = tf.tensor1d([1, 2, 9, 4]);
 *
 * a.squaredDifference(b).print(); // or tf.squaredDifference(a, b)
 * ```
 *
 * ```js
 * // Broadcast squared difference a with b.
 * const a = tf.tensor1d([2, 4, 6, 8]);
 * const b = tf.scalar(5);
 *
 * a.squaredDifference(b).print(); // or tf.squaredDifference(a, b)
 * ```
 *
 * @param a The first tensor; passed through `convertToTensor` before use.
 * @param b The second tensor. Must have the same type as `a`.
 * @returns Whatever the `SquaredDifference` kernel produces for the two
 *     converted, type-matched operands.
 *
 * @doc {heading: 'Operations', subheading: 'Arithmetic'}
 */
function squaredDifference_(a, b) {
  const opName = 'squaredDifference';

  // Operands are converted in order (a first, then b) and then reconciled
  // together by makeTypesMatch.
  // NOTE(review): makeTypesMatch presumably reconciles differing dtypes, which
  // would make the "same type" wording on `b` stricter than needed — confirm
  // against tensor_util.
  const [lhs, rhs] = makeTypesMatch(
    convertToTensor(a, 'a', opName),
    convertToTensor(b, 'b', opName)
  );

  // Invoked only for its validation (presumably throws when the shapes cannot
  // broadcast); the shape it returns is not used here.
  assertAndGetBroadcastShape(lhs.shape, rhs.shape);

  // The kernel takes its operands under the keys `a` and `b` and no attributes.
  return ENGINE.runKernel(SquaredDifference, { a: lhs, b: rhs }, {});
}
// Public op: the implementation is handed to `op` keyed by its own
// (underscore-suffixed) function name.
export const squaredDifference = op({ squaredDifference_: squaredDifference_ });
//# sourceMappingURL=squared_difference.js.map
\No newline at end of file