UNPKG

18.3 kBJavaScriptView Raw
1'use strict';
2
3Object.defineProperty(exports, "__esModule", {
4 value: true
5});
6exports.decide = exports.reduceDecisionRules = exports.formatProperty = exports.formatDecisionRules = undefined;
7
/**
 * Transpiler helper backing array destructuring.
 * Arrays are returned as-is; other iterables are drained into a new array holding
 * at most `i` items (every item when `i` is falsy), closing the iterator when
 * stopping early; anything else raises a TypeError.
 */
var _slicedToArray = function () {
  function drainIterable(iterable, limit) {
    var collected = [];
    var iterator = iterable[Symbol.iterator]();
    var exhausted = true;
    var threw = false;
    var thrown = undefined;
    try {
      var step;
      for (;;) {
        step = iterator.next();
        exhausted = step.done;
        if (exhausted) {
          break;
        }
        collected.push(step.value);
        if (limit && collected.length === limit) {
          break;
        }
        exhausted = true;
      }
    } catch (err) {
      threw = true;
      thrown = err;
    } finally {
      try {
        // Give the iterator a chance to clean up when we stopped before its end.
        if (!exhausted && iterator["return"]) {
          iterator["return"]();
        }
      } finally {
        if (threw) {
          throw thrown;
        }
      }
    }
    return collected;
  }

  return function (arr, i) {
    if (Array.isArray(arr)) {
      return arr;
    }
    if (Symbol.iterator in Object(arr)) {
      return drainIterable(arr, i);
    }
    throw new TypeError("Invalid attempt to destructure non-iterable instance");
  };
}();
9
// Publish the statistics helpers defined further down; function declarations are
// hoisted, so they are already available at this point.
(function publishStatisticsHelpers(target) {
  target.distribution = distribution;
  target.computeMeanValues = computeMeanValues;
  target.computeMeanDistributions = computeMeanDistributions;
})(exports);
13
14var _lodash = require('lodash');
15
16var _lodash2 = _interopRequireDefault(_lodash);
17
18var _reducer = require('./reducer');
19
20var _time = require('./time');
21
22var _errors = require('./errors');
23
24var _formatter = require('./formatter');
25
26var _timezones = require('./timezones');
27
28var _timezones2 = _interopRequireDefault(_timezones);
29
// Wrap a CommonJS export as `{ default: obj }` unless it already flags itself
// as an ES module (via `__esModule`).
function _interopRequireDefault(obj) {
  var isEsModule = obj && obj.__esModule;
  return isEsModule ? obj : { default: obj };
}

// Set `obj[key] = value` and return `obj`. When `key` is already reachable on `obj`
// (own or inherited), an own, enumerable, writable, configurable data property is
// defined explicitly instead of going through a plain assignment.
function _defineProperty(obj, key, value) {
  if (!(key in obj)) {
    obj[key] = value;
    return obj;
  }
  Object.defineProperty(obj, key, { value: value, enumerable: true, configurable: true, writable: true });
  return obj;
}

// Copy `arr` into a new dense array (holes become `undefined`); non-array
// iterables and array-likes go through `Array.from`.
function _toConsumableArray(arr) {
  if (!Array.isArray(arr)) {
    return Array.from(arr);
  }
  var copy = new Array(arr.length);
  for (var idx = 0; idx < arr.length; idx += 1) {
    copy[idx] = arr[idx];
  }
  return copy;
}
35
var DECISION_FORMAT_VERSION = '2.0.0';

/**
 * Decision-rule operators, keyed by the rule's `operator`. Each receives the
 * context value of the rule's property and the rule's operand, and returns
 * whether the rule matches.
 */
