// Listing captured from UNPKG (View Raw): 10.2 kB, JavaScript.

/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
17import { ENGINE } from '../engine';
18import { Tile } from '../kernel_names';
19import { convertToTensor } from '../tensor_util_env';
20import { clone } from './clone';
21import { op } from './operation';
22import { reshape } from './reshape';
/**
 * Broadcast an array to a compatible shape NumPy-style.
 *
 * The tensor's shape is compared to the broadcast shape from end to beginning.
 * Ones are prepended to the tensor's shape until it has the same length as
 * the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is
 * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then
 * the input tensor is tiled N times along that axis (using tf.tile).
 *
 * @param x The tensor that is to be broadcasted.
 * @param shape The input is to be broadcast to this shape. Every entry must be
 *     a positive integer, and `shape.length` must be at least the input rank.
 * @returns The result of the Tile kernel when at least one axis must be
 *     repeated; otherwise a clone of the (possibly reshaped) input.
 * @throws Error if `shape` contains a non-positive or non-integer entry, if
 *     `shape.length` is smaller than the input rank, or if some input
 *     dimension is neither equal to the target dimension nor 1.
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function broadcastTo_(x, shape) {
  let input = convertToTensor(x, 'broadcastTo', 'x');
  // Keep the caller-visible shape for error messages (before any reshape).
  const xShape = input.shape;

  // `!(d > 0)` (rather than `d <= 0`) also rejects NaN entries.
  if (shape.some((d) => !(d > 0) || d % 1 !== 0)) {
    throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`);
  }

  if (shape.length < input.rank) {
    throw new Error(
        `broadcastTo(): shape.length=${shape.length} < input.rank=${input.rank}.`);
  }

  // Left-pad the input shape with ones so both shapes have the same rank.
  if (shape.length > input.rank) {
    const padding = new Array(shape.length - input.rank).fill(1);
    input = reshape(input, [...padding, ...input.shape]);
  }

  const inputShape = input.shape;
  // reps[i] is the tile multiplier for axis i: 1 where the dimensions already
  // match, shape[i] where the input dimension is 1 and must be repeated.
  const reps = Array.from(shape);
  for (let i = shape.length - 1; i >= 0; i--) {
    if (inputShape[i] === shape[i]) {
      reps[i] = 1;
    } else if (inputShape[i] !== 1) {
      throw new Error(
          `broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`);
    }
  }

  // No axis needs repeating: return a clone of the (possibly reshaped) input.
  if (!reps.some((n) => n > 1)) {
    return clone(input);
  }

  // TODO: call the broadcastTo kernel directly once backends implement it.
  return ENGINE.runKernel(Tile, { x: input }, { reps });
}
/**
 * Exported `broadcastTo` op: `broadcastTo_` wrapped by `op` from
 * './operation'. NOTE(review): the exact wrapping behavior (scoping, naming)
 * is defined in './operation' — confirm there.
 */
