UNPKG

11 kBJavaScriptView Raw
1'use strict';
2
3Object.defineProperty(exports, "__esModule", {
4 value: true
5});
6exports.OneVsAllClassifier = exports.Classifier = exports.Estimator = undefined;
7
/**
 * Babel helper: install the methods/accessors described by `{key, value | get | set}` descriptor
 * objects onto a constructor's prototype (protoProps) and onto the constructor itself
 * (staticProps), then return the constructor. Each descriptor is made non-enumerable unless it
 * already says otherwise, always configurable, and writable when it carries a `value`.
 * The descriptor objects passed in are mutated in place.
 */
var _createClass = (function () {
  function installAll(target, props) {
    var idx = 0;
    while (idx < props.length) {
      var descriptor = props[idx];
      descriptor.enumerable = descriptor.enumerable || false;
      descriptor.configurable = true;
      if ("value" in descriptor) {
        descriptor.writable = true;
      }
      Object.defineProperty(target, descriptor.key, descriptor);
      idx += 1;
    }
  }

  return function (Constructor, protoProps, staticProps) {
    if (protoProps) {
      installAll(Constructor.prototype, protoProps);
    }
    if (staticProps) {
      installAll(Constructor, staticProps);
    }
    return Constructor;
  };
})();

// Standard imports
9
10
11var _arrays = require('../arrays');
12
13var Arrays = _interopRequireWildcard(_arrays);
14
/**
 * Babel helper: turn a required module into an ES-module-like namespace object.
 * Modules flagged with `__esModule` are returned as-is. Anything else gets its own enumerable
 * properties copied onto a fresh object, with the original value exposed as `default`.
 */
function _interopRequireWildcard(obj) {
  if (obj && obj.__esModule) {
    return obj;
  }

  var namespace = {};
  if (obj != null) {
    Object.keys(obj).forEach(function (key) {
      namespace[key] = obj[key];
    });
  }
  namespace.default = obj;
  return namespace;
}
16
/**
 * Babel helper: decide what a derived constructor returns. If the parent constructor returned
 * an object or function, that value wins; otherwise the derived instance `self` is returned.
 * Throws a ReferenceError when `self` is falsy (super() was not called).
 */
function _possibleConstructorReturn(self, call) {
  if (!self) {
    throw new ReferenceError("this hasn't been initialised - super() hasn't been called");
  }
  var isObjectLike = typeof call === "object" || typeof call === "function";
  if (call && isObjectLike) {
    return call;
  }
  return self;
}
18
/**
 * Babel helper: make `subClass` inherit from `superClass` (a function or null). Sets up the
 * instance prototype chain with a non-enumerable `constructor` back-reference and, when a
 * parent exists, the static chain so static members are inherited too.
 * Throws a TypeError for any other kind of parent.
 */
function _inherits(subClass, superClass) {
  var superIsValid = typeof superClass === "function" || superClass === null;
  if (!superIsValid) {
    throw new TypeError("Super expression must either be null or a function, not " + typeof superClass);
  }

  var parentPrototype = superClass && superClass.prototype;
  subClass.prototype = Object.create(parentPrototype, {
    constructor: { value: subClass, enumerable: false, writable: true, configurable: true }
  });

  if (!superClass) {
    return;
  }
  // Fall back to __proto__ on engines that lack Object.setPrototypeOf.
  if (Object.setPrototypeOf) {
    Object.setPrototypeOf(subClass, superClass);
  } else {
    subClass.__proto__ = superClass;
  }
}
20
/**
 * Babel helper: guard a compiled class constructor against being invoked without `new`.
 * Throws a TypeError unless `instance` is an instance of `Constructor`.
 */
function _classCallCheck(instance, Constructor) {
  if (instance instanceof Constructor) {
    return;
  }
  throw new TypeError("Cannot call a class as a function");
}
22
/**
 * Base class for supervised estimators (classifiers or regression models).
 * Child classes are expected to override `train` and `test`; the base implementations throw.
 */