var OPERATORS = {
  'is': function is(context, value) {
    var bothObjects = _lodash2.default.isObject(context) && _lodash2.default.isObject(value);
    if (!bothObjects) {
      return context === value;
    }
    // Two objects only match when both of them are empty.
    return _lodash2.default.isEmpty(context) && _lodash2.default.isEmpty(value);
  },
  '>=': function greaterOrEqual(context, value) {
    if (_lodash2.default.isNull(context)) {
      return false;
    }
    return context * 1 >= value;
  },
  '<': function lessThan(context, value) {
    if (_lodash2.default.isNull(context)) {
      return false;
    }
    return context * 1 < value;
  },
  '[in[': function inInterval(context, value) {
    var numeric = context * 1;
    var lower = value[0];
    var upper = value[1];
    // When lower is not below upper the interval wraps around:
    // e.g. [22, 2[ matches values >= 22 or < 2.
    var wraps = !(lower < upper);
    if (_lodash2.default.isNull(context)) {
      return false;
    }
    if (wraps) {
      return numeric >= lower || numeric < upper;
    }
    return numeric >= lower && numeric < upper;
  },
  'in': function inList(context, value) {
    return value.indexOf(context) !== -1;
  }
};
69
// Build a validator that accepts integers within the inclusive range [min, max].
function _integerBetween(min, max) {
  return function (value) {
    return _lodash2.default.isInteger(value) && value >= min && value <= max;
  };
}

/**
 * Context value validators, keyed by property type. Each returns a truthy value
 * when `value` is acceptable for that type.
 */
var VALUE_VALIDATOR = {
  continuous: function (value) {
    return _lodash2.default.isFinite(value);
  },
  enum: function (value) {
    return _lodash2.default.isString(value);
  },
  boolean: function (value) {
    return _lodash2.default.isBoolean(value);
  },
  timezone: function (value) {
    return (0, _timezones2.default)(value);
  },
  time_of_day: function (value) {
    // Any finite number in [0, 24) is accepted, fractional values included.
    return _lodash2.default.isFinite(value) && value >= 0 && value < 24;
  },
  day_of_week: _integerBetween(0, 6),
  day_of_month: _integerBetween(1, 31),
  month_of_year: _integerBetween(1, 12)
};
96
/**
 * Walk the decision tree below `node` and build the decision for `context`.
 *
 * @param {Object} node - tree node; a leaf carries a `prediction`, an inner node carries
 *   `children`, each with a `decision_rule` ({ property, operator, operand }).
 * @param {Object} context - property values to decide for.
 * @param {Object} configuration - agent configuration; only `deactivate_missing_values`
 *   is read here.
 * @param {string} outputType - type of the output property; 'enum' and 'boolean' are
 *   handled as classification.
 * @param {Array} outputValues - class labels, indexed like the leaf distributions.
 * @returns {Object} the decision (`predicted_value`, `confidence`, `decision_rules`, ...),
 *   or an object whose `error` ({ name, message[, metadata] }) explains why no decision
 *   could be taken.
 */
function decideRecursion(node, context, configuration, outputType, outputValues) {
  // Leaf
  if (!(node.children && node.children.length)) {
    return _decideLeaf(node.prediction);
  }

  // Regular node: pick the first child whose decision rule matches the context.
  // The missing-value check must live in this loop rather than in a `_.find`
  // predicate: `_.find` only uses the predicate result to decide whether a child
  // is selected, so an error object returned from it would be taken as a match.
  var matchingChild;
  for (var i = 0; i < node.children.length; i++) {
    var child = node.children[i];
    var decision_rule = child.decision_rule;
    var property = decision_rule.property;
    if (configuration.deactivate_missing_values && _lodash2.default.isNull(context[property])) {
      return {
        predicted_value: undefined,
        confidence: undefined,
        decision_rules: [],
        error: {
          name: 'CraftAiUnknownError',
          message: 'Unable to take decision: property \'' + property + '\' is missing from the given context.'
        }
      };
    }
    if (OPERATORS[decision_rule.operator](context[property], decision_rule.operand)) {
      matchingChild = child;
      break;
    }
  }

  if (_lodash2.default.isUndefined(matchingChild)) {
    return _decideWithoutMatchingChild(node, context, configuration, outputType, outputValues);
  }

  // Matching child found: recurse, then prepend this node's rule to the rule path.
  var result = decideRecursion(matchingChild, context, configuration, outputType, outputValues);
  return _lodash2.default.extend(result, {
    decision_rules: [matchingChild.decision_rule].concat(result.decision_rules)
  });
}

