UNPKG

9.88 kBSource Map (JSON)View Raw
1{"version":3,"file":"kernel_registry.js","sourceRoot":"","sources":["../src/kernel_registry.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AACH,OAAO,EAAC,GAAG,EAAC,MAAM,eAAe,CAAC;AAElC,OAAO,EAAC,SAAS,EAAC,MAAM,eAAe,CAAC;AAKxC,MAAM,cAAc,GAChB,SAAS,CAAC,gBAAgB,EAAE,GAAG,EAAE,CAAC,IAAI,GAAG,EAAwB,CAAC,CAAC;AACvE,MAAM,YAAY,GACd,SAAS,CAAC,cAAc,EAAE,GAAG,EAAE,CAAC,IAAI,GAAG,EAAsB,CAAC,CAAC;AA8DnE;;;;;GAKG;AACH,MAAM,UAAU,SAAS,CACrB,UAAkB,EAAE,WAAmB;IACzC,MAAM,GAAG,GAAG,OAAO,CAAC,UAAU,EAAE,WAAW,CAAC,CAAC;IAC7C,OAAO,cAAc,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC;AACjC,CAAC;AAED;;;GAGG;AACH,MAAM,UAAU,WAAW,CAAC,UAAkB;IAC5C,OAAO,YAAY,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC;AACtC,CAAC;AAED,MAAM,UAAU,oBAAoB,CAAC,WAAmB;IACtD,MAAM,EAAE,GAAG,cAAc,CAAC,OAAO,EAAE,CAAC;IACpC,MAAM,MAAM,GAAmB,EAAE,CAAC;IAElC,OAAO,IAAI,EAAE;QACX,MAAM,EAAC,IAAI,EAAE,KAAK,EAAC,GAAG,EAAE,CAAC,IAAI,EAAE,CAAC;QAChC,IAAI,IAAI,EAAE;YACR,MAAM;SACP;QACD,MAAM,CAAC,GAAG,EAAE,MAAM,CAAC,GAAG,KAAK,CAAC;QAC5B,MAAM,CAAC,OAAO,EAAG,GAAG,GAAG,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC;QACnC,IAAI,OAAO,KAAK,WAAW,EAAE;YAC3B,MAAM,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;SACrB;KACF;IACD,OAAO,MAAM,CAAC;AAChB,CAAC;AAED;;;;;;;;;;GAUG;AACH,MAAM,UAAU,cAAc,CAAC,MAAoB;IACjD,MAAM,EAAC,UAAU,EAAE,WAAW,EAAC,GAAG,MAAM,CAAC;IACzC,MAAM,GAAG,GAAG,OAAO,CAAC,UAAU,EAAE,WAAW,CAAC,CAAC;IAC7C,IAAI,cAAc,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE;QAC3B,OAAO,CAAC,IAAI,CACR,eAAe,UAAU,gBAAgB;YACzC,IAAI,WAAW,yBAAyB,CAAC,CAAC;KAC/C;IACD,cAAc,CAAC,GAAG,CAAC,GAAG,EAAE,MAAM,CAAC,CAAC;AAClC,CAAC;AAED;;;;;;;GAOG;AACH,MAAM,UAAU,gBAAgB,CAAC,MAAkB;IACjD,MAAM,EAAC,UAAU,EAAC,GAAG,MAAM,CAAC;IAE5B,IAAI,YAAY,CAAC,GAAG,CAAC,UAAU,CAAC,EAAE;QAChC,sEAAsE;QACtE,iBAAiB;QACjB,IAAI,GAAG,EAAE,CAAC,OAAO,CAAC,OAAO,CAAC,EAAE;YAC1B,OAAO,CAAC,IAAI,CAAC,gCAAgC,UAAU,GAAG,CAAC,CAAC;SAC7D;KACF;IACD,YAAY,CAAC,GAAG,CAAC,UAAU,EAAE,MAAM,CAAC,CAAC;AACvC,CAAC;AAED;;;;;;GAMG;AACH,MAAM,UAAU,gBAAgB,CAC5B,UAAkB,EAAE,WAAmB;IACzC,MAAM,GAAG,GAAG,OAAO,CAAC,UAAU,EAAE,WAAW,CAAC,CAAC;IAC7C,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE;QAC5B,MAAM,IAAI,KAAK,CACX,eAAe,UAAU,gBAAgB;YACzC,IAAI,WAAW,qBAAqB,CAAC,CAAC;KAC3C;IACD,cAAc,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC;AAC7B,CAAC;AAED,gEAAgE;AAChE,MAAM,UAAU,kBAAkB,CAAC,UAAkB;IACnD,IAAI,CAAC,YAAY,CAAC,GAAG,CAAC,UAAU,CAAC,EAAE;QACjC,MAAM,IAAI,KAAK,CACX,iBAAiB,UAAU,iCAAiC,CAAC,CAAC;KACnE;IACD,YAAY,CAAC,MAAM,CAAC,UAAU,CAAC,CAAC;AAClC,CAAC;AAED;;;;;GAKG;AACH,MAAM,UAAU,qBAAqB,CACjC,qBAA6B,EAAE,cAAsB;IACvD,MAAM,OAAO,GAAG,oBAAoB,CAAC,qBAAqB,CAAC,CAAC;IAC5D,OAAO,CAAC,OAAO,CAAC,YAAY,CAAC,EAAE;QAC7B,MAAM,eAAe,GACjB,MAAM,CAAC,MAAM,CAAC,EAAE,EAAE,YAAY,EAAE,EAAC,WAAW,EAAE,cAAc,EAAC,CAAC,CAAC;QACnE,cAAc,CAAC,eAAe,CAAC,CAAC;IAClC,CAAC,CAAC,CAAC;AACL,CAAC;AAED,SAAS,OAAO,CAAC,UAAkB,EAAE,WAAmB;IACtD,OAAO,GAAG,WAAW,IAAI,UAAU,EAAE,CAAC;AACxC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {env} from './environment';\n\nimport {getGlobal} from './global_util';\nimport {NamedGradientMap} from './tape';\nimport {Tensor} from './tensor';\nimport {DataType, RecursiveArray} from './types';\n\nconst kernelRegistry =\n getGlobal('kernelRegistry', () => new Map<string, KernelConfig>());\nconst gradRegistry =\n getGlobal('gradRegistry', () => new Map<string, GradConfig>());\n\nexport type DataId = object;\n\ntype AttributeValue =\n number|number[]|boolean|boolean[]|string|string[]|NamedAttrMap;\n\n/** These are extra non-tensor/primitive params passed to kernel functions. */\nexport type Attribute = AttributeValue|RecursiveArray<AttributeValue>;\n\n/** Specifies the code to run when executing a kernel. */\nexport type KernelFunc = (params: {\n inputs: NamedTensorInfoMap,\n backend: {},\n attrs?: NamedAttrMap,\n}) => TensorInfo|TensorInfo[];\n\n/** The function to run when computing a gradient during backprop. */\nexport type GradFunc =\n (dy: Tensor|Tensor[], saved: Tensor[], attrs: NamedAttrMap) =>\n NamedGradientMap;\n\n/** Function that gets called after the backend initializes. */\nexport type KernelSetupFunc = (backend: {}) => void;\n/** Function that gets called right before the backend is disposed. */\nexport type KernelDisposeFunc = KernelSetupFunc;\n\n/** Config object for registering a kernel in the global registry. */\nexport interface KernelConfig {\n kernelName: string;\n backendName: string;\n kernelFunc: KernelFunc;\n setupFunc?: KernelSetupFunc;\n disposeFunc?: KernelDisposeFunc;\n}\n\n/** Config object for registering a gradient in the global registry. */\nexport interface GradConfig {\n kernelName: string;\n inputsToSave?: string[];\n // When saveAllInputs is true, all inputs will be saved. Only use this flag\n // if inputs is an array of Tensors.\n saveAllInputs?: boolean;\n outputsToSave?: boolean[];\n gradFunc: GradFunc;\n}\n\n/** Holds metadata for a given tensor. */\nexport interface TensorInfo {\n dataId: DataId;\n shape: number[];\n dtype: DataType;\n}\n\nexport interface NamedTensorInfoMap {\n [name: string]: TensorInfo;\n}\n\nexport interface NamedAttrMap {\n [name: string]: Attribute;\n}\n\n/**\n * Returns the kernel function (code) associated with the provided names.\n *\n * @param kernelName The official name of the kernel.\n * @param backendName The official name of the backend.\n */\nexport function getKernel(\n kernelName: string, backendName: string): KernelConfig {\n const key = makeKey(kernelName, backendName);\n return kernelRegistry.get(key);\n}\n\n/**\n * Returns the registered gradient info associated with the provided kernel.\n * @param kernelName The official TF kernel name.\n */\nexport function getGradient(kernelName: string): GradConfig {\n return gradRegistry.get(kernelName);\n}\n\nexport function getKernelsForBackend(backendName: string): KernelConfig[] {\n const it = kernelRegistry.entries();\n const result: KernelConfig[] = [];\n\n while (true) {\n const {done, value} = it.next();\n if (done) {\n break;\n }\n const [key, config] = value;\n const [backend, ] = key.split('_');\n if (backend === backendName) {\n result.push(config);\n }\n }\n return result;\n}\n\n/**\n * Registers the function (forward pass) for the kernel in a global registry.\n *\n * @param config A config object with the following properties:\n * - `kernelName` The official name of the kernel.\n * - `backendName` The official name of the backend.\n * - `kernelFunc` The function to run during the forward pass of the kernel.\n * - `setupFunc` Optional. Gets called once, after the backend initializes.\n * - `disposeFunc` Optional. Gets called once, right before the backend is\n * disposed.\n */\nexport function registerKernel(config: KernelConfig) {\n const {kernelName, backendName} = config;\n const key = makeKey(kernelName, backendName);\n if (kernelRegistry.has(key)) {\n console.warn(\n `The kernel '${kernelName}' for backend ` +\n `'${backendName}' is already registered`);\n }\n kernelRegistry.set(key, config);\n}\n\n/**\n * Registers a gradient function for a given kernel in the global registry,\n * to be used during the back-propagation of that kernel.\n *\n * @param config An object with the following properties:\n * - `kernelName` The name of the kernel that the gradient function is for.\n * - `gradFunc` The function to run during back-propagation.\n */\nexport function registerGradient(config: GradConfig) {\n const {kernelName} = config;\n\n if (gradRegistry.has(kernelName)) {\n // TODO (yassogba) after 3.0 assess whether we need to keep this gated\n // to debug mode.\n if (env().getBool('DEBUG')) {\n console.warn(`Overriding the gradient for '${kernelName}'`);\n }\n }\n gradRegistry.set(kernelName, config);\n}\n\n/**\n * Removes the kernel function from the registry.\n *\n * @param kernelName The official name of the kernel.\n * @param backendName The official name of the backend.\n *\n */\nexport function unregisterKernel(\n kernelName: string, backendName: string): void {\n const key = makeKey(kernelName, backendName);\n if (!kernelRegistry.has(key)) {\n throw new Error(\n `The kernel '${kernelName}' for backend ` +\n `'${backendName}' is not registered`);\n }\n kernelRegistry.delete(key);\n}\n\n/** Removes the registered gradient from the global registry. */\nexport function unregisterGradient(kernelName: string): void {\n if (!gradRegistry.has(kernelName)) {\n throw new Error(\n `The gradient '${kernelName}' for backend is not registered`);\n }\n gradRegistry.delete(kernelName);\n}\n\n/**\n * Finds kernels that have already been registered to a backend and re-registers\n * them for a new backend. Useful for registering custom backends.\n * @param registeredBackendName Already registered backend.\n * @param newBackendName New backend.\n */\nexport function copyRegisteredKernels(\n registeredBackendName: string, newBackendName: string): void {\n const kernels = getKernelsForBackend(registeredBackendName);\n kernels.forEach(kernelConfig => {\n const newKernelConfig =\n Object.assign({}, kernelConfig, {backendName: newBackendName});\n registerKernel(newKernelConfig);\n });\n}\n\nfunction makeKey(kernelName: string, backendName: string) {\n return `${backendName}_${kernelName}`;\n}\n"]}
\No newline at end of file