1 | /**
|
2 | * @license
|
3 | * Copyright 2018 Google LLC
|
4 | *
|
5 | * Use of this source code is governed by an MIT-style
|
6 | * license that can be found in the LICENSE file or at
|
7 | * https://opensource.org/licenses/MIT.
|
8 | * =============================================================================
|
9 | */
|
10 | import * as tfc from '@tensorflow/tfjs-core';
|
11 | import { Dataset, LazyIterator } from './dataset_stub';
|
/**
 * Prepends the batch dimension to a per-example shape.
 *
 * @param batchSize Value placed in front of every shape.
 * @param shape Either a single shape (an array) or an object mapping names
 *   to shapes.
 * @returns A new array `[batchSize, ...shape]` when `shape` is an array;
 *   otherwise a new object with the same keys whose values are the
 *   batch-prefixed shapes. The input is not modified.
 */
function mergeBatchSizeAndShape(batchSize, shape) {
    const withBatchDim = (dims) => [batchSize].concat(dims);
    if (Array.isArray(shape)) {
        return withBatchDim(shape);
    }
    // Name -> shape map: prefix each entry independently.
    const batchShapes = {};
    for (const key in shape) {
        batchShapes[key] = withBatchDim(shape[key]);
    }
    return batchShapes;
}
|
/**
 * Creates random-normal tensor(s) matching the given shape description.
 *
 * @param shape Either a single shape (an array) or an object mapping names
 *   to shapes.
 * @returns The result of `tfc.randomNormal(shape)` when `shape` is an array;
 *   otherwise an object with the same keys, each holding the result of
 *   `tfc.randomNormal` for that key's shape.
 */
function generateRandomTensorContainer(shape) {
    if (Array.isArray(shape)) {
        return tfc.randomNormal(shape);
    }
    // Name -> shape map: generate one tensor per entry.
    const tensorsByName = {};
    for (const key in shape) {
        tensorsByName[key] = tfc.randomNormal(shape[key]);
    }
    return tensorsByName;
}
|
/**
 * Iterator that yields `numBatches` elements of the form `{xs, ys}` and
 * afterwards reports `{done: true, value: null}`.
 *
 * When `xTensorsFunc`/`yTensorsFunc` are not provided, every batch is
 * generated with `generateRandomTensorContainer`. Otherwise both functions
 * are invoked when the first batch is requested, and their results are
 * served one entry per `next()` call after a shape check.
 */
class FakeNumericIterator extends LazyIterator {
    /**
     * @param args Configuration object. Reads `batchSize`, `xShape`, `yShape`,
     *   `numBatches`, and the optional `xTensorsFunc` / `yTensorsFunc`, which
     *   must be either both set or both null/undefined.
     */
    constructor(args) {
        super();
        this.tensorIndex = 0;
        this.xBatchShape = mergeBatchSizeAndShape(args.batchSize, args.xShape);
        this.yBatchShape = mergeBatchSizeAndShape(args.batchSize, args.yShape);
        this.numBatches = args.numBatches;
        this.batchCount = 0;
        this.xTensorsFunc = args.xTensorsFunc;
        this.yTensorsFunc = args.yTensorsFunc;
        // Sanity check on the preset tensors.
        tfc.util.assert(this.xTensorsFunc == null && this.yTensorsFunc == null ||
            this.xTensorsFunc != null && this.yTensorsFunc != null, () => 'presetXTensors and presetYTensors must be both null/undefined ' +
            'or both set.');
    }
    /**
     * Produces the next batch.
     *
     * @returns `{done: false, value: {xs, ys}}` for the first `numBatches`
     *   calls and `{done: true, value: null}` for every call after that.
     */
    async next() {
        const done = ++this.batchCount > this.numBatches;
        if (done) {
            return { done, value: null };
        }
        if (this.xTensorsFunc == null) {
            // Generate data randomly.
            return {
                done,
                value: {
                    xs: generateRandomTensorContainer(this.xBatchShape),
                    ys: generateRandomTensorContainer(this.yBatchShape)
                }
            };
        }
        // Use preset tensors. They are (re)materialized when the first batch
        // is requested.
        if ((this.batchCount - 1) % this.numBatches === 0) {
            this.xTensorValues = this.xTensorsFunc();
            this.yTensorValues = this.yTensorsFunc();
            this.tensorIndex = 0;
        }
        const index = this.tensorIndex++;
        const xs = selectPresetBatch(this.xTensorValues, index, this.xBatchShape);
        const ys = selectPresetBatch(this.yTensorValues, index, this.yBatchShape);
        return { done, value: { xs, ys } };
    }
}
/**
 * Picks the `index`-th preset batch and verifies its shape.
 *
 * @param presetTensors Either an array of tensors, or an object mapping names
 *   to arrays of tensors.
 * @param index Position of the batch to pick from each array.
 * @param batchShape Expected batch shape: a single shape, or an object
 *   mapping names to shapes. With map-form presets, a single shape is applied
 *   to every key, while a map is looked up per key.
 * @returns A single tensor for array-form presets, otherwise an object with
 *   one tensor per key.
 * @throws Error (via `tfc.util.assert`) if no tensor exists at `index` or if
 *   a tensor's shape differs from the expected one.
 */
