/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {tensorToString} from './tensor_format';
import {ArrayMap, BackendValues, DataType, DataTypeMap, NumericDataType, Rank, ShapeMap, SingleValueMap, TensorLike, TensorLike1D, TensorLike3D, TensorLike4D, TypedArray} from './types';
import * as util from './util';
import {computeStrides, toNestedArray} from './util';

/**
 * Initial data used to construct a tensor: an existing data bucket to reuse
 * (`dataId`) and/or flat `values` to write. Both fields are optional.
 */
export interface TensorData<D extends DataType> {
  /** Existing data bucket to reuse; a fresh one is created when absent. */
  dataId?: DataId;
  /** Flat values to hand to the tracker when the tensor is constructed. */
  values?: DataTypeMap[D];
}

// This interface mimics KernelBackend (in backend.ts), which would create a
// circular dependency if imported. Only the data-management methods are
// declared here.
export interface Backend {
  /** Asynchronously reads the values stored under `dataId`. */
  read(dataId: object): Promise<BackendValues>;
  /** Synchronously reads the values stored under `dataId`. */
  readSync(dataId: object): BackendValues;
  /** Releases the data stored under `dataId`. */
  disposeData(dataId: object): void;
  /** Stores `values` under `dataId`. */
  write(dataId: object, values: BackendValues): void;
}

/**
 * A mutable object, similar to `tf.Tensor`, that allows users to set values
 * at locations before converting to an immutable `tf.Tensor`.
 *
 * See `tf.buffer` for creating a tensor buffer.
 */
/** @doc {heading: 'Tensors', subheading: 'Classes'} */
export class TensorBuffer<R extends Rank, D extends DataType = 'float32'> {
  /** Total number of elements implied by `shape`. */
  size: number;
  /** Shape of the buffer; a copy of the array passed to the constructor. */
  shape: ShapeMap[R];
  /** Strides from `computeStrides(shape)`, used to map locations to indices. */
  strides: number[];
  /** Flat backing storage holding `size` elements. */
  values: DataTypeMap[D];

  /**
   * @param shape Shape of the buffer. The array is copied.
   * @param dtype Data type of the elements. `complex64` is not supported.
   * @param values Optional flat backing storage. Its length must equal the
   *     number of elements implied by `shape`. When omitted, a new array is
   *     allocated with `util.getArrayFromDType`.
   */
  constructor(shape: ShapeMap[R], public dtype: D, values?: DataTypeMap[D]) {
    this.shape = shape.slice() as ShapeMap[R];
    this.size = util.sizeFromShape(shape);

    if (values != null) {
      const n = values.length;
      util.assert(
          n === this.size,
          () => `Length of values '${n}' does not match the size ` +
              `inferred by the shape '${this.size}'.`);
    }
    if (dtype === 'complex64') {
      throw new Error(
          `complex64 dtype TensorBuffers are not supported. Please create ` +
          `a TensorBuffer for the real and imaginary parts separately and ` +
          `call tf.complex(real, imag).`);
    }
    this.values = values || util.getArrayFromDType(dtype, this.size);
    this.strides = computeStrides(shape);
  }

  /**
   * Sets a value in the buffer at a given location.
   *
   * The number of location indices must equal the rank. A rank-0 buffer is
   * addressed with no indices. On a buffer of rank >= 1, omitting the indices
   * defaults the location to `[0]`, which passes the rank check only for
   * rank-1 buffers.
   *
   * @param value The value to set.
   * @param locs  The location indices.
   */
  /** @doc {heading: 'Tensors', subheading: 'Creation'} */
  set(value: SingleValueMap[D], ...locs: number[]): void {
    // Only default to [0] when there is at least one dimension. For rank 0
    // the empty list is already a complete location. Defaulting it to [0]
    // would make the rank check below fail unconditionally, so a scalar
    // buffer could never be written.
    if (locs.length === 0 && this.rank > 0) {
      locs = [0];
    }
    util.assert(
        locs.length === this.rank,
        () => `The number of provided coordinates (${locs.length}) must ` +
            `match the rank (${this.rank})`);

    const index = this.locToIndex(locs);
    this.values[index] = value as number;
  }

  /**
   * Returns the value in the buffer at the provided location.
   *
   * With no indices the location defaults to `[0]`. Each index is bounds-
   * checked against its dimension. Unlike `set`, the number of indices is not
   * validated against the rank, so callers should pass exactly `rank` indices.
   *
   * @param locs The location indices.
   */
  /** @doc {heading: 'Tensors', subheading: 'Creation'} */
  get(...locs: number[]): SingleValueMap[D] {
    if (locs.length === 0) {
      locs = [0];
    }
    for (let axis = 0; axis < locs.length; ++axis) {
      const loc = locs[axis];
      if (loc < 0 || loc >= this.shape[axis]) {
        const msg = `Requested out of range element at ${locs}. ` +
            `  Buffer shape=${this.shape}`;
        throw new Error(msg);
      }
    }
    let index = locs[locs.length - 1];
    for (let axis = 0; axis < locs.length - 1; ++axis) {
      index += this.strides[axis] * locs[axis];
    }
    return this.values[index] as SingleValueMap[D];
  }

  /**
   * Converts a location into a flat index into `values`.
   *
   * Rank 0 maps to 0 and rank 1 maps to `locs[0]`. Otherwise the result is
   * the last index plus `strides[i] * locs[i]` for each preceding dimension.
   */
  locToIndex(locs: number[]): number {
    if (this.rank === 0) {
      return 0;
    } else if (this.rank === 1) {
      return locs[0];
    }
    let index = locs[locs.length - 1];
    for (let i = 0; i < locs.length - 1; ++i) {
      index += this.strides[i] * locs[i];
    }
    return index;
  }

  /**
   * Inverse of `locToIndex`: converts a flat index into per-dimension
   * location indices.
   */
  indexToLoc(index: number): number[] {
    if (this.rank === 0) {
      return [];
    } else if (this.rank === 1) {
      return [index];
    }
    const locs: number[] = new Array(this.shape.length);
    for (let i = 0; i < locs.length - 1; ++i) {
      locs[i] = Math.floor(index / this.strides[i]);
      index -= locs[i] * this.strides[i];
    }
    locs[locs.length - 1] = index;
    return locs;
  }

  /** Number of dimensions, i.e. `shape.length`. */
  get rank() {
    return this.shape.length;
  }

  /**
   * Creates an immutable `tf.Tensor` object from the buffer.
   */
  /** @doc {heading: 'Tensors', subheading: 'Creation'} */
  toTensor(): Tensor<R> {
    return Tensor.make(this.shape, {values: this.values}, this.dtype);
  }
}

/**
 * Bookkeeping hooks that `Tensor` relies on. An implementation is registered
 * through `setTensorTracker`.
 *
 * The `Tensor` constructor uses it to obtain ids, register new tensors and
 * write their initial values. The data accessors (`data`, `dataSync`,
 * `bytes`) read values back through it, and `dispose` reports disposal to it.
 */
export interface TensorTracker {
  registerTensor(t: Tensor, backend?: Backend): void;
  disposeTensor(t: Tensor): void;
  disposeVariable(v: Variable): void;
  write(backend: Backend, dataId: DataId, values: BackendValues): void;
  read(dataId: DataId): Promise<BackendValues>;
  readSync(dataId: DataId): BackendValues;
  registerVariable(v: Variable): void;
  nextTensorId(): number;
  nextVariableId(): number;
}

/**
 * The Tensor class calls into this handler to delegate chaining operations.
 *
 * It is registered via `setOpHandler`. This lets `Tensor` expose chaining
 * methods (e.g. `x.reshape(...)`) without importing the op implementations,
 * which would create a circular dependency. Each member mirrors the op it is
 * named after.
 */