// Build the decision stored in a leaf's `prediction`, or a null-decision error when the
// leaf has no predicted value.
function _decideLeaf(prediction) {
  if (prediction.value == null) {
    return {
      predicted_value: undefined,
      confidence: undefined,
      decision_rules: [],
      error: {
        name: 'CraftAiNullDecisionError',
        message: 'Unable to take decision: the decision tree has no valid predicted value for the given context.'
      }
    };
  }

  var leafNode = {
    predicted_value: prediction.value,
    confidence: prediction.confidence || 0,
    decision_rules: [],
    nb_samples: prediction.nb_samples
  };

  // With a standard deviation, report it along with the min/max bounds when present;
  // otherwise the raw distribution is returned.
  if (!_lodash2.default.isUndefined(prediction.distribution.standard_deviation)) {
    leafNode.standard_deviation = prediction.distribution.standard_deviation;
    var min_value = prediction.distribution.min;
    var max_value = prediction.distribution.max;
    if (!_lodash2.default.isUndefined(min_value)) {
      leafNode.min = min_value;
    }
    if (!_lodash2.default.isUndefined(max_value)) {
      leafNode.max = max_value;
    }
  } else {
    leafNode.distribution = prediction.distribution;
  }

  return leafNode;
}

// Build the decision when no child of `node` matches the context: an error describing
// the unexpected value when missing values are deactivated, otherwise the size-weighted
// aggregate of the whole subtree.
function _decideWithoutMatchingChild(node, context, configuration, outputType, outputValues) {
  if (configuration.deactivate_missing_values) {
    // Should only happen when an unexpected value for an enum is encountered
    var operandList = _lodash2.default.uniq(_lodash2.default.map(_lodash2.default.values(node.children), function (child) {
      return child.decision_rule.operand;
    }));
    var property = _lodash2.default.head(node.children).decision_rule.property;
    return {
      predicted_value: undefined,
      confidence: undefined,
      decision_rules: [],
      error: {
        name: 'CraftAiNullDecisionError',
        message: 'Unable to take decision: value \'' + context[property] + '\' for property \'' + property + '\' doesn\'t validate any of the decision rules.',
        metadata: {
          property: property,
          value: context[property],
          expected_values: operandList
        }
      }
    };
  }

  var _distribution = distribution(node),
      value = _distribution.value,
      standard_deviation = _distribution.standard_deviation,
      size = _distribution.size;

  var _finalResult = {};
  // If it is a classification problem we return the class with the highest
  // probability. Otherwise, if the current output type is continuous/periodic
  // then the returned value corresponds to the subtree weighted output values.
  if (outputType === 'enum' || outputType === 'boolean') {
    // Compute the argmax function on the returned distribution:
    var argmax = value.map(function (x, i) {
      return [x, i];
    }).reduce(function (r, a) {
      return a[0] > r[0] ? a : r;
    })[1];

    _finalResult = {
      predicted_value: outputValues[argmax],
      distribution: value
    };
  } else {
    _finalResult = {
      predicted_value: value,
      standard_deviation: standard_deviation
    };
  }
  return _lodash2.default.extend(_finalResult, {
    confidence: null,
    decision_rules: [],
    nb_samples: size
  });
}
225
/**
 * Build a validator for contexts passed to `decide`.
 *
 * Every property declared in `configuration.context`, except the outputs, is expected.
 * The returned function throws a `CraftAiDecisionError` listing every missing
 * (`undefined`) property and every value rejected by its type validator. A rejected
 * value is still accepted when it is `null` and missing values are not deactivated,
 * or when the property is optional and the value is an "empty" non-number,
 * non-boolean value (e.g. `null`, `{}`, `[]`).
 *
 * @param {Object} configuration - agent configuration (`context`, `output`,
 *   `deactivate_missing_values`).
 * @returns {function(Object): void} validator throwing on an invalid context.
 */