function selectPresetBatch(presetTensors, index, batchShape) {
    const checked = (tensor, expectedShape) => {
        // Fail with a clear message instead of a TypeError on `.shape` when
        // fewer preset tensors than batches were supplied.
        tfc.util.assert(tensor != null, () => `Preset tensors exhausted: no tensor available at index ${index}.`);
        tfc.util.assert(tfc.util.arraysEqual(tensor.shape, expectedShape), () => `Shape mismatch: expected: ${JSON.stringify(batchShape)}; ` +
            `actual: ${JSON.stringify(tensor.shape)}`);
        return tensor;
    };
    if (Array.isArray(presetTensors)) {
        return checked(presetTensors[index], batchShape);
    }
    const batch = {};
    for (const key in presetTensors) {
        // Compare against this key's own shape when shapes are given per name.
        const expectedShape = Array.isArray(batchShape) ? batchShape : batchShape[key];
        batch[key] = checked(presetTensors[key][index], expectedShape);
    }
    return batch;
}
|
/**
 * A fake dataset with configurable feature and target shapes.
 *
 * The batch size and the number of batches are configurable as well.
 *
 * Iterators created from this dataset generate random-normal float32 values,
 * unless preset tensors are supplied through `args.xTensorsFunc` and
 * `args.yTensorsFunc`.
 */
export class FakeNumericDataset extends Dataset {
    /**
     * @param args Configuration object; `batchSize` and `numBatches` must both
     *   be positive integers. It is stored as `this.args` and handed to every
     *   iterator this dataset creates.
     */
    constructor(args) {
        super();
        this.args = args;
        const { batchSize, numBatches } = args;
        const isPositiveInteger = (value) => value > 0 && Number.isInteger(value);
        tfc.util.assert(isPositiveInteger(batchSize), () => `batchSize must be a positive integer, but got ${batchSize}`);
        tfc.util.assert(isPositiveInteger(numBatches), () => `numBatches must be positive integer, but got ${numBatches}`);
        // The dataset size is the number of batches.
        this.size = numBatches;
    }
    /** Creates a fresh iterator configured with this dataset's arguments. */
    async iterator() {
        const freshIterator = new FakeNumericIterator(this.args);
        return freshIterator;
    }
}
|
// Dataset.map(...) is not available because this module does not depend on
// tfjs-data, so the {xs, ys} dataset above is wrapped manually to produce
// elements in the [xs, ys] form.
export class FakeNumericDatasetLegacyArrayForm extends Dataset {
    /**
     * @param args Configuration object, stored as `this.args` and forwarded
     *   unchanged to the wrapped `FakeNumericDataset` (`this.ds`).
     */
    constructor(args) {
        super();
        this.args = args;
        this.ds = new FakeNumericDataset(args);
    }
    /** Wraps an iterator of the underlying dataset in the array-form adapter. */
    async iterator() {
        return new FakeNumericIteratorLegacyArrayForm(await this.ds.iterator());
    }
}
|
/**
 * Adapter that converts `{xs, ys}` elements of a wrapped iterator into
 * `[xs, ys]` arrays.
 */
