// UNPKG

// 2.96 kB | JavaScript | View Raw
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { ENGINE } from '../engine';
import { Tile } from '../kernel_names';
import { convertToTensor } from '../tensor_util_env';
import { clone } from './clone';
import { op } from './operation';
import { reshape } from './reshape';
/**
 * Broadcast an array to a compatible shape NumPy-style.
 *
 * The tensor's shape is compared to the broadcast shape from end to beginning.
 * Ones are prepended to the tensor's shape until it has the same length as
 * the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is
 * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then
 * the input tensor is tiled N times along that axis (using tf.tile).
 *
 * @param x The tensor that is to be broadcasted.
 * @param shape The input is to be broadcast to this shape. Every entry must
 *     be a positive integer.
 * @returns The result of `clone` on the (possibly reshaped) input when no
 *     axis needs tiling, otherwise the result of running the `Tile` kernel
 *     with one repetition count per axis.
 * @throws {Error} If `shape` contains an entry that is not a positive
 *     integer, if `shape` has fewer entries than the input has dimensions, or
 *     if an input dimension is neither 1 nor equal to the matching entry of
 *     `shape`.
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function broadcastTo_(x, shape) {
    let input = convertToTensor(x, 'broadcastTo', 'x');
    // Remember the shape as passed in: `input` may be reshaped below, and the
    // incompatibility error should report what the caller actually supplied.
    const xShape = input.shape;
    // `!(d > 0)` (rather than `d <= 0`) also rejects NaN and non-numeric
    // entries; `d % 1 !== 0` rejects fractions and Infinity.
    if (shape.some(d => !(d > 0) || d % 1 !== 0)) {
        throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`);
    }
    if (shape.length < input.rank) {
        throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${input.rank}.`);
    }
    if (shape.length > input.rank) {
        // Left-pad the input shape with ones so both shapes have equal rank.
        const leadingOnes = new Array(shape.length - input.rank).fill(1);
        input = reshape(input, [...leadingOnes, ...input.shape]);
    }
    const inputShape = input.shape;
    // One repetition count per axis: 1 where the sizes already match, the
    // target size where the input axis has size 1 and must be tiled.
    const reps = Array.from(shape, (dim, i) => {
        if (inputShape[i] === dim) {
            return 1;
        }
        if (inputShape[i] !== 1) {
            throw new Error(`broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`);
        }
        return dim;
    });
    // Nothing to tile: the input already has the requested shape.
    if (reps.every(rep => rep === 1)) {
        return clone(input);
    }
    // TODO call broadcastTo kernel directly once backends implement broadcastTo
    const inputs = { x: input };
    const attrs = { reps };
    return ENGINE.runKernel(Tile, inputs, attrs);
}
/** Exported `broadcastTo`: the result of passing `broadcastTo_` to the `op` helper. */
export const broadcastTo = op({ broadcastTo_ });
//# sourceMappingURL=broadcast_to.js.map