1 | 'use strict'
|
2 |
|
3 | function factory (type, config, load, typed) {
|
4 | const matrix = load(require('../../type/matrix/function/matrix'))
|
5 | const divide = load(require('../arithmetic/divide'))
|
6 | const sum = load(require('../statistics/sum'))
|
7 | const multiply = load(require('../arithmetic/multiply'))
|
8 | const dotDivide = load(require('../arithmetic/dotDivide'))
|
9 | const log = load(require('../arithmetic/log'))
|
10 | const isNumeric = load(require('../utils/isNumeric'))
|
11 |
|
12 | |
13 |
|
14 |
|
15 |
|
16 |
|
17 |
|
18 |
|
19 |
|
20 |
|
21 |
|
22 |
|
23 |
|
24 |
|
25 |
|
26 |
|
27 |
|
28 | const kldivergence = typed('kldivergence', {
|
29 | 'Array, Array': function (q, p) {
|
30 | return _kldiv(matrix(q), matrix(p))
|
31 | },
|
32 |
|
33 | 'Matrix, Array': function (q, p) {
|
34 | return _kldiv(q, matrix(p))
|
35 | },
|
36 |
|
37 | 'Array, Matrix': function (q, p) {
|
38 | return _kldiv(matrix(q), p)
|
39 | },
|
40 |
|
41 | 'Matrix, Matrix': function (q, p) {
|
42 | return _kldiv(q, p)
|
43 | }
|
44 |
|
45 | })
|
46 |
|
47 | function _kldiv (q, p) {
|
48 | const plength = p.size().length
|
49 | const qlength = q.size().length
|
50 | if (plength > 1) {
|
51 | throw new Error('first object must be one dimensional')
|
52 | }
|
53 |
|
54 | if (qlength > 1) {
|
55 | throw new Error('second object must be one dimensional')
|
56 | }
|
57 |
|
58 | if (plength !== qlength) {
|
59 | throw new Error('Length of two vectors must be equal')
|
60 | }
|
61 |
|
62 |
|
63 | const sumq = sum(q)
|
64 | if (sumq === 0) {
|
65 | throw new Error('Sum of elements in first object must be non zero')
|
66 | }
|
67 |
|
68 | const sump = sum(p)
|
69 | if (sump === 0) {
|
70 | throw new Error('Sum of elements in second object must be non zero')
|
71 | }
|
72 | const qnorm = divide(q, sum(q))
|
73 | const pnorm = divide(p, sum(p))
|
74 |
|
75 | const result = sum(multiply(qnorm, log(dotDivide(qnorm, pnorm))))
|
76 | if (isNumeric(result)) {
|
77 | return result
|
78 | } else {
|
79 | return Number.NaN
|
80 | }
|
81 | }
|
82 |
|
83 | return kldivergence
|
84 | }
|
85 |
|
// Module metadata consumed by the loader: the public function name and the
// factory that builds it.
Object.assign(exports, {
  name: 'kldivergence',
  factory
})
|