export interface OpHandler {
  cast<T extends Tensor>(x: T, dtype: DataType): T;
  buffer<R extends Rank, D extends DataType>(
      shape: ShapeMap[R], dtype: D,
      values?: DataTypeMap[D]): TensorBuffer<R, D>;
  print<T extends Tensor>(x: T, verbose: boolean): void;
  reshape<R2 extends Rank>(x: Tensor, shape: ShapeMap[R2]): Tensor<R2>;
  expandDims<R2 extends Rank>(x: Tensor, axis: number): Tensor<R2>;
  cumsum<T extends Tensor>(
      x: Tensor, axis: number, exclusive: boolean, reverse: boolean): T;
  squeeze<T extends Tensor>(x: Tensor, axis?: number[]): T;
  clone<T extends Tensor>(x: T): T;
  oneHot(
      x: Tensor|TensorLike, depth: number, onValue?: number,
      offValue?: number): Tensor;
  tile<T extends Tensor>(x: T, reps: number[]): T;
  gather<T extends Tensor>(x: T, indices: Tensor|TensorLike, axis: number): T;
  matMul<T extends Tensor>(
      a: T, b: T|TensorLike, transposeA: boolean, transposeB: boolean): T;
  dot(t1: Tensor, t2: Tensor|TensorLike): Tensor;
  norm(
      x: Tensor, ord: number|'euclidean'|'fro', axis: number|number[],
      keepDims: boolean): Tensor;
  slice<R extends Rank, T extends Tensor<R>>(
      x: T, begin: number|number[], size?: number|number[]): T;
  split<T extends Tensor>(
      x: T, numOrSizeSplits: number[]|number, axis?: number): T[];
  reverse<T extends Tensor>(x: T, axis?: number|number[]): T;
  concat<T extends Tensor>(tensors: Array<T|TensorLike>, axis: number): T;
  stack<T extends Tensor>(tensors: Array<T|TensorLike>, axis: number): Tensor;
  unstack<T extends Tensor>(value: T, axis: number): Tensor[];
  pad<T extends Tensor>(
      x: T, paddings: Array<[number, number]>, constantValue: number): T;
  batchNorm<R extends Rank>(
      x: Tensor<R>, mean: Tensor<R>|Tensor1D|TensorLike,
      variance: Tensor<R>|Tensor1D|TensorLike,
      offset?: Tensor<R>|Tensor1D|TensorLike,
      scale?: Tensor<R>|Tensor1D|TensorLike,
      varianceEpsilon?: number): Tensor<R>;
  all<T extends Tensor>(x: Tensor, axis: number|number[], keepDims: boolean): T;
  any<T extends Tensor>(x: Tensor, axis: number|number[], keepDims: boolean): T;
  logSumExp<T extends Tensor>(
      x: Tensor, axis: number|number[], keepDims: boolean): T;
  sum<T extends Tensor>(x: Tensor, axis: number|number[], keepDims: boolean): T;
  prod<T extends Tensor>(x: Tensor, axis: number|number[], keepDims: boolean):
      T;
  mean<T extends Tensor>(x: Tensor, axis: number|number[], keepDims: boolean):
      T;
  min<T extends Tensor>(x: Tensor, axis: number|number[], keepDims: boolean): T;
  max<T extends Tensor>(x: Tensor, axis: number|number[], keepDims: boolean): T;
  argMin<T extends Tensor>(x: Tensor, axis: number): T;
  argMax<T extends Tensor>(x: Tensor, axis: number): T;
  add<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
  addStrict<T extends Tensor>(a: T, b: T|TensorLike): T;
  atan2<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
  sub<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
  subStrict<T extends Tensor>(a: T, b: T|TensorLike): T;
  pow<T extends Tensor>(base: T, exp: Tensor|TensorLike): T;
  powStrict<T extends Tensor>(base: T, exp: Tensor|TensorLike): T;
  mul<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
  mulStrict<T extends Tensor>(a: T, b: T|TensorLike): T;
  div<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
  floorDiv<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
  divStrict<T extends Tensor>(a: T, b: T|TensorLike): T;
  mod<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
  modStrict<T extends Tensor>(a: T, b: T|TensorLike): T;
  minimum<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
  minimumStrict<T extends Tensor>(a: T, b: T|TensorLike): T;
  maximum<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
  maximumStrict<T extends Tensor>(a: T, b: T|TensorLike): T;
  squaredDifference<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
  squaredDifferenceStrict<T extends Tensor>(a: T, b: T|TensorLike): T;
  transpose<T extends Tensor>(x: T, perm?: number[]): T;
  logicalNot<T extends Tensor>(x: T): T;
  logicalAnd<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
  logicalOr<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
  logicalXor<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
  where<T extends Tensor>(condition: Tensor|TensorLike, a: T, b: T|TensorLike):
      T;
  notEqual<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
  notEqualStrict<T extends Tensor>(a: T, b: T|TensorLike): T;
  less<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
  lessStrict<T extends Tensor>(a: T, b: T|TensorLike): T;
  equal<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
  equalStrict<T extends Tensor>(a: T, b: T|TensorLike): T;
  lessEqual<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
  lessEqualStrict<T extends Tensor>(a: T, b: T|TensorLike): T;
  greater<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
  greaterStrict<T extends Tensor>(a: T, b: T|TensorLike): T;
  greaterEqual<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
  greaterEqualStrict<T extends Tensor>(a: T, b: T|TensorLike): T;
  neg<T extends Tensor>(x: T): T;
  ceil<T extends Tensor>(x: T): T;
  floor<T extends Tensor>(x: T): T;
  sign<T extends Tensor>(x: T): T;
  isNaN<T extends Tensor>(x: T): T;
  isInf<T extends Tensor>(x: T): T;
  isFinite<T extends Tensor>(x: T): T;
  round<T extends Tensor>(x: T): T;
  exp<T extends Tensor>(x: T): T;
  expm1<T extends Tensor>(x: T): T;
  log<T extends Tensor>(x: T): T;
  log1p<T extends Tensor>(x: T): T;
  sqrt<T extends Tensor>(x: T): T;
  rsqrt<T extends Tensor>(x: T): T;
  square<T extends Tensor>(x: T): T;
  reciprocal<T extends Tensor>(x: T): T;
  abs<T extends Tensor>(x: T): T;
  clipByValue<T extends Tensor>(
      x: T, clipValueMin: number, clipValueMax: number): T;
  sigmoid<T extends Tensor>(x: T): T;
  logSigmoid<T extends Tensor>(x: T): T;
  softplus<T extends Tensor>(x: T): T;
  zerosLike<T extends Tensor>(x: T): T;
  onesLike<T extends Tensor>(x: T): T;
  sin<T extends Tensor>(x: T): T;
  cos<T extends Tensor>(x: T): T;
  tan<T extends Tensor>(x: T): T;
  asin<T extends Tensor>(x: T): T;
  acos<T extends Tensor>(x: T): T;
  atan<T extends Tensor>(x: T): T;
  sinh<T extends Tensor>(x: T): T;
  cosh<T extends Tensor>(x: T): T;
  tanh<T extends Tensor>(x: T): T;
  asinh<T extends Tensor>(x: T): T;
  acosh<T extends Tensor>(x: T): T;
  atanh<T extends Tensor>(x: T): T;
  erf<T extends Tensor>(x: T): T;
  step<T extends Tensor>(x: T, alpha: number): T;
  relu<T extends Tensor>(x: T): T;
  elu<T extends Tensor>(x: T): T;
  selu<T extends Tensor>(x: T): T;
  leakyRelu<T extends Tensor>(x: T, alpha: number): T;
  prelu<T extends Tensor>(x: T, alpha: T|TensorLike): T;
  softmax<T extends Tensor>(logits: T, dim: number): T;
  logSoftmax<T extends Tensor>(logits: T, axis: number): T;
  image: {
    resizeBilinear<T extends Tensor3D|Tensor4D>(
        images: T, size: [number, number], alignCorners: boolean): T;
    resizeNearestNeighbor<T extends Tensor3D|Tensor4D>(
        images: T, size: [number, number], alignCorners: boolean): T;
  };
  conv1d<T extends Tensor2D|Tensor3D>(
      x: T, filter: Tensor3D|TensorLike3D, stride: number,
      pad: 'valid'|'same'|number, dataFormat: 'NWC'|'NCW', dilation: number,
      dimRoundingMode?: 'floor'|'round'|'ceil'): T;
  conv2d<T extends Tensor3D|Tensor4D>(
      x: T, filter: Tensor4D|TensorLike4D, strides: [number, number]|number,
      pad: 'valid'|'same'|number, dataFormat: 'NHWC'|'NCHW',
      dilations: [number, number]|number,
      dimRoundingMode?: 'floor'|'round'|'ceil'): T;
  conv2dTranspose<T extends Tensor3D|Tensor4D>(
      x: T, filter: Tensor4D|TensorLike4D,
      outputShape: [number, number, number, number]|[number, number, number],
      strides: [number, number]|number, pad: 'valid'|'same'|number,
      dimRoundingMode?: 'floor'|'round'|'ceil'): T;
  depthwiseConv2d<T extends Tensor3D|Tensor4D>(
      x: T, filter: Tensor4D|TensorLike4D, strides: [number, number]|number,
      pad: 'valid'|'same'|number, dataFormat: 'NHWC'|'NCHW',
      dilations: [number, number]|number,
      dimRoundingMode?: 'floor'|'round'|'ceil'): T;
  separableConv2d<T extends Tensor3D|Tensor4D>(
      x: T|TensorLike, depthwiseFilter: Tensor4D|TensorLike4D,
      pointwiseFilter: Tensor4D|TensorLike, strides: [number, number]|number,
      pad: 'valid'|'same', dilation: [number, number]|number,
      dataFormat: 'NHWC'|'NCHW'): T;
  maxPool<T extends Tensor3D|Tensor4D>(
      x: T, filterSize: [number, number]|number,
      strides: [number, number]|number, pad: 'valid'|'same'|number,
      dimRoundingMode?: 'floor'|'round'|'ceil'): T;
  avgPool<T extends Tensor3D|Tensor4D>(
      x: T, filterSize: [number, number]|number,
      strides: [number, number]|number, pad: 'valid'|'same'|number,
      dimRoundingMode?: 'floor'|'round'|'ceil'): T;
  pool<T extends Tensor3D|Tensor4D>(
      input: T, windowShape: [number, number]|number, poolingType: 'avg'|'max',
      padding: 'valid'|'same'|number, dilationRate?: [number, number]|number,
      strides?: [number, number]|number): T;
  localResponseNormalization<T extends Tensor3D|Tensor4D>(
      x: T, depthRadius: number, bias: number, alpha: number, beta: number): T;
  unsortedSegmentSum<T extends Tensor>(
      x: T, segmentIds: Tensor1D|TensorLike1D, numSegments: number): T;
  batchToSpaceND<T extends Tensor>(
      x: T, blockShape: number[], crops: number[][]): T;
  spaceToBatchND<T extends Tensor>(
      x: T, blockShape: number[], paddings: number[][]): T;
  topk<T extends Tensor>(x: T, k: number, sorted: boolean):
      {values: T, indices: T};
  stridedSlice(
      x: Tensor, begin: number[], end: number[], strides: number[],
      beginMask: number, endMask: number, ellipsisMask: number,
      newAxisMask: number, shrinkAxisMask: number): Tensor;
  depthToSpace(x: Tensor4D, blockSize: number, dataFormat: string): Tensor4D;
  spectral: {
    fft(x: Tensor): Tensor;
    ifft(x: Tensor): Tensor;
    rfft(x: Tensor): Tensor;
    irfft(x: Tensor): Tensor;
  };
}