export const broadcastTo = op({ broadcastTo_ });
73//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYnJvYWRjYXN0X3RvLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvYnJvYWRjYXN0X3RvLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFDakMsT0FBTyxFQUFDLElBQUksRUFBd0IsTUFBTSxpQkFBaUIsQ0FBQztBQUk1RCxPQUFPLEVBQUMsZUFBZSxFQUFDLE1BQU0sb0JBQW9CLENBQUM7QUFHbkQsT0FBTyxFQUFDLEtBQUssRUFBQyxNQUFNLFNBQVMsQ0FBQztBQUM5QixPQUFPLEVBQUMsRUFBRSxFQUFDLE1BQU0sYUFBYSxDQUFDO0FBQy9CLE9BQU8sRUFBQyxPQUFPLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFFbEM7Ozs7Ozs7Ozs7Ozs7R0FhRztBQUNILFNBQVMsWUFBWSxDQUNqQixDQUFvQixFQUFFLEtBQWtCO0lBQzFDLElBQUksS0FBSyxHQUFHLGVBQWUsQ0FBQyxDQUFDLEVBQUUsYUFBYSxFQUFFLEdBQUcsQ0FBQyxDQUFDO0lBQ25ELE1BQU0sTUFBTSxHQUFHLEtBQUssQ0FBQyxLQUFLLENBQUM7SUFFM0IsSUFBSSxLQUFLLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUMsSUFBSSxDQUFDLEdBQUcsQ0FBQyxLQUFLLENBQUMsQ0FBQyxFQUFFO1FBQzVDLE1BQU0sSUFBSSxLQUFLLENBQUMsMkNBQTJDLEtBQUssSUFBSSxDQUFDLENBQUM7S0FDdkU7SUFFRCxJQUFJLEtBQUssQ0FBQyxNQUFNLEdBQUcsS0FBSyxDQUFDLElBQUksRUFBRTtRQUM3QixNQUFNLElBQUksS0FBSyxDQUFDLCtCQUErQixLQUFLLENBQUMsTUFBTSxpQkFDdkQsS0FBSyxDQUFDLElBQUksR0FBRyxDQUFDLENBQUM7S0FDcEI7SUFFRCxJQUFJLEtBQUssQ0FBQyxNQUFNLEdBQUcsS0FBSyxDQUFDLElBQUksRUFBRTtRQUM3QixNQUFNLFFBQVEsR0FBRyxLQUFLLENBQUMsS0FBSyxDQUFDLEtBQUssRUFBRSxDQUFDO1FBQ3JDLE9BQU8sUUFBUSxDQUFDLE1BQU0sR0FBRyxLQUFLLENBQUMsTUFBTSxFQUFFO1lBQ3JDLFFBQVEsQ0FBQyxPQUFPLENBQUMsQ0FBQyxDQUFDLENBQUM7U0FDckI7UUFDRCxLQUFLLEdBQUcsT0FBTyxDQUFDLEtBQUssRUFBRSxRQUFRLENBQUMsQ0FBQztLQUNsQztJQUVELE1BQU0sVUFBVSxHQUFHLEtBQUssQ0FBQyxLQUFLLENBQUM7SUFDL0IsTUFBTSxJQUFJLEdBQWEsS0FBSyxDQUFDLElBQUksQ0FBQyxLQUFLLENBQUMsQ0FBQztJQUN6QyxLQUFLLElBQUksQ0FBQyxHQUFHLEtBQUssQ0FBQyxNQUFNLEdBQUcsQ0FBQyxFQUFFLENBQUMsSUFBSSxDQUFDLEVBQUUsQ0FBQyxFQUFFLEVBQUU7UUFDMUMsSUFBSSxVQUFVLENBQUMsQ0FBQyxDQUFDLEtBQUssS0FBSyxDQUFDLENBQUMsQ0FBQyxFQUFFO1lBQzlCLElBQUksQ0FBQyxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUM7U0FDYjth
QUFNLElBQUksS0FBSyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLEVBQUU7WUFDL0IsTUFBTSxJQUFJLEtBQUssQ0FDWCxtQkFBbUIsTUFBTSw2QkFBNkIsS0FBSyxJQUFJLENBQUMsQ0FBQztTQUN0RTtLQUNGO0lBQ0QsTUFBTSxJQUFJLEdBQUcsSUFBSSxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsRUFBRSxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxNQUFNLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLElBQUksQ0FBQyxDQUFDLENBQUM7SUFFcEUsSUFBSSxJQUFJLENBQUMsTUFBTSxLQUFLLENBQUMsRUFBRTtRQUNyQixPQUFPLEtBQUssQ0FBQyxLQUFLLENBQWMsQ0FBQztLQUNsQztJQUVELDJFQUEyRTtJQUMzRSxNQUFNLE1BQU0sR0FBZSxFQUFDLENBQUMsRUFBRSxLQUFLLEVBQUMsQ0FBQztJQUN0QyxNQUFNLEtBQUssR0FBYyxFQUFDLElBQUksRUFBQyxDQUFDO0lBQ2hDLE9BQU8sTUFBTSxDQUFDLFNBQVMsQ0FDbkIsSUFBSSxFQUFFLE1BQThCLEVBQUUsS0FBZ0MsQ0FBQyxDQUFDO0FBQzlFLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxXQUFXLEdBQUcsRUFBRSxDQUFDLEVBQUMsWUFBWSxFQUFDLENBQUMsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtFTkdJTkV9IGZyb20gJy4uL2VuZ2luZSc7XG5pbXBvcnQge1RpbGUsIFRpbGVBdHRycywgVGlsZUlucHV0c30gZnJvbSAnLi4va2VybmVsX25hbWVz
JztcbmltcG9ydCB7TmFtZWRBdHRyTWFwfSBmcm9tICcuLi9rZXJuZWxfcmVnaXN0cnknO1xuaW1wb3J0IHtUZW5zb3J9IGZyb20gJy4uL3RlbnNvcic7XG5pbXBvcnQge05hbWVkVGVuc29yTWFwfSBmcm9tICcuLi90ZW5zb3JfdHlwZXMnO1xuaW1wb3J0IHtjb252ZXJ0VG9UZW5zb3J9IGZyb20gJy4uL3RlbnNvcl91dGlsX2Vudic7XG5pbXBvcnQge1JhbmssIFNoYXBlTWFwLCBUZW5zb3JMaWtlfSBmcm9tICcuLi90eXBlcyc7XG5cbmltcG9ydCB7Y2xvbmV9IGZyb20gJy4vY2xvbmUnO1xuaW1wb3J0IHtvcH0gZnJvbSAnLi9vcGVyYXRpb24nO1xuaW1wb3J0IHtyZXNoYXBlfSBmcm9tICcuL3Jlc2hhcGUnO1xuXG4vKipcbiAqIEJyb2FkY2FzdCBhbiBhcnJheSB0byBhIGNvbXBhdGlibGUgc2hhcGUgTnVtUHktc3R5bGUuXG4gKlxuICogVGhlIHRlbnNvcidzIHNoYXBlIGlzIGNvbXBhcmVkIHRvIHRoZSBicm9hZGNhc3Qgc2hhcGUgZnJvbSBlbmQgdG8gYmVnaW5uaW5nLlxuICogT25lcyBhcmUgcHJlcGVuZGVkIHRvIHRoZSB0ZW5zb3IncyBzaGFwZSB1bnRpbCBpcyBoYXMgdGhlIHNhbWUgbGVuZ3RoIGFzXG4gKiB0aGUgYnJvYWRjYXN0IHNoYXBlLiBJZiBpbnB1dC5zaGFwZVtpXT09c2hhcGVbaV0sIHRoZSAoaSsxKS10aCBheGlzIGlzXG4gKiBhbHJlYWR5IGJyb2FkY2FzdC1jb21wYXRpYmxlLiBJZiBpbnB1dC5zaGFwZVtpXT09MSBhbmQgc2hhcGVbaV09PU4sIHRoZW5cbiAqIHRoZSBpbnB1dCB0ZW5zb3IgaXMgdGlsZWQgTiB0aW1lcyBhbG9uZyB0aGF0IGF4aXMgKHVzaW5nIHRmLnRpbGUpLlxuICpcbiAqIEBwYXJhbSBpbnB1dCBUaGUgdGVuc29yIHRoYXQgaXMgdG8gYmUgYnJvYWRjYXN0ZWQuXG4gKiBAcGFyYW0gc2hhcGUgVGhlIGlucHV0IGlzIHRvIGJlIGJyb2FkY2FzdCB0byB0aGlzIHNoYXBlLlxuICpcbiAqIEBkb2Mge2hlYWRpbmc6ICdUZW5zb3JzJywgc3ViaGVhZGluZzogJ1RyYW5zZm9ybWF0aW9ucyd9XG4gKi9cbmZ1bmN0aW9uIGJyb2FkY2FzdFRvXzxSIGV4dGVuZHMgUmFuaz4oXG4gICAgeDogVGVuc29yfFRlbnNvckxpa2UsIHNoYXBlOiBTaGFwZU1hcFtSXSk6IFRlbnNvcjxSPiB7XG4gIGxldCBpbnB1dCA9IGNvbnZlcnRUb1RlbnNvcih4LCAnYnJvYWRjYXN0VG8nLCAneCcpO1xuICBjb25zdCB4U2hhcGUgPSBpbnB1dC5zaGFwZTtcblxuICBpZiAoc2hhcGUuc29tZShkID0+ICEoZCA+IDApIHx8IGQgJSAxICE9PSAwKSkge1xuICAgIHRocm93IG5ldyBFcnJvcihgYnJvYWRjYXN0VG8oKTogSW52YWxpZCBicm9hZGNhc3Qgc2hhcGUgWyR7c2hhcGV9XS5gKTtcbiAgfVxuXG4gIGlmIChzaGFwZS5sZW5ndGggPCBpbnB1dC5yYW5rKSB7XG4gICAgdGhyb3cgbmV3IEVycm9yKGBicm9hZGNhc3RUbygpOiBzaGFwZS5sZW5ndGg9JHtzaGFwZS5sZW5ndGh9IDwgaW5wdXQucmFuaz0ke1xuICAgICAgICBpbnB1dC5yYW5rfS5gKTtcbiAgfVxuXG4gIGlmIChzaGFwZS5sZW5ndGggPiBpbnB1dC5yYW5rKSB7XG4gICAg
Y29uc3QgbmV3U2hhcGUgPSBpbnB1dC5zaGFwZS5zbGljZSgpO1xuICAgIHdoaWxlIChuZXdTaGFwZS5sZW5ndGggPCBzaGFwZS5sZW5ndGgpIHtcbiAgICAgIG5ld1NoYXBlLnVuc2hpZnQoMSk7XG4gICAgfVxuICAgIGlucHV0ID0gcmVzaGFwZShpbnB1dCwgbmV3U2hhcGUpO1xuICB9XG5cbiAgY29uc3QgaW5wdXRTaGFwZSA9IGlucHV0LnNoYXBlO1xuICBjb25zdCByZXBzOiBudW1iZXJbXSA9IEFycmF5LmZyb20oc2hhcGUpO1xuICBmb3IgKGxldCBpID0gc2hhcGUubGVuZ3RoIC0gMTsgaSA+PSAwOyBpLS0pIHtcbiAgICBpZiAoaW5wdXRTaGFwZVtpXSA9PT0gc2hhcGVbaV0pIHtcbiAgICAgIHJlcHNbaV0gPSAxO1xuICAgIH0gZWxzZSBpZiAoaW5wdXQuc2hhcGVbaV0gIT09IDEpIHtcbiAgICAgIHRocm93IG5ldyBFcnJvcihcbiAgICAgICAgICBgYnJvYWRjYXN0VG8oKTogWyR7eFNoYXBlfV0gY2Fubm90IGJlIGJyb2FkY2FzdCB0byBbJHtzaGFwZX1dLmApO1xuICAgIH1cbiAgfVxuICBjb25zdCBheGVzID0gcmVwcy5tYXAoKG4sIGkpID0+IG4gPiAxID8gaSA6IC0xKS5maWx0ZXIoaSA9PiBpID49IDApO1xuXG4gIGlmIChheGVzLmxlbmd0aCA9PT0gMCkge1xuICAgIHJldHVybiBjbG9uZShpbnB1dCkgYXMgVGVuc29yPFI+O1xuICB9XG5cbiAgLy8gVE9ETyBjYWxsIGJyb2FkY2FzdFRvIGtlcm5lbCBkaXJlY3RseSBvbmNlIGJhY2tlbmRzIGltcGxlbWVudCBicm9hZGNzdFRvXG4gIGNvbnN0IGlucHV0czogVGlsZUlucHV0cyA9IHt4OiBpbnB1dH07XG4gIGNvbnN0IGF0dHJzOiBUaWxlQXR0cnMgPSB7cmVwc307XG4gIHJldHVybiBFTkdJTkUucnVuS2VybmVsKFxuICAgICAgVGlsZSwgaW5wdXRzIGFzIHt9IGFzIE5hbWVkVGVuc29yTWFwLCBhdHRycyBhcyB1bmtub3duIGFzIE5hbWVkQXR0ck1hcCk7XG59XG5cbmV4cG9ydCBjb25zdCBicm9hZGNhc3RUbyA9IG9wKHticm9hZGNhc3RUb199KTtcbiJdfQ==
\No newline at end of file