ISODATA算法 python实现

文章目录

前言

ISODATA经常被用来与Kmeans算法进行对比,其本质也是按照欧式距离来对样本进行分类,不同的是ISODATA可以根据一个大概的指定类别数去确定最终的聚类数(两者可能不同),而Kmeans指定聚类数是多少后,最终的聚类就一定是多少。

一、ISODATA的流程

本质上只有分裂和合并两个步骤加更新中心三个步骤。了解这个算法,核心需要解决下面的三个问题:

Question 1. 什么时候分裂?

现有的聚类数太少就进行分裂。你一开始指定100个聚类,现在只有2个,那就进行分裂。(大的分裂方向,还有细节见下面流程图)

Question 2. 什么时候合并?

现有的聚类数太多就进行分裂。你一开始指定100个聚类,现在上一次刚好裂成200个,那就进行合并。(大的合并方向,还有细节见下面流程图)

Question 3. 现在有的中心数不上不下怎么办?

如果是奇数次迭代,那就尝试去分裂吧(虽然最后不一定分裂了)
如果是偶数次迭代,那就尝试去合并吧(虽然最后不一定合并了)

1.流程图(这里按迭代的奇偶来判断分裂或者合并)

ISODATA算法 python实现

注意:

在流程图中,”合并”步骤并不一定执行了合并,只有满足在所有的中心中,存在一些中心的距离太近(这个距离低于了设定的阈值)才会真正的执行合并的操作,其余不执行合并的操作。而在分裂中,只有现有的中心数太少或者满足”类内的距离太大而且样本数太少”进行分裂的操作。其中类内的距离太大则表示了这个聚类太过于松散,再加上类的数量太少的话,才进行分裂。

分裂的细节:如何分裂?

计算需要分裂的这个簇在各个维度上的方差,如果最大的方差超过了特定的阈值,就在这个最大方差的维度上分裂成两个,其他维度的值保持不变。

比如现在有一个中心 (1, 3) , 对于属于这个中心的所有样本,我们计算其在第一个维度 (数值1的维度) 的方差,再计算其在第二个维度 (数值3的维度) 的方差。假设维度1计算的方差结果为 0.3,维度2计算的方差为1.5,预先设定的阈值为0.5;所以我们要在第二个维度上把中心分成2个:(1, 3 + 1.5 * k ), (1, 3 – 1.5 *k) ,其中k又是控制分裂远近的一个超参数,在代码中取0.5。由此,我们得到了新分裂的两个中心,并把原来的中心去掉。

合并的细节: 如何合并?

合并使用加权平均的方法,两个权重是两个中心控制的两簇样本的数量百分比,加权求和即可。

; 二、使用步骤

1.代码实现

Tips: 注意需要用到sklearn的库来产生数据集:


"""
@author:zsiming
@fileName:ISODATA.py
@Time:2022/1/9  12:33
"""
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.metrics import euclidean_distances

class ISODATA():
    def __init__(self, designCenterNum, LeastSampleNum, StdThred, LeastCenterDist, iterationNum):

        self.K = designCenterNum
        self.thetaN = LeastSampleNum
        self.thetaS = StdThred
        self.thetaC = LeastCenterDist
        self.iteration = iterationNum

        self.n_samples = 1500

        self.random_state1 = 200
        self.random_state2 = 160
        self.random_state3 = 170
        self.data, self.label = make_blobs(n_samples=self.n_samples, random_state=self.random_state3)

        self.center = self.data[0, :].reshape((1, -1))
        self.centerNum = 1
        self.centerMeanDist = 0

        sns.set()

    def updateLabel(self):
