JavaScript版k-means++算法实现

按照维基百科的定义,k-means为:把n个点(可以是样本的一次观察或一个实例)划分到k个聚类中,使得每个点都属于离他最近的均值(此即聚类中心)对应的聚类。

k-means的过程为:
1、随机选取k个中心
2、将数据点依据中心归类为k个聚集
3、选择每个聚集的中心,作为新的中心
4、重复2、3步,直到聚集不再发生变化

k-means的缺陷为第一步的随机选择中心。如果数据点为[1, 5, 6, 9, 10]。k-means的结果可能为[[1], [5, 6], [9, 10]],也可能为[[1, 5, 6], [9], [10]]。从概念上看,这两个结果都是正确的,但后者很明显不是我们想要的。

k-means++改进了第一次选取中心的方法,使得第一次选取的中心之间间距够大。k-means++第一次选取中心的过程为:
1、随机选取一个数据点作为第一个中心
2、计算每个数据点到最近的中心的距离
3、选取上一步最后距离最大的数据点,作为新的中心
4、重复2、3步,直到选取到k个中心

下面是实现,包含了k-means和k-means++。
kmeans函数第三个参数为一个计算数据点权值的函数,用来计算数据点间距离和聚集平均权值。缺点是权值函数会被多次应用于同一数据点上。

exports = module.exports = kmeans;

//随机取k个中心
function randomCentroids(points, k, weightFun) {
    var centroids = new Array(k);
    var weights = new Array(k);
    var idxs = [];
    for (var i = 0; i < k; i++) {
        while (true) {
            var idx = parseInt(Math.random() * points.length);
            //下标不能重复 
            if (idxs.indexOf(idx) !== -1) continue;
            //权值不能相等 
            if (weights.indexOf(weightFun(points[idx])) !== -1) continue;
            break;
        }

        centroids[i] = points[idx];
        weights[i] = weightFun(points[idx]);
        idxs.push(idx);
    }
    return centroids;
}

//k-means++第一次选取中心 
function firstCentroids(points, k, weightFun) {
    var centroids = [];
    var m = k;

    //随机选出第一个中心 
    var first = points[parseInt(Math.random() * points.length)];
    centroids.push(first);
    m--;

    //选取剩下的中心 
    while (m > 0) {
        //每个点到最近中心的距离
        var dists = points.map(function(point) {
            var dists = centroids.map(function(centroid) {
                return Math.abs(weightFun(centroid) - weightFun(point))
            }) return Math.min.apply(null, dists);
        })
        //取上面的距离中最大者
        var max_dist = Math.max.apply(null, dists);
        var max_idx = dists.indexOf(max_dist);

        centroids.push(points[max_idx]);
        m--;
    }

    return centroids;
}

//新的k个中心
//计算每个中心的平均权值,取聚集中权值与平均权值最接近的为中心
function newCentroids(clusters, weightFun) {
    return clusters.map(function(cluster) {
        var sum = cluster.reduce(function(a, b) {
            return a + weightFun(b);
        },
        0) var mean = sum / cluster.length;

        var dists = cluster.map(function(point) {
            return Math.abs(weightFun(point) - mean);
        }) var min_dist = Math.min.apply(null, dists);
        return cluster[dists.indexOf(min_dist)];
    })
}

//聚类,返回k个聚集
function classify(points, centroids, weightFun) {
    var clusters = centroids.map(function() {
        return [];
    });

    for (var i = 0; i < points.length; i++) {
        var min_idx;
        var min_dist = undefined;
        for (var j = 0; j < centroids.length; j++) {
            dist = Math.abs(weightFun(points[i]) - weightFun(centroids[j]));
            if (min_dist === undefined || dist < min_dist) {
                min_dist = dist;
                min_idx = j;
            }
        }

        clusters[min_idx].push(points[i]);
    }

    return clusters;
}

function kmeans(points, k, weightFun) {
    var centroids = firstCentroids(points, k, weightFun);

    while (true) {
        var clusters = classify(points, centroids, weightFun);

        var old = centroids;
        var centroids = newCentroids(clusters, weightFun);

        //当新旧聚集中心相等时,聚类结果不会再发生变化,跳出循环
        for (var i = 0; i < k; i++) {
            if (centroids.indexOf(old[i]) === -1) {
                break;
            }
        }
        if (i === k) break;
    }

    return clusters;
}

if (!module.parent) {
    var points = [1, 5, 6, 9, 10, 34, 67, 12, 34, 67, 12, 344, 56, 23, 68, 23, 11, 333, 65, 23, 45, 23, 12];
    console.log(kmeans(points, 3,
    function(x) {
        return x;
    }))
}
This entry was posted in Uncategorized and tagged . Bookmark the permalink.

Leave a Reply

Your email address will not be published. Required fields are marked *