1 | import _ from 'lodash';
|
2 | import { reduceDecisionRules } from './reducer';
|
3 | import { tzFromOffset } from './time';
|
4 | import { CraftAiDecisionError, CraftAiNullDecisionError, CraftAiUnknownError } from './errors';
|
5 | import { formatDecisionRules, formatProperty } from './formatter';
|
6 | import isTimezone, { getTimezoneKey } from './timezones';
|
7 |
|
// Version of the decision object format; returned as `_version` by `decide`.
const DECISION_FORMAT_VERSION = '2.0.0';

/**
 * Decision-rule operators, keyed by the `operator` of a tree node's
 * `decision_rule`. Each one receives the context value of the rule's property
 * and the rule's operand, and returns whether that value satisfies the rule.
 */
const OPERATORS = {
  // Equality. Two objects are only considered equal when both are empty.
  'is': (context, value) => (
    _.isObject(context) && _.isObject(value)
      ? _.isEmpty(context) && _.isEmpty(value)
      : context === value
  ),
  // Numeric comparisons: the context value is coerced with `* 1`; null never matches.
  '>=': (context, value) => context !== null && context * 1 >= value,
  '<': (context, value) => context !== null && context * 1 < value,
  // Half-open interval [from, to[. When from >= to the interval wraps around and
  // covers values >= from OR < to (presumably for cyclic properties such as
  // time_of_day — confirm).
  '[in[': (context, value) => {
    const numeric = context * 1;
    const lowerBound = value[0];
    const upperBound = value[1];
    if (context === null) {
      return false;
    }
    if (lowerBound < upperBound) {
      return numeric >= lowerBound && numeric < upperBound;
    }
    return numeric >= lowerBound || numeric < upperBound;
  },
  // Membership in the operand list (strict equality, via indexOf).
  'in': (context, value) => value.indexOf(context) !== -1
};
|
36 |
|
/**
 * Context value validators, keyed by property type. Each returns true when
 * the value is acceptable for a property of that type.
 */
const VALUE_VALIDATOR = {
  continuous(value) {
    return _.isFinite(value);
  },
  enum(value) {
    return _.isString(value);
  },
  boolean(value) {
    return _.isBoolean(value);
  },
  timezone(value) {
    return isTimezone(value);
  },
  // Hours in [0, 24[, fractional values allowed.
  time_of_day(value) {
    return _.isFinite(value) && 0 <= value && value < 24;
  },
  // Integer in [0, 6].
  day_of_week(value) {
    return _.isInteger(value) && 0 <= value && value <= 6;
  },
  // Integer in [1, 31].
  day_of_month(value) {
    return _.isInteger(value) && 1 <= value && value <= 31;
  },
  // Integer in [1, 12].
  month_of_year(value) {
    return _.isInteger(value) && 1 <= value && value <= 12;
  }
};
|
47 |
|
/**
 * Descends a decision tree from `node`. At each level it follows the first
 * child whose decision rule accepts the context value, then returns the
 * prediction of the reached leaf.
 *
 * @param {object} node - tree node: either a leaf with a `prediction`, or a
 *   node whose `children` each carry a `decision_rule`
 *   (`{ property, operator, operand }`).
 * @param {object} context - context values keyed by property name.
 * @param {object} configuration - agent configuration; only
 *   `deactivate_missing_values` is read here.
 * @param {string} outputType - type of the output property; 'enum' and
 *   'boolean' are handled as distributions over `outputValues`.
 * @param {Array} outputValues - output label for each index of such a distribution.
 * @returns {object} the decision, one of:
 *   - `predicted_value`, `confidence`, `decision_rules` (the rules followed, from
 *     `node` down to the leaf) and prediction statistics;
 *   - when no decision can be taken, an object carrying an
 *     `error` (`{ name, message[, metadata] }`).
 */