"""
            更新中心
"""
        for i in range(self.centerNum):

            distance = euclidean_distances(self.data, self.center.reshape((self.centerNum, -1)))

            self.label = np.argmin(distance, 1)

            index = np.argwhere(self.label == i).squeeze()
            sameClassSample = self.data[index, :]

            self.center[i, :] = np.mean(sameClassSample, 0)

        for i in range(self.centerNum):

            index = np.argwhere(self.label == i).squeeze()
            sameClassSample = self.data[index, :]

            distance = np.mean(euclidean_distances(sameClassSample, self.center[i, :].reshape((1, -1))))

            self.centerMeanDist += distance
        self.centerMeanDist /= self.centerNum

    def divide(self):

        newCenterSet = self.center

        for i in range(self.centerNum):

            index = np.argwhere(self.label == i).squeeze()
            sameClassSample = self.data[index, :]

            stdEachDim = np.mean((sameClassSample - self.center[i, :])**2, axis=0)

            maxIndex = np.argmax(stdEachDim)
            maxStd = stdEachDim[maxIndex]

            distance = np.mean(euclidean_distances(sameClassSample, self.center[i, :].reshape((1, -1))))

            if maxStd > self.thetaS:

                if self.centerNum  self.K//2 or \
                        sameClassSample.shape[0] > 2 * (self.thetaN+1) and distance >= self.centerMeanDist:
                    newCenterFirst = self.center[i, :].copy()
                    newCenterSecond = self.center[i, :].copy()

                    newCenterFirst[maxIndex] += 0.5 * maxStd
                    newCenterSecond[maxIndex] -= 0.5 * maxStd

                    newCenterSet = np.delete(newCenterSet, i, axis=0)

                    newCenterSet = np.vstack((newCenterSet, newCenterFirst))
                    newCenterSet = np.vstack((newCenterSet, newCenterSecond))

            else:
                continue

        self.center = newCenterSet
        self.centerNum = self.center.shape[0]

    def combine(self):

        delIndexList = []

        centerDist = euclidean_distances(self.center, self.center)
        centerDist += (np.eye(self.centerNum)) * 10**10

        while True:

            minDist = np.min(centerDist)
            if minDist >= self.thetaC:
                break

            index = np.argmin(centerDist)
            row = index // self.centerNum
            col = index % self.centerNum

            index = np.argwhere(self.label == row).squeeze()
            classNumFirst = len(index)
            index = np.argwhere(self.label == col).squeeze()
            classNumSecond = len(index)
            newCenter = self.center[row, :] * (classNumFirst / (classNumFirst+ classNumSecond)) + \
                        self.center[col, :] * (classNumSecond / (classNumFirst+ classNumSecond))

            delIndexList.append(row)
            delIndexList.append(col)

            self.center = np.vstack((self.center, newCenter))
            self.centerNum -= 1

            centerDist[row, :] = float("inf")
            centerDist[col, :] = float("inf")
            centerDist[:, col] = float("inf")
            centerDist[:, row] = float("inf")

        self.center = np.delete(self.center, delIndexList, axis=0)
        self.centerNum = self.center.shape[0]

    def drawResult(self):
        ax = plt.gca()
        ax.clear()
        ax.scatter(self.data[:, 0], self.data[:, 1], c=self.label, cmap="cool")

        ax.set_xlabel('x axis')
        ax.set_ylabel('y axis')
        plt.show()

    def train(self):

        self.updateLabel()
        self.drawResult()

        for i in range(self.iteration):

            if self.centerNum < self.K //2:
                self.divide()
            elif (i > 0 and i % 2 == 0) or self.centerNum > 2 * self.K:
                self.combine()
            else:
                self.divide()

            self.updateLabel()
            self.drawResult()
            print("中心数量:{}".format(self.centerNum))

if __name__ == "__main__":
    isoData = ISODATA(designCenterNum=5, LeastSampleNum=20, StdThred=0.1, LeastCenterDist=2, iterationNum=20)
    isoData.train()

2.迭代过程

1. &#x539F;&#x59CB;&#x6570;&#x636E;&#x5982;&#x4E0B;&#x56FE;&#x6240;&#x793A;&#xFF0C;&#x53EF;&#x4EE5;&#x770B;&#x89C1;&#x6211;&#x5728;&#x8FD9;&#x513F;&#x6BD4;&#x8F83;&#x660E;&#x663E;&#x7684;&#x751F;&#x6210;&#x4E09;&#x4E2A;&#x7C07;&#x7684;&#x6570;&#x636E;&#xFF08;&#x7136;&#x540E;&#x6307;&#x5B9A;&#x7C7B;&#x522B;&#x6570;&#x4E3A;5&#xFF09;:

ISODATA算法 python实现

2. &#x4ECE;&#x4E00;&#x4E2A;&#x4E2D;&#x5FC3;&#x5206;&#x88C2;&#x6210;&#x4E3A;&#x4E24;&#x4E2A;&#x4E2D;&#x5FC3;&#xFF08;&#x7528;&#x989C;&#x8272;&#x533A;&#x5206;&#x4E0D;&#x540C;&#x7684;&#x805A;&#x7C7B;&#xFF09;&#xFF1A;

ISODATA算法 python实现

3. &#x672A;&#x5230;&#x8FBE;&#x6307;&#x5B9A;&#x7C7B;&#x522B;&#x6570;&#xFF08;2 < 5&#xFF09;&#x7EE7;&#x7EED;&#x5206;&#x88C2;&#x4E3A;4&#x4E2A;&#x4E2D;&#x5FC3;:

ISODATA算法 python实现

4.&#x4E2D;&#x5FC3;&#x8D34;&#x5F97;&#x592A;&#x8FD1;&#x4E86;&#xFF0C;&#x9700;&#x8981;&#x5408;&#x5E76;:

ISODATA算法 python实现

5. &#x66F4;&#x65B0;&#x4E2D;&#x5FC3;&#x7684;&#x4F4D;&#x7F6E;&#x548C;&#x5206;&#x88C2;&#xFF1A;

ISODATA算法 python实现

6.&#x4E2D;&#x5FC3;&#x8D34;&#x5F97;&#x592A;&#x8FD1;&#x4E86;&#xFF0C;&#x5408;&#x5E76;

ISODATA算法 python实现

7.&#x540E;&#x9762;&#x5C06;&#x4E0D;&#x518D;&#x53D8;&#x5316;&#x3002;

ISODATA算法 python实现

; 3. 总结

个人觉得:

从参数的角度来看,相比于Kmeans,由一个超参数数变成了六个超参数,不能说是改进。只能说某些先验知识比较完善的情况下,可能适用于数据流形分布比较复杂的场景。

Original: https://blog.csdn.net/zsiming/article/details/122410398
Author: zsiming
Title: ISODATA算法 python实现

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/549614/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球