1 | 'use strict';
|
2 |
|
3 | Object.defineProperty(exports, "__esModule", {
|
4 | value: true
|
5 | });
|
6 | exports.OneVsAllClassifier = exports.Classifier = exports.Estimator = undefined;
|
7 |
|
/**
 * Babel helper: install class members as property descriptors.
 *
 * Each descriptor is mutated in place before installation:
 * - enumerable defaults to false;
 * - configurable is forced to true;
 * - writable is forced to true for data (value) descriptors.
 *
 * Returns the constructor it was given.
 */
var _createClass = (function () {
  // Install every descriptor in `descriptors` on `target`, keyed by descriptor.key.
  function installDescriptors(target, descriptors) {
    for (var idx = 0; idx < descriptors.length; idx += 1) {
      var desc = descriptors[idx];
      desc.enumerable = desc.enumerable || false;
      desc.configurable = true;
      if ("value" in desc) {
        desc.writable = true;
      }
      Object.defineProperty(target, desc.key, desc);
    }
  }

  return function (Constructor, protoProps, staticProps) {
    if (protoProps) {
      installDescriptors(Constructor.prototype, protoProps);
    }
    if (staticProps) {
      installDescriptors(Constructor, staticProps);
    }
    return Constructor;
  };
})();
|
9 |
|
10 |
|
11 | var _arrays = require('../arrays');
|
12 |
|
13 | var Arrays = _interopRequireWildcard(_arrays);
|
14 |
|
/**
 * Babel helper: view a required module as an ES-module namespace.
 *
 * Modules flagged with `__esModule` are returned unchanged. Any other value gets
 * a fresh wrapper that contains:
 * - a shallow copy of its own enumerable string-keyed properties;
 * - `default` pointing at the original value.
 */
function _interopRequireWildcard(obj) {
  if (obj && obj.__esModule) {
    return obj;
  }

  var namespace = {};
  if (obj != null) {
    // Object.keys yields exactly the own enumerable string keys that a
    // for-in loop filtered by hasOwnProperty would visit, in the same order.
    Object.keys(obj).forEach(function (name) {
      namespace[name] = obj[name];
    });
  }
  namespace.default = obj;
  return namespace;
}
|
16 |
|
/**
 * Babel helper: choose what a derived constructor returns.
 *
 * If the parent constructor returned an object or function (`call`), that value
 * replaces the instance; otherwise the instance (`self`) is returned.
 * Throws a ReferenceError when `self` is falsy.
 */
function _possibleConstructorReturn(self, call) {
  if (!self) {
    throw new ReferenceError("this hasn't been initialised - super() hasn't been called");
  }
  var parentReturnedObject = call && (typeof call === "object" || typeof call === "function");
  if (parentReturnedObject) {
    return call;
  }
  return self;
}
|
18 |
|
/**
 * Babel helper: wire up prototype inheritance from `superClass` to `subClass`.
 *
 * - The subclass prototype is replaced by an object whose prototype is the
 *   superclass prototype (or null), with a non-enumerable `constructor`.
 * - Static members are inherited by setting the subclass constructor's own
 *   prototype to the superclass.
 *
 * Throws a TypeError unless `superClass` is a function or null.
 */
function _inherits(subClass, superClass) {
  var acceptableSuper = typeof superClass === "function" || superClass === null;
  if (!acceptableSuper) {
    throw new TypeError("Super expression must either be null or a function, not " + typeof superClass);
  }

  var parentPrototype = superClass && superClass.prototype;
  subClass.prototype = Object.create(parentPrototype, {
    constructor: {
      value: subClass,
      enumerable: false,
      writable: true,
      configurable: true
    }
  });

  if (!superClass) {
    return;
  }
  // Fall back to __proto__ on engines without Object.setPrototypeOf.
  if (Object.setPrototypeOf) {
    Object.setPrototypeOf(subClass, superClass);
  } else {
    subClass.__proto__ = superClass;
  }
}
|
20 |
|
/**
 * Babel helper: reject calling a compiled class constructor without `new`.
 *
 * Throws a TypeError unless `instance` is an instance of `Constructor`.
 */
function _classCallCheck(instance, Constructor) {
  if (instance instanceof Constructor) {
    return;
  }
  throw new TypeError("Cannot call a class as a function");
}
|
22 |
|
23 |
|
24 |
|
25 |
|
/**
 * Abstract base class for estimators.
 *
 * Declares the `train` / `test` interface. Both methods throw here and must be
 * overridden by subclasses.
 */
