// Retrieved from UNPKG (2.39 kB, JavaScript).
'use strict'

// Collection helpers used by the sum implementation below.
const deepForEach = require('../../utils/collection/deepForEach')
const reduce = require('../../utils/collection/reduce')
const containsCollections = require('../../utils/collection/containsCollections')

7function factory (type, config, load, typed) {
8 const add = load(require('../arithmetic/addScalar'))
9 const improveErrorMessage = load(require('./utils/improveErrorMessage'))
10
11 /**
12 * Compute the sum of a matrix or a list with values.
13 * In case of a (multi dimensional) array or matrix, the sum of all
14 * elements will be calculated.
15 *
16 * Syntax:
17 *
18 * math.sum(a, b, c, ...)
19 * math.sum(A)
20 *
21 * Examples:
22 *
23 * math.sum(2, 1, 4, 3) // returns 10
24 * math.sum([2, 1, 4, 3]) // returns 10
25 * math.sum([[2, 5], [4, 3], [1, 7]]) // returns 22
26 *
27 * See also:
28 *
29 * mean, median, min, max, prod, std, var
30 *
31 * @param {... *} args A single matrix or or multiple scalar values
32 * @return {*} The sum of all values
33 */
34 const sum = typed('sum', {
35 // sum([a, b, c, d, ...])
36 'Array | Matrix': _sum,
37
38 // sum([a, b, c, d, ...], dim)
39 'Array | Matrix, number | BigNumber': _nsumDim,
40
41 // sum(a, b, c, d, ...)
42 '...': function (args) {
43 if (containsCollections(args)) {
44 throw new TypeError('Scalar values expected in function sum')
45 }
46
47 return _sum(args)
48 }
49 })
50
51 sum.toTex = undefined // use default template
52
53 return sum
54
55 /**
56 * Recursively calculate the sum of an n-dimensional array
57 * @param {Array} array
58 * @return {number} sum
59 * @private
60 */
61 function _sum (array) {
62 let sum
63
64 deepForEach(array, function (value) {
65 try {
66 sum = (sum === undefined) ? value : add(sum, value)
67 } catch (err) {
68 throw improveErrorMessage(err, 'sum', value)
69 }
70 })
71
72 if (sum === undefined) {
73 switch (config.number) {
74 case 'number':
75 return 0
76 case 'BigNumber':
77 return new type.BigNumber(0)
78 case 'Fraction':
79 return new type.Fraction(0)
80 default:
81 return 0
82 }
83 }
84
85 return sum
86 }
87 function _nsumDim (array, dim) {
88 try {
89 const sum = reduce(array, dim, add)
90 return sum
91 } catch (err) {
92 throw improveErrorMessage(err, 'sum')
93 }
94 }
95}

// Module interface: the function's name and the factory that builds it.
exports.name = 'sum'
exports.factory = factory