UNPKG

3.74 kB · JavaScript · View Raw
// Neural-network helpers; this file uses NN.sigmoid as the activation function.
const NN = require("../neural");
// rlab handle and its matrix namespace; both are assigned by the RBM constructor.
let R;
let M;
3
/**
 * Restricted Boltzmann Machine.
 *
 * @param {Object} rlab      rlab namespace; `rlab` and `rlab.M` are cached in the
 *                           module-level `R` / `M` used by the prototype methods.
 * @param {Object} settings  copied onto the instance (e.g. nVisible, nHidden and,
 *                           optionally, pre-built W / hbias / vbias).
 */
const RBM = function (rlab, settings) {
  R = rlab;
  M = R.M;
  Object.assign(this, settings);
  // Logging verbosity. 0 : nothing, 1 : info, 2: warn
  this.settings = { 'log level': 1 };

  // Bounds passed to M.randomM for the initial weights: [-1/nVisible, 1/nVisible].
  const bound = 1. / this.nVisible;
  if (!this.W) {
    this.W = M.randomM(this.nVisible, this.nHidden, -bound, bound);
  }
  if (!this.hbias) {
    this.hbias = M.newV(this.nHidden);
  }
  if (!this.vbias) {
    this.vbias = M.newV(this.nVisible);
  }
};
16
/**
 * Train the RBM with k-step Contrastive Divergence (CD-k), one full-batch
 * update per epoch.
 *
 * @param {Object} settings
 * @param {number} [settings.lr=0.8]       learning rate; an explicit 0 is honoured.
 * @param {number} [settings.k=1]          Gibbs steps per epoch (at least 1 is run).
 * @param {number} [settings.epochs=1500]  number of updates; an explicit 0 is honoured.
 * @param {*}      [settings.input]        training matrix; falls back to this.input.
 * @throws {Error} when no training input is available.
 */
RBM.prototype.train = function (settings) {
  const opts = settings || {};
  const lr = opts.lr == null ? 0.8 : opts.lr;
  const epochs = opts.epochs == null ? 1500 : opts.epochs;
  // CD-k needs at least one Gibbs step to produce negative-phase statistics.
  const k = Math.max(1, opts.k || 1);
  this.input = opts.input || this.input;
  if (!this.input) {
    throw new Error("RBM.train: no training input (pass settings.input or set rbm.input)");
  }

  let nextReport = 1;
  for (let epoch = 0; epoch < epochs; epoch++) {
    // Positive phase: hidden statistics driven by the data.
    const [phMean, phSample] = this.sampleHgivenV(this.input);

    // Negative phase: k alternating Gibbs steps starting from the sampled hidden state.
    let nvSamples;
    let nhMeans;
    let hiddenState = phSample;
    for (let step = 0; step < k; step++) {
      const chain = this.gibbsHVH(hiddenState);
      nvSamples = chain[1];
      nhMeans = chain[2];
      hiddenState = chain[3];
    }

    // deltaW = (input^T * phMean - nvSamples^T * nhMeans) / input.length
    const deltaW = this.input.tr().dot(phMean)
      .sub(nvSamples.tr().dot(nhMeans))
      .mul(1. / this.input.length);
    // deltaVbias = colMean(input - nvSamples)
    const deltaVbias = this.input.sub(nvSamples).colMean();
    // deltaHbias = colMean(phSample - nhMeans)
    const deltaHbias = phSample.sub(nhMeans).colMean();

    // Scale every gradient by the scalar learning rate with `mul`, as for W;
    // `dot` is used above as a matrix product, not a scalar scaling.
    this.W = this.W.add(deltaW.mul(lr));
    this.vbias = this.vbias.add(deltaVbias.mul(lr));
    this.hbias = this.hbias.add(deltaHbias.mul(lr));

    if (this.settings['log level'] > 0) {
      const progress = (1. * epoch / epochs) * 100;
      if (progress > nextReport) {
        console.log("RBM", progress.toFixed(0), "% Completed.");
        nextReport += 8;
      }
    }
  }
  if (this.settings['log level'] > 0) {
    console.log("RBM Final Cross Entropy : ", this.getReconstructionCrossEntropy());
  }
};
64
/**
 * Hidden-unit activation probabilities for visible units `v`:
 * sigmoid(v * W + hbias).
 */
