1 | import { computeStrides, sizeFromShape } from '../util';
|
2 |
|
3 |
|
4 |
|
5 |
|
6 |
|
7 |
|
8 |
|
9 |
|
10 | export function prepareAndValidate(tensor, indices) {
|
11 | const tensorRank = tensor.shape.length;
|
12 | const indicesRank = indices.shape.length;
|
13 | if (tensorRank < 1) {
|
14 | throw new Error('tf.gatherND() expects the input to be rank 1 or higher,' +
|
15 | ` but the rank was ${tensorRank}.`);
|
16 | }
|
17 | if (indicesRank < 1) {
|
18 | throw new Error('tf.gatherND() expects the indices to be rank 1 or higher,' +
|
19 | ` but the rank was ${indicesRank}.`);
|
20 | }
|
21 | if (indices.dtype !== 'int32') {
|
22 | throw new Error('tf.gatherND() expects the indices to be int32 type,' +
|
23 | ` but the dtype was ${indices.dtype}.`);
|
24 | }
|
25 | if (indices.shape[indicesRank - 1] > tensorRank) {
|
26 | throw new Error('index innermost dimension length must be <= tensor rank; saw: ' +
|
27 | `${indices.shape[indicesRank - 1]} vs. ${tensorRank}`);
|
28 | }
|
29 | if (sizeFromShape(tensor.shape) === 0) {
|
30 | throw new Error('Requested more than 0 entries, but input is empty.' +
|
31 | ` Input shape: ${tensor.shape}.`);
|
32 | }
|
33 | const indicesShape = indices.shape;
|
34 | const sliceRank = indicesShape[indicesShape.length - 1];
|
35 |
|
36 |
|
37 | let nResult = 1;
|
38 | for (let i = 0; i < indicesShape.length - 1; ++i) {
|
39 | nResult *= indicesShape[i];
|
40 | }
|
41 | const inputShape = tensor.shape;
|
42 | const resultShape = indicesShape.slice();
|
43 | resultShape.pop();
|
44 | let sliceSize = 1;
|
45 | for (let i = sliceRank; i < tensorRank; ++i) {
|
46 | sliceSize *= inputShape[i];
|
47 | resultShape.push(inputShape[i]);
|
48 | }
|
49 | const strides = [...computeStrides(tensor.shape).map(stride => stride / sliceSize),
|
50 | 1].slice(0, sliceRank);
|
51 | return [resultShape, nResult, sliceSize, strides];
|
52 | }
|
53 |
|
\ | No newline at end of file |