UNPKG

4 kBJavaScriptView Raw
1import euclideanDistance from "./euclidean_distance";
2import makeMatrix from "./make_matrix";
3import sample from "./sample";
4
/**
 * @typedef {Object} kMeansReturn
 * @property {Array<number>} labels The cluster index assigned to each input point, in input order.
 * @property {Array<Array<number>>} centroids The cluster centroids, one per cluster, indexed by label.
 */
10
/**
 * Perform k-means clustering (Lloyd's algorithm): pick initial centroids by
 * `sample`-ing the input points, then repeatedly assign every point to its
 * nearest centroid and move each centroid to the mean of its points, until
 * an update leaves the centroids exactly where they were.
 *
 * @param {Array<Array<number>>} points N-dimensional coordinates of points to be clustered.
 * @param {number} numCluster How many clusters to create.
 * @param {Function} [randomSource=Math.random] An optional entropy source that generates uniform values in [0, 1).
 * @return {kMeansReturn} Labels (same length as points) and centroids (one per cluster, numCluster in total).
 * @throws {Error} If any centroids wind up friendless (i.e., without associated points).
 *
 * @example
 * kMeansCluster([[0.0, 0.5], [1.0, 0.5]], 2);
 * // => {labels: [0, 1], centroids: [[0.0, 0.5], [1.0, 0.5]]}
 * // Cluster numbering follows the random initial sample, so the result may
 * // equally be {labels: [1, 0], centroids: [[1.0, 0.5], [0.0, 0.5]]}.
 */
function kMeansCluster(points, numCluster, randomSource = Math.random) {
    let centroids = sample(points, numCluster, randomSource);
    let labels;
    let change;
    // The exact `!== 0` test is reachable: once the labelling stops changing,
    // calculateCentroids sums the same points in the same order, so the new
    // centroids are bit-for-bit identical to the previous ones.
    do {
        labels = labelPoints(points, centroids);
        const previous = centroids;
        centroids = calculateCentroids(points, labels, numCluster);
        change = calculateChange(centroids, previous);
    } while (change !== 0);
    return { labels, centroids };
}
39
/**
 * Assign every point to the centroid nearest to it.
 *
 * @private
 * @param {Array<Array<number>>} points N-dimensional coordinates of the points to label.
 * @param {Array<Array<number>>} centroids Current centroids.
 * @return {Array<number>} For each point, the index of its closest centroid.
 * Ties go to the lowest index. The label is -1 when no distance is below
 * Number.MAX_VALUE (for example when `centroids` is empty).
 */
function labelPoints(points, centroids) {
    const nearestIndex = (point) => {
        let bestIndex = -1;
        let bestDistance = Number.MAX_VALUE;
        for (const [index, centroid] of centroids.entries()) {
            const distance = euclideanDistance(point, centroid);
            // Strict comparison: on a tie the earlier centroid keeps the point.
            if (distance < bestDistance) {
                bestIndex = index;
                bestDistance = distance;
            }
        }
        return bestIndex;
    };
    return points.map((point) => nearestIndex(point));
}
62
/**
 * Calculate centroids for points given labels: each centroid is the
 * component-wise mean of the points carrying its label.
 *
 * @private
 * @param {Array<Array<number>>} points N-dimensional coordinates. Every point is
 * assumed to have the same dimension as `points[0]`.
 * @param {Array<number>} labels Cluster index of each point, parallel to `points`.
 * @param {number} numCluster Number of clusters being created.
 * @return {Array<Array<number>>} Centroid for each group, indexed by label.
 * @throws {Error} If `points` is empty, if a label does not select a cluster
 * (e.g. -1), or if any centroids wind up friendless (i.e., without associated points).
 */
function calculateCentroids(points, labels, numCluster) {
    const numPoints = points.length;
    // Without this guard `points[0].length` fails with an opaque TypeError.
    if (numPoints === 0) {
        throw new Error("Cannot calculate centroids without any points");
    }

    // Initialize accumulators: one row of `dimension` running sums per cluster
    // (assumes makeMatrix yields zero-filled rows; TODO confirm) and a member
    // count per cluster.
    const dimension = points[0].length;
    const centroids = makeMatrix(numCluster, dimension);
    const counts = Array(numCluster).fill(0);

    // Add points to centroids' accumulators and count points per centroid.
    for (let i = 0; i < numPoints; i++) {
        const point = points[i];
        const label = labels[i];
        const current = centroids[label];
        // An out-of-range label (e.g. -1 for a point labelPoints could not
        // assign) would otherwise crash below with an opaque TypeError.
        if (current === undefined) {
            throw new Error(`Point ${i} has invalid label ${label}`);
        }
        for (let j = 0; j < dimension; j++) {
            current[j] += point[j];
        }
        counts[label] += 1;
    }

    // Rescale sums into means, checking for any cluster that has no points.
    for (let i = 0; i < numCluster; i++) {
        if (counts[i] === 0) {
            throw new Error(`Centroid ${i} has no friends`);
        }
        const centroid = centroids[i];
        for (let j = 0; j < dimension; j++) {
            centroid[j] /= counts[i];
        }
    }

    return centroids;
}
104
/**
 * Measure how far the centroids moved between two iterations.
 *
 * @private
 * @param {Array<Array<number>>} left One list of centroids.
 * @param {Array<Array<number>>} right Another list of centroids, paired with `left` by index.
 * @return {number} Sum of `euclideanDistance` over the index-paired centroids
 * (0 when `left` is empty).
 */
function calculateChange(left, right) {
    let movement = 0;
    for (const [index, centroid] of left.entries()) {
        movement += euclideanDistance(centroid, right[index]);
    }
    return movement;
}
120
121export default kMeansCluster;