UNPKG

14.2 kBJavaScriptView Raw
1import _ from 'lodash';
2import { reduceDecisionRules } from './reducer';
3import { tzFromOffset } from './time';
4import { CraftAiDecisionError, CraftAiNullDecisionError, CraftAiUnknownError } from './errors';
5import { formatDecisionRules, formatProperty } from './formatter';
6import isTimezone, { getTimezoneKey } from './timezones';
7
// Format version stamped as `_version` on every decision returned by `decide`.
const DECISION_FORMAT_VERSION = '2.0.0';
9
/**
 * Decision-rule operators, keyed by the `operator` field of a decision rule.
 * Each receives the context value and the rule operand and tells whether the rule holds.
 */
const OPERATORS = {
  // Strict equality, except that two objects only match when both of them are empty.
  'is': (context, value) => ((_.isObject(context) && _.isObject(value))
    ? (_.isEmpty(context) && _.isEmpty(value))
    : context === value),
  '>=': (context, value) => context !== null && +context >= value,
  '<': (context, value) => context !== null && +context < value,
  // Half-open interval [lower, upper[. When `lower` is not below `upper` the interval
  // wraps around (e.g. 22 -> 6 for a time of day).
  '[in[': (context, value) => {
    const numeric = +context;
    const lower = value[0];
    const upper = value[1];
    const wraps = !(lower < upper);
    if (context === null) {
      return false;
    }
    return wraps
      ? (numeric >= lower || numeric < upper)
      : (numeric >= lower && numeric < upper);
  },
  // Membership in the operand list.
  'in': (context, value) => value.indexOf(context) !== -1
};
36
/**
 * Context value validators, keyed by property type.
 * Each returns true when the value is acceptable for that type.
 */
const VALUE_VALIDATOR = (() => {
  // Accepts integers in the closed range [low, high].
  const integerBetween = (low, high) => (value) => _.isInteger(value) && value >= low && value <= high;

  return {
    continuous: (value) => _.isFinite(value),
    enum: (value) => _.isString(value),
    boolean: (value) => _.isBoolean(value),
    timezone: (value) => isTimezone(value),
    // Hours as a finite number in [0, 24[.
    time_of_day: (value) => _.isFinite(value) && value >= 0 && value < 24,
    day_of_week: integerBetween(0, 6),
    day_of_month: integerBetween(1, 31),
    month_of_year: integerBetween(1, 12)
  };
})();
47
/**
 * Walk a decision tree from `node` down to the leaf selected by `context`.
 *
 * At each inner node the first child whose decision rule accepts the context value is
 * followed. When no child matches, the decision is either aggregated over the whole
 * subtree (missing values allowed) or reported as an error (missing values deactivated).
 *
 * @param {object} node - Tree node; it is a leaf when it has no non-empty `children`.
 * @param {object} context - Context values keyed by property name.
 * @param {object} configuration - Agent configuration; only `deactivate_missing_values` is read here.
 * @param {string} outputType - Output type; 'enum' and 'boolean' are handled as classification.
 * @param {Array} outputValues - Labels matching the positions of a classification distribution.
 * @returns {object} A decision (`predicted_value`, `confidence`, `decision_rules` followed from
 *   `node`, plus statistics), or an object carrying an `error` ({ name, message[, metadata] }).
 */
function decideRecursion(node, context, configuration, outputType, outputValues) {
  if (!(node.children && node.children.length)) {
    return _decideLeaf(node.prediction);
  }

  let matchingChild;
  for (const child of node.children) {
    const decisionRule = child.decision_rule;
    const value = context[decisionRule.property];
    // With missing values deactivated, a null context value cannot be routed to any
    // child: report it as an error instead of testing it against the rules.
    if (configuration.deactivate_missing_values && _.isNull(value)) {
      return {
        predicted_value: undefined,
        confidence: undefined,
        decision_rules: [],
        error: {
          name: 'CraftAiUnknownError',
          message: `Unable to take decision: property '${decisionRule.property}' is missing from the given context.`
        }
      };
    }
    if (OPERATORS[decisionRule.operator](value, decisionRule.operand)) {
      matchingChild = child;
      break;
    }
  }

  if (_.isUndefined(matchingChild)) {
    return configuration.deactivate_missing_values
      ? _unmatchedValueError(node, context)
      : _decideFromSubtree(node, outputType, outputValues);
  }

  // Matching child found: recurse, then prepend the rule that led to it.
  const result = decideRecursion(matchingChild, context, configuration, outputType, outputValues);
  return _.extend(result, {
    decision_rules: [matchingChild.decision_rule].concat(result.decision_rules)
  });
}