// For tracking tensor creation and disposal. Set via setTensorTracker(). The
// Tensor constructor calls it unconditionally, so it must be registered
// before any tensor is created.
let trackerFn: () => TensorTracker = null;
// Used by chaining methods to call into ops. Set via setOpHandler().
let opHandler: OpHandler = null;
// Used to warn about deprecated methods. Set via setDeprecationWarningFn().
let deprecationWarningFn: (msg: string) => void = null;
// This is here so that we can use this function on dev branches and keep the
// functionality at master. Referencing the variable in a bare expression
// presumably keeps it "used" for lint/compiler purposes even when no
// deprecated method calls it.
// tslint:disable-next-line:no-unused-expression
[deprecationWarningFn];

/**
 * An external consumer can register itself as the tensor tracker. This way
 * the Tensor class can notify the tracker for every tensor created and
 * disposed.
 *
 * The registered function is invoked on every tensor construction, data read
 * and disposal, so it must be set before any `Tensor` is created.
 *
 * @param fn Returns the tracker to use.
 */
export function setTensorTracker(fn: () => TensorTracker) {
  trackerFn = fn;
}

/**
 * An external consumer can register itself as the op handler. This way the
 * Tensor class can have chaining methods that call into ops via the op handler.
 *
 * Chaining methods such as `reshape` call the handler directly, so it must be
 * registered before they are used.
 *
 * @param handler The implementation the chaining methods delegate to.
 */
export function setOpHandler(handler: OpHandler) {
  opHandler = handler;
}

/**
 * Sets the deprecation warning function to be used by this file. This way the
 * Tensor class can be a leaf but still use the environment.
 *
 * @param fn Called with a human-readable message when a deprecated method
 *     (e.g. `batchNormalization`) is used.
 */
export function setDeprecationWarningFn(fn: (msg: string) => void) {
  deprecationWarningFn = fn;
}

/**
 * Opaque key identifying a tensor's data bucket.
 *
 * We wrap the data id in an object since we use a WeakMap to avoid memory
 * leaks. Since we have our own memory management, we have a reference counter
 * mapping a tensor to its data, so there is always a pointer (even if that
 * data is otherwise garbage collectable).
 * See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/
 * Global_Objects/WeakMap
 */
export type DataId = object;  // object instead of {} to force non-primitive.

/**
 * A `tf.Tensor` object represents an immutable, multidimensional array of
 * numbers that has a shape and a data type.
 *
 * See `tf.tensor` for details on how to create a `tf.Tensor`.
 */
/** @doc {heading: 'Tensors', subheading: 'Classes'} */
export class Tensor<R extends Rank = Rank> {
  /** Unique id of this tensor, obtained from the tracker at construction. */
  readonly id: number;
  /**
   * Id of the bucket holding the data for this tensor. Multiple tensors can
   * point to the same bucket (e.g. when calling tensor.reshape()).
   */
  dataId: DataId;
  /** The shape of the tensor (a copy of the constructor argument). */
  readonly shape: ShapeMap[R];
  /** Number of elements in the tensor. */
  readonly size: number;
  /** The data type for the tensor; 'float32' when none was given. */
  readonly dtype: DataType;
  /**
   * The rank type for the tensor (see `Rank` enum): the rank as a string for
   * ranks below 5, and 'higher' otherwise.
   */
  readonly rankType: R;

  /** Whether this tensor has been globally kept. */
  kept = false;
  /** The id of the scope this tensor is being tracked in. */
  scopeId: number;

  /**
   * Number of elements to skip in each dimension when indexing. See
   * https://docs.scipy.org/doc/numpy/reference/generated/\
   * numpy.ndarray.strides.html
   */
  readonly strides: number[];

  protected constructor(
      shape: ShapeMap[R], dtype: DataType, values?: BackendValues,
      dataId?: DataId, backend?: Backend) {
    // Keep a private copy so later mutation of the caller's array is harmless.
    this.shape = shape.slice() as ShapeMap[R];
    this.dtype = dtype || 'float32';
    this.size = util.sizeFromShape(shape);
    this.strides = computeStrides(shape);
    // Without an explicit bucket, a fresh empty object serves as the key.
    this.dataId = dataId == null ? {} : dataId;
    this.id = trackerFn().nextTensorId();
    const rank = this.rank;
    // Ranks below 5 are stored as their string form; others as 'higher'.
    this.rankType = (rank >= 5 ? 'higher' : rank.toString()) as R;
    trackerFn().registerTensor(this, backend);
    if (values == null) {
      return;
    }
    // Hand the initial values to the tracker for storage.
    trackerFn().write(backend, this.dataId, values);
  }

  /**
   * Makes a new tensor with the provided shape and values. Values should be in
   * a flat array.
   */
  static make<T extends Tensor<R>, D extends DataType = 'float32',
                                             R extends Rank = Rank>(
      shape: ShapeMap[R], data: TensorData<D>, dtype?: D,
      backend?: Backend): T {
    const {values, dataId} = data;
    // String tensors given as JS strings are stored as encoded bytes.
    const needsEncoding =
        values != null && dtype === 'string' && util.isString(values[0]);
    const backendVals: BackendValues = needsEncoding ?
        (values as string[]).map(s => util.encodeString(s)) :
        (values as BackendValues);
    return new Tensor(shape, dtype, backendVals, dataId, backend) as T;
  }

  /** Flatten a Tensor to a 1D array. */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  flatten(): Tensor1D {
    this.throwIfDisposed();
    // Flattening is the rank-1 reshape to [size] that as1D() performs.
    const flat = this.as1D();
    return flat;
  }

  /** Converts a size-1 `tf.Tensor` to a `tf.Scalar`. */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  asScalar(): Scalar {
    this.throwIfDisposed();
    // Only a single-element tensor can take the empty (scalar) shape.
    util.assert(this.size === 1, () => 'The array must have only 1 element.');
    const scalar = this.reshape<Rank.R0>([]);
    return scalar;
  }

  /** Converts a `tf.Tensor` to a `tf.Tensor1D`. */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  as1D(): Tensor1D {
    this.throwIfDisposed();
    // One dimension holding every element.
    const newShape: ShapeMap[Rank.R1] = [this.size];
    return this.reshape<Rank.R1>(newShape);
  }