var Estimator = exports.Estimator = function () {
  function Estimator() {
    _classCallCheck(this, Estimator);
  }

  _createClass(Estimator, [{
    key: 'train',

    /**
     * Train the estimator. Abstract: subclasses must override this.
     *
     * @param {Array} X - Training features (presumably one entry per sample — TODO confirm)
     * @param {Array} y - Training targets, aligned with X
     * @throws {Error} Always, in this base implementation
     */
    value: function train(X, y) {
      throw new Error('Method must be implemented in child class.');
    }
  }, {
    key: 'test',

    /**
     * Apply the trained estimator to new data. Abstract: subclasses must override this.
     *
     * @param {Array} X - Features to evaluate
     * @throws {Error} Always, in this base implementation
     */
    value: function test(X) {
      throw new Error('Method must be implemented in child class.');
    }
  }]);

  return Estimator;
}();
|
65 |
|
66 |
|
67 |
|
68 |
|
69 |
|
70 |
|
/**
 * Base class for classifiers.
 *
 * Extends Estimator without adding members of its own; it exists so classifier
 * implementations share a common ancestor.
 */
var Classifier = exports.Classifier = (function (_Estimator) {
  _inherits(Classifier, _Estimator);

  function Classifier() {
    _classCallCheck(this, Classifier);

    // Run the parent constructor, then let Babel's helper decide whether its
    // return value replaces `this`.
    var parentConstructor = Classifier.__proto__ || Object.getPrototypeOf(Classifier);
    var parentResult = parentConstructor.apply(this, arguments);
    return _possibleConstructorReturn(this, parentResult);
  }

  return Classifier;
})(Estimator);
|
82 |
|
83 |
|
84 |
|
85 |
|
86 |
|
87 |
|
88 |
|
89 |
|
90 |
|
91 |
|
92 |
|
/**
 * One-vs-all (one-vs-rest) multiclass classifier.
 *
 * Holds one binary base classifier per unique class index in `this.classifiers`.
 * That is an array of `{ classIndex, classifier }` entries, filled by
 * `createClassifiers`. Subclasses supply the base classifiers by overriding
 * `createClassifier`.
 */