// Build the decision stored in a leaf, or a null-decision error when the leaf has no value.
function _decideLeaf(prediction) {
  if (prediction.value == null) {
    return {
      predicted_value: undefined,
      confidence: undefined,
      decision_rules: [],
      error: {
        name: 'CraftAiNullDecisionError',
        message: 'Unable to take decision: the decision tree has no valid predicted value for the given context.'
      }
    };
  }

  const leafDistribution = prediction.distribution;
  const leafDecision = {
    predicted_value: prediction.value,
    confidence: prediction.confidence || 0,
    decision_rules: [],
    nb_samples: prediction.nb_samples
  };

  if (!_.isUndefined(leafDistribution.standard_deviation)) {
    // Regression leaf: expose its spread and, when known, its bounds.
    leafDecision.standard_deviation = leafDistribution.standard_deviation;
    if (!_.isUndefined(leafDistribution.min)) {
      leafDecision.min = leafDistribution.min;
    }
    if (!_.isUndefined(leafDistribution.max)) {
      leafDecision.max = leafDistribution.max;
    }
  }
  else {
    leafDecision.distribution = leafDistribution;
  }

  return leafDecision;
}

// Aggregate the predictions of every leaf below `node` when no child matches the context.
function _decideFromSubtree(node, outputType, outputValues) {
  const { value, standard_deviation, size } = distribution(node);
  let decision;
  if (outputType === 'enum' || outputType === 'boolean') {
    // Classification: predict the class with the highest aggregated probability.
    const argmax = value
      .map((probability, index) => [probability, index])
      .reduce((best, candidate) => (candidate[0] > best[0] ? candidate : best))[1];
    decision = {
      predicted_value: outputValues[argmax],
      distribution: value
    };
  }
  else {
    // Regression: predict the sample-weighted mean of the subtree.
    decision = {
      predicted_value: value,
      standard_deviation
    };
  }
  return _.extend(decision, {
    confidence: null,
    decision_rules: [],
    nb_samples: size
  });
}

// Report a context value that none of the children of `node` accepts.
function _unmatchedValueError(node, context) {
  // Should only happen when an unexpected value for an enum is encountered.
  const operandList = _.uniq(_.map(_.values(node.children), (child) => child.decision_rule.operand));
  const property = _.head(node.children).decision_rule.property;
  return {
    predicted_value: undefined,
    confidence: undefined,
    decision_rules: [],
    error: {
      name: 'CraftAiNullDecisionError',
      message: `Unable to take decision: value '${context[property]}' for property '${property}' doesn't validate any of the decision rules.`,
      metadata: {
        property: property,
        value: context[property],
        expected_values: operandList
      }
    }
  };
}
175
/**
 * Build a validator for the decision contexts of `configuration`.
 *
 * Every property of `configuration.context` that is not an output must be present and
 * hold a value accepted by the validator of its type. `null` is tolerated unless
 * `configuration.deactivate_missing_values` is set. Optional properties (`is_optional`)
 * may also be `null` or an empty object/array.
 *
 * @param {object} configuration - Agent configuration (`context`, `output`, `deactivate_missing_values`).
 * @returns {function(object): void} Validator that throws a `CraftAiDecisionError` listing
 *   every missing or invalid property.
 */
