UNPKG

46.3 kBJavaScriptView Raw
1/**
2 * @license
3 * Copyright 2020 Google LLC. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * =============================================================================
16 */
17import { concat, keep, reshape, scalar, slice, stack, tensor, tidy, unstack } from '@tensorflow/tfjs-core';
18import { assertShapesMatchAllowUndefinedSize, inferElementShape, mergeElementShape } from './tensor_utils';
/**
 * TensorList stores a container of `tf.Tensor` objects, which are accessible
 * via the `tensors` field.
 *
 * In order to get a copy of the underlying list, use the copy method:
 * ```
 *    const b = a.copy();
 *    b.pushBack(t);  // This does not modify a.tensors.
 * ```
 *
 * Note that this is not a deep copy: the copy's array holds references to the
 * same tensor objects as the original.
 */
export class TensorList {
  /**
   * @param tensors list of tensors. Each one is checked against
   *     `elementDtype` and `elementShape` and passed to `keep`.
   * @param elementShape shape of each tensor, this can be a single number (any
   *     shape is allowed) or partial shape (dim = -1).
   * @param elementDtype data type of each tensor
   * @param maxNumElements The maximum allowed size of `tensors`. Defaults to -1
   *     meaning that the size of `tensors` is unbounded.
   * @throws Error if any tensor's dtype differs from `elementDtype` or its
   *     shape does not match `elementShape`.
   */
  constructor(tensors, elementShape, elementDtype, maxNumElements = -1) {
    this.tensors = tensors;
    this.elementShape = elementShape;
    this.elementDtype = elementDtype;
    if (tensors != null) {
      tensors.forEach(tensor => {
        if (elementDtype !== tensor.dtype) {
          throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${tensor.dtype}`);
        }
        assertShapesMatchAllowUndefinedSize(elementShape, tensor.shape, 'TensorList shape mismatch: ');
        keep(tensor);
      });
    }
    // A kept scalar whose tensor id is exposed as this list's `id`.
    this.idTensor = scalar(0);
    this.maxNumElements = maxNumElements;
    keep(this.idTensor);
  }

  /** Id of this list, taken from its internal id tensor. */
  get id() {
    return this.idTensor.id;
  }

  /**
   * Get a new TensorList containing a copy of the underlying tensor container.
   * The copy shares the tensor objects with this list and keeps the same
   * element shape, element dtype and `maxNumElements` bound.
   */
  copy() {
    return new TensorList([...this.tensors], this.elementShape, this.elementDtype, this.maxNumElements);
  }

  /**
   * Dispose the tensors and idTensor and clear the tensor list.
   * @param keepIds optional set of tensor ids that must NOT be disposed.
   */
  clearAndClose(keepIds) {
    this.tensors.forEach(tensor => {
      if (keepIds == null || !keepIds.has(tensor.id)) {
        tensor.dispose();
      }
    });
    this.tensors.length = 0;
    this.idTensor.dispose();
  }

  /**
   * The number of slots in the tensor list (including unset slots left by
   * `resize` or `setItem`).
   */
  size() {
    return this.tensors.length;
  }

  /**
   * Return a tensor that stacks a list of rank-R tf.Tensors into one rank-(R+1)
   * tf.Tensor. An empty list yields a tensor of shape `[0, ...elementShape]`.
   * @param elementShape shape of each tensor
   * @param elementDtype data type of each tensor
   * @param numElements the number of elements to stack; -1 skips the check.
   * @throws Error on dtype mismatch, element-count mismatch or shape mismatch.
   */
  stack(elementShape, elementDtype, numElements = -1) {
    if (elementDtype !== this.elementDtype) {
      throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
    }
    if (numElements !== -1 && this.tensors.length !== numElements) {
      throw new Error(`Operation expected a list with ${numElements} elements but got a list with ${this.tensors.length} elements.`);
    }
    assertShapesMatchAllowUndefinedSize(elementShape, this.elementShape, 'TensorList shape mismatch: ');
    const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
    if (this.tensors.length === 0) {
      // Same convention as gather() and concat(): tf.stack rejects an empty
      // input list, so build the zero-length result directly.
      return tensor([], [0].concat(outputElementShape));
    }
    return tidy(() => {
      const reshapedTensors = this.tensors.map(tensor => reshape(tensor, outputElementShape));
      return stack(reshapedTensors, 0);
    });
  }

  /**
   * Pop a tensor from the end of the list.
   * The last element is validated before it is removed, so a dtype or shape
   * mismatch leaves the list unchanged.
   * @param elementShape shape of the tensor
   * @param elementDtype data type of the tensor
   * @throws Error on dtype mismatch, empty list or shape mismatch.
   */
  popBack(elementShape, elementDtype) {
    if (elementDtype !== this.elementDtype) {
      throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
    }
    if (this.size() === 0) {
      throw new Error('Trying to pop from an empty list.');
    }
    const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
    const last = this.tensors[this.tensors.length - 1];
    assertShapesMatchAllowUndefinedSize(last.shape, elementShape, 'TensorList shape mismatch: ');
    this.tensors.pop();
    return reshape(last, outputElementShape);
  }

  /**
   * Push a tensor to the end of the list.
   * @param tensor Tensor to be pushed.
   * @throws Error on dtype/shape mismatch or when the list is full.
   */
  pushBack(tensor) {
    if (tensor.dtype !== this.elementDtype) {
      throw new Error(`Invalid data types; op elements ${tensor.dtype}, but list elements ${this.elementDtype}`);
    }
    assertShapesMatchAllowUndefinedSize(tensor.shape, this.elementShape, 'TensorList shape mismatch: ');
    if (this.maxNumElements === this.size()) {
      throw new Error(`Trying to push element into a full list.`);
    }
    keep(tensor);
    this.tensors.push(tensor);
  }

  /**
   * Update the size of the list.
   * Growing leaves unset (undefined) slots. Shrinking only truncates the
   * array; the dropped tensors are not disposed here.
   * @param size the new size of the list.
   * @throws Error if `size` is negative or exceeds `maxNumElements`.
   */
  resize(size) {
    if (size < 0) {
      throw new Error(`TensorListResize expects size to be non-negative. Got: ${size}`);
    }
    if (this.maxNumElements !== -1 && size > this.maxNumElements) {
      throw new Error(`TensorListResize input size ${size} is greater maxNumElement ${this.maxNumElements}.`);
    }
    this.tensors.length = size;
  }

  /**
   * Retrieve the element at the provided index
   * @param elementIndex index of the tensor; must be in `[0, size())`.
   * @param elementShape shape of the tensor
   * @param elementDtype dtype of the tensor
   * @throws Error on dtype mismatch, out-of-range index, unset slot or shape
   *     mismatch.
   */
  getItem(elementIndex, elementShape, elementDtype) {
    if (elementDtype !== this.elementDtype) {
      throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
    }
    // Valid indices are 0..length-1; index == length is out of range.
    if (elementIndex < 0 || elementIndex >= this.tensors.length) {
      throw new Error(`Trying to access element ${elementIndex} in a list with ${this.tensors.length} elements.`);
    }
    if (this.tensors[elementIndex] == null) {
      throw new Error(`element at index ${elementIndex} is null.`);
    }
    assertShapesMatchAllowUndefinedSize(this.tensors[elementIndex].shape, elementShape, 'TensorList shape mismatch: ');
    const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
    return reshape(this.tensors[elementIndex], outputElementShape);
  }

  /**
   * Set the tensor at the index. Indices past the current size grow the
   * underlying array, up to `maxNumElements` when the list is bounded.
   * @param elementIndex index of the tensor
   * @param tensor the tensor to be inserted into the list
   * @throws Error on dtype mismatch, out-of-range index or shape mismatch.
   */
  setItem(elementIndex, tensor) {
    if (tensor.dtype !== this.elementDtype) {
      throw new Error(`Invalid data types; op elements ${tensor.dtype}, but list elements ${this.elementDtype}`);
    }
    if (elementIndex < 0 ||
        this.maxNumElements !== -1 && elementIndex >= this.maxNumElements) {
      throw new Error(`Trying to set element ${elementIndex} in a list with max ${this.maxNumElements} elements.`);
    }
    assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, 'TensorList shape mismatch: ');
    keep(tensor);
    this.tensors[elementIndex] = tensor;
  }

  /**
   * Return selected values in the TensorList as a stacked Tensor. All of
   * selected values must have been written and their shapes must all match.
   * @param indices indices of tensors to gather
   * @param elementDtype output tensor dtype
   * @param elementShape output tensor element shape
   */
  gather(indices, elementDtype, elementShape) {
    if (elementDtype !== this.elementDtype) {
      throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${this.elementDtype}`);
    }
    assertShapesMatchAllowUndefinedSize(this.elementShape, elementShape, 'TensorList shape mismatch: ');
    // Only the first size() entries of `indices` are used; entries beyond
    // that count are ignored. Individual index values are not range-checked.
    indices = indices.slice(0, this.size());
    const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
    if (indices.length === 0) {
      return tensor([], [0].concat(outputElementShape));
    }
    return tidy(() => {
      const tensors = indices.map(i => reshape(this.tensors[i], outputElementShape));
      return stack(tensors, 0);
    });
  }

  /**
   * Return the values in the TensorList as a concatenated Tensor (along the
   * first axis). An empty list yields a tensor of shape
   * `[0, ...elementShape]`.
   * @param elementDtype output tensor dtype; a falsy value skips the check.
   * @param elementShape output tensor element shape
   */
  concat(elementDtype, elementShape) {
    if (!!elementDtype && elementDtype !== this.elementDtype) {
      throw new Error(`TensorList dtype is ${this.elementDtype} but concat requested dtype ${elementDtype}`);
    }
    assertShapesMatchAllowUndefinedSize(this.elementShape, elementShape, 'TensorList shape mismatch: ');
    const outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
    if (this.size() === 0) {
      return tensor([], [0].concat(outputElementShape));
    }
    return tidy(() => {
      const tensors = this.tensors.map(t => reshape(t, outputElementShape));
      return concat(tensors, 0);
    });
  }
}
/**
 * Build a TensorList by unstacking `tensor` along its first dimension, so
 * that stacking the resulting list reproduces `tensor`.
 * @param tensor source tensor; must have rank >= 1.
 * @param elementShape expected shape of each list element (the source shape
 *     without its first dimension must match it).
 * @param elementDtype expected dtype; must equal `tensor.dtype`.
 * @returns a new TensorList (no `maxNumElements` bound) holding the rows.
 * @throws Error if `tensor` is a scalar, its dtype differs from
 *     `elementDtype`, or its row shape does not match `elementShape`.
 */
export function fromTensor(tensor, elementShape, elementDtype) {
  const { dtype, shape } = tensor;
  const isScalar = shape.length < 1;
  if (isScalar) {
    throw new Error(`Tensor must be at least a vector, but saw shape: ${tensor.shape}`);
  }
  if (dtype !== elementDtype) {
    throw new Error(`Invalid data types; op elements ${tensor.dtype}, but list elements ${elementDtype}`);
  }
  const [, ...rowShape] = shape;
  assertShapesMatchAllowUndefinedSize(rowShape, elementShape, 'TensorList shape mismatch: ');
  return new TensorList(unstack(tensor), elementShape, dtype);
}
/**
 * Create an empty TensorList whose size is bounded by `numElements`.
 * The list starts with no elements; slots are filled later (e.g. via
 * `setItem` or `pushBack`).
 * @param elementShape the shape of the future elements of the list
 * @param elementDtype the desired type of elements in the list
 * @param numElements the maximum number of elements the list may hold
 */
export function reserve(elementShape, elementDtype, numElements) {
  const initialElements = [];
  const reserved = new TensorList(initialElements, elementShape, elementDtype, numElements);
  return reserved;
}
/**
 * Put tensors at specific indices of a stacked tensor into a TensorList.
 * Row `i` of `tensor` (along its first dimension) is stored at list
 * position `indices[i]`.
 * @param tensor input tensor.
 * @param indices list of indices on how to scatter the tensor; its length
 *     must equal `tensor.shape[0]`.
 * @param elementShape the shape of the future elements of the list
 * @param numElements the maximum number of elements of the resulting list;
 *     `null`/`undefined` or -1 skips the index bound check here.
 * @throws Error if `indices.length !== tensor.shape[0]` or the largest index
 *     is not below `numElements`.
 */
export function scatter(tensor, indices, elementShape, numElements) {
  if (indices.length !== tensor.shape[0]) {
    throw new Error(`Expected len(indices) == tensor.shape[0], but saw: ${indices.length} vs. ${tensor.shape[0]}`);
  }
  // Same result as Math.max(...indices) (including -Infinity for an empty
  // list and NaN propagation), but does not spread the whole array into call
  // arguments, which throws a RangeError for very large index lists.
  const maxIndex = indices.reduce((max, index) => Math.max(max, index), -Infinity);
  if (numElements != null && numElements !== -1 && maxIndex >= numElements) {
    throw new Error(`Max index must be < array size (${maxIndex} vs. ${numElements})`);
  }
  const list = new TensorList([], elementShape, tensor.dtype, numElements);
  const tensors = unstack(tensor, 0);
  indices.forEach((value, index) => {
    list.setItem(value, tensors[index]);
  });
  return list;
}
/**
 * Split the values of a Tensor into a TensorList.
 * @param tensor the tensor to split along its first dimension.
 * @param length the lengths to use when splitting `tensor` along its first
 *     dimension; they must sum to `tensor.shape[0]`.
 * @param elementShape the shape of the future elements of the list
 * @returns a TensorList with one element per entry of `length`, bounded to
 *     `length.length` elements.
 * @throws Error if the lengths do not sum to `tensor.shape[0]`.
 */
export function split(tensor, length, elementShape) {
  let totalLength = 0;
  const cumulativeLengths = length.map(len => {
    totalLength += len;
    return totalLength;
  });
  if (totalLength !== tensor.shape[0]) {
    throw new Error(`Expected sum of lengths to be equal to tensor.shape[0], but sum of lengths is ${totalLength}, and tensor's shape is: ${tensor.shape}`);
  }
  const shapeWithoutFirstDim = tensor.shape.slice(1);
  const outputElementShape = mergeElementShape(shapeWithoutFirstDim, elementShape);
  // Number of values in each row of the first dimension.
  const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
  const pieces = tidy(() => {
    // View the input as [1, rows, valuesPerRow] so each piece is a contiguous
    // run of rows. A fresh local is used instead of reassigning the `tensor`
    // parameter; tidy() disposes this intermediate when the callback returns.
    const rows = reshape(tensor, [1, totalLength, elementPerRow]);
    const result = [];
    for (let i = 0; i < length.length; ++i) {
      const previousLength = (i === 0) ? 0 : cumulativeLengths[i - 1];
      const begin = [0, previousLength, 0];
      const size = [1, length[i], elementPerRow];
      result.push(reshape(slice(rows, begin, size), outputElementShape));
    }
    return result;
  });
  const list = new TensorList([], elementShape, tensor.dtype, length.length);
  for (let i = 0; i < pieces.length; i++) {
    list.setItem(i, pieces[i]);
  }
  return list;
}
322//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"tensor_list.js","sourceRoot":"","sources":["../../../../../../tfjs-converter/src/executor/tensor_list.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,MAAM,EAAY,IAAI,EAAE,OAAO,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAU,MAAM,EAAE,IAAI,EAAE,OAAO,EAAC,MAAM,uBAAuB,CAAC;AAE3H,OAAO,EAAC,mCAAmC,EAAE,iBAAiB,EAAE,iBAAiB,EAAC,MAAM,gBAAgB,CAAC;AAEzG;;;;;;;;;;;;;GAaG;AAEH,MAAM,OAAO,UAAU;IAOrB;;;;;;;;OAQG;IACH,YACa,OAAiB,EAAW,YAA6B,EACzD,YAAsB,EAAE,cAAc,GAAG,CAAC,CAAC;QAD3C,YAAO,GAAP,OAAO,CAAU;QAAW,iBAAY,GAAZ,YAAY,CAAiB;QACzD,iBAAY,GAAZ,YAAY,CAAU;QACjC,IAAI,OAAO,IAAI,IAAI,EAAE;YACnB,OAAO,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;gBACvB,IAAI,YAAY,KAAK,MAAM,CAAC,KAAK,EAAE;oBACjC,MAAM,IAAI,KAAK,CAAC,mCACZ,YAAY,uBAAuB,MAAM,CAAC,KAAK,EAAE,CAAC,CAAC;iBACxD;gBACD,mCAAmC,CAC/B,YAAY,EAAE,MAAM,CAAC,KAAK,EAAE,6BAA6B,CAAC,CAAC;gBAE/D,IAAI,CAAC,MAAM,CAAC,CAAC;YACf,CAAC,CAAC,CAAC;SACJ;QACD,IAAI,CAAC,QAAQ,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;QAC1B,IAAI,CAAC,cAAc,GAAG,cAAc,CAAC;QACrC,IAAI,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAC;IACtB,CAAC;IA9BD,IAAI,EAAE;QACJ,OAAO,IAAI,CAAC,QAAQ,CAAC,EAAE,CAAC;IAC1B,CAAC;IA8BD;;OAEG;IACH,IAAI;QACF,OAAO,IAAI,UAAU,CACjB,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,YAAY,CAAC,CAAC;IAC/D,CAAC;IAED;;OAEG;IACH,aAAa,CAAC,OAAqB;QACjC,IAAI,CAAC,OAAO,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;YAC5B,IAAI,OAAO,IAAI,IAAI,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,CAAC,EAAE;gBAC9C,MAAM,CAAC,OAAO,EAAE,CAAC;aAClB;QACH,CAAC,CAAC,CAAC;QACH,IAAI,CAAC,OAAO,CAAC,MAAM,GAAG,CAAC,CAAC;QACxB,IAAI,CAAC,QAAQ,CAAC,OAAO,EAAE,CAAC;IAC1B,CAAC;IACD;;OAEG;IACH,IAAI;QACF,OAAO,IAAI,CAAC,OAAO,CAAC,MAAM,CAAC;IAC7B,CAAC;IAED;;;;;;OAMG;IACH,KAAK,CAAC,YAAsB,EAAE,YAAsB,EAAE,WAAW,GAAG,CAAC,CAAC;QAEpE,IAAI,YAAY,KAAK,IAAI,CAAC,YAAY,EAAE;YACtC,MAAM,IAAI,KAAK,CAAC,mCACZ,YAAY,uBAAuB,IAAI,CAAC,YAAY,EAAE,CAAC,CAAC;SAC7D;QACD,IAAI,WAAW,KAAK,CAAC,CAAC,IAAI,IAAI,CAAC,OAAO,CAAC,MAAM,KAAK,WAAW,EAAE;YAC7D,MAAM,IAAI,KAAK,CAAC,kCACZ,WAAW,iCACX,IAAI,CAAC,O
AAO,CAAC,MAAM,YAAY,CAAC,CAAC;SACtC;QACD,mCAAmC,CAC/B,YAAY,EAAE,IAAI,CAAC,YAAY,EAAE,6BAA6B,CAAC,CAAC;QACpE,MAAM,kBAAkB,GACpB,iBAAiB,CAAC,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,OAAO,EAAE,YAAY,CAAC,CAAC;QACrE,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,MAAM,eAAe,GACjB,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,CAAC,OAAO,CAAC,MAAM,EAAE,kBAAkB,CAAC,CAAC,CAAC;YACpE,OAAO,KAAK,CAAC,eAAe,EAAE,CAAC,CAAC,CAAC;QACnC,CAAC,CAAC,CAAC;IACL,CAAC;IAED;;;;OAIG;IACH,OAAO,CAAC,YAAsB,EAAE,YAAsB;QACpD,IAAI,YAAY,KAAK,IAAI,CAAC,YAAY,EAAE;YACtC,MAAM,IAAI,KAAK,CAAC,mCACZ,YAAY,uBAAuB,IAAI,CAAC,YAAY,EAAE,CAAC,CAAC;SAC7D;QAED,IAAI,IAAI,CAAC,IAAI,EAAE,KAAK,CAAC,EAAE;YACrB,MAAM,IAAI,KAAK,CAAC,mCAAmC,CAAC,CAAC;SACtD;QACD,MAAM,kBAAkB,GACpB,iBAAiB,CAAC,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,OAAO,EAAE,YAAY,CAAC,CAAC;QACrE,MAAM,MAAM,GAAG,IAAI,CAAC,OAAO,CAAC,GAAG,EAAE,CAAC;QAElC,mCAAmC,CAC/B,MAAM,CAAC,KAAK,EAAE,YAAY,EAAE,6BAA6B,CAAC,CAAC;QAE/D,OAAO,OAAO,CAAC,MAAM,EAAE,kBAAkB,CAAC,CAAC;IAC7C,CAAC;IAED;;;OAGG;IACH,QAAQ,CAAC,MAAc;QACrB,IAAI,MAAM,CAAC,KAAK,KAAK,IAAI,CAAC,YAAY,EAAE;YACtC,MAAM,IAAI,KAAK,CAAC,mCACZ,MAAM,CAAC,KAAK,uBAAuB,IAAI,CAAC,YAAY,EAAE,CAAC,CAAC;SAC7D;QAED,mCAAmC,CAC/B,MAAM,CAAC,KAAK,EAAE,IAAI,CAAC,YAAY,EAAE,6BAA6B,CAAC,CAAC;QAEpE,IAAI,IAAI,CAAC,cAAc,KAAK,IAAI,CAAC,IAAI,EAAE,EAAE;YACvC,MAAM,IAAI,KAAK,CAAC,0CAA0C,CAAC,CAAC;SAC7D;QACD,IAAI,CAAC,MAAM,CAAC,CAAC;QACb,IAAI,CAAC,OAAO,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;IAC5B,CAAC;IAED;;;OAGG;IACH,MAAM,CAAC,IAAY;QACjB,IAAI,IAAI,GAAG,CAAC,EAAE;YACZ,MAAM,IAAI,KAAK,CACX,0DAA0D,IAAI,EAAE,CAAC,CAAC;SACvE;QAED,IAAI,IAAI,CAAC,cAAc,KAAK,CAAC,CAAC,IAAI,IAAI,GAAG,IAAI,CAAC,cAAc,EAAE;YAC5D,MAAM,IAAI,KAAK,CAAC,+BACZ,IAAI,6BAA6B,IAAI,CAAC,cAAc,GAAG,CAAC,CAAC;SAC9D;QACD,IAAI,CAAC,OAAO,CAAC,MAAM,GAAG,IAAI,CAAC;IAC7B,CAAC;IAED;;;;;OAKG;IACH,OAAO,CAAC,YAAoB,EAAE,YAAsB,EAAE,YAAsB;QAE1E,IAAI,YAAY,KAAK,IAAI,CAAC,YAAY,EAAE;YACtC,MAAM,IAAI,KAAK,CAAC,mCACZ,YAAY,uBAAuB,IAAI,CAAC,YAAY,EAAE,CAAC,CAAC;SAC7D;QACD,IAAI,YAAY,GAAG,CAAC,IAAI,YAAY,GAAG,IAAI,CAAC,OAAO,CAAC,MAAM,EAAE;YAC1D,MAAM,IAAI,KAAK,CAAC,4BACZ,Y
AAY,mBAAmB,IAAI,CAAC,OAAO,CAAC,MAAM,YAAY,CAAC,CAAC;SACrE;QAED,IAAI,IAAI,CAAC,OAAO,CAAC,YAAY,CAAC,IAAI,IAAI,EAAE;YACtC,MAAM,IAAI,KAAK,CAAC,oBAAoB,YAAY,WAAW,CAAC,CAAC;SAC9D;QAED,mCAAmC,CAC/B,IAAI,CAAC,OAAO,CAAC,YAAY,CAAC,CAAC,KAAK,EAAE,YAAY,EAC9C,6BAA6B,CAAC,CAAC;QACnC,MAAM,kBAAkB,GACpB,iBAAiB,CAAC,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,OAAO,EAAE,YAAY,CAAC,CAAC;QACrE,OAAO,OAAO,CAAC,IAAI,CAAC,OAAO,CAAC,YAAY,CAAC,EAAE,kBAAkB,CAAC,CAAC;IACjE,CAAC;IAED;;;;OAIG;IACH,OAAO,CAAC,YAAoB,EAAE,MAAc;QAC1C,IAAI,MAAM,CAAC,KAAK,KAAK,IAAI,CAAC,YAAY,EAAE;YACtC,MAAM,IAAI,KAAK,CAAC,mCACZ,MAAM,CAAC,KAAK,uBAAuB,IAAI,CAAC,YAAY,EAAE,CAAC,CAAC;SAC7D;QAED,IAAI,YAAY,GAAG,CAAC;YAChB,IAAI,CAAC,cAAc,KAAK,CAAC,CAAC,IAAI,YAAY,IAAI,IAAI,CAAC,cAAc,EAAE;YACrE,MAAM,IAAI,KAAK,CAAC,yBACZ,YAAY,uBAAuB,IAAI,CAAC,cAAc,YAAY,CAAC,CAAC;SACzE;QAED,mCAAmC,CAC/B,IAAI,CAAC,YAAY,EAAE,MAAM,CAAC,KAAK,EAAE,6BAA6B,CAAC,CAAC;QACpE,IAAI,CAAC,MAAM,CAAC,CAAC;QACb,IAAI,CAAC,OAAO,CAAC,YAAY,CAAC,GAAG,MAAM,CAAC;IACtC,CAAC;IAED;;;;;;OAMG;IACH,MAAM,CAAC,OAAiB,EAAE,YAAsB,EAAE,YAAsB;QAEtE,IAAI,YAAY,KAAK,IAAI,CAAC,YAAY,EAAE;YACtC,MAAM,IAAI,KAAK,CAAC,mCACZ,YAAY,uBAAuB,IAAI,CAAC,YAAY,EAAE,CAAC,CAAC;SAC7D;QAED,mCAAmC,CAC/B,IAAI,CAAC,YAAY,EAAE,YAAY,EAAE,6BAA6B,CAAC,CAAC;QAEpE,wEAAwE;QACxE,gCAAgC;QAChC,OAAO,GAAG,OAAO,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,EAAE,CAAC,CAAC;QACxC,MAAM,kBAAkB,GACpB,iBAAiB,CAAC,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,OAAO,EAAE,YAAY,CAAC,CAAC;QACrE,IAAI,OAAO,CAAC,MAAM,KAAK,CAAC,EAAE;YACxB,OAAO,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,kBAAkB,CAAC,CAAC,CAAC;SACnD;QAED,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,MAAM,OAAO,GACT,OAAO,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,OAAO,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,kBAAkB,CAAC,CAAC,CAAC;YACnE,OAAO,KAAK,CAAC,OAAO,EAAE,CAAC,CAAC,CAAC;QAC3B,CAAC,CAAC,CAAC;IACL,CAAC;IAED;;;;OAIG;IACH,MAAM,CAAC,YAAsB,EAAE,YAAsB;QACnD,IAAI,CAAC,CAAC,YAAY,IAAI,YAAY,KAAK,IAAI,CAAC,YAAY,EAAE;YACxD,MAAM,IAAI,KAAK,CAAC,uBACZ,IAAI,CAAC,YAAY,+BAA+B,YAAY,EAAE,CAAC,CAAC;SACrE;QAED,mCAAmC,CAC/B,IAAI,CAAC,YAAY,EAAE,YAAY,EAAE,
6BAA6B,CAAC,CAAC;QACpE,MAAM,kBAAkB,GACpB,iBAAiB,CAAC,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,OAAO,EAAE,YAAY,CAAC,CAAC;QAErE,IAAI,IAAI,CAAC,IAAI,EAAE,KAAK,CAAC,EAAE;YACrB,OAAO,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,kBAAkB,CAAC,CAAC,CAAC;SACnD;QACD,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,MAAM,OAAO,GAAG,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,OAAO,CAAC,CAAC,EAAE,kBAAkB,CAAC,CAAC,CAAC;YACtE,OAAO,MAAM,CAAC,OAAO,EAAE,CAAC,CAAC,CAAC;QAC5B,CAAC,CAAC,CAAC;IACL,CAAC;CACF;AAED;;;;GAIG;AACH,MAAM,UAAU,UAAU,CACtB,MAAc,EAAE,YAAsB,EAAE,YAAsB;IAChE,MAAM,KAAK,GAAG,MAAM,CAAC,KAAK,CAAC;IAC3B,IAAI,MAAM,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE;QAC3B,MAAM,IAAI,KAAK,CACX,oDAAoD,MAAM,CAAC,KAAK,EAAE,CAAC,CAAC;KACzE;IACD,IAAI,MAAM,CAAC,KAAK,KAAK,YAAY,EAAE;QACjC,MAAM,IAAI,KAAK,CAAC,mCACZ,MAAM,CAAC,KAAK,uBAAuB,YAAY,EAAE,CAAC,CAAC;KACxD;IACD,MAAM,kBAAkB,GAAG,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IACjD,mCAAmC,CAC/B,kBAAkB,EAAE,YAAY,EAAE,6BAA6B,CAAC,CAAC;IACrE,MAAM,UAAU,GAAa,OAAO,CAAC,MAAM,CAAC,CAAC;IAC7C,OAAO,IAAI,UAAU,CAAC,UAAU,EAAE,YAAY,EAAE,KAAK,CAAC,CAAC;AACzD,CAAC;AAED;;;;;GAKG;AACH,MAAM,UAAU,OAAO,CACnB,YAAsB,EAAE,YAAsB,EAAE,WAAmB;IACrE,OAAO,IAAI,UAAU,CAAC,EAAE,EAAE,YAAY,EAAE,YAAY,EAAE,WAAW,CAAC,CAAC;AACrE,CAAC;AAED;;;;;;GAMG;AACH,MAAM,UAAU,OAAO,CACnB,MAAc,EAAE,OAAiB,EAAE,YAAsB,EACzD,WAAoB;IACtB,IAAI,OAAO,CAAC,MAAM,KAAK,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE;QACtC,MAAM,IAAI,KAAK,CAAC,sDACZ,OAAO,CAAC,MAAM,QAAQ,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC;KAC9C;IAED,MAAM,QAAQ,GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,OAAO,CAAC,CAAC;IAEtC,IAAI,WAAW,IAAI,IAAI,IAAI,WAAW,KAAK,CAAC,CAAC,IAAI,QAAQ,IAAI,WAAW,EAAE;QACxE,MAAM,IAAI,KAAK,CACX,mCAAmC,QAAQ,SAAS,WAAW,GAAG,CAAC,CAAC;KACzE;IAED,MAAM,IAAI,GAAG,IAAI,UAAU,CAAC,EAAE,EAAE,YAAY,EAAE,MAAM,CAAC,KAAK,EAAE,WAAW,CAAC,CAAC;IACzE,MAAM,OAAO,GAAG,OAAO,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC;IACnC,OAAO,CAAC,OAAO,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,EAAE;QAC/B,IAAI,CAAC,OAAO,CAAC,KAAK,EAAE,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC;IACtC,CAAC,CAAC,CAAC;IACH,OAAO,IAAI,CAAC;AACd,CAAC;AAED;;;;;;GAMG;AACH,
MAAM,UAAU,KAAK,CACjB,MAAc,EAAE,MAAgB,EAAE,YAAsB;IAC1D,IAAI,WAAW,GAAG,CAAC,CAAC;IACpB,MAAM,iBAAiB,GAAG,MAAM,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE;QACzC,WAAW,IAAI,GAAG,CAAC;QACnB,OAAO,WAAW,CAAC;IACrB,CAAC,CAAC,CAAC;IAEH,IAAI,WAAW,KAAK,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE;QACnC,MAAM,IAAI,KAAK,CAAC;;UAEV,WAAW,4BAA4B,MAAM,CAAC,KAAK,EAAE,CAAC,CAAC;KAC9D;IAED,MAAM,oBAAoB,GAAG,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IACnD,MAAM,kBAAkB,GACpB,iBAAiB,CAAC,oBAAoB,EAAE,YAAY,CAAC,CAAC;IAC1D,MAAM,aAAa,GAAG,WAAW,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,IAAI,GAAG,WAAW,CAAC;IACxE,MAAM,OAAO,GAAa,IAAI,CAAC,GAAG,EAAE;QAClC,MAAM,OAAO,GAAG,EAAE,CAAC;QACnB,MAAM,GAAG,OAAO,CAAC,MAAM,EAAE,CAAC,CAAC,EAAE,WAAW,EAAE,aAAa,CAAC,CAAC,CAAC;QAC1D,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;YACtC,MAAM,cAAc,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,iBAAiB,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;YAChE,MAAM,OAAO,GAAG,CAAC,CAAC,EAAE,cAAc,EAAE,CAAC,CAAC,CAAC;YACvC,MAAM,KAAK,GAAG,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,EAAE,aAAa,CAAC,CAAC;YAC5C,OAAO,CAAC,CAAC,CAAC,GAAG,OAAO,CAChB,KAAK,CAAC,MAAM,EAAE,OAAO,EAAE,KAAK,CAAC,EAAE,kBAA8B,CAAC,CAAC;SACpE;QACD,MAAM,CAAC,OAAO,EAAE,CAAC;QACjB,OAAO,OAAO,CAAC;IACjB,CAAC,CAAC,CAAC;IAEH,MAAM,IAAI,GAAG,IAAI,UAAU,CAAC,EAAE,EAAE,YAAY,EAAE,MAAM,CAAC,KAAK,EAAE,MAAM,CAAC,MAAM,CAAC,CAAC;IAE3E,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,OAAO,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;QACvC,IAAI,CAAC,OAAO,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC;KAC7B;IACD,OAAO,IAAI,CAAC;AACd,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {concat, DataType, keep, reshape, scalar, slice, stack, Tensor, tensor, tidy, unstack} from '@tensorflow/tfjs-core';\n\nimport {assertShapesMatchAllowUndefinedSize, inferElementShape, mergeElementShape} from './tensor_utils';\n\n/**\n * TensorList stores a container of `tf.Tensor` objects, which are accessible\n * via tensors field.\n *\n * In order to get a copy of the underlying list, use the copy method:\n * ```\n *    TensorList b = a.copy();\n *    b.tensors().pushBack(t);  // This does not modify a.tensors().\n * ```\n *\n * Note that this is not a deep copy: the memory locations of the underlying\n * tensors will still point to the same locations of the corresponding tensors\n * in the original.\n */\n\nexport class TensorList {\n  readonly idTensor: Tensor;\n  maxNumElements: number;\n\n  get id() {\n    return this.idTensor.id;\n  }\n  /**\n   *\n   * @param tensors list of tensors\n   * @param elementShape shape of each tensor, this can be a single number (any\n   * shape is allowed) or partial shape (dim = -1).\n   * @param elementDtype data type of each tensor\n   * @param maxNumElements The maximum allowed size of `tensors`. 
Defaults to -1\n   *   meaning that the size of `tensors` is unbounded.\n   */\n  constructor(\n      readonly tensors: Tensor[], readonly elementShape: number|number[],\n      readonly elementDtype: DataType, maxNumElements = -1) {\n    if (tensors != null) {\n      tensors.forEach(tensor => {\n        if (elementDtype !== tensor.dtype) {\n          throw new Error(`Invalid data types; op elements ${\n              elementDtype}, but list elements ${tensor.dtype}`);\n        }\n        assertShapesMatchAllowUndefinedSize(\n            elementShape, tensor.shape, 'TensorList shape mismatch: ');\n\n        keep(tensor);\n      });\n    }\n    this.idTensor = scalar(0);\n    this.maxNumElements = maxNumElements;\n    keep(this.idTensor);\n  }\n\n  /**\n   * Get a new TensorList containing a copy of the underlying tensor container.\n   */\n  copy(): TensorList {\n    return new TensorList(\n        [...this.tensors], this.elementShape, this.elementDtype);\n  }\n\n  /**\n   * Dispose the tensors and idTensor and clear the tensor list.\n   */\n  clearAndClose(keepIds?: Set<number>) {\n    this.tensors.forEach(tensor => {\n      if (keepIds == null || !keepIds.has(tensor.id)) {\n        tensor.dispose();\n      }\n    });\n    this.tensors.length = 0;\n    this.idTensor.dispose();\n  }\n  /**\n   * The size of the tensors in the tensor list.\n   */\n  size() {\n    return this.tensors.length;\n  }\n\n  /**\n   * Return a tensor that stacks a list of rank-R tf.Tensors into one rank-(R+1)\n   * tf.Tensor.\n   * @param elementShape shape of each tensor\n   * @param elementDtype data type of each tensor\n   * @param numElements the number of elements to stack\n   */\n  stack(elementShape: number[], elementDtype: DataType, numElements = -1):\n      Tensor {\n    if (elementDtype !== this.elementDtype) {\n      throw new Error(`Invalid data types; op elements ${\n          elementDtype}, but list elements ${this.elementDtype}`);\n    }\n    if (numElements !== -1 && 
this.tensors.length !== numElements) {\n      throw new Error(`Operation expected a list with ${\n          numElements} elements but got a list with ${\n          this.tensors.length} elements.`);\n    }\n    assertShapesMatchAllowUndefinedSize(\n        elementShape, this.elementShape, 'TensorList shape mismatch: ');\n    const outputElementShape =\n        inferElementShape(this.elementShape, this.tensors, elementShape);\n    return tidy(() => {\n      const reshapedTensors =\n          this.tensors.map(tensor => reshape(tensor, outputElementShape));\n      return stack(reshapedTensors, 0);\n    });\n  }\n\n  /**\n   * Pop a tensor from the end of the list.\n   * @param elementShape shape of the tensor\n   * @param elementDtype data type of the tensor\n   */\n  popBack(elementShape: number[], elementDtype: DataType): Tensor {\n    if (elementDtype !== this.elementDtype) {\n      throw new Error(`Invalid data types; op elements ${\n          elementDtype}, but list elements ${this.elementDtype}`);\n    }\n\n    if (this.size() === 0) {\n      throw new Error('Trying to pop from an empty list.');\n    }\n    const outputElementShape =\n        inferElementShape(this.elementShape, this.tensors, elementShape);\n    const tensor = this.tensors.pop();\n\n    assertShapesMatchAllowUndefinedSize(\n        tensor.shape, elementShape, 'TensorList shape mismatch: ');\n\n    return reshape(tensor, outputElementShape);\n  }\n\n  /**\n   * Push a tensor to the end of the list.\n   * @param tensor Tensor to be pushed.\n   */\n  pushBack(tensor: Tensor) {\n    if (tensor.dtype !== this.elementDtype) {\n      throw new Error(`Invalid data types; op elements ${\n          tensor.dtype}, but list elements ${this.elementDtype}`);\n    }\n\n    assertShapesMatchAllowUndefinedSize(\n        tensor.shape, this.elementShape, 'TensorList shape mismatch: ');\n\n    if (this.maxNumElements === this.size()) {\n      throw new Error(`Trying to push element into a full list.`);\n    }\n    
keep(tensor);\n    this.tensors.push(tensor);\n  }\n\n  /**\n   * Update the size of the list.\n   * @param size the new size of the list.\n   */\n  resize(size: number) {\n    if (size < 0) {\n      throw new Error(\n          `TensorListResize expects size to be non-negative. Got: ${size}`);\n    }\n\n    if (this.maxNumElements !== -1 && size > this.maxNumElements) {\n      throw new Error(`TensorListResize input size ${\n          size} is greater maxNumElement ${this.maxNumElements}.`);\n    }\n    this.tensors.length = size;\n  }\n\n  /**\n   * Retrieve the element at the provided index\n   * @param elementShape shape of the tensor\n   * @param elementDtype dtype of the tensor\n   * @param elementIndex index of the tensor\n   */\n  getItem(elementIndex: number, elementShape: number[], elementDtype: DataType):\n      Tensor {\n    if (elementDtype !== this.elementDtype) {\n      throw new Error(`Invalid data types; op elements ${\n          elementDtype}, but list elements ${this.elementDtype}`);\n    }\n    if (elementIndex < 0 || elementIndex > this.tensors.length) {\n      throw new Error(`Trying to access element ${\n          elementIndex} in a list with ${this.tensors.length} elements.`);\n    }\n\n    if (this.tensors[elementIndex] == null) {\n      throw new Error(`element at index ${elementIndex} is null.`);\n    }\n\n    assertShapesMatchAllowUndefinedSize(\n        this.tensors[elementIndex].shape, elementShape,\n        'TensorList shape mismatch: ');\n    const outputElementShape =\n        inferElementShape(this.elementShape, this.tensors, elementShape);\n    return reshape(this.tensors[elementIndex], outputElementShape);\n  }\n\n  /**\n   * Set the tensor at the index\n   * @param elementIndex index of the tensor\n   * @param tensor the tensor to be inserted into the list\n   */\n  setItem(elementIndex: number, tensor: Tensor) {\n    if (tensor.dtype !== this.elementDtype) {\n      throw new Error(`Invalid data types; op elements ${\n          
tensor.dtype}, but list elements ${this.elementDtype}`);\n    }\n\n    if (elementIndex < 0 ||\n        this.maxNumElements !== -1 && elementIndex >= this.maxNumElements) {\n      throw new Error(`Trying to set element ${\n          elementIndex} in a list with max ${this.maxNumElements} elements.`);\n    }\n\n    assertShapesMatchAllowUndefinedSize(\n        this.elementShape, tensor.shape, 'TensorList shape mismatch: ');\n    keep(tensor);\n    this.tensors[elementIndex] = tensor;\n  }\n\n  /**\n   * Return selected values in the TensorList as a stacked Tensor. All of\n   * selected values must have been written and their shapes must all match.\n   * @param indices indices of tensors to gather\n   * @param elementDtype output tensor dtype\n   * @param elementShape output tensor element shape\n   */\n  gather(indices: number[], elementDtype: DataType, elementShape: number[]):\n      Tensor {\n    if (elementDtype !== this.elementDtype) {\n      throw new Error(`Invalid data types; op elements ${\n          elementDtype}, but list elements ${this.elementDtype}`);\n    }\n\n    assertShapesMatchAllowUndefinedSize(\n        this.elementShape, elementShape, 'TensorList shape mismatch: ');\n\n    // When indices is greater than the size of the list, indices beyond the\n    // size of the list are ignored.\n    indices = indices.slice(0, this.size());\n    const outputElementShape =\n        inferElementShape(this.elementShape, this.tensors, elementShape);\n    if (indices.length === 0) {\n      return tensor([], [0].concat(outputElementShape));\n    }\n\n    return tidy(() => {\n      const tensors =\n          indices.map(i => reshape(this.tensors[i], outputElementShape));\n      return stack(tensors, 0);\n    });\n  }\n\n  /**\n   * Return the values in the TensorList as a concatenated Tensor.\n   * @param elementDtype output tensor dtype\n   * @param elementShape output tensor element shape\n   */\n  concat(elementDtype: DataType, elementShape: number[]): Tensor {\n 
   if (!!elementDtype && elementDtype !== this.elementDtype) {\n      throw new Error(`TensorList dtype is ${\n          this.elementDtype} but concat requested dtype ${elementDtype}`);\n    }\n\n    assertShapesMatchAllowUndefinedSize(\n        this.elementShape, elementShape, 'TensorList shape mismatch: ');\n    const outputElementShape =\n        inferElementShape(this.elementShape, this.tensors, elementShape);\n\n    if (this.size() === 0) {\n      return tensor([], [0].concat(outputElementShape));\n    }\n    return tidy(() => {\n      const tensors = this.tensors.map(t => reshape(t, outputElementShape));\n      return concat(tensors, 0);\n    });\n  }\n}\n\n/**\n * Creates a TensorList which, when stacked, has the value of tensor.\n * @param tensor from tensor\n * @param elementShape output tensor element shape\n */\nexport function fromTensor(\n    tensor: Tensor, elementShape: number[], elementDtype: DataType) {\n  const dtype = tensor.dtype;\n  if (tensor.shape.length < 1) {\n    throw new Error(\n        `Tensor must be at least a vector, but saw shape: ${tensor.shape}`);\n  }\n  if (tensor.dtype !== elementDtype) {\n    throw new Error(`Invalid data types; op elements ${\n        tensor.dtype}, but list elements ${elementDtype}`);\n  }\n  const tensorElementShape = tensor.shape.slice(1);\n  assertShapesMatchAllowUndefinedSize(\n      tensorElementShape, elementShape, 'TensorList shape mismatch: ');\n  const tensorList: Tensor[] = unstack(tensor);\n  return new TensorList(tensorList, elementShape, dtype);\n}\n\n/**\n * Return a TensorList of the given size with empty elements.\n * @param elementShape the shape of the future elements of the list\n * @param elementDtype the desired type of elements in the list\n * @param numElements the number of elements to reserve\n */\nexport function reserve(\n    elementShape: number[], elementDtype: DataType, numElements: number) {\n  return new TensorList([], elementShape, elementDtype, numElements);\n}\n\n/**\n * 
Put tensors at specific indices of a stacked tensor into a TensorList.\n * @param indices list of indices on how to scatter the tensor.\n * @param tensor input tensor.\n * @param elementShape the shape of the future elements of the list\n * @param numElements the number of elements to scatter\n */\nexport function scatter(\n    tensor: Tensor, indices: number[], elementShape: number[],\n    numElements?: number): TensorList {\n  if (indices.length !== tensor.shape[0]) {\n    throw new Error(`Expected len(indices) == tensor.shape[0], but saw: ${\n        indices.length} vs. ${tensor.shape[0]}`);\n  }\n\n  const maxIndex = Math.max(...indices);\n\n  if (numElements != null && numElements !== -1 && maxIndex >= numElements) {\n    throw new Error(\n        `Max index must be < array size (${maxIndex}  vs. ${numElements})`);\n  }\n\n  const list = new TensorList([], elementShape, tensor.dtype, numElements);\n  const tensors = unstack(tensor, 0);\n  indices.forEach((value, index) => {\n    list.setItem(value, tensors[index]);\n  });\n  return list;\n}\n\n/**\n * Split the values of a Tensor into a TensorList.\n * @param length the lengths to use when splitting value along\n *    its first dimension.\n * @param tensor the tensor to split.\n * @param elementShape the shape of the future elements of the list\n */\nexport function split(\n    tensor: Tensor, length: number[], elementShape: number[]) {\n  let totalLength = 0;\n  const cumulativeLengths = length.map(len => {\n    totalLength += len;\n    return totalLength;\n  });\n\n  if (totalLength !== tensor.shape[0]) {\n    throw new Error(`Expected sum of lengths to be equal to\n          tensor.shape[0], but sum of lengths is\n        ${totalLength}, and tensor's shape is: ${tensor.shape}`);\n  }\n\n  const shapeWithoutFirstDim = tensor.shape.slice(1);\n  const outputElementShape =\n      mergeElementShape(shapeWithoutFirstDim, elementShape);\n  const elementPerRow = totalLength === 0 ? 
0 : tensor.size / totalLength;\n  const tensors: Tensor[] = tidy(() => {\n    const tensors = [];\n    tensor = reshape(tensor, [1, totalLength, elementPerRow]);\n    for (let i = 0; i < length.length; ++i) {\n      const previousLength = (i === 0) ? 0 : cumulativeLengths[i - 1];\n      const indices = [0, previousLength, 0];\n      const sizes = [1, length[i], elementPerRow];\n      tensors[i] = reshape(\n          slice(tensor, indices, sizes), outputElementShape as number[]);\n    }\n    tensor.dispose();\n    return tensors;\n  });\n\n  const list = new TensorList([], elementShape, tensor.dtype, length.length);\n\n  for (let i = 0; i < tensors.length; i++) {\n    list.setItem(i, tensors[i]);\n  }\n  return list;\n}\n"]}
\No newline at end of file