function checkContext(configuration) {
  // Extract the required properties (i.e. those that are not the output)
  var expectedProperties = _lodash2.default.difference(_lodash2.default.keys(configuration.context), configuration.output);

  // Build a context validator
  var validators = _lodash2.default.map(expectedProperties, function (property) {
    var otherValidator = function otherValidator() {
      console.warn('WARNING: "' + configuration.context[property].type + '" is not a supported type. Please refer to the documention to see what type you can use');
      return true;
    };
    return {
      property: property,
      type: configuration.context[property].type,
      is_optional: configuration.context[property].is_optional,
      validator: VALUE_VALIDATOR[configuration.context[property].type] || otherValidator
    };
  });

  return function (context) {
    var missingProperties = [];
    var badProperties = [];

    validators.forEach(function (entry) {
      var value = context[entry.property];
      if (value === undefined) {
        missingProperties.push(entry.property);
        return;
      }
      // null is accepted as a missing value unless missing values are deactivated.
      var isNullAuthorized = _lodash2.default.isNull(value) && !configuration.deactivate_missing_values;
      // `_.isEmpty` reports every number and boolean as empty, so those must not use
      // the optional-value escape hatch (otherwise e.g. an optional day_of_week of 42
      // would pass validation).
      var isOptionalAuthorized = entry.is_optional && !_lodash2.default.isNumber(value) && !_lodash2.default.isBoolean(value) && _lodash2.default.isEmpty(value);
      if (!entry.validator(value) && !isNullAuthorized && !isOptionalAuthorized) {
        badProperties.push({ property: entry.property, type: entry.type, value: value });
      }
    });

    if (missingProperties.length || badProperties.length) {
      var messages = _lodash2.default.concat(_lodash2.default.map(missingProperties, function (property) {
        return 'expected property \'' + property + '\' is not defined';
      }), _lodash2.default.map(badProperties, function (bad) {
        return '\'' + bad.value + '\' is not a valid value for property \'' + bad.property + '\' of type \'' + bad.type + '\'';
      }));
      throw new _errors.CraftAiDecisionError({
        message: 'Unable to take decision, the given context is not valid: ' + messages.join(', ') + '.',
        metadata: _lodash2.default.assign({}, missingProperties.length && { missingProperties: missingProperties }, badProperties.length && { badProperties: badProperties })
      });
    }
  };
}
282
/**
 * Aggregate the predictions stored under `node`.
 *
 * For a leaf whose distribution is an array, returns `{ value, size }` with that array;
 * otherwise returns `{ value, standard_deviation, size, min, max }` from its prediction.
 * For an inner node, the children's results are merged, weighted by their sizes, with
 * `computeMeanDistributions` (array values) or `computeMeanValues`.
 *
 * @param {Object} node - decision tree node.
 * @returns {Object} aggregated value(s) and branch size.
 */
