1 | var NN = require("../neural");
|
let R;
let M;

/**
 * Restricted Boltzmann Machine.
 *
 * @param {object} rlab - rlab instance. It and its `M` helper are stored in the
 *   module-level `R` / `M` bindings that other methods of this module read.
 * @param {object} [settings] - own properties copied onto the new instance,
 *   e.g. nVisible, nHidden and optionally pre-built W, hbias, vbias.
 */
function RBM(rlab, settings) {
  R = rlab;
  M = R.M;
  Object.assign(this, settings);

  // Runtime options (changed through RBM.prototype.set). Note this replaces any
  // `settings` property the Object.assign above may have copied in.
  this.settings = { 'log level': 1 };

  // Unless supplied by the caller, W comes from M.randomM with the range
  // [-1/nVisible, 1/nVisible]; the biases come from M.newV.
  const limit = 1 / this.nVisible;
  if (!this.W) this.W = M.randomM(this.nVisible, this.nHidden, -limit, limit);
  if (!this.hbias) this.hbias = M.newV(this.nHidden);
  if (!this.vbias) this.vbias = M.newV(this.nVisible);
}
|
16 |
|
/**
 * Trains the RBM with k-step contrastive divergence (CD-k).
 *
 * Each epoch runs a positive phase on `this.input`, then k Gibbs steps
 * (hidden -> visible -> hidden via gibbsHVH), then updates W, vbias and hbias.
 *
 * @param {object} [settings] - training options; missing or falsy values fall back to defaults.
 * @param {number} [settings.lr=0.8] - learning rate.
 * @param {number} [settings.k=1] - number of Gibbs steps per epoch.
 * @param {number} [settings.epochs=1500] - number of training epochs.
 * @param {*} [settings.input] - training data; when omitted, the stored `this.input` is reused.
 */
RBM.prototype.train = function (settings) {
  // Accept `train()` with no argument instead of throwing on `settings.lr`.
  const opts = settings || {};
  const lr = opts.lr || 0.8;
  const k = opts.k || 1;
  const epochs = opts.epochs || 1500;
  this.input = opts.input || this.input;

  const verbose = this.settings['log level'] > 0;
  let currentProgress = 1;

  for (let epoch = 0; epoch < epochs; epoch++) {
    // Positive phase: hidden probabilities and binary samples driven by the data.
    const positive = this.sampleHgivenV(this.input);
    const phMean = positive[0];
    const phSample = positive[1];

    // Negative phase: k Gibbs steps, each starting from the previous hidden sample.
    let hiddenSample = phSample;
    let nvSamples;
    let nhMeans;
    for (let step = 0; step < k; step++) {
      const chain = this.gibbsHVH(hiddenSample);
      nvSamples = chain[1];
      nhMeans = chain[2];
      hiddenSample = chain[3];
    }

    // Weight gradient, scaled by 1 / number of input rows.
    const deltaW = this.input.tr().dot(phMean)
      .sub(nvSamples.tr().dot(nhMeans))
      .mul(1 / this.input.length);
    const deltaVbias = this.input.sub(nvSamples).colMean();
    // Use the hidden probabilities (as the W gradient does), not the binary
    // samples, for the hidden-bias gradient.
    const deltaHbias = phMean.sub(nhMeans).colMean();

    this.W = this.W.add(deltaW.mul(lr));
    // NOTE(review): `.dot(lr)` is assumed to scale the vector by the scalar lr
    // (as `.mul(lr)` does for W above) — confirm against the matrix library.
    this.vbias = this.vbias.add(deltaVbias.dot(lr));
    this.hbias = this.hbias.add(deltaHbias.dot(lr));

    if (verbose) {
      const progress = (epoch / epochs) * 100;
      if (progress > currentProgress) {
        console.log("RBM", progress.toFixed(0), "% Completed.");
        currentProgress += 8;
      }
    }
  }

  if (verbose) {
    console.log("RBM Final Cross Entropy : ", this.getReconstructionCrossEntropy());
  }
};
|
64 |
|
/**
 * Propagates visible activations up to hidden-unit probabilities: the
 * element-wise sigmoid of v · W with hbias added via addMV.
 */