function decideRecursion(node, context, configuration, outputType, outputValues) {
  if (!(node.children && node.children.length)) {
    return _leafDecision(node.prediction);
  }

  // Children are tried in order; the first whose rule accepts the context value wins.
  let matchingChild;
  for (const child of node.children) {
    const decision_rule = child.decision_rule;
    const property = decision_rule.property;
    // NOTE(review): this tests the rule's property *name*, not the context value,
    // although the message speaks of a property missing from the context — confirm intent.
    if (configuration.deactivate_missing_values && _.isNull(property)) {
      // Returned straight away: built inside a `_.find` predicate, this object
      // would only count as a truthy match and the error would be dropped.
      return {
        predicted_value: undefined,
        confidence: undefined,
        decision_rules: [],
        error: {
          name: 'CraftAiUnknownError',
          message: `Unable to take decision: property '${property}' is missing from the given context.`
        }
      };
    }
    if (OPERATORS[decision_rule.operator](context[property], decision_rule.operand)) {
      matchingChild = child;
      break;
    }
  }

  if (_.isUndefined(matchingChild)) {
    // No child rule accepts the context value.
    return configuration.deactivate_missing_values
      ? _noMatchingRuleError(node, context)
      : _missingValueDecision(node, outputType, outputValues);
  }

  const result = decideRecursion(matchingChild, context, configuration, outputType, outputValues);
  // Prepend this level's rule so the list reads from `node` down to the leaf.
  result.decision_rules = [matchingChild.decision_rule].concat(result.decision_rules);
  return result;
}

// Builds the decision of a leaf from its `prediction`, or a
// CraftAiNullDecisionError result when the leaf has no predicted value.
function _leafDecision(prediction) {
  if (prediction.value == null) {
    return {
      predicted_value: undefined,
      confidence: undefined,
      decision_rules: [],
      error: {
        name: 'CraftAiNullDecisionError',
        message: 'Unable to take decision: the decision tree has no valid predicted value for the given context.'
      }
    };
  }

  const leafNode = {
    predicted_value: prediction.value,
    confidence: prediction.confidence || 0,
    decision_rules: [],
    nb_samples: prediction.nb_samples
  };

  const leafDistribution = prediction.distribution;
  if (!_.isUndefined(leafDistribution.standard_deviation)) {
    // Distribution summarized by its spread: expose std, and min/max when known.
    leafNode.standard_deviation = leafDistribution.standard_deviation;
    if (!_.isUndefined(leafDistribution.min)) {
      leafNode.min = leafDistribution.min;
    }
    if (!_.isUndefined(leafDistribution.max)) {
      leafNode.max = leafDistribution.max;
    }
  }
  else {
    leafNode.distribution = leafDistribution;
  }
  return leafNode;
}

// Builds the decision for `node` when no child rule matches and missing values
// are allowed. The prediction is aggregated over the whole subtree with
// `distribution`, and the confidence is null.
function _missingValueDecision(node, outputType, outputValues) {
  const { value, standard_deviation, size, min, max } = distribution(node);
  let decision;
  if (outputType === 'enum' || outputType === 'boolean') {
    // Pick the label with the largest aggregated weight (the first one on ties).
    const argmax = value
      .map((x, i) => [x, i])
      .reduce((best, current) => (current[0] > best[0] ? current : best))[1];
    decision = {
      predicted_value: outputValues[argmax],
      distribution: value
    };
  }
  else {
    decision = {
      predicted_value: value,
      standard_deviation: standard_deviation
    };
    // Same optional statistics as a leaf decision exposes.
    if (!_.isUndefined(min)) {
      decision.min = min;
    }
    if (!_.isUndefined(max)) {
      decision.max = max;
    }
  }
  return _.extend(decision, {
    confidence: null,
    decision_rules: [],
    nb_samples: size
  });
}

