UNPKG

4.22 kBJavaScriptView Raw
1'use strict';
2
3Object.defineProperty(exports, "__esModule", {
4 value: true
5});
6
// Shallow-merge helper: the native Object.assign when the runtime provides it,
// otherwise a fallback that copies the own enumerable string-keyed properties
// of every source (left to right) onto `target` and returns `target`.
var _extends = Object.assign || function (target) {
  var hasOwn = Object.prototype.hasOwnProperty;
  var sources = Array.prototype.slice.call(arguments, 1);

  sources.forEach(function (source) {
    // for...in over null/undefined is a no-op, so such sources are skipped
    for (var key in source) {
      if (hasOwn.call(source, key)) {
        target[key] = source[key];
      }
    }
  });

  return target;
}; /* eslint import/prefer-default-export: "off" */
8
9// Internal dependencies
10
11
12exports.trainTestSplit = trainTestSplit;
13
14var _random = require('../random');
15
16var Random = _interopRequireWildcard(_random);
17
/**
 * Normalise the result of `require()` for namespace-style access.
 *
 * Objects flagged with `__esModule` are returned untouched. Anything else is
 * wrapped in a fresh object that carries a copy of the value's own enumerable
 * properties plus a `default` property pointing at the original value.
 *
 * @param {*} obj - Value returned by `require()`
 * @return {Object} The module itself, or a namespace-like wrapper around it
 */
function _interopRequireWildcard(obj) {
  if (obj && obj.__esModule) {
    return obj;
  }

  var wrapper = {};

  if (obj != null) {
    var hasOwn = Object.prototype.hasOwnProperty;

    for (var name in obj) {
      if (!hasOwn.call(obj, name)) {
        continue;
      }
      wrapper[name] = obj[name];
    }
  }

  // Assigned last, so it wins over any own `default` key copied above
  wrapper.default = obj;
  return wrapper;
}
19
/**
 * Produce a fresh array holding the elements of `arr` (spread helper).
 *
 * Real arrays are copied index by index into a new array of the same length;
 * any other value is handed to `Array.from`.
 *
 * @param {*} arr - Array, iterable or array-like value
 * @return {Array} New array with the elements of `arr`
 */
function _toConsumableArray(arr) {
  if (!Array.isArray(arr)) {
    return Array.from(arr);
  }

  var copy = Array(arr.length);
  var idx = 0;

  while (idx < arr.length) {
    copy[idx] = arr[idx];
    idx += 1;
  }

  return copy;
}
21
/**
 * Split a dataset into a training and a test set.
 *
 * Each input array is split along its first dimension using the same randomly chosen set of
 * training indices, so corresponding rows of all inputs end up on the same side of the split.
 * Within each output array the elements keep their original relative order. Elements that are
 * arrays are shallow-copied; all other elements are passed through as-is.
 *
 * @example <caption>Example with n=5 datapoints and d=2 features per sample</caption>
 * // n x d array of features
 * var X = [[0, 0], [0.5, 0.2], [0.3, 2.5], [0.8, 0.9], [0.7, 0.2]];
 *
 * // n-dimensional array of labels
 * var y = [1, 0, 0, 1, 1];
 *
 * // Split into training and test set
 * var [X_train, y_train, X_test, y_test] = trainTestSplit([X, y], {trainSize: 0.8});
 *
 * // Now, X_train and y_train will contain the features and labels of the training set,
 * // respectively, and X_test and y_test will contain the features and labels of the test set.
 *
 * // Depending on the random seed, the result might be the following
 * // X_train: [[0, 0], [0.5, 0.2], [0.3, 2.5], [0.7, 0.2]]
 * // y_train: [1, 0, 0, 1]
 * // X_test: [[0.8, 0.9]]
 * // y_test: [1]
 *
 * @param {Array.<Array.<mixed>>} input - List of input arrays. The input arrays should have the
 *   same length (i.e., they should have the same first dimension size)
 * @param {Object} [optionsUser] - User-defined options
 * @param {number} [optionsUser.trainSize = 0.8] - Size of the training set. An integer greater
 *   than 1 is used as the exact number of training samples. Any other value (including 0 and 1)
 *   is treated as a fraction: the total number of elements times this value, rounded, is used
 *   as the number of training elements
 * @return {Array} List of output arrays, twice as long as `input`: first the training part of
 *   every input array (in input order), followed by the test part of every input array (in
 *   input order). An empty `input` yields an empty list
 * @throws {Error} If the input arrays differ in length, if `trainSize` is not a non-negative
 *   finite number, or if it asks for more training elements than there are elements
 */
function trainTestSplit(input) {
  var optionsUser = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};

  // Options
  var optionsDefault = {
    trainSize: 0.8
  };

  var options = _extends({}, optionsDefault, optionsUser);

  // Nothing to split: avoid dereferencing input[0] below
  if (input.length === 0) {
    return [];
  }

  // Total number of elements
  var numElements = input[0].length;

  // Check whether all input data sets have the same size
  if (!input.every(function (x) {
    return x.length === numElements;
  })) {
    throw new Error('All input arrays should have the same length (i.e., the size of their\n first dimensions should be the same');
  }

  // An explicitly passed `trainSize: undefined` overrides the default during the merge above;
  // fall back to the default instead of computing with NaN
  var trainSize = options.trainSize === undefined ? optionsDefault.trainSize : options.trainSize;

  if (typeof trainSize !== 'number' || !isFinite(trainSize) || trainSize < 0) {
    throw new Error('trainSize should be a non-negative finite number');
  }

  // Number of training elements. JavaScript cannot tell 1 from 1.0, so 1 keeps its fractional
  // meaning ("use everything") and only integers above 1 are read as absolute sample counts
  var numTrainElements = Number.isInteger(trainSize) && trainSize > 1
    ? trainSize
    : Math.round(numElements * trainSize);

  if (numTrainElements > numElements) {
    throw new Error('trainSize requests ' + numTrainElements + ' training elements, but only '
      + numElements + ' elements are available');
  }

  // Generate list of all possible array indices
  var indices = Array.from({ length: numElements }, function (_, i) {
    return i;
  });

  // Take a random sample (without replacement) from the list of possible indices, which are then
  // used as the indices of the elements to use for the training data
  var trainIndices = Random.sample(indices, numTrainElements, false);

  // Set lookup keeps the membership test constant-time per element
  var isTrainIndex = new Set(trainIndices);

  // Shallow-copy array elements so the outputs do not share row arrays with the input
  var copyElement = function copyElement(x) {
    return Array.isArray(x) ? x.slice() : x;
  };

  // Create resulting training and test sets
  var trainArrays = input.map(function (array) {
    return array.filter(function (x, i) {
      return isTrainIndex.has(i);
    }).map(function (x) {
      return copyElement(x);
    });
  });

  var testArrays = input.map(function (array) {
    return array.filter(function (x, i) {
      return !isTrainIndex.has(i);
    }).map(function (x) {
      return copyElement(x);
    });
  });

  // Return train and test sets
  return trainArrays.concat(testArrays);
}
\No newline at end of file