UNPKG

2.13 kBJavaScriptView Raw
1import { computeStrides, sizeFromShape } from '../util';
/**
 * Validate gather nd inputs and compute the parameters a gatherND kernel
 * needs to perform the gather.
 *
 * The result shape is `indices.shape[:-1] + tensor.shape[indices.shape[-1]:]`.
 *
 * @param tensor The tensor that contains the source values.
 * @param indices The int32 tensor that contains the indices used to slice the
 *     source. Its innermost dimension (`sliceRank`) is the number of leading
 *     source dimensions each index row addresses.
 *
 * @returns [resultShape, numSlices, sliceSize, strides] where `numSlices` is
 *     the number of slices gathered (the product of every dimension of
 *     `indices` except the last), `sliceSize` is the number of elements in
 *     each slice, and `strides` has `sliceRank` entries: the row-major stride,
 *     measured in whole slices, of each indexed source dimension.
 * @throws Error if either input has rank < 1, `indices` is not int32, the
 *     innermost dimension of `indices` exceeds the rank of `tensor`, or one or
 *     more slices are requested from an empty `tensor`.
 */
export function prepareAndValidate(tensor, indices) {
  const tensorRank = tensor.shape.length;
  const indicesRank = indices.shape.length;
  if (tensorRank < 1) {
    throw new Error('tf.gatherND() expects the input to be rank 1 or higher,' +
      ` but the rank was ${tensorRank}.`);
  }
  if (indicesRank < 1) {
    throw new Error('tf.gatherND() expects the indices to be rank 1 or higher,' +
      ` but the rank was ${indicesRank}.`);
  }
  if (indices.dtype !== 'int32') {
    throw new Error('tf.gatherND() expects the indices to be int32 type,' +
      ` but the dtype was ${indices.dtype}.`);
  }
  if (indices.shape[indicesRank - 1] > tensorRank) {
    throw new Error('index innermost dimension length must be <= tensor rank; saw: ' +
      `${indices.shape[indicesRank - 1]} vs. ${tensorRank}`);
  }

  const indicesShape = indices.shape;
  const inputShape = tensor.shape;
  const sliceRank = indicesShape[indicesRank - 1];

  // Number of slices to gather: product of every indices dim except the last.
  let nResult = 1;
  for (let i = 0; i < indicesRank - 1; ++i) {
    nResult *= indicesShape[i];
  }

  // An empty input is only an error when at least one slice is requested, as
  // the message itself states (TensorFlow's GatherNd applies the same gate).
  // Gathering zero slices from an empty input yields an empty result.
  if (nResult > 0 && sizeFromShape(inputShape) === 0) {
    throw new Error('Requested more than 0 entries, but input is empty.' +
      ` Input shape: ${inputShape}.`);
  }

  // The result shape is
  // indices.shape[:-1] + params.shape[indices.shape[-1]:]
  const resultShape = indicesShape.slice(0, -1);
  let sliceSize = 1;
  for (let i = sliceRank; i < tensorRank; ++i) {
    sliceSize *= inputShape[i];
    resultShape.push(inputShape[i]);
  }

  // Row-major strides of the indexed (outer) dims, in units of whole slices;
  // the innermost indexed dim has stride 1. Computed from the outer shape
  // directly instead of dividing the full strides by sliceSize, which would
  // produce NaN when sliceSize is 0 (possible for an empty, zero-slice gather).
  const strides =
    [...computeStrides(inputShape.slice(0, sliceRank)), 1].slice(0, sliceRank);
  return [resultShape, nResult, sliceSize, strides];
}
//# sourceMappingURL=gather_nd_util.js.map
\No newline at end of file