1 | /**
|
2 | * @license
|
3 | * Copyright 2020 Google LLC. All Rights Reserved.
|
4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | * you may not use this file except in compliance with the License.
|
6 | * You may obtain a copy of the License at
|
7 | *
|
8 | * http://www.apache.org/licenses/LICENSE-2.0
|
9 | *
|
10 | * Unless required by applicable law or agreed to in writing, software
|
11 | * distributed under the License is distributed on an "AS IS" BASIS,
|
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | * See the License for the specific language governing permissions and
|
14 | * limitations under the License.
|
15 | * =============================================================================
|
16 | */
|
17 | import { ENGINE } from '../engine';
|
18 | import { Tile } from '../kernel_names';
|
19 | import { convertToTensor } from '../tensor_util_env';
|
20 | import { clone } from './clone';
|
21 | import { op } from './operation';
|
22 | import { reshape } from './reshape';
|
/**
 * Broadcast an array to a compatible shape NumPy-style.
 *
 * The tensor's shape is compared to the broadcast shape from end to beginning.
 * Ones are prepended to the tensor's shape until it has the same length as
 * the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is
 * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then
 * the input tensor is tiled N times along that axis (using tf.tile).
 *
 * @param input The tensor that is to be broadcasted.
 * @param shape The input is to be broadcast to this shape.
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function broadcastTo_(x, shape) {
    let input = convertToTensor(x, 'broadcastTo', 'x');
    const xShape = input.shape;
    // Every target dimension must be a positive integer.
    const hasBadDim = shape.some(d => !(d > 0) || d % 1 !== 0);
    if (hasBadDim) {
        throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`);
    }
    if (shape.length < input.rank) {
        throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${input.rank}.`);
    }
    if (shape.length > input.rank) {
        // Pad the input's shape with leading 1s so both ranks match.
        const padded = new Array(shape.length - input.rank).fill(1).concat(input.shape);
        input = reshape(input, padded);
    }
    const inputShape = input.shape;
    const reps = Array.from(shape);
    // Walk the axes from last to first: a matching dimension needs no tiling
    // (rep of 1); a mismatch is only legal when the input dimension is 1.
    for (let i = shape.length - 1; i >= 0; i--) {
        if (inputShape[i] === shape[i]) {
            reps[i] = 1;
        }
        else if (inputShape[i] !== 1) {
            throw new Error(`broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`);
        }
    }
    // Axes that actually require replication.
    const axes = reps.map((n, i) => n > 1 ? i : -1).filter(i => i >= 0);
    if (axes.length === 0) {
        // Nothing to tile; hand back a copy of the (possibly reshaped) input.
        return clone(input);
    }
    // TODO call broadcastTo kernel directly once backends implement broadcastTo
    const inputs = { x: input };
    const attrs = { reps };
    return ENGINE.runKernel(Tile, inputs, attrs);
}
export const broadcastTo = op({ broadcastTo_ });
|
73 | //# sourceMappingURL=broadcast_to.js.map |
\ | No newline at end of file |