RBM.prototype.propup = function (v) {
  const preActivation = v.dot(this.W).addMV(this.hbias);
  return preActivation.mapM(NN.sigmoid);
};
|
69 |
|
/**
 * Draws a binary matrix from a matrix of probabilities: each entry becomes 1
 * when Math.random() falls below it, otherwise 0.
 *
 * @param m - matrix whose entries are treated as probabilities; mapped with M.mapM.
 * @returns the 0/1 matrix produced by M.mapM.
 */
RBM.prototype.probToBinaryMat = function (m) {
  // The callback must use its own argument: the previous version read m[i][j]
  // with i/j not in scope here, so every call threw a ReferenceError.
  return M.mapM(m, (p) => (Math.random() < p ? 1 : 0));
};
|
73 |
|
/**
 * Propagates hidden activations down to visible-unit probabilities: the
 * element-wise sigmoid of h · Wᵀ with vbias added via addMV.
 */
RBM.prototype.propdown = function (h) {
  const preActivation = h.dot(this.W.tr()).addMV(this.vbias);
  return preActivation.mapM(NN.sigmoid);
};
|
77 |
|
/**
 * Samples the hidden layer given a visible state.
 *
 * @returns {Array} [hidden probabilities from propup, binary sample drawn from them]
 */
RBM.prototype.sampleHgivenV = function (v0_sample) {
  const probabilities = this.propup(v0_sample);
  return [probabilities, this.probToBinaryMat(probabilities)];
};
|
83 |
|
/**
 * Samples the visible layer given a hidden state.
 *
 * @returns {Array} [visible probabilities from propdown, binary sample drawn from them]
 */
RBM.prototype.sampleVgivenH = function (h0_sample) {
  const probabilities = this.propdown(h0_sample);
  const binary = this.probToBinaryMat(probabilities);
  return [probabilities, binary];
};
|
89 |
|
/**
 * One Gibbs step starting from the hidden layer: hidden -> visible -> hidden.
 * The new hidden layer is sampled from the binary visible sample.
 *
 * @returns {Array} [visible means, visible samples, hidden means, hidden samples]
 */
RBM.prototype.gibbsHVH = function (h0_sample) {
  const [vMean, vSample] = this.sampleVgivenH(h0_sample);
  const [hMean, hSample] = this.sampleHgivenV(vSample);
  return [vMean, vSample, hMean, hSample];
};
|
95 |
|
/**
 * Reconstructs visible probabilities for v with one deterministic up-down pass
 * (probabilities only, no sampling): propdown(propup(v)).
 */
RBM.prototype.reconstruct = function (v) {
  // Delegate to propup/propdown instead of duplicating their formulas here,
  // so the reconstruction cannot drift from the propagation used in training.
  return this.propdown(this.propup(v));
};
|
100 |
|
/**
 * Reconstruction cross-entropy of `this.input` against
 * y = this.reconstruct(this.input):
 *   -( x*log(y) + (1-x)*log(1-y) ), reduced with rowSum() then mean()
 *   (presumably per-row totals averaged over rows).
 *
 * @returns {number}
 */
RBM.prototype.getReconstructionCrossEntropy = function () {
  // Sigmoid outputs can round to exactly 0 or 1; log(0) is -Infinity and
  // 0 * -Infinity is NaN, so keep y strictly inside (0, 1).
  const EPS = 1e-10;
  const clamp = (y) => Math.min(Math.max(y, EPS), 1 - EPS);

  const reconstructedV = this.reconstruct(this.input);
  // Both log terms are summed per element in a single pass.
  const logLikelihood = M.mapMM(this.input, reconstructedV, function (x, y) {
    const p = clamp(y);
    return x * Math.log(p) + (1 - x) * Math.log(1 - p);
  });
  return -logLikelihood.rowSum().mean();
};
|
114 |
|
/**
 * Stores a single runtime option (for example 'log level') in `this.settings`.
 *
 * @param {string} property - option name.
 * @param {*} value - option value.
 */
RBM.prototype.set = function (property, value) {
  const { settings } = this;
  settings[property] = value;
};
|
118 |
|
// The RBM constructor is this module's exported value.
module.exports = RBM;
\ | No newline at end of file |