var Estimator = exports.Estimator = function () {
  function Estimator() {
    _classCallCheck(this, Estimator);
  }

  _createClass(Estimator, [{
    key: 'train',

    /**
     * Train the supervised learning algorithm on a dataset.
     *
     * @abstract
     *
     * @param {Array.<Array.<number>>} X - Features per data point
     * @param {Array.<mixed>} y - Class labels per data point
     * @throws {Error} Always; child classes must override this method
     */
    value: function train(X, y) {
      throw new Error('Method must be implemented in child class.');
    }
  }, {
    key: 'test',

    /**
     * Make a prediction for a data set.
     *
     * @abstract
     *
     * @param {Array.<Array.<number>>} X - Features for each data point
     * @return {Array.<mixed>} Predictions, one per data point
     * @throws {Error} Always; child classes must override this method
     */
    value: function test(X) {
      throw new Error('Method must be implemented in child class.');
    }
  }]);

  return Estimator;
}();
65
/**
 * Base class for classifiers. It adds no behavior of its own on top of Estimator.
 */
var Classifier = exports.Classifier = (function (ParentEstimator) {
  function Classifier() {
    // Reject plain calls such as `Classifier()` without `new`.
    _classCallCheck(this, Classifier);

    // Run the parent constructor with the same arguments, honoring any object it returns.
    var parentConstructor = Object.getPrototypeOf(Classifier);
    var parentResult = parentConstructor.apply(this, arguments);
    return _possibleConstructorReturn(this, parentResult);
  }

  // Link both the instance prototype chain and the static (constructor) chain to the parent.
  _inherits(Classifier, ParentEstimator);

  return Classifier;
})(Estimator);
82
/**
 * Base class for multiclass classifiers using the one-vs-all classification method. For a training
 * set with k unique class labels, the one-vs-all classifier creates k binary classifiers. Each of
 * these classifiers is trained on the entire data set, where the i-th classifier treats all samples
 * that do not come from the i-th class as being from the same class. In the prediction phase, the
 * one-vs-all classifier runs all k binary classifiers on the test data point, and predicts the
 * class that has the highest normalized prediction value.
 *
 * Child classes must implement createClassifier. After createClassifiers has run, the binary
 * classifiers are stored in this.classifiers as {classIndex, classifier} entries. Here classIndex
 * is the class label that the binary classifier treats as its positive class.
 */
