UNPKG

9.68 kB · JavaScript · View Raw
1/**
2 * @license
3 * Copyright 2021 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17// tslint:disable-next-line: no-imports-from-dist
18import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
19import { getParamValue } from './utils';
/**
 * Executes a node whose op belongs to the 'sparse' category.
 *
 * Each supported op resolves its named inputs through `getParamValue` and
 * forwards them to the matching function on `ops.sparse`. The op's outputs
 * are returned as an array.
 *
 * @param {object} node - Graph node; `node.op` selects the operation.
 * @param {object} tensorMap - Named tensors, passed through to `getParamValue`.
 * @param {object} context - Execution context, passed through to `getParamValue`.
 * @param {object} [ops=tfOps] - Op namespace to dispatch to; defaults to the
 *     tfjs-core converter op set.
 * @returns {Array} The executed op's outputs.
 * @throws {TypeError} If `node.op` is not one of the sparse ops handled here.
 */
export const executeOp = (node, tensorMap, context, ops = tfOps) => {
  // Every input is looked up against the same node/tensorMap/context triple.
  const param = (name) => getParamValue(name, node, tensorMap, context);

  switch (node.op) {
    case 'SparseFillEmptyRows': {
      const {
        outputIndices,
        outputValues,
        emptyRowIndicator,
        reverseIndexMap,
      } = ops.sparse.sparseFillEmptyRows(
        param('indices'),
        param('values'),
        param('denseShape'),
        param('defaultValue'),
      );
      return [outputIndices, outputValues, emptyRowIndicator, reverseIndexMap];
    }
    case 'SparseReshape': {
      const { outputIndices, outputShape } = ops.sparse.sparseReshape(
        param('inputIndices'),
        param('inputShape'),
        param('newShape'),
      );
      return [outputIndices, outputShape];
    }
    case 'SparseSegmentMean': {
      const outputData = ops.sparse.sparseSegmentMean(
        param('data'),
        param('indices'),
        param('segmentIds'),
      );
      return [outputData];
    }
    case 'SparseSegmentSum': {
      const outputData = ops.sparse.sparseSegmentSum(
        param('data'),
        param('indices'),
        param('segmentIds'),
      );
      return [outputData];
    }
    default:
      throw new TypeError(`Node type ${node.op} is not implemented`);
  }
};

