UNPKG

8.83 kBSource Map (JSON)View Raw
1{"version":3,"file":"scatter_nd_util.js","sourceRoot":"","sources":["../../src/ops/scatter_nd_util.ts"],"names":[],"mappings":"AAkBA,OAAO,EAAC,cAAc,EAAE,aAAa,EAAC,MAAM,SAAS,CAAC;AAEtD;;;;;GAKG;AACH,MAAM,UAAU,mBAAmB,CAC/B,KAAe,EAAE,OAAe,EAAE,OAAe;IACnD,MAAM,QAAQ,GAAG,CAAC,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,KAAK,CAAC,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAC1E,MAAM,QAAQ,GAAG,CAAC,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAE3D,MAAM,UAAU,GAAG,uDAAuD;QACtE,wCAAwC,OAAO,CAAC,KAAK,EAAE;QACvD,oBAAoB,OAAO,CAAC,KAAK,YAAY,KAAK,EAAE;QACpD,eAAe,QAAQ,mBAAmB,QAAQ,GAAG,CAAC;IAE1D,IAAI,OAAO,CAAC,IAAI,GAAG,QAAQ,EAAE;QAC3B,MAAM,IAAI,KAAK,CAAC,UAAU,GAAG,kBAAkB,QAAQ,IAAI,CAAC,CAAC;KAC9D;IACD,IAAI,KAAK,CAAC,MAAM,GAAG,QAAQ,GAAG,CAAC,OAAO,CAAC,IAAI,GAAG,QAAQ,CAAC,EAAE;QACvD,MAAM,IAAI,KAAK,CACX,UAAU;YACV,0BAA0B,QAAQ,GAAG,CAAC,OAAO,CAAC,IAAI,GAAG,QAAQ,CAAC,EAAE,CAAC,CAAC;KACvE;IACD,IAAI,OAAO,CAAC,IAAI,KAAK,QAAQ,GAAG,KAAK,CAAC,MAAM,GAAG,QAAQ,EAAE;QACvD,MAAM,IAAI,KAAK,CACX,UAAU,GAAG,mBAAmB,QAAQ,GAAG,KAAK,CAAC,MAAM,GAAG,QAAQ,EAAE,CAAC,CAAC;KAC3E;IACD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,EAAE,EAAE,CAAC,EAAE;QACjC,IAAI,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE;YACzC,MAAM,IAAI,KAAK,CACX,UAAU;gBACV,kBAAkB,CAAC,MAAM,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,sBAAsB,CAAC,MAC5D,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;SAC/B;KACF;IACD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,OAAO,CAAC,IAAI,GAAG,QAAQ,EAAE,EAAE,CAAC,EAAE;QAChD,IAAI,OAAO,CAAC,KAAK,CAAC,CAAC,GAAG,QAAQ,CAAC,KAAK,KAAK,CAAC,CAAC,GAAG,QAAQ,CAAC,EAAE;YACvD,MAAM,IAAI,KAAK,CACX,UAAU;gBACV,kBAAkB,CAAC,GAAG,QAAQ,MAC1B,OAAO,CAAC,KAAK,CAAC,CAAC,GAAG,QAAQ,CAAC,cAAc,CAAC,GAAG,QAAQ,MACrD,KAAK,CAAC,CAAC,GAAG,QAAQ,CAAC,GAAG,CAAC,CAAC;SACjC;KACF;AACH,CAAC;AASD;;;;;;GAMG;AACH,MAAM,UAAU,aAAa,CACzB,OAAe,EAAE,OAAe,EAAE,KAAe;IACnD,IAAI,OAAO,CAAC,IAAI,GAAG,CAAC,EAAE;QACpB,MAAM,IAAI,KAAK,CACX,4DAA4D;YAC5D,qBAAqB,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC;KAC3C;IACD,IAAI,OAAO,CAAC,IAAI,GAAG,CAAC,EAAE;QACpB,MAAM,IAAI,KAAK,CACX,4DAA4D;YAC5D,qBAAqB,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC;KAC3C;IACD,IAAI,OAAO,CAAC,KAAK,KAAK,OAAO,EAAE;QAC7B,MAAM,IAAI,KAAK,CAAC,0DACZ,OAAO,CAAC,KAAK,EAAE,CAAC,CAAC;KACtB;IACD,IAAI,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE;QACpB,MAAM,IAAI,KAAK,CACX,6DAA6D,KAAK,EAAE,CAAC,CAAC;KAC3E;IAED,IAAI,KAAK,CAAC,MAAM,KAAK,CAAC,EAAE;QACtB,IAAI,OAAO,CAAC,IAAI,KAAK,CAAC,EAAE;YACtB,MAAM,IAAI,KAAK,CAAC,sDACZ,OAAO,CAAC,KAAK,EAAE,CAAC,CAAC;SACtB;QACD,IAAI,OAAO,CAAC,IAAI,KAAK,CAAC,EAAE;YACtB,MAAM,IAAI,KAAK,CAAC,sDACZ,OAAO,CAAC,KAAK,EAAE,CAAC,CAAC;SACtB;KACF;IAED,mBAAmB,CAAC,KAAK,EAAE,OAAO,EAAE,OAAO,CAAC,CAAC;AAC/C,CAAC;AAED;;;;;;;;GAQG;AACH,MAAM,UAAU,eAAe,CAC3B,OAAmB,EAAE,OAAmB,EACxC,KAAe;IACjB,gDAAgD;IAChD,MAAM,WAAW,GAAG,OAAO,CAAC,KAAK,CAAC,MAAM,CAAC;IACzC,MAAM,SAAS,GAAG,CAAC,WAAW,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,KAAK,CAAC,WAAW,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAEzE,0EAA0E;IAC1E,4EAA4E;IAC5E,oBAAoB;IACpB,MAAM,OAAO,GAAG,KAAK,CAAC,MAAM,CAAC;IAE7B,IAAI,SAAS,GAAG,CAAC,CAAC;IAClB,KAAK,IAAI,CAAC,GAAG,SAAS,EAAE,CAAC,GAAG,OAAO,EAAE,EAAE,CAAC,EAAE;QACxC,SAAS,IAAI,KAAK,CAAC,CAAC,CAAC,CAAC;KACvB;IAED,MAAM,YAAY,GAAG,CAAC,SAAS,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC;IACrD,MAAM,UAAU,GAAG,aAAa,CAAC,OAAO,CAAC,KAAK,CAAC,GAAG,YAAY,CAAC;IAE/D,MAAM,OAAO,GAAG,CAAC,GAAG,cAAc,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;IAClE,MAAM,UAAU,GAAG,aAAa,CAAC,KAAK,CAAC,CAAC;IACxC,OAAO,EAAC,SAAS,EAAE,UAAU,EAAE,SAAS,EAAE,OAAO,EAAE,UAAU,EAAC,CAAC;AACjE,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {TensorInfo} from '../kernel_registry';\nimport {Tensor} from '../tensor';\nimport {computeStrides, sizeFromShape} from '../util';\n\n/**\n * Check whether updates.shape = indices.shape[:batchDim] +\n * shape[sliceDim:]\n *\n * @param x The input tensor.\n */\nexport function validateUpdateShape(\n shape: number[], indices: Tensor, updates: Tensor) {\n const sliceDim = (indices.rank > 1) ? indices.shape[indices.rank - 1] : 1;\n const batchDim = (indices.rank > 1) ? indices.rank - 1 : 1;\n\n const shapeError = 'Must have updates.shape = indices.shape[:batchDim] + ' +\n `shape[sliceDim:], got updates.shape: ${updates.shape}` +\n `, indices.shape: ${indices.shape}, shape: ${shape}` +\n `, sliceDim: ${sliceDim}, and batchDim: ${batchDim}.`;\n\n if (updates.rank < batchDim) {\n throw new Error(shapeError + ` update.rank < ${batchDim}. `);\n }\n if (shape.length < sliceDim + (updates.rank - batchDim)) {\n throw new Error(\n shapeError +\n ` Output shape length < ${sliceDim + (updates.rank - batchDim)}`);\n }\n if (updates.rank !== batchDim + shape.length - sliceDim) {\n throw new Error(\n shapeError + ` update.rank != ${batchDim + shape.length - sliceDim}`);\n }\n for (let d = 0; d < batchDim; ++d) {\n if (updates.shape[d] !== indices.shape[d]) {\n throw new Error(\n shapeError +\n ` updates.shape[${d}] (${updates.shape[d]}) != indices.shape[${d}] (${\n indices.shape[d]}).`);\n }\n }\n for (let d = 0; d < updates.rank - batchDim; ++d) {\n if (updates.shape[d + batchDim] !== shape[d + sliceDim]) {\n throw new Error(\n shapeError +\n ` updates.shape[${d + batchDim}] (${\n updates.shape[d + batchDim]}) != shape[${d + batchDim}] (${\n shape[d + batchDim]})`);\n }\n }\n}\n\nexport interface ScatterShapeInfo {\n sliceRank: number;\n numUpdates: number;\n sliceSize: number;\n strides: number[];\n outputSize: number;\n}\n/**\n * Validate scatter nd inputs.\n *\n * @param update The tensor contains the update values.\n * @param indices The tensor contains the indices for the update values.\n * @param shape The shape of the output tensor.\n */\nexport function validateInput(\n updates: Tensor, indices: Tensor, shape: number[]) {\n if (indices.rank < 1) {\n throw new Error(\n 'tf.scatterND() expects the indices to be rank 1 or higher,' +\n ` but the rank was ${indices.rank}.`);\n }\n if (updates.rank < 1) {\n throw new Error(\n 'tf.scatterND() expects the updates to be rank 1 or higher,' +\n ` but the rank was ${updates.rank}.`);\n }\n if (indices.dtype !== 'int32') {\n throw new Error(`The dtype of 'indices' should be int32, but got dtype: ${\n indices.dtype}`);\n }\n if (shape.length < 1) {\n throw new Error(\n `Output rank must be greater or equal to 1, but got shape: ${shape}`);\n }\n\n if (shape.length === 0) {\n if (indices.size === 0) {\n throw new Error(`Indices specified for empty output. indices shape: ${\n indices.shape}`);\n }\n if (updates.size === 0) {\n throw new Error(`Updates specified for empty output. updates shape: ${\n updates.shape}`);\n }\n }\n\n validateUpdateShape(shape, indices, updates);\n}\n\n/**\n * Calculate the shape information for the output.\n *\n * @param update The tensor contains the update values.\n * @param indices The tensor contains the indices for the update values.\n * @param shape The shape of the output tensor.\n *\n * @returns ScatterShapeInfo\n */\nexport function calculateShapes(\n updates: TensorInfo, indices: TensorInfo,\n shape: number[]): ScatterShapeInfo {\n // Calculate the number of dimensions in indices\n const indicesRank = indices.shape.length;\n const sliceRank = (indicesRank > 1) ? indices.shape[indicesRank - 1] : 1;\n\n // Calculate the number of elements that make up each slice of our updated\n // tensor. This allows us to work with flattened tensors and copy over whole\n // slices at a time.\n const totalNd = shape.length;\n\n let sliceSize = 1;\n for (let i = sliceRank; i < totalNd; ++i) {\n sliceSize *= shape[i];\n }\n\n const safeSliceDim = (sliceRank < 1) ? 1 : sliceRank;\n const numUpdates = sizeFromShape(indices.shape) / safeSliceDim;\n\n const strides = [...computeStrides(shape.slice(0, sliceRank)), 1];\n const outputSize = sizeFromShape(shape);\n return {sliceRank, numUpdates, sliceSize, strides, outputSize};\n}\n"]}
\No newline at end of file