// Retrieved from UNPKG — 12.7 kB JavaScript file ("View Raw").
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// tslint:disable-next-line: no-imports-from-dist
18import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
19import { getParamValue } from './utils';
/**
 * Executes one graph node whose op belongs to the "matrices" category.
 *
 * Node inputs and attributes are looked up by name through `getParamValue`
 * and forwarded to the corresponding function on `ops`. Every branch wraps
 * the op's return value in a one-element array.
 *
 * @param {Object} node - Graph node; `node.op` selects which op runs.
 * @param {Object} tensorMap - Named tensors, forwarded to `getParamValue`.
 * @param {Object} context - Execution context, forwarded to `getParamValue`.
 * @param {Object} [ops=tfOps] - Op implementations to call. Must provide
 *     `matMul`, `einsum`, `transpose`, `fused.matMul` and `linalg.bandPart`
 *     for the node types that use them.
 * @returns {Array} One-element array holding the op's result.
 * @throws {Error} For `_FusedMatMul` with a `'biasadd'` extra op when
 *     `numArgs` is not 2 (prelu activation) or not 1 (any other activation).
 * @throws {TypeError} When `node.op` is not one of the handled op names.
 */
export const executeOp = (node, tensorMap, context, ops = tfOps) => {
  // Resolves the named input/attribute of the current node.
  const param = (name) => getParamValue(name, node, tensorMap, context);

  switch (node.op) {
    case 'BatchMatMul':
    case 'BatchMatMulV2':
    case 'MatMul':
      return [ops.matMul(
          param('a'), param('b'), param('transposeA'), param('transposeB'))];

    case 'Einsum':
      // `tensors` is spread into einsum's variadic operand list.
      return [ops.einsum(param('equation'), ...param('tensors'))];

    case 'Transpose':
      return [ops.transpose(param('x'), param('perm'))];

    case '_FusedMatMul': {
      // Braces give these `const`s their own scope instead of leaking them
      // into the whole switch body (ESLint `no-case-declarations`).
      // fusedOps entries are compared against lower-case names below;
      // presumably they are normalized before reaching here — confirm.
      const [extraOp, activationFunc] = param('fusedOps');
      const isBiasAdd = extraOp === 'biasadd';
      const isPrelu = activationFunc === 'prelu';
      const numArgs = param('numArgs');
      const leakyreluAlpha = param('leakyreluAlpha');

      // With BiasAdd, the extra args must be the bias, plus the prelu alpha
      // when the activation is prelu.
      if (isBiasAdd) {
        if (isPrelu && numArgs !== 2) {
          throw new Error('Fused MatMul with BiasAdd and Prelu must have two ' +
              'extra arguments: bias and alpha.');
        }
        if (!isPrelu && numArgs !== 1) {
          throw new Error(
              'Fused MatMul with BiasAdd must have one extra argument: bias.');
        }
      }

      // Positional extra inputs: bias first, then prelu weights (if any).
      const [biasArg, preluArg] = param('args');
      return [ops.fused.matMul({
        a: param('a'),
        b: param('b'),
        transposeA: param('transposeA'),
        transposeB: param('transposeB'),
        bias: biasArg,
        activation: activationFunc,
        preluActivationWeights: preluArg,
        leakyreluAlpha,
      })];
    }

    case 'MatrixBandPart':
      return [ops.linalg.bandPart(
          param('a'), param('numLower'), param('numUpper'))];

    default:
      throw new TypeError(`Node type ${node.op} is not implemented`);
  }
};
/** Op-category label exported alongside `executeOp` by this module. */
export const CATEGORY = 'matrices';
63//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoibWF0cmljZXNfZXhlY3V0b3IuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWNvbnZlcnRlci9zcmMvb3BlcmF0aW9ucy9leGVjdXRvcnMvbWF0cmljZXNfZXhlY3V0b3IudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBR0gsaURBQWlEO0FBQ2pELE9BQU8sS0FBSyxLQUFLLE1BQU0sa0RBQWtELENBQUM7QUFNMUUsT0FBTyxFQUFDLGFBQWEsRUFBQyxNQUFNLFNBQVMsQ0FBQztBQUV0QyxNQUFNLENBQUMsTUFBTSxTQUFTLEdBQ2xCLENBQUMsSUFBVSxFQUFFLFNBQTBCLEVBQUUsT0FBeUIsRUFDakUsR0FBRyxHQUFHLEtBQUssRUFBWSxFQUFFO0lBQ3hCLFFBQVEsSUFBSSxDQUFDLEVBQUUsRUFBRTtRQUNmLEtBQUssYUFBYSxDQUFDO1FBQ25CLEtBQUssZUFBZSxDQUFDO1FBQ3JCLEtBQUssUUFBUTtZQUNYLE9BQU8sQ0FBQyxHQUFHLENBQUMsTUFBTSxDQUNkLGFBQWEsQ0FBQyxHQUFHLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQWEsRUFDeEQsYUFBYSxDQUFDLEdBQUcsRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBYSxFQUN4RCxhQUFhLENBQUMsWUFBWSxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFZLEVBQ2hFLGFBQWEsQ0FBQyxZQUFZLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQ3pDLENBQUMsQ0FBQyxDQUFDO1FBRXBCLEtBQUssUUFBUTtZQUNYLE9BQU8sQ0FBQyxHQUFHLENBQUMsTUFBTSxDQUNkLGFBQWEsQ0FBQyxVQUFVLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQVcsRUFDN0QsR0FBRyxhQUFhLENBQUMsU0FBUyxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUN4QyxDQUFDLENBQUMsQ0FBQztRQUVyQixLQUFLLFdBQVc7WUFDZCxPQUFPLENBQUMsR0FBRyxDQUFDLFNBQVMsQ0FDakIsYUFBYSxDQUFDLEdBQUcsRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBVyxFQUN0RCxhQUFhLENBQUMsTUFBTSxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFhLENBQUMsQ0FBQyxDQUFDO1FBRXBFLEtBQUssY0FBYztZQUNqQixNQUFNLENBQUMsT0FBTyxFQUFFLGNBQWMsQ0FBQyxHQUMxQixhQUFhLENBQUMsVUFBVSxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFjLENBQUM7WUFFdEUsTUFBTSxTQUFTLEdBQUcsT0FBTyxLQUFLLFNBQVMsQ0FBQztZQUN4QyxNQUFNLE9BQU8sR0FBRyxjQUFjLEtBQUssT0FBTyxDQUFDO1lBRTNDLE1BQU0sT0FBTyxHQUNSLGFBQWEsQ0FBQyxTQUFTLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQVksQ0FBQztZQUNuRSxNQUFNLGNBQWMsR0FDaEIsYUFBYSxDQUFDLGdCQUFnQixFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUNsRCxDQUFDO1lBRVgsSUFBSSxTQUFT
LEVBQUU7Z0JBQ2IsSUFBSSxPQUFPLElBQUksT0FBTyxLQUFLLENBQUMsRUFBRTtvQkFDNUIsTUFBTSxJQUFJLEtBQUssQ0FDWCxvREFBb0Q7d0JBQ3BELGtDQUFrQyxDQUFDLENBQUM7aUJBQ3pDO2dCQUNELElBQUksQ0FBQyxPQUFPLElBQUksT0FBTyxLQUFLLENBQUMsRUFBRTtvQkFDN0IsTUFBTSxJQUFJLEtBQUssQ0FDWCwrREFBK0QsQ0FBQyxDQUFDO2lCQUN0RTthQUNGO1lBQ0QsTUFBTSxDQUFDLE9BQU8sRUFBRSxRQUFRLENBQUMsR0FDckIsYUFBYSxDQUFDLE1BQU0sRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBYSxDQUFDO1lBQ2hFLE9BQU8sQ0FBQyxHQUFHLENBQUMsS0FBSyxDQUFDLE1BQU0sQ0FBQztvQkFDdkIsQ0FBQyxFQUFFLGFBQWEsQ0FBQyxHQUFHLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQWE7b0JBQzNELENBQUMsRUFBRSxhQUFhLENBQUMsR0FBRyxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFhO29CQUMzRCxVQUFVLEVBQUUsYUFBYSxDQUFDLFlBQVksRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FDckQ7b0JBQ1gsVUFBVSxFQUFFLGFBQWEsQ0FBQyxZQUFZLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQ3JEO29CQUNYLElBQUksRUFBRSxPQUFPO29CQUNiLFVBQVUsRUFBRSxjQUF3QztvQkFDcEQsc0JBQXNCLEVBQUUsUUFBUTtvQkFDaEMsY0FBYztpQkFDZixDQUFDLENBQUMsQ0FBQztRQUVOLEtBQUssZ0JBQWdCO1lBQ25CLE9BQU8sQ0FBQyxHQUFHLENBQUMsTUFBTSxDQUFDLFFBQVEsQ0FDdkIsYUFBYSxDQUFDLEdBQUcsRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBYSxFQUN4RCxhQUFhLENBQUMsVUFBVSxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFXLEVBQzdELGFBQWEsQ0FBQyxVQUFVLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQVcsQ0FBQyxDQUFDLENBQUM7UUFFdEU7WUFDRSxNQUFNLFNBQVMsQ0FBQyxhQUFhLElBQUksQ0FBQyxFQUFFLHFCQUFxQixDQUFDLENBQUM7S0FDOUQ7QUFDSCxDQUFDLENBQUM7QUFFTixNQUFNLENBQUMsTUFBTSxRQUFRLEdBQUcsVUFBVSxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMTggR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmli
dXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge1NjYWxhciwgVGVuc29yLCBUZW5zb3IyRH0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcbi8vIHRzbGludDpkaXNhYmxlLW5leHQtbGluZTogbm8taW1wb3J0cy1mcm9tLWRpc3RcbmltcG9ydCAqIGFzIHRmT3BzIGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZS9kaXN0L29wcy9vcHNfZm9yX2NvbnZlcnRlcic7XG5cbmltcG9ydCB7TmFtZWRUZW5zb3JzTWFwfSBmcm9tICcuLi8uLi9kYXRhL3R5cGVzJztcbmltcG9ydCB7RXhlY3V0aW9uQ29udGV4dH0gZnJvbSAnLi4vLi4vZXhlY3V0b3IvZXhlY3V0aW9uX2NvbnRleHQnO1xuaW1wb3J0IHtJbnRlcm5hbE9wRXhlY3V0b3IsIE5vZGV9IGZyb20gJy4uL3R5cGVzJztcblxuaW1wb3J0IHtnZXRQYXJhbVZhbHVlfSBmcm9tICcuL3V0aWxzJztcblxuZXhwb3J0IGNvbnN0IGV4ZWN1dGVPcDogSW50ZXJuYWxPcEV4ZWN1dG9yID1cbiAgICAobm9kZTogTm9kZSwgdGVuc29yTWFwOiBOYW1lZFRlbnNvcnNNYXAsIGNvbnRleHQ6IEV4ZWN1dGlvbkNvbnRleHQsXG4gICAgIG9wcyA9IHRmT3BzKTogVGVuc29yW10gPT4ge1xuICAgICAgc3dpdGNoIChub2RlLm9wKSB7XG4gICAgICAgIGNhc2UgJ0JhdGNoTWF0TXVsJzpcbiAgICAgICAgY2FzZSAnQmF0Y2hNYXRNdWxWMic6XG4gICAgICAgIGNhc2UgJ01hdE11bCc6XG4gICAgICAgICAgcmV0dXJuIFtvcHMubWF0TXVsKFxuICAgICAgICAgICAgICBnZXRQYXJhbVZhbHVlKCdhJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBUZW5zb3IyRCxcbiAgICAgICAgICAgICAgZ2V0UGFyYW1WYWx1ZSgnYicsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXMgVGVuc29yMkQsXG4gICAgICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ3RyYW5zcG9zZUEnLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzIGJvb2xlYW4sXG4gICAgICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ3RyYW5zcG9zZUInLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzXG4gICAgICAgICAgICAgICAgICBib29sZWFuKV07XG5cbiAgICAgICAgY2FzZSAnRWluc3VtJzpcbiAgICAgICAgICByZXR1cm4gW29wcy5laW5zdW0oXG4gICAgICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ2VxdWF0aW9uJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBzdHJpbmcs
XG4gICAgICAgICAgICAgIC4uLmdldFBhcmFtVmFsdWUoJ3RlbnNvcnMnLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzXG4gICAgICAgICAgICAgICAgICBUZW5zb3JbXSldO1xuXG4gICAgICAgIGNhc2UgJ1RyYW5zcG9zZSc6XG4gICAgICAgICAgcmV0dXJuIFtvcHMudHJhbnNwb3NlKFxuICAgICAgICAgICAgICBnZXRQYXJhbVZhbHVlKCd4Jywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBUZW5zb3IsXG4gICAgICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ3Blcm0nLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzIG51bWJlcltdKV07XG5cbiAgICAgICAgY2FzZSAnX0Z1c2VkTWF0TXVsJzpcbiAgICAgICAgICBjb25zdCBbZXh0cmFPcCwgYWN0aXZhdGlvbkZ1bmNdID1cbiAgICAgICAgICAgICAgKGdldFBhcmFtVmFsdWUoJ2Z1c2VkT3BzJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBzdHJpbmdbXSk7XG5cbiAgICAgICAgICBjb25zdCBpc0JpYXNBZGQgPSBleHRyYU9wID09PSAnYmlhc2FkZCc7XG4gICAgICAgICAgY29uc3QgaXNQcmVsdSA9IGFjdGl2YXRpb25GdW5jID09PSAncHJlbHUnO1xuXG4gICAgICAgICAgY29uc3QgbnVtQXJncyA9XG4gICAgICAgICAgICAgIChnZXRQYXJhbVZhbHVlKCdudW1BcmdzJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBudW1iZXIpO1xuICAgICAgICAgIGNvbnN0IGxlYWt5cmVsdUFscGhhID1cbiAgICAgICAgICAgICAgZ2V0UGFyYW1WYWx1ZSgnbGVha3lyZWx1QWxwaGEnLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzXG4gICAgICAgICAgICAgIG51bWJlcjtcblxuICAgICAgICAgIGlmIChpc0JpYXNBZGQpIHtcbiAgICAgICAgICAgIGlmIChpc1ByZWx1ICYmIG51bUFyZ3MgIT09IDIpIHtcbiAgICAgICAgICAgICAgdGhyb3cgbmV3IEVycm9yKFxuICAgICAgICAgICAgICAgICAgJ0Z1c2VkIE1hdE11bCB3aXRoIEJpYXNBZGQgYW5kIFByZWx1IG11c3QgaGF2ZSB0d28gJyArXG4gICAgICAgICAgICAgICAgICAnZXh0cmEgYXJndW1lbnRzOiBiaWFzIGFuZCBhbHBoYS4nKTtcbiAgICAgICAgICAgIH1cbiAgICAgICAgICAgIGlmICghaXNQcmVsdSAmJiBudW1BcmdzICE9PSAxKSB7XG4gICAgICAgICAgICAgIHRocm93IG5ldyBFcnJvcihcbiAgICAgICAgICAgICAgICAgICdGdXNlZCBNYXRNdWwgd2l0aCBCaWFzQWRkIG11c3QgaGF2ZSBvbmUgZXh0cmEgYXJndW1lbnQ6IGJpYXMuJyk7XG4gICAgICAgICAgICB9XG4gICAgICAgICAgfVxuICAgICAgICAgIGNvbnN0IFtiaWFzQXJnLCBwcmVsdUFyZ10gPVxuICAgICAgICAgICAgICBnZXRQYXJhbVZhbHVlKCdhcmdzJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBUZW5zb3JbXTtcbiAgICAgICAgICByZXR1cm4gW29wcy5mdXNlZC5tYXRNdWwoe1xuICAgICAgICAgICAgYTogZ2V0UGFyYW1WYWx1ZSgnYScsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXMgVGVuc29yMkQsXG4gICAg
ICAgICAgICBiOiBnZXRQYXJhbVZhbHVlKCdiJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBUZW5zb3IyRCxcbiAgICAgICAgICAgIHRyYW5zcG9zZUE6IGdldFBhcmFtVmFsdWUoJ3RyYW5zcG9zZUEnLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzXG4gICAgICAgICAgICAgICAgYm9vbGVhbixcbiAgICAgICAgICAgIHRyYW5zcG9zZUI6IGdldFBhcmFtVmFsdWUoJ3RyYW5zcG9zZUInLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzXG4gICAgICAgICAgICAgICAgYm9vbGVhbixcbiAgICAgICAgICAgIGJpYXM6IGJpYXNBcmcsXG4gICAgICAgICAgICBhY3RpdmF0aW9uOiBhY3RpdmF0aW9uRnVuYyBhcyB0Zk9wcy5mdXNlZC5BY3RpdmF0aW9uLFxuICAgICAgICAgICAgcHJlbHVBY3RpdmF0aW9uV2VpZ2h0czogcHJlbHVBcmcsXG4gICAgICAgICAgICBsZWFreXJlbHVBbHBoYVxuICAgICAgICAgIH0pXTtcblxuICAgICAgICBjYXNlICdNYXRyaXhCYW5kUGFydCc6XG4gICAgICAgICAgcmV0dXJuIFtvcHMubGluYWxnLmJhbmRQYXJ0KFxuICAgICAgICAgICAgICBnZXRQYXJhbVZhbHVlKCdhJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBUZW5zb3IyRCxcbiAgICAgICAgICAgICAgZ2V0UGFyYW1WYWx1ZSgnbnVtTG93ZXInLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzIFNjYWxhcixcbiAgICAgICAgICAgICAgZ2V0UGFyYW1WYWx1ZSgnbnVtVXBwZXInLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzIFNjYWxhcildO1xuXG4gICAgICAgIGRlZmF1bHQ6XG4gICAgICAgICAgdGhyb3cgVHlwZUVycm9yKGBOb2RlIHR5cGUgJHtub2RlLm9wfSBpcyBub3QgaW1wbGVtZW50ZWRgKTtcbiAgICAgIH1cbiAgICB9O1xuXG5leHBvcnQgY29uc3QgQ0FURUdPUlkgPSAnbWF0cmljZXMnO1xuIl19
\No newline at end of file