// Builds the CraftAiNullDecisionError result for a node where no child rule
// accepts the context value. The property is read from the first child;
// presumably every child of a node splits on the same property — confirm.
function _noMatchingRuleError(node, context) {
  const property = _.head(node.children).decision_rule.property;
  const operandList = _.uniq(_.map(node.children, (child) => child.decision_rule.operand));
  return {
    predicted_value: undefined,
    confidence: undefined,
    decision_rules: [],
    error: {
      name: 'CraftAiNullDecisionError',
      message: `Unable to take decision: value '${context[property]}' for property '${property}' doesn't validate any of the decision rules.`,
      metadata: {
        property: property,
        value: context[property],
        expected_values: operandList
      }
    }
  };
}
|
175 |
|
/**
 * Builds a validator for the decision contexts of an agent.
 *
 * Every property of `configuration.context` that is not listed in
 * `configuration.output` is expected in the context. Values are checked with
 * the `VALUE_VALIDATOR` entry of their type. A type without an entry is
 * accepted after a console warning.
 *
 * @param {object} configuration - agent configuration (`context`, `output`,
 *   `deactivate_missing_values`).
 * @returns {function(object): void} validator. It throws a
 *   `CraftAiDecisionError` whose metadata lists `missingProperties` and/or
 *   `badProperties`.
 */
function checkContext(configuration) {
  // Every configured property except the outputs must be present in a context.
  const expectedProperties = _.difference(
    _.keys(configuration.context),
    configuration.output
  );

  const validators = _.map(expectedProperties, (property) => {
    // Fallback for types without a VALUE_VALIDATOR entry: warn, then accept.
    const otherValidator = () => {
      console.warn(`WARNING: "${configuration.context[property].type}" is not a supported type. Please refer to the documention to see what type you can use`);
      return true;
    };
    return {
      property,
      type: configuration.context[property].type,
      is_optional: configuration.context[property].is_optional,
      validator: VALUE_VALIDATOR[configuration.context[property].type] || otherValidator
    };
  });

  return (context) => {
    const missingProperties = [];
    const badProperties = [];

    _.forEach(validators, ({ property, type, is_optional, validator }) => {
      const value = context[property];
      if (value === undefined) {
        missingProperties.push(property);
        return;
      }
      // null encodes a missing value, accepted unless missing values are deactivated.
      const isNullAuthorized = _.isNull(value) && !configuration.deactivate_missing_values;
      // An optional property may be left empty: `{}` or null. `_.isEmpty` alone is
      // not a usable test here because it is also true for every number and
      // boolean, which would accept e.g. 42 or NaN for an optional day_of_week.
      const isOptionalAuthorized = is_optional
        && (_.isNull(value) || (_.isPlainObject(value) && _.isEmpty(value)));
      if (!validator(value) && !isNullAuthorized && !isOptionalAuthorized) {
        badProperties.push({ property, type, value });
      }
    });

    if (missingProperties.length || badProperties.length) {
      const messages = _.concat(
        _.map(missingProperties, (property) => `expected property '${property}' is not defined`),
        _.map(badProperties, ({ property, type, value }) => `'${value}' is not a valid value for property '${property}' of type '${type}'`)
      );
      throw new CraftAiDecisionError({
        message: `Unable to take decision, the given context is not valid: ${messages.join(', ')}.`,
        metadata: _.assign({}, missingProperties.length && { missingProperties }, badProperties.length && { badProperties })
      });
    }
  };
}
|
227 |
|
/**
 * Summarizes the predictions of every leaf below `node`.
 *
 * A leaf returns its own prediction:
 * - `{ value, size }` when its distribution is an array;
 * - otherwise `{ value, standard_deviation, size, min, max }`.
 *
 * An inner node combines its children's summaries, weighted by size:
 * - array distributions go through `computeMeanDistributions`;
 * - other values go through `computeMeanValues`.
 *
 * @param {object} node - decision tree node (a leaf with `prediction`, or `children`).
 * @returns {object} the aggregated summary, shaped as described above.
 */