function checkContext(configuration) {
  // Extract the required properties (i.e. those that are not the output).
  const expectedProperties = _.difference(
    _.keys(configuration.context),
    configuration.output
  );

  // Build one validator per expected property.
  const validators = _.map(expectedProperties, (property) => {
    const { type, is_optional } = configuration.context[property];
    const otherValidator = () => {
      console.warn(`WARNING: "${type}" is not a supported type. Please refer to the documentation to see what type you can use`);
      return true;
    };
    return {
      property,
      type,
      is_optional,
      validator: VALUE_VALIDATOR[type] || otherValidator
    };
  });

  return (context) => {
    const missingProperties = [];
    const badProperties = [];

    _.forEach(validators, ({ property, type, is_optional, validator }) => {
      const value = context[property];
      if (_.isUndefined(value)) {
        missingProperties.push(property);
        return;
      }
      // null is accepted as a missing value unless missing values are deactivated.
      const isNullAuthorized = _.isNull(value) && !configuration.deactivate_missing_values;
      // An optional property may be left empty with null or an empty object/array.
      // `_.isEmpty` alone is not enough: it also returns true for every number and boolean.
      const isOptionalAuthorized = Boolean(is_optional)
        && (_.isNull(value) || (_.isObject(value) && _.isEmpty(value)));
      if (!validator(value) && !isNullAuthorized && !isOptionalAuthorized) {
        badProperties.push({ property, type, value });
      }
    });

    if (missingProperties.length || badProperties.length) {
      const messages = _.concat(
        _.map(missingProperties, (property) => `expected property '${property}' is not defined`),
        _.map(badProperties, ({ property, type, value }) => `'${value}' is not a valid value for property '${property}' of type '${type}'`)
      );
      throw new CraftAiDecisionError({
        message: `Unable to take decision, the given context is not valid: ${messages.join(', ')}.`,
        metadata: _.assign({}, missingProperties.length && { missingProperties }, badProperties.length && { badProperties })
      });
    }
  };
}
227
/**
 * Summarize the predictions of every leaf below `node`.
 *
 * Classification leaves hold an array of class probabilities. Regression leaves hold a value
 * plus `standard_deviation`, `min` and `max` statistics. Inner nodes merge the summaries of
 * their children, weighting each child by its number of samples.
 *
 * @param {object} node - Tree node; a leaf when it has no non-empty `children`.
 * @returns {object} `{ value, size }` for classification, or
 *   `{ value, standard_deviation, size, min, max }` for regression (the statistics may be
 *   absent when the leaves do not provide them).
 */
export function distribution(node) {
  if (!(node.children && node.children.length)) {
    const prediction = node.prediction;
    // An array distribution means a classification problem: return the class
    // distribution of this leaf and the branch size.
    if (_.isArray(prediction.distribution)) {
      return {
        value: prediction.distribution,
        size: prediction.nb_samples
      };
    }
    // Otherwise it is a regression problem: return the mean value of the leaf,
    // its standard deviation, bounds and the branch size.
    return {
      value: prediction.value,
      standard_deviation: prediction.distribution.standard_deviation,
      size: prediction.nb_samples,
      min: prediction.distribution.min,
      max: prediction.distribution.max
    };
  }

  // Inner node: summarize each child branch, then merge the summaries.
  const children = _.map(node.children, (child) => distribution(child));
  const values = _.map(children, 'value');
  const sizes = _.map(children, 'size');

  if (_.isArray(values[0])) {
    return computeMeanDistributions(values, sizes);
  }

  // Keep every statistic array aligned with `values`/`sizes`: filtering out children
  // without a standard deviation would pair each value with another child's statistics.
  const stds = _.map(children, 'standard_deviation');
  if (_.every(stds, _.isUndefined)) {
    return computeMeanValues(values, sizes);
  }
  return computeMeanValues(values, sizes, stds, _.map(children, 'min'), _.map(children, 'max'));
}
275
/**
 * Merge the regression summaries of several branches, weighting each branch by its size.
 *
 * Without standard deviations (`stds` undefined or empty), only the weighted mean is
 * computed. Example: for values = [4, 3, 6] and sizes = [1, 2, 1] the value is
 * (4*1 + 3*2 + 6*1) / (1+2+1) = 16/4 = 4.
 *
 * With standard deviations, branches are merged pairwise with the pooled-variance formula
 * (https://math.stackexchange.com/questions/2238086/calculate-variance-of-a-subset), and
 * `min`/`max` are merged as well. That formula weights each branch by `size - 1`, so it is
 * only meaningful for branches that hold samples; branches of size 0 are skipped.
 *
 * @param {number[]} values - Mean value of each branch.
 * @param {number[]} sizes - Number of samples of each branch.
 * @param {number[]} [stds] - Standard deviation of each branch.
 * @param {number[]} [mins] - Minimum of each branch.
 * @param {number[]} [maxs] - Maximum of each branch.
 * @returns {object} `{ value, size }` without standard deviations, otherwise
 *   `{ value, standard_deviation, size, min, max }`.
 */