/** Op category handled by this executor. */
export const CATEGORY = 'sparse';
45//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoic3BhcnNlX2V4ZWN1dG9yLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb252ZXJ0ZXIvc3JjL29wZXJhdGlvbnMvZXhlY3V0b3JzL3NwYXJzZV9leGVjdXRvci50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFHSCxpREFBaUQ7QUFDakQsT0FBTyxLQUFLLEtBQUssTUFBTSxrREFBa0QsQ0FBQztBQU0xRSxPQUFPLEVBQUMsYUFBYSxFQUFDLE1BQU0sU0FBUyxDQUFDO0FBRXRDLE1BQU0sQ0FBQyxNQUFNLFNBQVMsR0FDbEIsQ0FBQyxJQUFVLEVBQUUsU0FBMEIsRUFDdEMsT0FBeUIsRUFBRSxHQUFHLEdBQUcsS0FBSyxFQUFZLEVBQUU7SUFDbkQsUUFBUSxJQUFJLENBQUMsRUFBRSxFQUFFO1FBQ2YsS0FBSyxxQkFBcUIsQ0FBQyxDQUFDO1lBQzFCLE1BQU0sRUFDSixhQUFhLEVBQ2IsWUFBWSxFQUNaLGlCQUFpQixFQUNqQixlQUFlLEVBQ2hCLEdBQ0csR0FBRyxDQUFDLE1BQU0sQ0FBQyxtQkFBbUIsQ0FDMUIsYUFBYSxDQUFDLFNBQVMsRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FDckMsRUFDWixhQUFhLENBQUMsUUFBUSxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFhLEVBQzdELGFBQWEsQ0FBQyxZQUFZLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQ3hDLEVBQ1osYUFBYSxDQUFDLGNBQWMsRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FDNUMsQ0FBQyxDQUFDO1lBQ3BCLE9BQU87Z0JBQ0wsYUFBYSxFQUFFLFlBQVksRUFBRSxpQkFBaUIsRUFBRSxlQUFlO2FBQ2hFLENBQUM7U0FDSDtRQUNELEtBQUssZUFBZSxDQUFDLENBQUM7WUFDcEIsTUFBTSxFQUFDLGFBQWEsRUFBRSxXQUFXLEVBQUMsR0FBRyxHQUFHLENBQUMsTUFBTSxDQUFDLGFBQWEsQ0FDekQsYUFBYSxDQUFDLGNBQWMsRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FDMUMsRUFDWixhQUFhLENBQUMsWUFBWSxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFhLEVBQ2pFLGFBQWEsQ0FBQyxVQUFVLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQWEsQ0FBQyxDQUFDO1lBQ3JFLE9BQU8sQ0FBQyxhQUFhLEVBQUUsV0FBVyxDQUFDLENBQUM7U0FDckM7UUFDRCxLQUFLLG1CQUFtQixDQUFDLENBQUM7WUFDeEIsTUFBTSxVQUFVLEdBQUcsR0FBRyxDQUFDLE1BQU0sQ0FBQyxpQkFBaUIsQ0FDM0MsYUFBYSxDQUFDLE1BQU0sRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBVyxFQUN6RCxhQUFhLENBQUMsU0FBUyxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFhLEVBQzlELGFBQWEsQ0FBQyxZQUFZLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQ3hDLENBQUMsQ0FBQztZQUNsQixPQUFPLENBQUMsVUFBVSxDQUFDLENBQUM7U0FDckI7UUFDRCxLQUFLLGtCQUFr
QixDQUFDLENBQUM7WUFDdkIsTUFBTSxVQUFVLEdBQUcsR0FBRyxDQUFDLE1BQU0sQ0FBQyxnQkFBZ0IsQ0FDMUMsYUFBYSxDQUFDLE1BQU0sRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBVyxFQUN6RCxhQUFhLENBQUMsU0FBUyxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFhLEVBQzlELGFBQWEsQ0FBQyxZQUFZLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQ3hDLENBQUMsQ0FBQztZQUNsQixPQUFPLENBQUMsVUFBVSxDQUFDLENBQUM7U0FDckI7UUFDRDtZQUNFLE1BQU0sU0FBUyxDQUFDLGFBQWEsSUFBSSxDQUFDLEVBQUUscUJBQXFCLENBQUMsQ0FBQztLQUM5RDtBQUNILENBQUMsQ0FBQztBQUVOLE1BQU0sQ0FBQyxNQUFNLFFBQVEsR0FBRyxRQUFRLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMSBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7U2NhbGFyLCBUZW5zb3IsIFRlbnNvcjFELCBUZW5zb3IyRH0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcbi8vIHRzbGludDpkaXNhYmxlLW5leHQtbGluZTogbm8taW1wb3J0cy1mcm9tLWRpc3RcbmltcG9ydCAqIGFzIHRmT3BzIGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZS9kaXN0L29wcy9vcHNfZm9yX2NvbnZlcnRlcic7XG5cbmltcG9ydCB7TmFtZWRUZW5zb3JzTWFwfSBmcm9tICcuLi8uLi9kYXRhL3R5cGVzJztcbmltcG9ydCB7RXhlY3V0aW9uQ29udGV4dH0gZnJvbSAnLi4vLi4vZXhlY3V0b3IvZXhlY3V0aW9uX2NvbnRleHQnO1xuaW1wb3J0IHtJbnRlcm5hbE9wRXhlY3V0b3IsIE5vZGV9
IGZyb20gJy4uL3R5cGVzJztcblxuaW1wb3J0IHtnZXRQYXJhbVZhbHVlfSBmcm9tICcuL3V0aWxzJztcblxuZXhwb3J0IGNvbnN0IGV4ZWN1dGVPcDogSW50ZXJuYWxPcEV4ZWN1dG9yID1cbiAgICAobm9kZTogTm9kZSwgdGVuc29yTWFwOiBOYW1lZFRlbnNvcnNNYXAsXG4gICAgIGNvbnRleHQ6IEV4ZWN1dGlvbkNvbnRleHQsIG9wcyA9IHRmT3BzKTogVGVuc29yW10gPT4ge1xuICAgICAgc3dpdGNoIChub2RlLm9wKSB7XG4gICAgICAgIGNhc2UgJ1NwYXJzZUZpbGxFbXB0eVJvd3MnOiB7XG4gICAgICAgICAgY29uc3Qge1xuICAgICAgICAgICAgb3V0cHV0SW5kaWNlcyxcbiAgICAgICAgICAgIG91dHB1dFZhbHVlcyxcbiAgICAgICAgICAgIGVtcHR5Um93SW5kaWNhdG9yLFxuICAgICAgICAgICAgcmV2ZXJzZUluZGV4TWFwXG4gICAgICAgICAgfSA9XG4gICAgICAgICAgICAgIG9wcy5zcGFyc2Uuc3BhcnNlRmlsbEVtcHR5Um93cyhcbiAgICAgICAgICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ2luZGljZXMnLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzXG4gICAgICAgICAgICAgICAgICAgICAgVGVuc29yMkQsXG4gICAgICAgICAgICAgICAgICBnZXRQYXJhbVZhbHVlKCd2YWx1ZXMnLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzIFRlbnNvcjFELFxuICAgICAgICAgICAgICAgICAgZ2V0UGFyYW1WYWx1ZSgnZGVuc2VTaGFwZScsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXNcbiAgICAgICAgICAgICAgICAgICAgICBUZW5zb3IxRCxcbiAgICAgICAgICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ2RlZmF1bHRWYWx1ZScsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXNcbiAgICAgICAgICAgICAgICAgICAgICBTY2FsYXIpO1xuICAgICAgICAgIHJldHVybiBbXG4gICAgICAgICAgICBvdXRwdXRJbmRpY2VzLCBvdXRwdXRWYWx1ZXMsIGVtcHR5Um93SW5kaWNhdG9yLCByZXZlcnNlSW5kZXhNYXBcbiAgICAgICAgICBdO1xuICAgICAgICB9XG4gICAgICAgIGNhc2UgJ1NwYXJzZVJlc2hhcGUnOiB7XG4gICAgICAgICAgY29uc3Qge291dHB1dEluZGljZXMsIG91dHB1dFNoYXBlfSA9IG9wcy5zcGFyc2Uuc3BhcnNlUmVzaGFwZShcbiAgICAgICAgICAgICAgZ2V0UGFyYW1WYWx1ZSgnaW5wdXRJbmRpY2VzJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhc1xuICAgICAgICAgICAgICAgICAgVGVuc29yMkQsXG4gICAgICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ2lucHV0U2hhcGUnLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzIFRlbnNvcjFELFxuICAgICAgICAgICAgICBnZXRQYXJhbVZhbHVlKCduZXdTaGFwZScsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXMgVGVuc29yMUQpO1xuICAgICAgICAgIHJldHVybiBbb3V0cHV0SW5kaWNlcywgb3V0cHV0U2hhcGVdO1xuICAgICAgICB9XG4gICAgICAgIGNhc2UgJ1NwYXJzZVNlZ21lbnRNZWFuJzoge1xuICAgICAgICAgIGNvbnN0IG91dHB1dERh
dGEgPSBvcHMuc3BhcnNlLnNwYXJzZVNlZ21lbnRNZWFuKFxuICAgICAgICAgICAgICBnZXRQYXJhbVZhbHVlKCdkYXRhJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBUZW5zb3IsXG4gICAgICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ2luZGljZXMnLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzIFRlbnNvcjFELFxuICAgICAgICAgICAgICBnZXRQYXJhbVZhbHVlKCdzZWdtZW50SWRzJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhc1xuICAgICAgICAgICAgICAgICAgVGVuc29yMUQpO1xuICAgICAgICAgIHJldHVybiBbb3V0cHV0RGF0YV07XG4gICAgICAgIH1cbiAgICAgICAgY2FzZSAnU3BhcnNlU2VnbWVudFN1bSc6IHtcbiAgICAgICAgICBjb25zdCBvdXRwdXREYXRhID0gb3BzLnNwYXJzZS5zcGFyc2VTZWdtZW50U3VtKFxuICAgICAgICAgICAgICBnZXRQYXJhbVZhbHVlKCdkYXRhJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBUZW5zb3IsXG4gICAgICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ2luZGljZXMnLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzIFRlbnNvcjFELFxuICAgICAgICAgICAgICBnZXRQYXJhbVZhbHVlKCdzZWdtZW50SWRzJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhc1xuICAgICAgICAgICAgICAgICAgVGVuc29yMUQpO1xuICAgICAgICAgIHJldHVybiBbb3V0cHV0RGF0YV07XG4gICAgICAgIH1cbiAgICAgICAgZGVmYXVsdDpcbiAgICAgICAgICB0aHJvdyBUeXBlRXJyb3IoYE5vZGUgdHlwZSAke25vZGUub3B9IGlzIG5vdCBpbXBsZW1lbnRlZGApO1xuICAgICAgfVxuICAgIH07XG5cbmV4cG9ydCBjb25zdCBDQVRFR09SWSA9ICdzcGFyc2UnO1xuIl19
\No newline at end of file