export function distribution(node) {
  if (!(node.children && node.children.length)) {
    const prediction = node.prediction;
    if (_.isArray(prediction.distribution)) {
      return {
        value: prediction.distribution,
        size: prediction.nb_samples
      };
    }
    return {
      value: prediction.value,
      standard_deviation: prediction.distribution.standard_deviation,
      size: prediction.nb_samples,
      min: prediction.distribution.min,
      max: prediction.distribution.max
    };
  }

  const summaries = _.map(node.children, (child) => distribution(child));
  const values = _.map(summaries, (summary) => summary.value);
  const sizes = _.map(summaries, (summary) => summary.size);

  if (_.isArray(values[0])) {
    return computeMeanDistributions(values, sizes);
  }

  // Standard deviations (and min/max) can only be pooled when every child
  // provides one. Collecting only the defined ones would misalign them with
  // `values` and `sizes`. Otherwise fall back to a size-weighted mean.
  const stds = _.map(summaries, (summary) => summary.standard_deviation);
  if (_.some(stds, _.isUndefined)) {
    return computeMeanValues(values, sizes);
  }
  return computeMeanValues(
    values,
    sizes,
    stds,
    _.map(summaries, (summary) => summary.min),
    _.map(summaries, (summary) => summary.max)
  );
}
|
275 |
|
/**
 * Aggregates per-branch summaries (mean, size and, optionally, standard
 * deviation, min and max) into a single summary.
 *
 * Without `stds`, only the size-weighted mean and the total size are returned.
 * With `stds`, branches are merged pairwise with the combined sample variance
 *   s² = ((n1 - 1)·s1² + (n2 - 1)·s2² + n1·n2/(n1 + n2)·(m1 - m2)²) / (n1 + n2 - 1)
 * and min/max are the extremes over the merged branches. A branch with no
 * samples does not contribute (unless every branch is empty).
 *
 * @param {number[]} values - mean of each branch.
 * @param {number[]} sizes - number of samples of each branch.
 * @param {number[]} [stds] - standard deviation of each branch.
 * @param {number[]} [mins] - minimum of each branch (used together with `stds`).
 * @param {number[]} [maxs] - maximum of each branch (used together with `stds`).
 * @returns {object} `{ value, size }` without `stds`; otherwise
 *   `{ value, standard_deviation, size, min, max }`.
 */
export function computeMeanValues(values, sizes, stds, mins, maxs) {
  if (_.isUndefined(stds)) {
    const totalSize = _.sum(sizes);
    const newMean = _.zip(values, sizes)
      .map(([mean, size]) => mean * size / totalSize)
      .reduce(_.add);
    return {
      value: newMean,
      size: totalSize
    };
  }

  const merged = _.zip(values, stds, sizes, mins, maxs)
    .map(([mean, std, size, min, max]) => ({
      mean,
      variance: std * std,
      size,
      min,
      max
    }))
    .reduce((acc, branch) => {
      if (_.isUndefined(acc.mean)) {
        return branch;
      }
      // An empty branch carries no information. Merging it would either divide
      // by zero or subtract a meaningless variance in the formula below.
      if (!(branch.size > 0)) {
        return acc;
      }
      if (!(acc.size > 0)) {
        return branch;
      }
      const totalSize = acc.size + branch.size;
      const meanGap = acc.mean - branch.mean;
      return {
        mean: (1 / totalSize) * (acc.size * acc.mean + branch.size * branch.mean),
        variance: (1 / (totalSize - 1)) * (
          (acc.size - 1) * acc.variance
          + (branch.size - 1) * branch.variance
          + ((acc.size * branch.size) / totalSize) * meanGap * meanGap
        ),
        size: totalSize,
        min: branch.min < acc.min ? branch.min : acc.min,
        max: branch.max > acc.max ? branch.max : acc.max
      };
    }, {
      mean: undefined,
      variance: undefined,
      size: undefined,
      min: undefined,
      max: undefined
    });

  return {
    value: merged.mean,
    standard_deviation: Math.sqrt(merged.variance),
    size: merged.size,
    min: merged.min,
    max: merged.max
  };
}
|
353 |
|
/**
 * Merges several distributions into one. Each distribution is weighted by
 * its branch's share of the total number of samples, then the weighted
 * distributions are summed element-wise.
 *
 * @param {Array<number[]>} values - one distribution per branch.
 * @param {number[]} sizes - number of samples of each branch.
 * @returns {{value: number[], size: number}} merged distribution and total size.
 */