class FakeNumericIteratorLegacyArrayForm extends LazyIterator {
    /** @param it The iterator whose elements are converted. */
    constructor(it) {
        super();
        this.it = it;
    }
    /**
     * Pulls one element from the wrapped iterator.
     *
     * @returns The wrapped result's `done` flag together with `[xs, ys]`, or
     *   with `null` when the wrapped value is null/undefined.
     */
    async next() {
        const { done, value } = await this.it.next();
        if (value == null) {
            return { done, value: null };
        }
        return { done, value: [value.xs, value.ys] };
    }
}
|
155 | //# sourceMappingURL=data:application/json;base64,{"version":3,"file":"dataset_fakes.js","sourceRoot":"","sources":["../../../../../../tfjs-layers/src/engine/dataset_fakes.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH,OAAO,KAAK,GAAG,MAAM,uBAAuB,CAAC;AAK7C,OAAO,EAAC,OAAO,EAAE,YAAY,EAAC,MAAM,gBAAgB,CAAC;AA2CrD,SAAS,sBAAsB,CAC3B,SAAiB,EAAE,KAAoC;IAEzD,IAAI,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;QACxB,OAAO,CAAC,SAAS,CAAC,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;KAClC;SAAM;QACL,MAAM,MAAM,GAA4B,EAAE,CAAC;QAC3C,KAAK,MAAM,IAAI,IAAI,KAAK,EAAE;YACxB,MAAM,CAAC,IAAI,CAAC,GAAG,CAAC,SAAS,CAAC,CAAC,MAAM,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC,CAAC;SAChD;QACD,OAAO,MAAM,CAAC;KACf;AACH,CAAC;AAED,SAAS,6BAA6B,CAAC,KAAoC;IAEzE,IAAI,MAA+C,CAAC;IACpD,IAAI,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;QACxB,MAAM,GAAG,GAAG,CAAC,YAAY,CAAC,KAAK,CAAC,CAAC;KAClC;SAAM;QACL,MAAM,GAAG,EAAE,CAAC;QACZ,KAAK,MAAM,IAAI,IAAI,KAAK,EAAE;YACxB,MAAM,CAAC,IAAI,CAAC,GAAG,GAAG,CAAC,YAAY,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC,CAAC;SAC9C;KACF;IACD,OAAO,MAAM,CAAC;AAChB,CAAC;AAED,MAAM,mBAAoB,SAAQ,YAA+B;IAW/D,YAAY,IAAqB;QAC/B,KAAK,EAAE,CAAC;QAHF,gBAAW,GAAG,CAAC,CAAC;QAItB,IAAI,CAAC,WAAW,GAAG,sBAAsB,CAAC,IAAI,CAAC,SAAS,EAAE,IAAI,CAAC,MAAM,CAAC,CAAC;QACvE,IAAI,CAAC,WAAW,GAAG,sBAAsB,CAAC,IAAI,CAAC,SAAS,EAAE,IAAI,CAAC,MAAM,CAAC,CAAC;QACvE,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,UAAU,CAAC;QAClC,IAAI,CAAC,UAAU,GAAG,CAAC,CAAC;QACpB,IAAI,CAAC,YAAY,GAAG,IAAI,CAAC,YAAY,CAAC;QACtC,IAAI,CAAC,YAAY,GAAG,IAAI,CAAC,YAAY,CAAC;QAEtC,sCAAsC;QACtC,GAAG,CAAC,IAAI,CAAC,MAAM,CACX,IAAI,CAAC,YAAY,IAAI,IAAI,IAAI,IAAI,CAAC,YAAY,IAAI,IAAI;YAClD,IAAI,CAAC,YAAY,IAAI,IAAI,IAAI,IAAI,CAAC,YAAY,IAAI,IAAI,EAC1D,GAAG,EAAE,CAAC,gEAAgE;YAClE,cAAc,CAAC,CAAC;IAC1B,CAAC;IAED,KAAK,CAAC,IAAI;QACR,MAAM,IAAI,GAAG,EAAE,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,UAAU,CAAC;QACjD,IAAI,IAAI,EAAE;YACR,OAAO,EAAC,IAAI,EAAE,KAAK,EAAE,IAAI,EAAC,CAAC;SAC5B;QACD,IAAI,IAAI,CAAC,YAAY,IAAI,IAAI,EAAE;YAC7B,0BAA0B;YAC1B,OAAO;gBACL,IAAI;gBACJ,KAAK,EAAE,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC;oBACnB,EAAE,EAAE,6BAA6B,CAAC,IAAI,CAAC,WAA
W,CAAC;oBACnD,EAAE,EAAE,6BAA6B,CAAC,IAAI,CAAC,WAAW,CAAC;iBACpD;aACF,CAAC;SACH;aAAM;YACL,sBAAsB;YACtB,IAAI,CAAC,IAAI,CAAC,UAAU,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,UAAU,KAAK,CAAC,EAAE;gBACjD,IAAI,CAAC,aAAa,GAAG,IAAI,CAAC,YAAY,EAAE,CAAC;gBACzC,IAAI,CAAC,aAAa,GAAG,IAAI,CAAC,YAAY,EAAE,CAAC;gBACzC,IAAI,CAAC,WAAW,GAAG,CAAC,CAAC;aACtB;YACD,MAAM,KAAK,GAAG,IAAI,CAAC,WAAW,EAAE,CAAC;YAEjC,IAAI,EAA2C,CAAC;YAChD,IAAI,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,aAAa,CAAC,EAAE;gBACrC,EAAE,GAAG,IAAI,CAAC,aAAa,CAAC,KAAK,CAAC,CAAC;gBAC/B,GAAG,CAAC,IAAI,CAAC,MAAM,CACX,GAAG,CAAC,IAAI,CAAC,WAAW,CAAC,EAAE,CAAC,KAAK,EAAE,IAAI,CAAC,WAAoB,CAAC,EACzD,GAAG,EAAE,CAAC,6BACI,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,WAAW,CAAC,IAAI;oBAC1C,WAAW,IAAI,CAAC,SAAS,CAAE,EAAiB,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;aAChE;iBAAM;gBACL,EAAE,GAAG,EAAE,CAAC;gBACR,KAAK,MAAM,GAAG,IAAI,IAAI,CAAC,aAAa,EAAE;oBACpC,EAAE,CAAC,GAAG,CAAC,GAAG,IAAI,CAAC,aAAa,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC;oBACzC,GAAG,CAAC,IAAI,CAAC,MAAM,CACX,GAAG,CAAC,IAAI,CAAC,WAAW,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,KAAK,EAAE,IAAI,CAAC,WAAoB,CAAC,EAC9D,GAAG,EAAE,CAAC,6BACI,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,WAAW,CAAC,IAAI;wBAC1C,WAAW,IAAI,CAAC,SAAS,CAAE,EAAiB,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;iBAChE;aACF;YAED,IAAI,EAA2C,CAAC;YAChD,IAAI,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,aAAa,CAAC,EAAE;gBACrC,kDAAkD;gBAClD,EAAE,GAAG,IAAI,CAAC,aAAa,CAAC,KAAK,CAAC,CAAC;gBAC/B,GAAG,CAAC,IAAI,CAAC,MAAM,CACX,GAAG,CAAC,IAAI,CAAC,WAAW,CAAC,EAAE,CAAC,KAAK,EAAE,IAAI,CAAC,WAAoB,CAAC,EACzD,GAAG,EAAE,CAAC,6BACI,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,WAAW,CAAC,IAAI;oBAC1C,WAAW,IAAI,CAAC,SAAS,CAAE,EAAiB,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;aAChE;iBAAM;gBACL,iDAAiD;gBACjD,EAAE,GAAG,EAAE,CAAC;gBACR,IAAI,CAAC,WAAW,GAAG,IAAI,CAAC,WAAsC,CAAC;gBAC/D,KAAK,MAAM,GAAG,IAAI,IAAI,CAAC,aAAa,EAAE;oBACpC,EAAE,CAAC,GAAG,CAAC,GAAG,IAAI,CAAC,aAAa,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC;oBACzC,GAAG,CAAC,IAAI,CAAC,MAAM,CACX,GAAG,CAAC,IAAI,CAAC,WAAW,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,KAAK,EAAE,IAAI,CAAC,WAAW,CAAC,GAAG,CAAC,CAAC,EAC1D,GAAG,EAAE,CAAC,6BACI,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,
WAAW,CAAC,IAAI;wBAC1C,WACM,IAAI,CAAC,SAAS,CACT,EAAmC,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;iBACvE;aACF;YAED,OAAO,EAAC,IAAI,EAAE,KAAK,EAAE,EAAC,EAAE,EAAE,EAAE,EAAC,EAAC,CAAC;SAChC;IACH,CAAC;CACF;AAED;;;;;;GAMG;AACH,MAAM,OAAO,kBAAmB,SAAQ,OAA0B;IAChE,YAAqB,IAAqB;QACxC,KAAK,EAAE,CAAC;QADW,SAAI,GAAJ,IAAI,CAAiB;QAExC,GAAG,CAAC,IAAI,CAAC,MAAM,CACX,IAAI,CAAC,SAAS,GAAG,CAAC,IAAI,MAAM,CAAC,SAAS,CAAC,IAAI,CAAC,SAAS,CAAC,EACtD,GAAG,EAAE,CACD,iDAAiD,IAAI,CAAC,SAAS,EAAE,CAAC,CAAC;QAC3E,GAAG,CAAC,IAAI,CAAC,MAAM,CACX,IAAI,CAAC,UAAU,GAAG,CAAC,IAAI,MAAM,CAAC,SAAS,CAAC,IAAI,CAAC,UAAU,CAAC,EACxD,GAAG,EAAE,CACD,gDAAgD,IAAI,CAAC,UAAU,EAAE,CAAC,CAAC;QAC3E,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC,UAAU,CAAC;IAC9B,CAAC;IAED,KAAK,CAAC,QAAQ;QACZ,OAAO,IAAI,mBAAmB,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;IAC5C,CAAC;CACF;AAED,2EAA2E;AAC3E,4EAA4E;AAC5E,MAAM,OAAO,iCAAkC,SAC3C,OAAiD;IAEnD,YAAqB,IAAqB;QACxC,KAAK,EAAE,CAAC;QADW,SAAI,GAAJ,IAAI,CAAiB;QAExC,IAAI,CAAC,EAAE,GAAG,IAAI,kBAAkB,CAAC,IAAI,CAAC,CAAC;IACzC,CAAC;IAED,KAAK,CAAC,QAAQ;QAEZ,MAAM,EAAE,GAAG,MAAM,IAAI,CAAC,EAAE,CAAC,QAAQ,EAAE,CAAC;QACpC,OAAO,IAAI,kCAAkC,CAAC,EAAE,CAAC,CAAC;IACpD,CAAC;CACF;AAED,MAAM,kCAAmC,SACrC,YAAsD;IACxD,YAA6B,EAAmC;QAC9D,KAAK,EAAE,CAAC;QADmB,OAAE,GAAF,EAAE,CAAiC;IAEhE,CAAC;IAED,KAAK,CAAC,IAAI;QAER,MAAM,MAAM,GAAG,MAAM,IAAI,CAAC,EAAE,CAAC,IAAI,EAAE,CAAC;QACpC,OAAO;YACL,IAAI,EAAE,MAAM,CAAC,IAAI;YACjB,KAAK,EAAE,MAAM,CAAC,KAAK,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,KAAK,CAAC,EAAE,EAAE,MAAM,CAAC,KAAK,CAAC,EAAE,CAAC;SACxE,CAAC;IACJ,CAAC;CACF","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\nimport * as tfc from '@tensorflow/tfjs-core';\n\nimport {Shape} from '../keras_format/common';\nimport {TensorOrArrayOrMap} from '../types';\n\nimport {Dataset, 
LazyIterator} from './dataset_stub';\nimport {FitDatasetElement} from './training_dataset';\n\nexport interface FakeDatasetArgs {\n  /**\n   * The shape(s) of the features of a single example.\n   *\n   * Use an object mapping name to shape, if more than one feature tensors\n   * are required.\n   */\n  xShape: Shape|{[name: string]: Shape};\n\n  /**\n   * The shape of the target(s) of a single exapmle.\n   */\n  yShape: Shape|{[name: string]: Shape};\n\n  /**\n   * A function that generates preset sequence of X tensors.\n   *\n   * This function is invoked each time a new iterator is created.\n   */\n  xTensorsFunc?: () => tfc.Tensor[] | {[name: string]: tfc.Tensor[]};\n\n  /**\n   * A function that generates preset sequence of Y tensors.\n   *\n   * This function is invoked each time a new iterator is created.\n   */\n  yTensorsFunc?: () => tfc.Tensor[] | {[name: string]: tfc.Tensor[]};\n\n  /**\n   * The size of each batch generated by the iterator.\n   */\n  batchSize: number;\n\n  /**\n   * The number of batches an iterator generates before declaring done to be\n   * true.\n   */\n  numBatches: number;\n}\n\nfunction mergeBatchSizeAndShape(\n    batchSize: number, shape: Shape|{[name: string]: Shape}): Shape|\n    {[name: string]: Shape} {\n  if (Array.isArray(shape)) {\n    return [batchSize].concat(shape);\n  } else {\n    const output: {[name: string]: Shape} = {};\n    for (const name in shape) {\n      output[name] = [batchSize].concat(shape[name]);\n    }\n    return output;\n  }\n}\n\nfunction generateRandomTensorContainer(shape: Shape|{[name: string]: Shape}):\n    tfc.Tensor|{[name: string]: tfc.Tensor} {\n  let output: tfc.Tensor|{[name: string]: tfc.Tensor};\n  if (Array.isArray(shape)) {\n    output = tfc.randomNormal(shape);\n  } else {\n    output = {};\n    for (const name in shape) {\n      output[name] = tfc.randomNormal(shape[name]);\n    }\n  }\n  return output;\n}\n\nclass FakeNumericIterator extends LazyIterator<FitDatasetElement> {\n  
private xBatchShape: Shape|{[name: string]: Shape};\n  private yBatchShape: Shape|{[name: string]: Shape};\n  private numBatches: number;\n  private batchCount: number;\n  private xTensorsFunc: () => tfc.Tensor[] | {[name: string]: tfc.Tensor[]};\n  private yTensorsFunc: () => tfc.Tensor[] | {[name: string]: tfc.Tensor[]};\n  private xTensorValues: tfc.Tensor[]|{[name: string]: tfc.Tensor[]};\n  private yTensorValues: tfc.Tensor[]|{[name: string]: tfc.Tensor[]};\n  private tensorIndex = 0;\n\n  constructor(args: FakeDatasetArgs) {\n    super();\n    this.xBatchShape = mergeBatchSizeAndShape(args.batchSize, args.xShape);\n    this.yBatchShape = mergeBatchSizeAndShape(args.batchSize, args.yShape);\n    this.numBatches = args.numBatches;\n    this.batchCount = 0;\n    this.xTensorsFunc = args.xTensorsFunc;\n    this.yTensorsFunc = args.yTensorsFunc;\n\n    // Sanity check on the preset tensors.\n    tfc.util.assert(\n        this.xTensorsFunc == null && this.yTensorsFunc == null ||\n            this.xTensorsFunc != null && this.yTensorsFunc != null,\n        () => 'presetXTensors and presetYTensors must be both null/undefined ' +\n            'or both set.');\n  }\n\n  async next(): Promise<IteratorResult<FitDatasetElement>> {\n    const done = ++this.batchCount > this.numBatches;\n    if (done) {\n      return {done, value: null};\n    }\n    if (this.xTensorsFunc == null) {\n      // Generate data randomly.\n      return {\n        done,\n        value: done ? 
null : {\n          xs: generateRandomTensorContainer(this.xBatchShape),\n          ys: generateRandomTensorContainer(this.yBatchShape)\n        }\n      };\n    } else {\n      // Use preset tensors.\n      if ((this.batchCount - 1) % this.numBatches === 0) {\n        this.xTensorValues = this.xTensorsFunc();\n        this.yTensorValues = this.yTensorsFunc();\n        this.tensorIndex = 0;\n      }\n      const index = this.tensorIndex++;\n\n      let xs: tfc.Tensor|{[name: string]: tfc.Tensor};\n      if (Array.isArray(this.xTensorValues)) {\n        xs = this.xTensorValues[index];\n        tfc.util.assert(\n            tfc.util.arraysEqual(xs.shape, this.xBatchShape as Shape),\n            () => `Shape mismatch: expected: ${\n                      JSON.stringify(this.xBatchShape)}; ` +\n                `actual: ${JSON.stringify((xs as tfc.Tensor).shape)}`);\n      } else {\n        xs = {};\n        for (const key in this.xTensorValues) {\n          xs[key] = this.xTensorValues[key][index];\n          tfc.util.assert(\n              tfc.util.arraysEqual(xs[key].shape, this.xBatchShape as Shape),\n              () => `Shape mismatch: expected: ${\n                        JSON.stringify(this.xBatchShape)}; ` +\n                  `actual: ${JSON.stringify((xs as tfc.Tensor).shape)}`);\n        }\n      }\n\n      let ys: tfc.Tensor|{[name: string]: tfc.Tensor};\n      if (Array.isArray(this.yTensorValues)) {\n        // Get preset ys tensors for single-output models.\n        ys = this.yTensorValues[index];\n        tfc.util.assert(\n            tfc.util.arraysEqual(ys.shape, this.yBatchShape as Shape),\n            () => `Shape mismatch: expected: ${\n                      JSON.stringify(this.yBatchShape)}; ` +\n                `actual: ${JSON.stringify((ys as tfc.Tensor).shape)}`);\n      } else {\n        // Get preset ys tensors for multi-output models.\n        ys = {};\n        this.yBatchShape = this.yBatchShape as {[name: string]: Shape};\n        for 
(const key in this.yTensorValues) {\n          ys[key] = this.yTensorValues[key][index];\n          tfc.util.assert(\n              tfc.util.arraysEqual(ys[key].shape, this.yBatchShape[key]),\n              () => `Shape mismatch: expected: ${\n                        JSON.stringify(this.yBatchShape)}; ` +\n                  `actual: ${\n                        JSON.stringify(\n                            (ys as {[name: string]: tfc.Tensor})[key].shape)}`);\n        }\n      }\n\n      return {done, value: {xs, ys}};\n    }\n  }\n}\n\n/**\n * A fake dataset with configurable feature and target shapes.\n *\n * The batch size and # of batches are also configurable.\n *\n * The iterator from the dataset always generate random-normal float32 values.\n */\nexport class FakeNumericDataset extends Dataset<FitDatasetElement> {\n  constructor(readonly args: FakeDatasetArgs) {\n    super();\n    tfc.util.assert(\n        args.batchSize > 0 && Number.isInteger(args.batchSize),\n        () =>\n            `batchSize must be a positive integer, but got ${args.batchSize}`);\n    tfc.util.assert(\n        args.numBatches > 0 && Number.isInteger(args.numBatches),\n        () =>\n            `numBatches must be positive integer, but got ${args.numBatches}`);\n    this.size = args.numBatches;\n  }\n\n  async iterator(): Promise<LazyIterator<FitDatasetElement>> {\n    return new FakeNumericIterator(this.args);\n  }\n}\n\n// We can't use Dataset.map(...) 
because we don't depend on tfjs-data here,\n// so we manually transform the above {xs, ys} dataset to the [xs, ys] form.\nexport class FakeNumericDatasetLegacyArrayForm extends\n    Dataset<[TensorOrArrayOrMap, TensorOrArrayOrMap]> {\n  ds: FakeNumericDataset;\n  constructor(readonly args: FakeDatasetArgs) {\n    super();\n    this.ds = new FakeNumericDataset(args);\n  }\n\n  async iterator():\n      Promise<LazyIterator<[TensorOrArrayOrMap, TensorOrArrayOrMap]>> {\n    const it = await this.ds.iterator();\n    return new FakeNumericIteratorLegacyArrayForm(it);\n  }\n}\n\nclass FakeNumericIteratorLegacyArrayForm extends\n    LazyIterator<[TensorOrArrayOrMap, TensorOrArrayOrMap]> {\n  constructor(private readonly it: LazyIterator<FitDatasetElement>) {\n    super();\n  }\n\n  async next():\n      Promise<IteratorResult<[TensorOrArrayOrMap, TensorOrArrayOrMap]>> {\n    const result = await this.it.next();\n    return {\n      done: result.done,\n      value: result.value == null ? null : [result.value.xs, result.value.ys]\n    };\n  }\n}\n"]} |
\ | No newline at end of file |