1 | import euclideanDistance from "./euclidean_distance";
|
2 | import makeMatrix from "./make_matrix";
|
3 | import sample from "./sample";
|
4 |
|
5 | /**
|
6 | * @typedef {Object} kMeansReturn
|
7 | * @property {Array<number>} labels Index of the centroid assigned to each point, in the same order as the input points.
|
8 | * @property {Array<Array<number>>} centroids The cluster centroids.
|
9 | */
|
10 |
|
/**
 * Perform k-means clustering.
 *
 * Initial centroids come from `sample(points, numCluster, randomSource)`.
 * Points are then relabeled and centroids recomputed repeatedly until the
 * total centroid movement reported by `calculateChange` is exactly zero.
 *
 * @param {Array<Array<number>>} points N-dimensional coordinates of points to be clustered.
 * @param {number} numCluster How many clusters to create.
 * @param {Function} [randomSource=Math.random] An optional entropy source that generates uniform values in [0, 1).
 * @return {kMeansReturn} Labels (same length as data) and centroids (same length as numCluster).
 * @throws {Error} If any centroids wind up friendless (i.e., without associated points).
 *
 * @example
 * kMeansCluster([[0.0, 0.5], [1.0, 0.5]], 2); // => {labels: [0, 1], centroids: [[0.0, 0.5], [1.0, 0.5]]}
 */
function kMeansCluster(points, numCluster, randomSource = Math.random) {
    let centroids = sample(points, numCluster, randomSource);
    let previousCentroids;
    let labels;

    // Always run at least one assignment/update round, then keep going
    // until the centroids no longer move at all.
    do {
        labels = labelPoints(points, centroids);
        previousCentroids = centroids;
        centroids = calculateCentroids(points, labels, numCluster);
    } while (calculateChange(centroids, previousCentroids) !== 0);

    return { labels, centroids };
}
|
39 |
|
/**
 * Assign every point the index of its nearest centroid.
 *
 * Distances are measured with `euclideanDistance`. When several centroids are
 * equally close, the one with the lowest index wins. A point gets the label
 * -1 if no centroid is closer than `Number.MAX_VALUE` (for example, when
 * `centroids` is empty).
 *
 * @private
 * @param {Array<Array<number>>} points Coordinates of the points to label.
 * @param {Array<Array<number>>} centroids Current centroids.
 * @return {Array<number>} Group labels, one per point.
 */
function labelPoints(points, centroids) {
    const nearestCentroidIndex = (point) => {
        let bestIndex = -1;
        let bestDistance = Number.MAX_VALUE;
        for (const [index, centroid] of centroids.entries()) {
            const distance = euclideanDistance(point, centroid);
            // Strict comparison keeps the earliest centroid on ties.
            if (distance < bestDistance) {
                bestIndex = index;
                bestDistance = distance;
            }
        }
        return bestIndex;
    };

    return points.map((point) => nearestCentroidIndex(point));
}
|
62 |
|
/**
 * Compute the mean position of each group of points.
 *
 * The dimensionality is taken from the first point. Coordinates are summed
 * per label into a matrix obtained from `makeMatrix(numCluster, dimension)`
 * and then divided by the number of points carrying that label.
 *
 * @private
 * @param {Array<Array<number>>} points Coordinates of the points.
 * @param {Array<number>} labels Which groups points belong to.
 * @param {number} numCluster Number of clusters being created.
 * @return {Array<Array<number>>} Centroid for each group.
 * @throws {Error} If any centroids wind up friendless (i.e., without associated points).
 */
function calculateCentroids(points, labels, numCluster) {
    const dimension = points[0].length;

    // Per-cluster coordinate sums and membership counts.
    const sums = makeMatrix(numCluster, dimension);
    const sizes = Array(numCluster).fill(0);

    for (const [index, point] of points.entries()) {
        const label = labels[index];
        const sum = sums[label];
        for (let axis = 0; axis < dimension; axis++) {
            sum[axis] = sum[axis] + point[axis];
        }
        sizes[label] = sizes[label] + 1;
    }

    // Turn each sum into a mean; a cluster without members is an error.
    for (let cluster = 0; cluster < numCluster; cluster++) {
        const size = sizes[cluster];
        if (size === 0) {
            throw new Error(`Centroid ${cluster} has no friends`);
        }
        const sum = sums[cluster];
        for (let axis = 0; axis < dimension; axis++) {
            sum[axis] = sum[axis] / size;
        }
    }

    return sums;
}
|
104 |
|
/**
 * Measure how far a set of centroids has moved.
 *
 * Adds up the `euclideanDistance` between each centroid in `left` and the
 * centroid at the same index in `right`.
 *
 * @private
 * @param {Array<Array<number>>} left One list of centroids.
 * @param {Array<Array<number>>} right Another list of centroids.
 * @return {number} Sum of the distances between corresponding centroids.
 */
function calculateChange(left, right) {
    let distanceSum = 0;
    for (const [index, centroid] of left.entries()) {
        distanceSum += euclideanDistance(centroid, right[index]);
    }
    return distanceSum;
}
|
120 |
|
121 | export default kMeansCluster;
|