  /**
   * Converts a `tf.Tensor` to a `tf.Tensor2D`.
   *
   * @param rows Number of rows in `tf.Tensor2D`.
   * @param columns Number of columns in `tf.Tensor2D`.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  as2D(rows: number, columns: number): Tensor2D {
    this.throwIfDisposed();
    const newShape: ShapeMap[Rank.R2] = [rows, columns];
    return this.reshape<Rank.R2>(newShape);
  }

  /**
   * Converts a `tf.Tensor` to a `tf.Tensor3D`.
   *
   * @param rows Number of rows in `tf.Tensor3D`.
   * @param columns Number of columns in `tf.Tensor3D`.
   * @param depth Depth of `tf.Tensor3D`.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  as3D(rows: number, columns: number, depth: number): Tensor3D {
    this.throwIfDisposed();
    const newShape: ShapeMap[Rank.R3] = [rows, columns, depth];
    return this.reshape<Rank.R3>(newShape);
  }

  /**
   * Converts a `tf.Tensor` to a `tf.Tensor4D`.
   *
   * @param rows Number of rows in `tf.Tensor4D`.
   * @param columns Number of columns in `tf.Tensor4D`.
   * @param depth Depth of `tf.Tensor4D`.
   * @param depth2 4th dimension of `tf.Tensor4D`.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  as4D(rows: number, columns: number, depth: number, depth2: number): Tensor4D {
    this.throwIfDisposed();
    const newShape: ShapeMap[Rank.R4] = [rows, columns, depth, depth2];
    return this.reshape<Rank.R4>(newShape);
  }

  /**
   * Converts a `tf.Tensor` to a `tf.Tensor5D`.
   *
   * @param rows Number of rows in `tf.Tensor5D`.
   * @param columns Number of columns in `tf.Tensor5D`.
   * @param depth Depth of `tf.Tensor5D`.
   * @param depth2 4th dimension of `tf.Tensor5D`.
   * @param depth3 5th dimension of `tf.Tensor5D`.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  as5D(
      rows: number, columns: number, depth: number, depth2: number,
      depth3: number): Tensor5D {
    this.throwIfDisposed();
    const newShape: ShapeMap[Rank.R5] = [rows, columns, depth, depth2, depth3];
    return this.reshape<Rank.R5>(newShape);
  }

  /**
   * Casts a `tf.Tensor` to a specified dtype.
   *
   * @param dtype Data-type to cast the tensor to.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  asType<T extends this>(this: T, dtype: DataType): T {
    this.throwIfDisposed();
    // The cast itself is performed by the registered op handler.
    const converted = opHandler.cast(this, dtype);
    return converted as T;
  }

  get rank(): number {
    // The rank is the number of dimensions in the shape.
    return this.shape.length;
  }

  /** Returns a promise of `tf.TensorBuffer` that holds the underlying data. */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  async buffer<D extends DataType = 'float32'>(): Promise<TensorBuffer<R, D>> {
    // Download asynchronously, then wrap the flat values in a mutable buffer.
    const values = await this.data<D>();
    const dtype = this.dtype as D;
    return opHandler.buffer(this.shape, dtype, values);
  }

  /** Returns a `tf.TensorBuffer` that holds the underlying data. */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  bufferSync<D extends DataType = 'float32'>(): TensorBuffer<R, D> {
    // Synchronous counterpart of buffer(): blocks on dataSync().
    const values = this.dataSync<D>();
    const dtype = this.dtype as D;
    return opHandler.buffer(this.shape, dtype, values);
  }

  /**
   * Returns the tensor data as a nested array. The transfer of data is done
   * asynchronously.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  async array(): Promise<ArrayMap[R]> {
    // Download the flat values, then rebuild the nesting from the shape.
    const flat = await this.data();
    const nested = toNestedArray(this.shape, flat);
    return nested as ArrayMap[R];
  }

  /**
   * Returns the tensor data as a nested array. The transfer of data is done
   * synchronously.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  arraySync(): ArrayMap[R] {
    // Blocking download of the flat values, re-nested according to the shape.
    const flat = this.dataSync();
    const nested = toNestedArray(this.shape, flat);
    return nested as ArrayMap[R];
  }

  /**
   * Asynchronously downloads the values from the `tf.Tensor`. Returns a promise
   * of `TypedArray` (or of `string[]` for string tensors) that resolves when
   * the computation has finished.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  async data<D extends DataType = NumericDataType>(): Promise<DataTypeMap[D]> {
    this.throwIfDisposed();
    const pending = trackerFn().read(this.dataId);
    if (this.dtype !== 'string') {
      // Numeric data is returned exactly as the tracker provides it.
      return pending as Promise<DataTypeMap[D]>;
    }
    // String tensors are stored as encoded bytes; decode every element.
    const encoded = await pending as Uint8Array[];
    try {
      return encoded.map(raw => util.decodeString(raw));
    } catch {
      throw new Error(
          'Failed to decode the string bytes into utf-8. ' +
          'To get the original bytes, call tensor.bytes().');
    }
  }

  /**
   * Synchronously downloads the values from the `tf.Tensor`. This blocks the UI
   * thread until the values are ready, which can cause performance issues.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  dataSync<D extends DataType = NumericDataType>(): DataTypeMap[D] {
    this.throwIfDisposed();
    const stored = trackerFn().readSync(this.dataId);
    if (this.dtype !== 'string') {
      // Numeric data is returned exactly as the tracker provides it.
      return stored as DataTypeMap[D];
    }
    // String tensors are stored as encoded bytes; decode every element.
    try {
      return (stored as Uint8Array[]).map(raw => util.decodeString(raw));
    } catch {
      throw new Error(
          'Failed to decode the string bytes into utf-8. ' +
          'To get the original bytes, call tensor.bytes().');
    }
  }

  /** Returns the underlying bytes of the tensor's data. */
  async bytes(): Promise<Uint8Array[]|Uint8Array> {
    this.throwIfDisposed();
    const data = await trackerFn().read(this.dataId);
    if (this.dtype === 'string') {
      // String tensors are already stored as one byte array per element.
      return data as Uint8Array[];
    }
    // A typed array may be a view over a larger ArrayBuffer (non-zero
    // byteOffset, or shorter than its buffer). Wrap exactly the bytes the view
    // covers instead of the whole underlying buffer.
    const typed = data as TypedArray;
    return new Uint8Array(typed.buffer, typed.byteOffset, typed.byteLength);
  }

  /**
   * Disposes `tf.Tensor` from memory.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  dispose(): void {
    // Disposing twice is a no-op.
    if (!this.isDisposed) {
      trackerFn().disposeTensor(this);
      this.isDisposedInternal = true;
    }
  }

  // Set by dispose(); exposed read-only through the `isDisposed` getter.
  protected isDisposedInternal = false;
  /** Whether this tensor has been disposed (see `dispose()`). */
  get isDisposed(): boolean {
    return this.isDisposedInternal;
  }

  private throwIfDisposed() {
    // Guard used by accessors and chaining methods before touching the data.
    if (!this.isDisposed) {
      return;
    }
    throw new Error('Tensor is disposed.');
  }

  /** Casts the array to type `float32` */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  toFloat<T extends this>(this: T): T {
    // Shorthand for asType('float32').
    const floats = this.asType('float32');
    return floats;
  }

  /** Casts the array to type `int32` */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  toInt() {
    // Shorthand for asType('int32').
    const ints = this.asType('int32');
    return ints;
  }

  /** Casts the array to type `bool` */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  toBool() {
    // Shorthand for asType('bool').
    const bools = this.asType('bool');
    return bools;
  }