RBM.prototype.propup = function (v) {
  const preActivation = v.dot(this.W).addMV(this.hbias);
  return preActivation.mapM(NN.sigmoid);
};
69
/**
 * Bernoulli-sample a matrix of probabilities: each entry p becomes 1 when a
 * draw from `random` is below p, and 0 otherwise.
 *
 * @param {*} m  matrix of probabilities, mapped element-wise with M.mapM.
 * @param {function(): number} [random=Math.random]  source of numbers in [0, 1);
 *        injectable for reproducible sampling.
 * @returns {*} the 0/1 matrix produced by M.mapM.
 */
RBM.prototype.probToBinaryMat = function (m, random = Math.random) {
  // Compare against the element passed to the callback; `i`/`j` are not in
  // scope here, so indexing m[i][j] threw a ReferenceError.
  return M.mapM(m, (p) => (random() < p ? 1 : 0));
};
73
/**
 * Visible-unit activation probabilities for hidden units `h`:
 * sigmoid(h * W^T + vbias).
 */
RBM.prototype.propdown = function (h) {
  const weightsT = this.W.tr();
  const preActivation = h.dot(weightsT).addMV(this.vbias);
  return preActivation.mapM(NN.sigmoid);
};
77
/**
 * Sample the hidden layer given visible units.
 *
 * @returns {Array} [hidden probabilities, binary hidden sample]
 */
RBM.prototype.sampleHgivenV = function (v0_sample) {
  const mean = this.propup(v0_sample);
  return [mean, this.probToBinaryMat(mean)];
};
83
/**
 * Sample the visible layer given hidden units.
 *
 * @returns {Array} [visible probabilities, binary visible sample]
 */
RBM.prototype.sampleVgivenH = function (h0_sample) {
  const mean = this.propdown(h0_sample);
  return [mean, this.probToBinaryMat(mean)];
};
89
/**
 * One Gibbs step hidden -> visible -> hidden.
 *
 * @returns {Array} [visible means, visible samples, hidden means, hidden samples]
 */
RBM.prototype.gibbsHVH = function (h0_sample) {
  const [vMean, vSample] = this.sampleVgivenH(h0_sample);
  const [hMean, hSample] = this.sampleHgivenV(vSample);
  return [vMean, vSample, hMean, hSample];
};
95
/**
 * Mean-field reconstruction of visible units: propdown(propup(v)).
 *
 * Delegates to propup/propdown rather than repeating their formulas, so the
 * reconstruction always agrees with the activations used for sampling.
 */
RBM.prototype.reconstruct = function (v) {
  const hiddenProbs = this.propup(v);
  return this.propdown(hiddenProbs);
};
100
/**
 * Reconstruction cross entropy of the stored training input.
 *
 * Element-wise term: x*log(y) + (1-x)*log(1-y), with y = reconstruct(input).
 * The result is the negated `.rowSum().mean()` of that matrix.
 *
 * @param {number} [epsilon=1e-12]  y is clamped to [epsilon, 1 - epsilon] so that
 *        saturated sigmoid outputs (exactly 0 or 1) give a finite value instead
 *        of -Infinity or NaN (0 * log(0)).
 * @returns {number}
 */
RBM.prototype.getReconstructionCrossEntropy = function (epsilon = 1e-12) {
  const reconstructedV = this.reconstruct(this.input);
  const lo = epsilon;
  const hi = 1 - epsilon;
  const logLikelihood = M.mapMM(this.input, reconstructedV, (x, y) => {
    const p = Math.min(hi, Math.max(lo, y));
    return x * Math.log(p) + (1 - x) * Math.log(1 - p);
  });
  return -logLikelihood.rowSum().mean();
};
114
/**
 * Set one entry of this RBM's settings object (e.g. 'log level').
 *
 * @param {string} property  settings key to assign.
 * @param {*} value          new value.
 */
RBM.prototype.set = function (property, value) {
  const target = this.settings;
  target[property] = value;
};
118
// Public CommonJS export: the RBM constructor.
module.exports = RBM;
\No newline at end of file