var OneVsAllClassifier = exports.OneVsAllClassifier = function (_Classifier) {
  _inherits(OneVsAllClassifier, _Classifier);

  function OneVsAllClassifier() {
    _classCallCheck(this, OneVsAllClassifier);

    return _possibleConstructorReturn(this, (OneVsAllClassifier.__proto__ || Object.getPrototypeOf(OneVsAllClassifier)).apply(this, arguments));
  }

  _createClass(OneVsAllClassifier, [{
    key: 'createClassifier',

    /**
     * Create the binary base classifier for a single class. Abstract: subclasses
     * must override this.
     *
     * @param {mixed} classIndex - Class index the new classifier is responsible for
     * @return {Object} Base classifier. The other methods of this class call its
     *   train, predict, predictProba, trainIteration and checkConvergence methods.
     * @throws {Error} Always, in this base implementation
     */
    value: function createClassifier(classIndex) {
      throw new Error('Method must be implemented in child class.');
    }
  }, {
    key: 'createClassifiers',

    /**
     * Create one base classifier per unique class index in `y`. The result is
     * stored in `this.classifiers`, replacing any previous set.
     *
     * @param {Array.<mixed>} y - Class index of each training sample
     */
    value: function createClassifiers(y) {
      var _this = this;

      var uniqueClassIndices = Arrays.unique(y);

      this.classifiers = uniqueClassIndices.map(function (classIndex) {
        // Pass the class index through so subclasses can build a class-specific
        // classifier, as the createClassifier(classIndex) signature promises.
        var classifier = _this.createClassifier(classIndex);

        return {
          classIndex: classIndex,
          classifier: classifier
        };
      });
    }
  }, {
    key: 'getClasses',

    /**
     * Get the class index handled by each base classifier, in `this.classifiers` order.
     *
     * This is also the column order of the rows returned by predictProba.
     *
     * @return {Array.<mixed>} Class indices
     */
    value: function getClasses() {
      return this.classifiers.map(function (entry) {
        return entry.classIndex;
      });
    }
  }, {
    key: 'trainBatch',

    /**
     * Train every base classifier on all of X.
     *
     * Labels are binarized per classifier: 1 where the sample's class index equals
     * the classifier's class index, 0 otherwise. Requires createClassifiers to have
     * been called first.
     *
     * @param {Array} X - Training features
     * @param {Array.<mixed>} y - Class index of each training sample
     */
    value: function trainBatch(X, y) {
      this.classifiers.forEach(function (entry) {
        var yOneVsAll = y.map(function (classIndex) {
          return classIndex === entry.classIndex ? 1 : 0;
        });
        entry.classifier.train(X, yOneVsAll);
      });
    }
  }, {
    key: 'trainIterative',

    /**
     * Train all base classifiers iteratively until each has converged or the epoch
     * limit is reached.
     *
     * In each epoch, every classifier that has not yet converged runs
     * trainIteration() once. It is then dropped from further training as soon as
     * its checkConvergence() returns a truthy value. The method emits
     * 'iterationCompleted' after each epoch and 'converged' at the end.
     *
     * NOTE(review): reads `this.training.labels` and calls `this.emit`, neither of
     * which is defined by any class in this file; presumably a subclass or mixin
     * provides them — confirm.
     *
     * @param {number} [maxEpochs=100] - Maximum number of epochs to run
     * @throws {Error} If a training label has no matching base classifier
     */
    value: function trainIterative() {
      var maxEpochs = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : 100;

      // `this.classifiers` is a positional array of { classIndex, classifier }
      // entries, so base classifiers must be looked up by class index rather
      // than by array position.
      var classifiersByIndex = new Map();
      this.classifiers.forEach(function (entry) {
        classifiersByIndex.set(entry.classIndex, entry.classifier);
      });

      var remainingClassIndices = Arrays.unique(this.training.labels);
      var epoch = 0;

      while (epoch < maxEpochs && remainingClassIndices.length > 0) {
        var unconverged = [];

        for (var i = 0; i < remainingClassIndices.length; i += 1) {
          var classIndex = remainingClassIndices[i];

          if (!classifiersByIndex.has(classIndex)) {
            throw new Error('No classifier found for class index ' + classIndex + '; call createClassifiers first.');
          }
          var classifier = classifiersByIndex.get(classIndex);

          classifier.trainIteration();

          if (!classifier.checkConvergence()) {
            unconverged.push(classIndex);
          }
        }

        remainingClassIndices = unconverged;

        this.emit('iterationCompleted');

        epoch += 1;
      }

      this.emit('converged');
    }
  }, {
    key: 'predict',

    /**
     * Predict a class index for each data point.
     *
     * Every base classifier scores every data point via
     * predict(X, { output: 'normalized' }). The class whose classifier gives the
     * highest score wins.
     *
     * @param {Array} X - Data points to classify
     * @return {Array.<mixed>} Predicted class index per data point
     */
    value: function predict(X) {
      var _this = this;

      // One row per classifier, transposed into one row of scores per data point.
      var datapointsPredictions = Arrays.transpose(this.classifiers.map(function (entry) {
        return entry.classifier.predict(X, { output: 'normalized' });
      }));

      // argMax gives a position in `this.classifiers`; translate it to the class
      // index of that entry, since class indices need not be 0..k-1 in order.
      return datapointsPredictions.map(function (scores) {
        return _this.classifiers[Arrays.argMax(scores)].classIndex;
      });
    }
  }, {
    key: 'predictProba',

    /**
     * Predict class probabilities for each data point.
     *
     * Takes element [1] of each base classifier's predictProba output per data
     * point (assumed to be the positive-class probability — TODO confirm). Each
     * data point's row is then scaled by 1 / Arrays.internalSum(row).
     * Columns follow `this.classifiers` order (see getClasses).
     *
     * @param {Array} X - Data points
     * @return {Array.<Array.<number>>} One probability row per data point
     * @throws {Error} If the first base classifier has no predictProba method
     */
    value: function predictProba(X) {
      if (typeof this.classifiers[0].classifier.predictProba !== 'function') {
        throw new Error('Base classifier does not implement the predictProba method, which was attempted to be called from the one-vs-all classifier.');
      }

      var predictions = Arrays.transpose(this.classifiers.map(function (entry) {
        return entry.classifier.predictProba(X).map(function (probs) {
          return probs[1];
        });
      }));

      return predictions.map(function (row) {
        return Arrays.scale(row, 1 / Arrays.internalSum(row));
      });
    }
  }, {
    key: 'getClassifiers',

    /**
     * Get the `{ classIndex, classifier }` entries created by createClassifiers.
     *
     * @return {Array.<Object>} The `this.classifiers` array itself (not a copy)
     */
    value: function getClassifiers() {
      return this.classifiers;
    }
  }]);

  return OneVsAllClassifier;
}(Classifier);
\ | No newline at end of file |