export function computeMeanDistributions(values, sizes) {
  const totalSize = _.sum(sizes);
  const weightedDistributions = _.map(
    _.zip(values, sizes),
    ([branchDistribution, branchSize]) =>
      _.map(branchDistribution, (weight) => weight * branchSize / totalSize)
  );
  return {
    value: _sumArrays(weightedDistributions),
    size: totalSize
  };
}
|
365 |
|
/**
 * Element-wise sum of a list of numeric arrays.
 *
 * Positions not yet present in the running total count as 0, and the result
 * has the length of the last array. An empty (or missing) list yields `[]`
 * instead of failing on `arrays[0]`.
 *
 * @param {Array<number[]>} arrays - arrays to add together.
 * @returns {number[]} the element-wise sum.
 */
function _sumArrays(arrays) {
  if (!arrays || arrays.length === 0) {
    return [];
  }
  return _.reduce(
    arrays,
    (accSum, array) => _.map(array, (val, i) => (accSum[i] || 0) + val),
    new Array(arrays[0].length)
  );
}
|
371 |
|
/**
 * Takes a decision for each output property of an agent.
 *
 * @param {object} configuration - agent configuration (`context`, `output`, ...).
 * @param {object} trees - decision tree per output property; each tree's
 *   `output_values` is forwarded to `decideRecursion`.
 * @param {object} context - context values keyed by property. It is validated
 *   with `checkContext` and is not modified.
 * @returns {object} `{ _version, context, output }`:
 *   - `context` is a shallow copy of the input in which the timezone property
 *     (if any) has been passed through `tzFromOffset`;
 *   - `output` maps each output property to its decision.
 * @throws {CraftAiDecisionError} when the context is invalid (from `checkContext`).
 * @throws {CraftAiNullDecisionError} when a decision error is named
 *   'CraftAiNullDecisionError'. Its metadata includes the `decision_rules`
 *   followed so far.
 * @throws {CraftAiUnknownError} for any other decision error.
 */
function decide(configuration, trees, context) {
  checkContext(configuration)(context);

  // Work on a shallow copy so the timezone normalization below does not
  // modify the caller's object.
  const decisionContext = _.clone(context);
  const timezoneProperty = getTimezoneKey(configuration.context);
  if (!_.isUndefined(timezoneProperty)) {
    decisionContext[timezoneProperty] = tzFromOffset(decisionContext[timezoneProperty]);
  }

  const output = _.reduce(configuration.output, (decisions, outputProperty) => {
    const outputType = configuration.context[outputProperty].type;
    const tree = trees[outputProperty];
    const decision = decideRecursion(tree, decisionContext, configuration, outputType, tree.output_values);
    if (decision.error) {
      if (decision.error.name === 'CraftAiNullDecisionError') {
        throw new CraftAiNullDecisionError({
          message: decision.error.message,
          metadata: _.assign({}, decision.error.metadata, {
            decision_rules: decision.decision_rules
          })
        });
      }
      throw new CraftAiUnknownError({
        message: decision.error.message
      });
    }
    decisions[outputProperty] = decision;
    return decisions;
  }, {});

  return {
    _version: DECISION_FORMAT_VERSION,
    context: decisionContext,
    output
  };
}
|
407 |
|
408 | export { formatDecisionRules, formatProperty, reduceDecisionRules, decide };
|