export function computeMeanValues(values, sizes, stds, mins, maxs) {
  if (stds === undefined || stds.length === 0) {
    const totalSize = sizes.reduce((sum, size) => sum + size, 0);
    const weightedMean = values.reduce((sum, value, i) => sum + value * sizes[i] / totalSize, 0);
    return {
      value: weightedMean,
      size: totalSize
    };
  }

  const branches = values.map((mean, i) => ({
    mean,
    variance: stds[i] * stds[i],
    size: sizes[i],
    min: mins ? mins[i] : undefined,
    max: maxs ? maxs[i] : undefined
  }));

  const merged = branches.reduce((acc, branch) => {
    if (acc.mean === undefined) {
      return branch;
    }
    // Empty branches carry no information and would corrupt the pooled variance.
    if (!(branch.size > 0)) {
      return acc;
    }
    if (!(acc.size > 0)) {
      return branch;
    }
    const totalSize = 1.0 * (acc.size + branch.size);
    const meanGap = acc.mean - branch.mean;
    return {
      mean: (1.0 / totalSize) * (acc.size * acc.mean + branch.size * branch.mean),
      variance: (1.0 / (totalSize - 1)) * (
        (acc.size - 1) * acc.variance
        + (branch.size - 1) * branch.variance
        + ((acc.size * branch.size) / totalSize) * meanGap * meanGap
      ),
      size: totalSize,
      min: branch.min < acc.min ? branch.min : acc.min,
      max: branch.max > acc.max ? branch.max : acc.max
    };
  }, {
    mean: undefined,
    variance: undefined,
    size: undefined,
    min: undefined,
    max: undefined
  });

  return {
    value: merged.mean,
    standard_deviation: Math.sqrt(merged.variance),
    size: merged.size,
    min: merged.min,
    max: merged.max
  };
}
353
/**
 * Merge the class-probability distributions of several branches, weighting each by its size.
 *
 * Example: for values = [[0.5, 0.5], [1, 0]] and sizes = [1, 3], the result value is
 * ([0.5, 0.5]*1 + [1, 0]*3) / (1+3) = [0.875, 0.125].
 *
 * @param {number[][]} values - Class distribution of each branch (same length for all).
 * @param {number[]} sizes - Number of samples of each branch.
 * @returns {{value: number[], size: number}} Weighted mean distribution and total size.
 */
export function computeMeanDistributions(values, sizes) {
  const totalSize = _.sum(sizes);
  const weightedDistributions = _.map(
    _.zip(values, sizes),
    ([probabilities, size]) => _.map(probabilities, (probability) => probability * size / totalSize)
  );
  return { value: _sumArrays(weightedDistributions), size: totalSize };
}
365
/**
 * Element-wise sum of numeric arrays of equal length.
 *
 * @param {number[][]} arrays - Arrays to add together.
 * @returns {number[]} The element-wise sums; an empty list sums to an empty array.
 */
function _sumArrays(arrays) {
  if (arrays.length === 0) {
    return [];
  }
  return arrays.reduce(
    (sums, array) => array.map((value, i) => (sums[i] || 0) + value),
    new Array(arrays[0].length)
  );
}
371
/**
 * Take a decision for every output of `configuration` from its tree and the given context.
 *
 * @param {object} configuration - Agent configuration (`context`, `output`, ...).
 * @param {object} trees - Decision trees keyed by output property name.
 * @param {object} context - Context to decide on. It is validated first; the caller's
 *   object itself is left unmodified.
 * @returns {object} `{ _version, context, output }`, where `output` maps each output name
 *   to its decision and `context` is the (timezone-normalized) context used.
 * @throws {CraftAiDecisionError} When the context is not valid.
 * @throws {CraftAiNullDecisionError} When a tree cannot produce a value for the context.
 * @throws {CraftAiUnknownError} For any other decision error.
 */
function decide(configuration, trees, context) {
  checkContext(configuration)(context);

  // Work on a shallow copy so the timezone normalization below does not mutate the
  // caller's context object.
  const decisionContext = _.clone(context);
  // Convert timezones given as integers to the standard +/-hh:mm format.
  // This should only happen when no Time() object is passed to the interpreter.
  const timezoneProperty = getTimezoneKey(configuration.context);
  if (!_.isUndefined(timezoneProperty)) {
    decisionContext[timezoneProperty] = tzFromOffset(decisionContext[timezoneProperty]);
  }

  const decisions = _.map(configuration.output, (output) => {
    const outputType = configuration.context[output].type;
    const decision = decideRecursion(trees[output], decisionContext, configuration, outputType, trees[output].output_values);
    if (decision.error) {
      switch (decision.error.name) {
        case 'CraftAiNullDecisionError':
          throw new CraftAiNullDecisionError({
            message: decision.error.message,
            metadata: _.extend(decision.error.metadata, {
              decision_rules: decision.decision_rules
            })
          });
        default:
          throw new CraftAiUnknownError({
            message: decision.error.message
          });
      }
    }
    return { [output]: decision };
  });

  return {
    _version: DECISION_FORMAT_VERSION,
    context: decisionContext,
    output: _.assign({}, ...decisions)
  };
}
407
408export { formatDecisionRules, formatProperty, reduceDecisionRules, decide };