function distribution(node) {
  var isLeaf = !(node.children && node.children.length);
  if (isLeaf) {
    var prediction = node.prediction;
    // An array distribution means a classification leaf: pass it through with its size.
    if (_lodash2.default.isArray(prediction.distribution)) {
      return {
        value: prediction.distribution,
        size: prediction.nb_samples
      };
    }
    // Regression leaf: mean value, spread, size and bounds.
    return {
      value: prediction.value,
      standard_deviation: prediction.distribution.standard_deviation,
      size: prediction.nb_samples,
      min: prediction.distribution.min,
      max: prediction.distribution.max
    };
  }

  // Inner node: aggregate every child branch first, then merge them.
  var childStats = _lodash2.default.map(node.children, function (child) {
    return distribution(child);
  });

  var values = [];
  var sizes = [];
  var stds = [];
  var mins = [];
  var maxs = [];
  childStats.forEach(function (stats) {
    values.push(stats.value);
    sizes.push(stats.size);
    // Only results with a standard deviation contribute spread and bounds.
    if (!_lodash2.default.isUndefined(stats.standard_deviation)) {
      stds.push(stats.standard_deviation);
      mins.push(stats.min);
      maxs.push(stats.max);
    }
  });

  return _lodash2.default.isArray(values[0])
    ? computeMeanDistributions(values, sizes)
    : computeMeanValues(values, sizes, stds, mins, maxs);
}
343
/**
 * Combine per-branch statistics into one size-weighted summary.
 *
 * Without `stds`, this is the classical weighted mean of `values`: for
 * values = [4, 3, 6] and sizes = [1, 2, 1] it computes
 * (4*1 + 3*2 + 6*1) / (1+2+1) = 16/4 = 4, and returns `{ value, size }`.
 *
 * With `stds`, branches are merged pairwise (left to right) using
 * https://math.stackexchange.com/questions/2238086/calculate-variance-of-a-subset
 * and the overall min/max are tracked too; returns
 * `{ value, standard_deviation, size, min, max }`.
 *
 * @param {number[]} values - mean value of each branch.
 * @param {number[]} sizes - number of samples in each branch.
 * @param {number[]} [stds] - standard deviation of each branch.
 * @param {number[]} [mins] - minimum of each branch.
 * @param {number[]} [maxs] - maximum of each branch.
 * @returns {Object} the merged statistics.
 */
function computeMeanValues(values, sizes, stds, mins, maxs) {
  if (_lodash2.default.isUndefined(stds)) {
    var totalSize = _lodash2.default.sum(sizes);
    var newMean = _lodash2.default.zip(values, sizes).map(function (pair) {
      var mean = pair[0];
      var size = pair[1];
      return mean * (1.0 * size) / (1.0 * totalSize);
    }).reduce(_lodash2.default.add);
    return {
      value: newMean,
      size: totalSize
    };
  }

  var merged = _lodash2.default.zip(values, stds, sizes, mins, maxs).map(function (row) {
    return {
      mean: row[0],
      variance: row[1] * row[1],
      size: row[2],
      min: row[3],
      max: row[4]
    };
  }).reduce(function (acc, branch) {
    if (_lodash2.default.isUndefined(acc.mean)) {
      return branch;
    }
    var totalSize = 1.0 * (acc.size + branch.size);
    // Nothing to weight when both sides are empty. Note the parentheses: the former
    // `!totalSize > 0.0` parsed as `(!totalSize) > 0`. min/max are carried over so
    // that later branches still have bounds to compare against.
    if (!(totalSize > 0.0)) {
      return {
        mean: acc.mean,
        variance: acc.variance,
        size: acc.size,
        min: acc.min,
        max: acc.max
      };
    }
    var newVariance = 1.0 / (totalSize - 1) * ((acc.size - 1) * acc.variance + (branch.size - 1) * branch.variance + acc.size * branch.size / totalSize * (acc.mean - branch.mean) * (acc.mean - branch.mean));
    var newMean = 1.0 / totalSize * (acc.size * acc.mean + branch.size * branch.mean);
    return {
      mean: newMean,
      variance: newVariance,
      size: totalSize,
      min: branch.min < acc.min ? branch.min : acc.min,
      max: branch.max > acc.max ? branch.max : acc.max
    };
  }, {
    mean: undefined,
    variance: undefined,
    size: undefined,
    min: undefined,
    max: undefined
  });

  return {
    value: merged.mean,
    standard_deviation: Math.sqrt(merged.variance),
    size: merged.size,
    min: merged.min,
    max: merged.max
  };
}
437
/**
 * Size-weighted mean of several distributions (arrays of probabilities).
 *
 * Example: for values = [[4, 3, 6], [1, 2, 3], [3, 4, 5]] and sizes = [1, 2, 1]
 * this computes ([4, 3, 6]*1 + [1, 2, 3]*2 + [3, 4, 5]*1) / (1+2+1).
 *
 * @param {number[][]} values - one distribution per branch.
 * @param {number[]} sizes - number of samples in each branch.
 * @returns {{value: number[], size: number}} merged distribution and total size.
 */