var OneVsAllClassifier = exports.OneVsAllClassifier = function (_Classifier) {
  _inherits(OneVsAllClassifier, _Classifier);

  function OneVsAllClassifier() {
    _classCallCheck(this, OneVsAllClassifier);

    return _possibleConstructorReturn(this, (OneVsAllClassifier.__proto__ || Object.getPrototypeOf(OneVsAllClassifier)).apply(this, arguments));
  }

  _createClass(OneVsAllClassifier, [{
    key: 'createClassifier',

    /**
     * Create a binary classifier for one of the classes.
     *
     * @abstract
     *
     * @param {mixed} classIndex - Class label of the positive class for the binary classifier
     * @return {BinaryClassifier} Binary classifier
     * @throws {Error} Always; child classes must override this method
     */
    value: function createClassifier(classIndex) {
      throw new Error('Method must be implemented in child class.');
    }
  }, {
    key: 'createClassifiers',

    /**
     * Create all binary classifiers: one per unique class label in y, in the order produced by
     * Arrays.unique. Stores the result in this.classifiers.
     *
     * @param {Array.<mixed>} y - Class labels for the training data
     */
    value: function createClassifiers(y) {
      var _this = this;

      this.classifiers = Arrays.unique(y).map(function (classIndex) {
        return {
          classIndex: classIndex,
          // Pass the label through so the child class knows which class is the positive one
          classifier: _this.createClassifier(classIndex)
        };
      });
    }
  }, {
    key: 'getClasses',

    /**
     * Get the class label handled by each internal binary classifier. Can be used to determine
     * which prediction is for which class in predictProba.
     *
     * @return {Array.<mixed>} The n-th element in this array contains the class label of what is
     *   internally class n
     */
    value: function getClasses() {
      return this.classifiers.map(function (entry) {
        return entry.classIndex;
      });
    }
  }, {
    key: 'trainBatch',

    /**
     * Train all binary classifiers one-by-one. Each binary classifier receives labels 1 for
     * samples of its own class and 0 for all other samples.
     *
     * @param {Array.<Array.<number>>} X - Features per data point
     * @param {Array.<mixed>} y - Class labels per data point
     */
    value: function trainBatch(X, y) {
      this.classifiers.forEach(function (entry) {
        var yOneVsAll = y.map(function (classIndex) {
          return classIndex === entry.classIndex ? 1 : 0;
        });
        entry.classifier.train(X, yOneVsAll);
      });
    }
  }, {
    key: 'trainIterative',

    /**
     * Train all binary classifiers iteration by iteration, i.e. start with the first training
     * iteration for each binary classifier, then execute the second training iteration for each
     * binary classifier, and so forth. A binary classifier stops receiving iterations once its
     * checkConvergence() returns true. Can be used when one needs to keep track of information
     * per iteration, e.g. accuracy.
     *
     * Requires createClassifiers to have been called. Each binary classifier must implement
     * trainIteration() and checkConvergence(). If the instance provides an emit method (this
     * class does not define one), 'iterationCompleted' is emitted after every epoch. 'converged'
     * is emitted at the end, also when the epoch limit was reached first.
     *
     * @param {number} [maxEpochs=100] - Maximum number of epochs to run
     */
    value: function trainIterative(maxEpochs) {
      var epochLimit = maxEpochs === undefined ? 100 : maxEpochs;
      var canEmit = typeof this.emit === 'function';

      // Binary classifiers that still need more training iterations
      var remaining = this.classifiers.slice();
      var epoch = 0;

      while (epoch < epochLimit && remaining.length > 0) {
        remaining = remaining.filter(function (entry) {
          // Run a single iteration, keep the classifier only if it has not converged yet
          entry.classifier.trainIteration();
          return !entry.classifier.checkConvergence();
        });

        // Emit event the outside can hook into
        if (canEmit) {
          this.emit('iterationCompleted');
        }

        epoch += 1;
      }

      // Emit event the outside can hook into
      if (canEmit) {
        this.emit('converged');
      }
    }
  }, {
    key: 'predict',

    /**
     * Predict a class label for each data point: the label of the binary classifier with the
     * highest normalized output.
     *
     * @param {Array.<Array.<number>>} X - Features for each data point
     * @return {Array.<mixed>} Predicted class label for each data point
     */
    value: function predict(X) {
      var classifiers = this.classifiers;

      // Get predictions from all classifiers for all data points by predicting all data points with
      // each classifier (getting an array of predictions for each classifier) and transposing
      var datapointsPredictions = Arrays.transpose(classifiers.map(function (entry) {
        return entry.classifier.predict(X, { output: 'normalized' });
      }));

      // argMax yields a position within this.classifiers; translate it back to that class label
      return datapointsPredictions.map(function (x) {
        return classifiers[Arrays.argMax(x)].classIndex;
      });
    }
  }, {
    key: 'predictProba',

    /**
     * Make a probabilistic prediction for a data set.
     *
     * @param {Array.<Array.<number>>} X - Features for each data point
     * @return {Array.<Array.<number>>} Probability predictions. Each array element contains the
     *   probability of that particular class. The array elements are ordered like getClasses()
     *   (i.e., if class "A" occurs first in the labels list in the training procedure, its
     *   probability is returned in the first array element of each sub-array). When no binary
     *   classifier assigns any probability to a data point, a uniform distribution is returned.
     * @throws {Error} If there are no binary classifiers, or any of them lacks predictProba
     */
    value: function predictProba(X) {
      var classifiers = this.classifiers;

      if (!classifiers || classifiers.length === 0) {
        throw new Error('No binary classifiers available; call createClassifiers before predicting.');
      }

      var allSupportProba = classifiers.every(function (entry) {
        return typeof entry.classifier.predictProba === 'function';
      });
      if (!allSupportProba) {
        throw new Error('Base classifier does not implement the predictProba method, which was attempted to be called from the one-vs-all classifier.');
      }

      // Get probability predictions from all classifiers for all data points by predicting all data
      // points with each classifier (getting an array of predictions for each classifier) and
      // transposing. Column 1 of each binary output is used as the positive-class probability.
      var predictions = Arrays.transpose(classifiers.map(function (entry) {
        return entry.classifier.predictProba(X).map(function (probs) {
          return probs[1];
        });
      }));

      // Scale all predictions to yield valid probabilities
      return predictions.map(function (x) {
        var total = Arrays.internalSum(x);
        if (total === 0) {
          // Avoid 0 * Infinity = NaN when every binary classifier outputs 0
          return x.map(function () {
            return 1 / x.length;
          });
        }
        return Arrays.scale(x, 1 / total);
      });
    }
  }, {
    key: 'getClassifiers',

    /**
     * Retrieve the individual binary one-vs-all classifiers.
     *
     * @return {Array.<{classIndex: mixed, classifier: Classifier}>} Entries pairing each class
     *   label with the binary classifier used as a base classifier for this multiclass classifier
     */
    value: function getClassifiers() {
      return this.classifiers;
    }
  }]);

  return OneVsAllClassifier;
}(Classifier);
\No newline at end of file