按照维基百科的定义,k-means为:把n个点(可以是样本的一次观察或一个实例)划分到k个聚类中,使得每个点都属于离他最近的均值(此即聚类中心)对应的聚类。
k-means的过程为:
1、随机选取k个中心
2、将数据点依据中心归类为k个聚集
3、选择每个聚集的中心,作为新的中心
4、重复2、3步,直到聚集不再发生变化
k-means的缺陷为第一步的随机选择中心。如果数据点为[1, 5, 6, 9, 10]。k-means的结果可能为[[1], [5, 6], [9, 10]],也可能为[[1, 5, 6], [9], [10]]。从概念上看,这两个结果都是正确的,但后者很明显不是我们想要的。
k-means++改进了第一次选取中心的方法,使得第一次选取的中心之间间距够大。k-means++第一次选取中心的过程为:
1、随机选取一个数据点作为第一个中心
2、计算每个数据点到最近的中心的距离
3、选取上一步最后距离最大的数据点,作为新的中心
4、重复2、3步,直到选取到k个中心
下面是实现,包含了k-means和k-means++。
kmeans函数第三个参数为一个计算数据点权值的函数,用来计算数据点间距离和聚集平均权值。缺点是权值函数会被多次应用于同一数据点上。
exports = module.exports = kmeans;
//随机取k个中心
function randomCentroids(points, k, weightFun) {
var centroids = new Array(k);
var weights = new Array(k);
var idxs = [];
for (var i = 0; i < k; i++) {
while (true) {
var idx = parseInt(Math.random() * points.length);
//下标不能重复
if (idxs.indexOf(idx) !== -1) continue;
//权值不能相等
if (weights.indexOf(weightFun(points[idx])) !== -1) continue;
break;
}
centroids[i] = points[idx];
weights[i] = weightFun(points[idx]);
idxs.push(idx);
}
return centroids;
}
//k-means++第一次选取中心
function firstCentroids(points, k, weightFun) {
var centroids = [];
var m = k;
//随机选出第一个中心
var first = points[parseInt(Math.random() * points.length)];
centroids.push(first);
m--;
//选取剩下的中心
while (m > 0) {
//每个点到最近中心的距离
var dists = points.map(function(point) {
var dists = centroids.map(function(centroid) {
return Math.abs(weightFun(centroid) - weightFun(point))
}) return Math.min.apply(null, dists);
})
//取上面的距离中最大者
var max_dist = Math.max.apply(null, dists);
var max_idx = dists.indexOf(max_dist);
centroids.push(points[max_idx]);
m--;
}
return centroids;
}
//新的k个中心
//计算每个中心的平均权值,取聚集中权值与平均权值最接近的为中心
function newCentroids(clusters, weightFun) {
return clusters.map(function(cluster) {
var sum = cluster.reduce(function(a, b) {
return a + weightFun(b);
},
0) var mean = sum / cluster.length;
var dists = cluster.map(function(point) {
return Math.abs(weightFun(point) - mean);
}) var min_dist = Math.min.apply(null, dists);
return cluster[dists.indexOf(min_dist)];
})
}
//聚类,返回k个聚集
function classify(points, centroids, weightFun) {
var clusters = centroids.map(function() {
return [];
});
for (var i = 0; i < points.length; i++) {
var min_idx;
var min_dist = undefined;
for (var j = 0; j < centroids.length; j++) {
dist = Math.abs(weightFun(points[i]) - weightFun(centroids[j]));
if (min_dist === undefined || dist < min_dist) {
min_dist = dist;
min_idx = j;
}
}
clusters[min_idx].push(points[i]);
}
return clusters;
}
function kmeans(points, k, weightFun) {
var centroids = firstCentroids(points, k, weightFun);
while (true) {
var clusters = classify(points, centroids, weightFun);
var old = centroids;
var centroids = newCentroids(clusters, weightFun);
//当新旧聚集中心相等时,聚类结果不会再发生变化,跳出循环
for (var i = 0; i < k; i++) {
if (centroids.indexOf(old[i]) === -1) {
break;
}
}
if (i === k) break;
}
return clusters;
}
if (!module.parent) {
var points = [1, 5, 6, 9, 10, 34, 67, 12, 34, 67, 12, 344, 56, 23, 68, 23, 11, 333, 65, 23, 45, 23, 12];
console.log(kmeans(points, 3,
function(x) {
return x;
}))
}