function computeMeanDistributions(values, sizes) {
  var totalSize = _lodash2.default.sum(sizes);
  // With no distribution at all there is nothing to average.
  if (_lodash2.default.isEmpty(values)) {
    return { value: [], size: totalSize };
  }
  var multiplyByBranchRatio = _lodash2.default.zip(values, sizes).map(function (zipped) {
    return _lodash2.default.map(zipped[0], function (val) {
      return val * zipped[1] / totalSize;
    });
  });
  var sumArrays = _sumArrays(multiplyByBranchRatio);
  return { value: sumArrays, size: totalSize };
}
451
/**
 * Element-wise sum of numeric arrays, e.g. _sumArrays([[1, 2], [3, 4]]) -> [4, 6].
 * Returns [] when given no arrays.
 *
 * @param {number[][]} arrays - dense numeric arrays (expected to share one length).
 * @returns {number[]} the element-wise sum.
 */
function _sumArrays(arrays) {
  // Slots missing from the running total count as 0, so an empty seed works for any
  // length and does not require `arrays[0]` to exist.
  return arrays.reduce(function (accSum, array) {
    return array.map(function (val, i) {
      return (accSum[i] || 0.) + val;
    });
  }, []);
}
459
/**
 * Take a decision for `context` using the decision tree of every output property.
 *
 * @param {Object} configuration - agent configuration (`context`, `output`, ...).
 * @param {Object} trees - decision tree root per output property; each root also
 *   carries the `output_values` used to label classes.
 * @param {Object} context - property values to decide for. It is not modified.
 * @returns {{_version: string, context: Object, output: Object}} the decision for
 *   each output property, along with the context actually used.
 * @throws {CraftAiDecisionError} when `context` is not valid for `configuration`.
 * @throws {CraftAiNullDecisionError} when a tree cannot produce a value for `context`.
 * @throws {CraftAiUnknownError} for any other decision error.
 */
function decide(configuration, trees, context) {
  checkContext(configuration)(context);

  // Convert timezones as integers to the standard +/-hh:mm format.
  // This should only happen when no Time() object is passed to the interpreter.
  // The conversion is written to a shallow copy so that taking a decision does not
  // mutate the caller's context object.
  var decisionContext = context;
  var timezoneProperty = (0, _timezones.getTimezoneKey)(configuration.context);
  if (!_lodash2.default.isUndefined(timezoneProperty)) {
    decisionContext = _lodash2.default.clone(context);
    decisionContext[timezoneProperty] = (0, _time.tzFromOffset)(context[timezoneProperty]);
  }

  var decisionsPerOutput = _lodash2.default.map(configuration.output, function (output) {
    var outputType = configuration.context[output].type;
    var decision = decideRecursion(trees[output], decisionContext, configuration, outputType, trees[output].output_values);
    if (decision.error) {
      switch (decision.error.name) {
        case 'CraftAiNullDecisionError':
          throw new _errors.CraftAiNullDecisionError({
            message: decision.error.message,
            metadata: _lodash2.default.extend(decision.error.metadata, {
              decision_rules: decision.decision_rules
            })
          });
        default:
          throw new _errors.CraftAiUnknownError({
            message: decision.error.message
          });
      }
    }
    return _defineProperty({}, output, decision);
  });

  return {
    _version: DECISION_FORMAT_VERSION,
    context: decisionContext,
    output: _lodash2.default.assign.apply(_lodash2.default, _toConsumableArray(decisionsPerOutput))
  };
}
493
// Public API: `decide` is implemented in this file; the formatting and reduction
// helpers are re-exported from './formatter' and './reducer'. The keys were
// pre-declared at the top of the file, so assignment order does not affect them.
exports.decide = decide;
exports.reduceDecisionRules = _reducer.reduceDecisionRules;
exports.formatProperty = _formatter.formatProperty;
exports.formatDecisionRules = _formatter.formatDecisionRules;
\No newline at end of file