UNPKG

2.31 kBJavaScriptView Raw
1'use strict'
2
3function factory (type, config, load, typed) {
4 const matrix = load(require('../../type/matrix/function/matrix'))
5 const divide = load(require('../arithmetic/divide'))
6 const sum = load(require('../statistics/sum'))
7 const multiply = load(require('../arithmetic/multiply'))
8 const dotDivide = load(require('../arithmetic/dotDivide'))
9 const log = load(require('../arithmetic/log'))
10 const isNumeric = load(require('../utils/isNumeric'))
11
12 /**
13 * Calculate the Kullback-Leibler (KL) divergence between two distributions
14 *
15 * Syntax:
16 *
17 * math.kldivergence(x, y)
18 *
19 * Examples:
20 *
21 * math.kldivergence([0.7,0.5,0.4], [0.2,0.9,0.5]) //returns 0.24376698773121153
22 *
23 *
24 * @param {Array | Matrix} q First vector
25 * @param {Array | Matrix} p Second vector
26 * @return {number} Returns distance between q and p
27 */
28 const kldivergence = typed('kldivergence', {
29 'Array, Array': function (q, p) {
30 return _kldiv(matrix(q), matrix(p))
31 },
32
33 'Matrix, Array': function (q, p) {
34 return _kldiv(q, matrix(p))
35 },
36
37 'Array, Matrix': function (q, p) {
38 return _kldiv(matrix(q), p)
39 },
40
41 'Matrix, Matrix': function (q, p) {
42 return _kldiv(q, p)
43 }
44
45 })
46
47 function _kldiv (q, p) {
48 const plength = p.size().length
49 const qlength = q.size().length
50 if (plength > 1) {
51 throw new Error('first object must be one dimensional')
52 }
53
54 if (qlength > 1) {
55 throw new Error('second object must be one dimensional')
56 }
57
58 if (plength !== qlength) {
59 throw new Error('Length of two vectors must be equal')
60 }
61
62 // Before calculation, apply normalization
63 const sumq = sum(q)
64 if (sumq === 0) {
65 throw new Error('Sum of elements in first object must be non zero')
66 }
67
68 const sump = sum(p)
69 if (sump === 0) {
70 throw new Error('Sum of elements in second object must be non zero')
71 }
72 const qnorm = divide(q, sum(q))
73 const pnorm = divide(p, sum(p))
74
75 const result = sum(multiply(qnorm, log(dotDivide(qnorm, pnorm))))
76 if (isNumeric(result)) {
77 return result
78 } else {
79 return Number.NaN
80 }
81 }
82
83 return kldivergence
84}
85
86exports.name = 'kldivergence'
87exports.factory = factory