  /**
   * Prints the `tf.Tensor`. See `tf.print` for details.
   *
   * @param verbose Whether to print verbose information about the tensor,
   *    including dtype and size.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  print(verbose = false): void {
    // Printing is delegated entirely to the op handler.
    opHandler.print(this, verbose);
  }

  /**
   * Reshapes the tensor into the provided shape.
   * See `tf.reshape` for more details.
   *
   * @param newShape An array of integers defining the output tensor shape.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  reshape<R2 extends Rank>(newShape: ShapeMap[R2]): Tensor<R2> {
    this.throwIfDisposed();
    // The reshape itself is performed by the registered op handler.
    const reshaped: Tensor<R2> = opHandler.reshape(this, newShape);
    return reshaped;
  }

  /**
   * Reshapes the tensor into the shape of the provided tensor.
   *
   * @param x The tensor of required shape.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  reshapeAs<T extends Tensor>(x: T): T {
    this.throwIfDisposed();
    // Only the shape of `x` is used; its data is not read.
    const targetShape = x.shape;
    return this.reshape(targetShape) as T;
  }

  /**
   * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension
   * into the tensor's shape. See `tf.expandDims` for details.
   *
   * @param axis The dimension index at which to insert shape of 1. Defaults to
   *    0 (the first dimension).
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  expandDims<R2 extends Rank>(axis = 0): Tensor<R2> {
    // Fail fast with a clear error on a disposed tensor, like the other
    // chaining methods do.
    this.throwIfDisposed();
    return opHandler.expandDims(this, axis);
  }

  /**
   * Returns the cumulative sum of the `tf.Tensor` along `axis`.
   *
   * @param axis The axis along which to sum. Optional. Defaults to 0.
   * @param exclusive Whether to perform exclusive cumulative sum. Defaults to
   *    false. If set to true then the sum of each tensor entry does not include
   *    its own value, but only the values previous to it along the specified
   *    axis.
   * @param reverse Whether to sum in the opposite direction. Defaults to
   *    false.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  cumsum<T extends Tensor>(axis = 0, exclusive = false, reverse = false): T {
    // Fail fast with a clear error on a disposed tensor, like the other
    // chaining methods do.
    this.throwIfDisposed();
    return opHandler.cumsum(this, axis, exclusive, reverse);
  }

  /**
   * Returns a `tf.Tensor` with dimensions of size 1 removed from the shape.
   * See `tf.squeeze` for more details.
   *
   * @param axis A list of numbers. If specified, only squeezes the
   *    dimensions listed. The dimension index starts at 0. It is an error to
   *    squeeze a dimension that is not 1.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  squeeze<T extends Tensor>(axis?: number[]): T {
    this.throwIfDisposed();
    // `axis` is forwarded unchanged, including when it is undefined.
    const squeezed: T = opHandler.squeeze(this, axis);
    return squeezed;
  }

  /** Returns a copy of the tensor. See `tf.clone` for details. */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  clone<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    // Copying is delegated to the op handler.
    const copy = opHandler.clone(this);
    return copy;
  }

  /**
   * Chaining wrapper around `opHandler.oneHot` with this tensor as `x`.
   * `onValue` and `offValue` are forwarded as given (possibly undefined).
   */
  oneHot(this: Tensor, depth: number, onValue?: number, offValue?: number):
      Tensor {
    this.throwIfDisposed();
    const encoded = opHandler.oneHot(this, depth, onValue, offValue);
    return encoded;
  }

  /** Returns a human-readable description of the tensor. Useful for logging. */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  toString(verbose = false): string {
    // Formatting needs the values, so this performs a synchronous download.
    return tensorToString(this.dataSync(), this.shape, this.dtype, verbose);
  }

  // Below is chain API that is not exposed to docs to avoid repetition. To
  // expose a method, move it above this comment and add @doc and jsdoc.

  /** Chaining wrapper around `opHandler.tile`. */
  tile<T extends this>(this: T, reps: number[]): T {
    this.throwIfDisposed();
    const tiled = opHandler.tile(this, reps);
    return tiled as T;
  }

  /** Chaining wrapper around `opHandler.gather`; `axis` defaults to 0. */
  gather<T extends this>(this: T, indices: Tensor|TensorLike, axis = 0): T {
    this.throwIfDisposed();
    const gathered = opHandler.gather(this, indices, axis);
    return gathered as T;
  }

  /** Chaining wrapper around `opHandler.matMul` with this tensor as `a`. */
  matMul<T extends Tensor>(
      this: T, b: T|TensorLike, transposeA = false, transposeB = false): T {
    this.throwIfDisposed();
    const product = opHandler.matMul(this, b, transposeA, transposeB);
    return product;
  }
  /** Chaining wrapper around `opHandler.dot` with this tensor as `t1`. */
  dot(b: Tensor|TensorLike): Tensor {
    this.throwIfDisposed();
    const product = opHandler.dot(this, b);
    return product;
  }
  /** Chaining wrapper around `opHandler.norm`. */
  norm(
      ord: number|'euclidean'|'fro' = 'euclidean', axis: number|number[] = null,
      keepDims = false): Tensor {
    this.throwIfDisposed();
    const result = opHandler.norm(this, ord, axis, keepDims);
    return result;
  }
  /** Chaining wrapper around `opHandler.slice`. */
  slice<T extends Tensor<R>>(
      this: T, begin: number|number[], size?: number|number[]): T {
    this.throwIfDisposed();
    const sliced: T = opHandler.slice(this, begin, size);
    return sliced;
  }
  /** Chaining wrapper around `opHandler.reverse`. */
  reverse<T extends Tensor>(this: T, axis?: number|number[]): T {
    this.throwIfDisposed();
    const reversed: T = opHandler.reverse(this, axis);
    return reversed;
  }
  /**
   * Concatenates this tensor with `x` (a single tensor or a list) along
   * `axis`. This tensor always comes first.
   */
  concat<T extends Tensor>(this: T, x: T|Array<T|TensorLike>, axis = 0): T {
    this.throwIfDisposed();
    // Normalize a single tensor into a one-element list without reassigning
    // the parameter.
    const others = x instanceof Tensor ? [x] : x;
    return opHandler.concat([this, ...others], axis);
  }
  /** Chaining wrapper around `opHandler.split`; `axis` defaults to 0. */
  split<T extends Tensor>(this: T, numOrSizeSplits: number[]|number, axis = 0):
      T[] {
    this.throwIfDisposed();
    const pieces: T[] = opHandler.split(this, numOrSizeSplits, axis);
    return pieces;
  }
  /** Stacks this tensor and `x` (in that order) along `axis`. */
  stack(x: Tensor, axis = 0): Tensor {
    // Fail fast on a disposed receiver, like the other chaining methods.
    this.throwIfDisposed();
    return opHandler.stack([this, x], axis);
  }
  /** Chaining wrapper around `opHandler.unstack`; `axis` defaults to 0. */
  unstack(axis = 0): Tensor[] {
    // Fail fast on a disposed receiver, like the other chaining methods.
    this.throwIfDisposed();
    return opHandler.unstack(this, axis);
  }
  /** Chaining wrapper around `opHandler.pad`; `constantValue` defaults to 0. */
  pad<T extends Tensor>(
      this: T, paddings: Array<[number, number]>, constantValue = 0): T {
    // Fail fast on a disposed receiver, like the other chaining methods.
    this.throwIfDisposed();
    return opHandler.pad(this, paddings, constantValue);
  }
  /**
   * @deprecated Use `tf.batchNorm` instead, and note the positional argument
   *     change of scale, offset, and varianceEpsilon.
   */
  batchNormalization(
      mean: Tensor<R>|Tensor1D|TensorLike,
      variance: Tensor<R>|Tensor1D|TensorLike, varianceEpsilon = .001,
      scale?: Tensor<R>|Tensor1D|TensorLike,
      offset?: Tensor<R>|Tensor1D|TensorLike): Tensor<R> {
    // Check disposal before emitting a warning for a call that will fail.
    this.throwIfDisposed();
    // The warning hook starts out null and is only set through
    // setDeprecationWarningFn(). A missing hook must not break the call.
    if (deprecationWarningFn != null) {
      deprecationWarningFn(
          'tf.batchNormalization() is going away. ' +
          'Use tf.batchNorm() instead, and note the positional argument change ' +
          'of scale, offset, and varianceEpsilon');
    }
    // batchNorm takes (offset, scale, varianceEpsilon) in that order.
    return this.batchNorm(mean, variance, offset, scale, varianceEpsilon);
  }

  batchNorm(
      mean: Tensor<R>|Tensor1D|TensorLike,
      variance: Tensor<R>|Tensor1D|TensorLike,
      offset?: Tensor<R>|Tensor1D|TensorLike,
      scale?: Tensor<R>|Tensor1D|TensorLike,
      varianceEpsilon = .001,
      ): Tensor<R> {
    this.throwIfDisposed();
    return opHandler.batchNorm(
        this, mean, variance, offset, scale, varianceEpsilon);
  }
  // Reduction ops. Each method throws if this tensor has been disposed, then
  // forwards `this` and its arguments to the same-named `opHandler` function.
  // `axis` defaults to null (presumably "reduce over every axis" — confirm in
  // the op implementations) and `keepDims` defaults to false. `argMin` and
  // `argMax` take only an `axis`, also defaulting to null.
  all<T extends Tensor>(axis: number|number[] = null, keepDims = false): T {
    this.throwIfDisposed();
    return opHandler.all(this, axis, keepDims);
  }
  any<T extends Tensor>(axis: number|number[] = null, keepDims = false): T {
    this.throwIfDisposed();
    return opHandler.any(this, axis, keepDims);
  }
  logSumExp<T extends Tensor>(axis: number|number[] = null, keepDims = false):
      T {
    this.throwIfDisposed();
    return opHandler.logSumExp(this, axis, keepDims);
  }
  sum<T extends Tensor>(axis: number|number[] = null, keepDims = false): T {
    this.throwIfDisposed();
    return opHandler.sum(this, axis, keepDims);
  }
  prod<T extends Tensor>(axis: number|number[] = null, keepDims = false): T {
    this.throwIfDisposed();
    return opHandler.prod(this, axis, keepDims);
  }
  mean<T extends Tensor>(axis: number|number[] = null, keepDims = false): T {
    this.throwIfDisposed();
    return opHandler.mean(this, axis, keepDims);
  }
  min<T extends Tensor>(axis: number|number[] = null, keepDims = false): T {
    this.throwIfDisposed();
    return opHandler.min(this, axis, keepDims);
  }
  max<T extends Tensor>(axis: number|number[] = null, keepDims = false): T {
    this.throwIfDisposed();
    return opHandler.max(this, axis, keepDims);
  }
  argMin<T extends Tensor>(axis: number = null): T {
    this.throwIfDisposed();
    return opHandler.argMin(this, axis);
  }
  argMax<T extends Tensor>(axis: number = null): T {
    this.throwIfDisposed();
    return opHandler.argMax(this, axis);
  }

  // Transformations
  /**
   * Throws if disposed, then delegates to `opHandler.cast(this, dtype)`. The
   * result is cast to `T`, so its static type follows the receiver's type
   * even though its runtime `dtype` may be different.
   */
  cast<T extends this>(dtype: DataType): T {
    this.throwIfDisposed();
    return opHandler.cast(this as T, dtype) as T;
  }

  // Binary ops.
  //
  // Each method throws if this tensor has been disposed, then returns
  // `opHandler.<sameName>(this, x)`; `pow`/`powStrict` name the operand `exp`.
  // The `*Strict` variants (except `powStrict`) type the operand as
  // `T|TensorLike` and cast the result to `T`. Presumably the op itself
  // enforces matching shapes at runtime — confirm in the op implementations.
  // `transpose` at the end of this group is unary and forwards an optional
  // `perm`.

  add<T extends Tensor>(x: Tensor|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.add(this, x);
  }
  addStrict<T extends this>(this: T, x: T|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.addStrict(this, x) as T;
  }
  atan2<T extends this>(this: T, x: T|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.atan2(this, x) as T;
  }
  sub<T extends Tensor>(x: Tensor|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.sub(this, x);
  }
  subStrict<T extends this>(this: T, x: T|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.subStrict(this, x) as T;
  }
  pow<T extends Tensor>(this: T, exp: Tensor|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.pow(this, exp);
  }
  powStrict(exp: Tensor|TensorLike): Tensor<R> {
    this.throwIfDisposed();
    return opHandler.powStrict(this, exp);
  }
  mul<T extends Tensor>(x: Tensor|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.mul(this, x);
  }
  mulStrict<T extends this>(this: T, x: T|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.mulStrict(this, x) as T;
  }
  div<T extends Tensor>(x: Tensor|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.div(this, x);
  }
  floorDiv<T extends Tensor>(x: Tensor|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.floorDiv(this, x);
  }
  divStrict<T extends this>(this: T, x: T|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.divStrict(this, x) as T;
  }
  minimum<T extends Tensor>(x: Tensor|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.minimum(this, x);
  }
  minimumStrict<T extends this>(this: T, x: T|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.minimumStrict(this, x) as T;
  }
  maximum<T extends Tensor>(x: Tensor|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.maximum(this, x);
  }
  maximumStrict<T extends this>(this: T, x: T|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.maximumStrict(this, x) as T;
  }
  mod<T extends Tensor>(x: Tensor|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.mod(this, x);
  }
  modStrict<T extends this>(this: T, x: T|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.modStrict(this, x) as T;
  }
  squaredDifference<T extends Tensor>(x: Tensor|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.squaredDifference(this, x);
  }
  squaredDifferenceStrict<T extends this>(this: T, x: T|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.squaredDifferenceStrict(this, x) as T;
  }
  transpose<T extends Tensor>(this: T, perm?: number[]): T {
    this.throwIfDisposed();
    return opHandler.transpose(this, perm);
  }

  // Compare ops.
  //
  // Each method throws if this tensor has been disposed, then returns
  // `opHandler.<sameName>(this, x)`. The `*Strict` variants type `x` as
  // `T|TensorLike` and cast the result to `T`.

  notEqual<T extends Tensor>(x: Tensor|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.notEqual(this, x);
  }
  notEqualStrict<T extends this>(this: T, x: T|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.notEqualStrict(this, x) as T;
  }
  less<T extends Tensor>(x: Tensor|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.less(this, x);
  }
  lessStrict<T extends this>(this: T, x: T|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.lessStrict(this, x) as T;
  }
  equal<T extends Tensor>(x: Tensor|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.equal(this, x);
  }
  equalStrict<T extends this>(this: T, x: T|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.equalStrict(this, x) as T;
  }
  lessEqual<T extends Tensor>(x: Tensor|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.lessEqual(this, x);
  }
  lessEqualStrict<T extends this>(this: T, x: T|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.lessEqualStrict(this, x) as T;
  }
  greater<T extends Tensor>(x: Tensor|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.greater(this, x);
  }
  greaterStrict<T extends this>(this: T, x: T|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.greaterStrict(this, x) as T;
  }
  greaterEqual<T extends Tensor>(x: Tensor|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.greaterEqual(this, x);
  }
  greaterEqualStrict<T extends this>(this: T, x: T|TensorLike): T {
    this.throwIfDisposed();
    return opHandler.greaterEqualStrict(this, x) as T;
  }

  // Logical ops. Each method throws if this tensor has been disposed, then
  // forwards to the same-named `opHandler` function with `this` as the first
  // operand. `where` is the exception; see its comment.
  logicalAnd(x: Tensor|TensorLike): Tensor {
    this.throwIfDisposed();
    return opHandler.logicalAnd(this, x);
  }
  logicalOr(x: Tensor|TensorLike): Tensor {
    this.throwIfDisposed();
    return opHandler.logicalOr(this, x);
  }
  logicalNot<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.logicalNot(this);
  }
  logicalXor(x: Tensor|TensorLike): Tensor {
    this.throwIfDisposed();
    return opHandler.logicalXor(this, x);
  }
  /**
   * Delegates to `opHandler.where(condition, this, x)`. Note the argument
   * order: `condition` comes first and this tensor is the SECOND operand
   * (presumably the "true" branch), followed by `x`.
   */
  where(condition: Tensor|TensorLike, x: Tensor|TensorLike): Tensor {
    this.throwIfDisposed();
    return opHandler.where(condition, this, x);
  }

  // Unary ops. Each method below throws if this tensor has been disposed and
  // then returns `opHandler.<sameName>(this)`. The `this: T` parameter gives
  // the result the receiver's concrete tensor type `T`.
  neg<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.neg(this);
  }
  ceil<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.ceil(this);
  }
  floor<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.floor(this);
  }
  sign<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.sign(this);
  }
  isNaN<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.isNaN(this);
  }
  isInf<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.isInf(this);
  }
  isFinite<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.isFinite(this);
  }
  exp<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.exp(this);
  }
  expm1<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.expm1(this);
  }
  log<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.log(this);
  }
  log1p<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.log1p(this);
  }
  sqrt<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.sqrt(this);
  }
  rsqrt<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.rsqrt(this);
  }
  square<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.square(this);
  }
  reciprocal<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.reciprocal(this);
  }
  abs<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.abs(this);
  }
  // Activation-style unary ops. All of them throw if this tensor has been
  // disposed and delegate to the same-named `opHandler` function. The
  // parameterized ones forward their extra arguments unchanged.

  /** Delegates to `opHandler.clipByValue(this, min, max)`. */
  clipByValue(min: number, max: number): Tensor<R> {
    this.throwIfDisposed();
    return opHandler.clipByValue(this, min, max);
  }
  relu<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.relu(this);
  }
  elu<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.elu(this);
  }
  selu<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.selu(this);
  }
  /** Delegates to `opHandler.leakyRelu(this, alpha)`; `alpha` defaults to 0.2. */
  leakyRelu(alpha = 0.2): Tensor<R> {
    this.throwIfDisposed();
    return opHandler.leakyRelu(this, alpha);
  }
  /** Delegates to `opHandler.prelu(this, alpha)`. */
  prelu(alpha: Tensor<R>|TensorLike): Tensor<R> {
    this.throwIfDisposed();
    return opHandler.prelu(this, alpha);
  }
  sigmoid<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.sigmoid(this);
  }
  logSigmoid<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.logSigmoid(this);
  }
  softplus<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.softplus(this);
  }
  zerosLike<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.zerosLike(this);
  }
  onesLike<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.onesLike(this);
  }
  // Trigonometric / hyperbolic / rounding ops and softmax variants. Each one
  // throws if this tensor has been disposed and delegates to the same-named
  // `opHandler` function, forwarding any extra arguments unchanged.
  sin<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.sin(this);
  }
  cos<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.cos(this);
  }
  tan<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.tan(this);
  }
  asin<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.asin(this);
  }
  acos<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.acos(this);
  }
  atan<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.atan(this);
  }
  sinh<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.sinh(this);
  }
  cosh<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.cosh(this);
  }
  tanh<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.tanh(this);
  }
  asinh<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.asinh(this);
  }
  acosh<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.acosh(this);
  }
  atanh<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.atanh(this);
  }
  erf<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.erf(this);
  }
  round<T extends Tensor>(this: T): T {
    this.throwIfDisposed();
    return opHandler.round(this);
  }
  /** Delegates to `opHandler.step(this, alpha)`; `alpha` defaults to 0.0. */
  step<T extends Tensor>(this: T, alpha = 0.0): T {
    this.throwIfDisposed();
    return opHandler.step(this, alpha);
  }
  /** Delegates to `opHandler.softmax(this, dim)`; `dim` defaults to -1. */
  softmax<T extends this>(this: T, dim = -1): T {
    this.throwIfDisposed();
    return opHandler.softmax(this, dim) as T;
  }
  /** Delegates to `opHandler.logSoftmax(this, axis)`; `axis` defaults to -1. */
  logSoftmax<T extends this>(this: T, axis = -1): T {
    this.throwIfDisposed();
    return opHandler.logSoftmax(this, axis) as T;
  }

  // Image ops. The receiver is typed as a rank-3 or rank-4 tensor `T`, so it is
  // cast back to `Tensor` for the disposal check before delegating to
  // `opHandler.image`. `alignCorners` defaults to false in both methods.
  resizeBilinear<T extends Tensor3D|Tensor4D>(
      this: T, newShape2D: [number, number], alignCorners = false): T {
    (this as Tensor).throwIfDisposed();
    return opHandler.image.resizeBilinear(this, newShape2D, alignCorners);
  }

  resizeNearestNeighbor<T extends Tensor3D|Tensor4D>(
      this: T, newShape2D: [number, number], alignCorners = false): T {
    (this as Tensor).throwIfDisposed();
    return opHandler.image.resizeNearestNeighbor(
        this, newShape2D, alignCorners);
  }

  // Convolutions. Each method runs the disposal check on `this` (cast to
  // `Tensor`) and forwards every argument, in order, to `opHandler`.
  // Defaults, as declared in the signatures: conv1d `dataFormat = 'NWC'`,
  // `dilation = 1`; conv2d/depthwiseConv2D `dataFormat = 'NHWC'`,
  // `dilations = [1, 1]`. `dimRoundingMode` is optional everywhere.
  conv1d<T extends Tensor2D|Tensor3D>(
      this: T, filter: Tensor3D|TensorLike3D, stride: number,
      pad: 'valid'|'same'|number, dataFormat: 'NWC'|'NCW' = 'NWC', dilation = 1,
      dimRoundingMode?: 'floor'|'round'|'ceil'): T {
    (this as Tensor).throwIfDisposed();
    return opHandler.conv1d(
        this, filter, stride, pad, dataFormat, dilation, dimRoundingMode);
  }
  conv2d<T extends Tensor3D|Tensor4D>(
      this: T, filter: Tensor4D|TensorLike4D, strides: [number, number]|number,
      pad: 'valid'|'same'|number, dataFormat: 'NHWC'|'NCHW' = 'NHWC',
      dilations: [number, number]|number = [1, 1],
      dimRoundingMode?: 'floor'|'round'|'ceil'): T {
    (this as Tensor).throwIfDisposed();
    return opHandler.conv2d(
        this, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
  }
  conv2dTranspose<T extends Tensor3D|Tensor4D>(
      this: T, filter: Tensor4D|TensorLike4D,
      outputShape: [number, number, number, number]|[number, number, number],
      strides: [number, number]|number, pad: 'valid'|'same'|number,
      dimRoundingMode?: 'floor'|'round'|'ceil'): T {
    (this as Tensor).throwIfDisposed();
    return opHandler.conv2dTranspose(
        this, filter, outputShape, strides, pad, dimRoundingMode);
  }
  // Note the casing: the public method is `depthwiseConv2D`, but it delegates
  // to `opHandler.depthwiseConv2d` (lowercase "d").
  depthwiseConv2D<T extends Tensor3D|Tensor4D>(
      this: T, filter: Tensor4D|TensorLike4D, strides: [number, number]|number,
      pad: 'valid'|'same'|number, dataFormat: 'NHWC'|'NCHW' = 'NHWC',
      dilations: [number, number]|number = [1, 1],
      dimRoundingMode?: 'floor'|'round'|'ceil'): T {
    (this as Tensor).throwIfDisposed();
    return opHandler.depthwiseConv2d(
        this, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
  }

  /**
   * Throws if this tensor has been disposed, then delegates to
   * `opHandler.separableConv2d`, forwarding all arguments in order.
   *
   * The receiver is declared as `T` (a rank-3 or rank-4 tensor), as in the
   * sibling convolution methods. A method invoked on a tensor instance can
   * never have a `TensorLike` receiver, so the previous `T|TensorLike` `this`
   * type only weakened type checking.
   *
   * @param depthwiseFilter Forwarded as the op's depthwise filter.
   * @param pointwiseFilter Forwarded as the op's pointwise filter.
   * @param strides A single stride or a `[number, number]` pair.
   * @param pad Only 'valid' or 'same'; numeric padding is not accepted here.
   * @param dilation Defaults to `[1, 1]`.
   * @param dataFormat Defaults to 'NHWC'.
   */
  separableConv2d<T extends Tensor3D|Tensor4D>(
      this: T, depthwiseFilter: Tensor4D|TensorLike4D,
      pointwiseFilter: Tensor4D|TensorLike, strides: [number, number]|number,
      pad: 'valid'|'same', dilation: [number, number]|number = [1, 1],
      dataFormat: 'NHWC'|'NCHW' = 'NHWC'): T {
    (this as Tensor).throwIfDisposed();
    return opHandler.separableConv2d(
        this, depthwiseFilter, pointwiseFilter, strides, pad, dilation,
        dataFormat);
  }

  // Pooling. Every method throws if this tensor has been disposed, then
  // forwards its arguments, in order, to the same-named `opHandler` function.

  /** Delegates to `opHandler.avgPool`; `dimRoundingMode` is optional. */
  avgPool<T extends Tensor3D|Tensor4D>(
      this: T, filterSize: [number, number]|number,
      strides: [number, number]|number, pad: 'valid'|'same'|number,
      dimRoundingMode?: 'floor'|'round'|'ceil'): T {
    (this as Tensor).throwIfDisposed();
    return opHandler.avgPool(this, filterSize, strides, pad, dimRoundingMode);
  }
  /** Delegates to `opHandler.maxPool`; `dimRoundingMode` is optional. */
  maxPool<T extends Tensor3D|Tensor4D>(
      this: T, filterSize: [number, number]|number,
      strides: [number, number]|number, pad: 'valid'|'same'|number,
      dimRoundingMode?: 'floor'|'round'|'ceil'): T {
    (this as Tensor).throwIfDisposed();
    return opHandler.maxPool(this, filterSize, strides, pad, dimRoundingMode);
  }
  /**
   * Delegates to `opHandler.localResponseNormalization`. Defaults: `radius = 5`,
   * `bias = 1`, `alpha = 1`, `beta = 0.5`.
   *
   * This method previously skipped the disposal check that every other
   * chained op performs. It now fails fast on a disposed tensor as well.
   */
  localResponseNormalization<T extends Tensor3D|Tensor4D>(
      this: T, radius = 5, bias = 1, alpha = 1, beta = 0.5): T {
    (this as Tensor).throwIfDisposed();
    return opHandler.localResponseNormalization(
        this, radius, bias, alpha, beta);
  }
  /**
   * Delegates to `opHandler.pool`. `poolingType` selects 'max' or 'avg';
   * `dilationRate` and `strides` are optional.
   */
  pool<T extends Tensor3D|Tensor4D>(
      this: T, windowShape: [number, number]|number, poolingType: 'max'|'avg',
      padding: 'valid'|'same'|number, dilationRate?: [number, number]|number,
      strides?: [number, number]|number): T {
    (this as Tensor).throwIfDisposed();
    return opHandler.pool(
        this, windowShape, poolingType, padding, dilationRate, strides);
  }

  /**
   * Wraps this tensor in a new `Variable` via the static factory
   * `Variable.variable(this, trainable, name, dtype)`. `trainable` defaults to
   * true; `name` and `dtype` are optional. Throws if this tensor has been
   * disposed.
   */
  variable(trainable = true, name?: string, dtype?: DataType): Variable<R> {
    this.throwIfDisposed();
    return Variable.variable(this, trainable, name, dtype);
  }

  /** Delegates to `opHandler.unsortedSegmentSum(this, segmentIds, numSegments)`. */
  unsortedSegmentSum<T extends Tensor>(
      this: T, segmentIds: Tensor1D|TensorLike1D, numSegments: number): T {
    this.throwIfDisposed();
    return opHandler.unsortedSegmentSum(this, segmentIds, numSegments);
  }

  /** Delegates to `opHandler.batchToSpaceND(this, blockShape, crops)`. */
  batchToSpaceND<T extends Tensor>(
      this: T, blockShape: number[], crops: number[][]): T {
    this.throwIfDisposed();
    return opHandler.batchToSpaceND(this, blockShape, crops);
  }

  /** Delegates to `opHandler.spaceToBatchND(this, blockShape, paddings)`. */
  spaceToBatchND<T extends Tensor>(
      this: T, blockShape: number[], paddings: number[][]): T {
    this.throwIfDisposed();
    return opHandler.spaceToBatchND(this, blockShape, paddings);
  }

  /**
   * Delegates to `opHandler.topk(this, k, sorted)` and returns its
   * `{values, indices}` pair. `k` defaults to 1 and `sorted` to true.
   */
  topk<T extends Tensor>(this: T, k = 1, sorted = true):
      {values: T, indices: T} {
    this.throwIfDisposed();
    return opHandler.topk(this, k, sorted);
  }

  /**
   * Delegates to `opHandler.stridedSlice`, forwarding every argument in order.
   * All five mask arguments default to 0.
   */
  stridedSlice(
      this: Tensor, begin: number[], end: number[], strides: number[],
      beginMask = 0, endMask = 0, ellipsisMask = 0, newAxisMask = 0,
      shrinkAxisMask = 0): Tensor {
    this.throwIfDisposed();
    return opHandler.stridedSlice(
        this, begin, end, strides, beginMask, endMask, ellipsisMask,
        newAxisMask, shrinkAxisMask);
  }

  /**
   * Only callable on a rank-4 receiver (`this: Tensor4D`). Delegates to
   * `opHandler.depthToSpace`; `dataFormat` is required here (no default).
   */
  depthToSpace(this: Tensor4D, blockSize: number, dataFormat: 'NHWC'|'NCHW'):
      Tensor4D {
    this.throwIfDisposed();
    return opHandler.depthToSpace(this, blockSize, dataFormat);
  }

  // Spectral ops: each throws if this tensor has been disposed and delegates
  // to the same-named function on `opHandler.spectral`.
  fft(this: Tensor): Tensor {
    this.throwIfDisposed();
    return opHandler.spectral.fft(this);
  }

  ifft(this: Tensor): Tensor {
    this.throwIfDisposed();
    return opHandler.spectral.ifft(this);
  }

  rfft(this: Tensor): Tensor {
    this.throwIfDisposed();
    return opHandler.spectral.rfft(this);
  }

  irfft(this: Tensor): Tensor {
    this.throwIfDisposed();
    return opHandler.spectral.irfft(this);
  }
}
// Structural `instanceof` for Tensor: any truthy value that carries non-null
// `dataId`, `shape` and `dtype` fields counts as a Tensor, so objects created
// by another copy of this class are accepted too.
Object.defineProperty(Tensor, Symbol.hasInstance, {
  value: (instance: Tensor): boolean => {
    if (!instance) {
      return false;
    }
    return instance.dataId != null && instance.shape != null &&
        instance.dtype != null;
  }
});

/**
 * A `Tensor` whose `dtype` is narrowed to `NumericDataType`. The `D` type
 * parameter of `dataSync`/`data` also defaults to `NumericDataType`.
 */
export interface NumericTensor<R extends Rank = Rank> extends Tensor<R> {
  dtype: NumericDataType;
  dataSync<D extends DataType = NumericDataType>(): DataTypeMap[D];
  data<D extends DataType = NumericDataType>(): Promise<DataTypeMap[D]>;
}

/**
 * A `Tensor` whose `dtype` is fixed to `'string'`. The `D` type parameter of
 * `dataSync`/`data` also defaults to `'string'`.
 */
export interface StringTensor<R extends Rank = Rank> extends Tensor<R> {
  dtype: 'string';
  dataSync<D extends DataType = 'string'>(): DataTypeMap[D];
  data<D extends DataType = 'string'>(): Promise<DataTypeMap[D]>;
}

// Rank-specific aliases of `Tensor`, from `Rank.R0` (Scalar) through
// `Rank.R6`. The `@doclink` tags are presumably consumed by the API-doc
// generator to point these aliases at `Tensor`'s documentation — confirm
// before editing them.
/** @doclink Tensor */
export type Scalar = Tensor<Rank.R0>;
/** @doclink Tensor */
export type Tensor1D = Tensor<Rank.R1>;
/** @doclink Tensor */
export type Tensor2D = Tensor<Rank.R2>;
/** @doclink Tensor */
export type Tensor3D = Tensor<Rank.R3>;
/** @doclink Tensor */
export type Tensor4D = Tensor<Rank.R4>;
/** @doclink Tensor */
export type Tensor5D = Tensor<Rank.R5>;
/** @doclink Tensor */
export type Tensor6D = Tensor<Rank.R6>;

/**
 * A mutable `tf.Tensor`, useful for persisting state, e.g. for training.
 */
/** @doc {heading: 'Tensors', subheading: 'Classes'} */
export class Variable<R extends Rank = Rank> extends Tensor<R> {
  name: string;

  /**
   * Private, because no logic may run before `super()`. Construction goes
   * through the static `Variable.variable` factory instead, which is also
   * exported at module level.
   *
   * The constructor:
   * - shares `initialValue`'s data by reusing its `dataId`;
   * - falls back to the tracker's next variable id (as a string) when no name
   *   is given;
   * - registers the variable with the tracker. If registration throws, the
   *   variable is disposed as a tensor and the error is re-thrown.
   */
  private constructor(
      initialValue: Tensor<R>, public trainable = true, name?: string) {
    super(
        initialValue.shape, initialValue.dtype, null /* values */,
        initialValue.dataId);
    this.name =
        name == null ? trackerFn().nextVariableId().toString() : name;
    try {
      trackerFn().registerVariable(this);
    } catch (err) {
      // Undo the partially-completed construction before surfacing the error.
      trackerFn().disposeTensor(this);
      throw err;
    }
  }

  /**
   * Creates a new variable with the provided initial value.
   * ```js
   * const x = tf.variable(tf.tensor([1, 2, 3]));
   * x.assign(tf.tensor([4, 5, 6]));
   *
   * x.print();
   * ```
   *
   * @param initialValue Initial value for the tensor.
   * @param trainable If true, optimizers are allowed to update it.
   * @param name Name of the variable. Defaults to a unique id.
   * @param dtype If set, initialValue will be converted to the given type.
   */
  /** @doc {heading: 'Tensors', subheading: 'Creation'} */
  static variable<R extends Rank>(
      initialValue: Tensor<R>, trainable = true, name?: string,
      dtype?: DataType): Variable<R> {
    // Convert only when a dtype was requested and it differs from the input's.
    const source = (dtype != null && dtype !== initialValue.dtype) ?
        initialValue.asType(dtype) as Tensor<R> :
        initialValue;
    return new Variable(source, trainable, name);
  }

  /**
   * Assign a new `tf.Tensor` to this variable. The new `tf.Tensor` must have
   * the same shape and dtype as the old `tf.Tensor`.
   *
   * The dtype is checked before the shape. On success, the variable's current
   * tensor registration is disposed, it adopts `newValue`'s `dataId`, and it
   * is registered again with the tracker.
   *
   * @param newValue New tensor to be assigned to this variable.
   */
  /** @doc {heading: 'Tensors', subheading: 'Classes'} */
  assign(newValue: Tensor<R>): void {
    const dtypeMatches = newValue.dtype === this.dtype;
    if (!dtypeMatches) {
      throw new Error(
          `dtype of the new value (${newValue.dtype}) and ` +
          `previous value (${this.dtype}) must match`);
    }
    const shapeMatches = util.arraysEqual(newValue.shape, this.shape);
    if (!shapeMatches) {
      throw new Error(
          `shape of the new value (${newValue.shape}) and ` +
          `previous value (${this.shape}) must match`);
    }
    trackerFn().disposeTensor(this);
    this.dataId = newValue.dataId;
    trackerFn().registerTensor(this);
  }

  /**
   * Releases this variable through the tracker's `disposeVariable` and marks
   * it as disposed.
   */
  dispose(): void {
    trackerFn().disposeVariable(this);
    this.isDisposedInternal = true;
  }
}

// Structural `instanceof` for Variable: the value must pass the structural
// Tensor check and also expose an `assign` function.
Object.defineProperty(Variable, Symbol.hasInstance, {
  value: (instance: Variable): boolean => {
    if (!(instance instanceof Tensor)) {
      return false;
    }
    return instance.assign != null && instance.assign instanceof Function;
  }
});

/** Module-level alias of the static factory `Variable.variable`. */
export const variable = Variable.variable;
