前文我们讲过使用Opencv现有的Kmeans聚类函数来获取COCO数据集anchor框尺寸:
基于libtorch的yolov5目标检测网络实现(3)——Kmeans聚类获取anchor框尺寸https://mp.weixin.qq.com/s/kQ7IOmluYwxRdLxX9okzZw https://mp.weixin.qq.com/s/kQ7IOmluYwxRdLxX9okzZw ;直接调用Opencv函数是很方便,不过存在一个问题:Opencv的Kmeans函数默认使用欧式距离来度量样本之间的距离,而且这是不可更改的。然而不同样本的宽、高差距通常比较大,使用欧式距离会导致聚类结果误差很大,因此yolo目标检测系列的作者改为使用iou来衡量样本距离,使得Kmeans聚类结果更准确稳定。
为了能够使用iou来实现样本距离度量,在本文中我们使用C++自己实现Kmeans算法。
01
为什么使用anchor框?
anchor框就是目标检测任务中目标框的先验信息。通俗理解,就是在网络训练之前,人为地先告诉网络目标框的信息范围,比如框中心坐标的取值范围、宽高的取值范围。然后网络在这个取值范围的基础上再去学习,以获得更精确的框信息。这样一来相当于对网络的学习方向作了限制,因此很大程度增加了网络的稳定性以及收敛速度。
[TencentCloudSDKException] code:FailedOperation.ServiceIsolate message:service is stopped due to arrears, please recharge your account in Tencent Cloud requestId:9d388fe4-bd67-42e6-b0e6-a4fcf22c68bc
[En]
[TencentCloudSDKException] code:FailedOperation.ServiceIsolate message:service is stopped due to arrears, please recharge your account in Tencent Cloud requestId:53cce362-c162-4d0f-800e-d769f6d0c1ad
02
json标签文件的解析
json是一种轻量级的数据交换格式,可以将不同信息打包成一个个模块,并将这些模块按照一定顺序存储到json文件中,读文件时只需要根据关键字对相应模块进行解析,即可得到该模块的打包信息。
COCO数据集把分类、位置等标签信息存储到json格式的文件,方便训练和测试时进行解析。
前文我们也详细介绍过json文件解析的C++实现:
03
Kmeans算法原理
Kmeans算法的基本思想是: 初始化K个中心点,然后计算每个样本点到所有中心点的距离,接着把样本划分到距离其最近的中心点。如下图所示,三个红点为中心点,若干黑点为样本,根据Kmeans算法思想,每个样本都被划分到距离其最近的红点,从而被划分到同一个红点的样本组成一个簇。
假设数据集有X0、X 1、X 2、…、X m-1这m个样本,其中每个样本Xi又是一个长度为n的一维向量:
Kmeans算法的基本步骤如下:
- 从m个样本中随机选择K个样本(C0、C 1、C 2、…、C k-1)作为中心点,这里的K为预先设定好的类别数,比如yolov5需要根据宽、高把训练集的所有目标框分成9类,那么K=9。
- 分别计算每个样本与K个中心点的距离,然后将该样本分配给距其最近的中心点。距离的度量通常使用欧式距离,比如对于任意样本Xi,分别计算其与C0、C 1、C 2、…、C k-1的欧式距离(如下式),然后比较d(Xi,C0)、d(Xi,C1)、d(Xi,C2)、…、d(Xi,Ck-1)得出距离样本Xi最小的中心点,并将样本Xi分配给该中心点。
- 判断步骤3得到的新中心点相对原中心点的的变化量是否小于设定阈值,小于阈值则停止计算。再判断迭代次数是否超过设定次数,如果超过也停止计算。否则回到第2步骤执行。
由于不同样本的宽、高差距较大,使用欧式距离会导致聚类结果误差很大,所以改为使用iou距离度量不同样本的差距。iou是衡量两个方框相似度(包括位置、宽、高的相似度)的量,是两个方框相交区域面积与相并部分面积的比值,所以也称为交并比。
如上图,两个方框的宽、高分别为(w1,h1)和(w2,h2),红色区域为两个框的相交区域,其宽、高为(w,h),”蓝+红+灰”区域为两个框的相并区域,那么相交区域的面积为:
相并区域的面积为:
那么得到iou:
iou的取值范围为[0,1],当两个框没有相交区域时iou取0,当两个框完全重合时iou取1。所以iou值越大说明两个框的形状和位置越相近。为了使度量值与相似度负相关,也即度量值越小相似度越大,我们对iou值取个负值并加1,得到:
上式就是使用Kmeans算法聚类目标框宽、高时使用的距离度量。 因为我们只关心框的宽、高,不关心它们的位置,所以计算iou时我们假设两个框的中心点是重合的,如下图所示:
04
Kmeans算法C++实现
在这里我们的任务是对目标检测的方框尺寸进行聚类,尺寸包含宽、高,也即每个样本有宽、高两个数据,所以可以使用点的形式来表达一个样本,该点的x坐标为宽、y坐标为高:
Point2f A; //A就是一个样本,A.x为宽,A.y为高
- 全局定义K值和中心点
#define K_NUM 9 //K=9,也即把所有目标框分成9类,最后得到的9个中心点就是9个anchor框
vector center_points(K_NUM); //定义9个中心点全局变量
- iou计算代码
//Point2f(w, h)
float cal_distance(Point2f A, Point2f B)
{
//求相交部分面积,假设两个方框的中心点重合
float S1 = std::min(A.x, B.x) * std::min(A.y, B.y);
//求相并部分面积
float S2 = A.x * A.y + B.x * B.y - S1;
float d = 1.0 - S1 / S2; //1-交并比
return d;
}
- 解析json文件获取所有样本,并初始化中心点
/*
根据图像id号,从json的images关键字中获取对应id号图像的宽、高
*/
void get_w_h(vector images_list, int img_id, int &w, int &h)
{
const int len = images_list.size(); //images关键字包含的图像总数
for (int i = 0; i < len; i++)
{
if (img_id == images_list[i].id) //查询id号匹配的图像
{
w = images_list[i].width; //得到宽、高
h = images_list[i].height;
break;
}
}
}
/*
[TencentCloudSDKException] code:FailedOperation.ServiceIsolate message:service is stopped due to arrears, please recharge your account in Tencent Cloud requestId:52a75603-4ce0-47cd-8ae3-05922e4d84ac<details><summary>*<font color='gray'>[En]</font>*</summary>*<font color='gray'>[TencentCloudSDKException] code:FailedOperation.ServiceIsolate message:service is stopped due to arrears, please recharge your account in Tencent Cloud requestId:aa8a0b4d-233e-40c4-bdbf-115440563fed</font>*</details>
*/
void get_anchor_sample(vector &anchor_list)
{
//解析json文件
json j;
ifstream jfile("D:/数据/coco/annotations_trainval2017/annotations/instances_train2017.json");
jfile >> j;
ns::coco_label cr;
ns::from_json(j, cr);
anchor_list.clear();
for (int i = 0; i < cr.annotations_list.size(); i++)
{
cout << "i: " << i << endl;
int img_w, img_h;
//获取目标框对应图像的宽、高
get_w_h(cr.images_list, cr.annotations_list[i].image_id, img_w, img_h);
//将目标框的原宽、高转换为640*640像素下的宽、高
float w = cr.annotations_list[i].bbox[2] / img_w * 640.0;
float h = cr.annotations_list[i].bbox[3] / img_h * 640.0;
//将获取到的样本保存到数组中
anchor_list.push_back(Point2f(w, h));
}
const int len = anchor_list.size(); //总样本数
//随机使用9个不重复的样本来初始化中心点
srand(time(NULL));
vector idx_list(K_NUM);
idx_list[0] = rand() % len;
center_points[0] = anchor_list[idx_list[0]]; //随机获取第1个样本
//再随机获取后面8个不重复的样本
for (int i = 1; i < K_NUM; i++)
{
int idx;
while(1)
{
idx = rand() % len; //0~len-1
int j;
//该样本如果与前面的重复,则重新获取,确保随机获取的9个样本不重复
for (j = 0; j < i; j++)
{
if (idx == idx_list[j])
break;
}
if (j >= i)
{
idx_list[i] = idx;
break;
}
}
//将随机获取的样本赋值给对应中心点
center_points[i] = anchor_list[idx_list[i]];
}
}
- 更新中心点
[TencentCloudSDKException] code:FailedOperation.ServiceIsolate message:service is stopped due to arrears, please recharge your account in Tencent Cloud requestId:9146a04e-71b0-4429-a2ef-edc0fbbbce44
[En]
[TencentCloudSDKException] code:FailedOperation.ServiceIsolate message:service is stopped due to arrears, please recharge your account in Tencent Cloud requestId:141bc759-f908-4b1b-a543-e425f01d9c27
/*
class_list为二维数组,第1维表示不同的样本簇,第2维表示样本簇中的不同样本
*/
void cal_center_points(vector> class_list, vector ¢er_points_new)
{
center_points_new.clear(); //将中心点清除
for (int i = 0; i < class_list.size(); i++) //遍历所有样本簇
{
Point2f p = Point2f(0, 0);
for (int j = 0; j < class_list[i].size(); j++) //遍历每个样本簇中所有样本
{
p = p + class_list[i][j]; //累加和
}
p.x = p.x / class_list[i].size(); //求平均得到质心
p.y = p.y / class_list[i].size();
center_points_new.push_back(p); //将质心保存为中心点
}
}
[TencentCloudSDKException] code:FailedOperation.ServiceIsolate message:service is stopped due to arrears, please recharge your account in Tencent Cloud requestId:d5f5921c-78fc-4675-a8ee-e9e8af790965
[En]
[TencentCloudSDKException] code:FailedOperation.ServiceIsolate message:service is stopped due to arrears, please recharge your account in Tencent Cloud requestId:64de0e57-1a11-49c9-97e0-50f15b565de4
[TencentCloudSDKException] code:FailedOperation.ServiceIsolate message:service is stopped due to arrears, please recharge your account in Tencent Cloud requestId:c31f938d-b2a2-45b1-8cbf-ee6d293f2d99
[En]
[TencentCloudSDKException] code:FailedOperation.ServiceIsolate message:service is stopped due to arrears, please recharge your account in Tencent Cloud requestId:12ca496f-6ee3-42a2-a879-8b0f950efc5d
那么距离计算如下式:
代码实现:
float cal_center_points_distance(vector center_points_new, vector center_points)
{
float d = 0;
for (int i = 0; i < center_points.size(); i++) //遍历所有中心点
{
float diff_x = center_points_new[i].x - center_points[i].x;
float diff_y = center_points_new[i].y - center_points[i].y;
d += sqrt(diff_x * diff_x + diff_y * diff_y); //累计和
}
d = d / center_points.size(); //求平均
return d;
}
- 对每个样本分类
[TencentCloudSDKException] code:FailedOperation.ServiceIsolate message:service is stopped due to arrears, please recharge your account in Tencent Cloud requestId:5f33eae9-bb9f-47d5-80f7-040a630a2d5a
[En]
[TencentCloudSDKException] code:FailedOperation.ServiceIsolate message:service is stopped due to arrears, please recharge your account in Tencent Cloud requestId:d0f9a2db-805f-4d6e-821f-d61b6c793135
int classify_points(Point2f p)
{
int idx = 0;
//计算该样本与第1个中心点的iou距离
float d = cal_distance(p, center_points[0]);
//计算该样本与后8个中心点的iou距离,得到最短距离的中心点并返回其索引
for (int i = 1; i < K_NUM; i++)
{
float di = cal_distance(p, center_points[i]);
if (d > di)
{
d = di;
idx = i;
}
}
return idx;
}
- 聚类实现
void kmean_classify_anchor(void)
{
vector anchor_list;
float EPS = 1e-5; //迭代精度,当中心点变化量小于该值则停止迭代
int iter_num = 500; //最大迭代次数,当迭代次数达到该值则停止迭代
//解析json文件获取所有样本,并初始化中心点
get_anchor_sample(anchor_list);
for (int k = 0; k > class_list(K_NUM);
//对每个样本分类
for (int i = 0; i < anchor_list.size(); i++)
{
int idx = classify_points(anchor_list[i]);
class_list[idx].push_back(anchor_list[i]);
}
//更新中心点
vector center_points_new;
cal_center_points(class_list, center_points_new);
//判断中心点变化量是否小于EPS,小于EPS则停止迭代
if (cal_center_points_distance(center_points_new, center_points) < EPS)
{
break;
}
//将新的中心点赋值给原中心点,实现替换
for (int i = 0; i < K_NUM; i++)
{
center_points[i] = center_points_new[i];
}
}
//打印中心点,也即最后得到的anchor框
for (int i = 0; i < K_NUM; i++)
{
cout << "w: " << center_points[i].x << " h: " << center_points[i].y << endl;
}
}
05
聚类结果
运行上述kmean_classify_anchor函数,得到9个anchor框如下:
根据从小到大的尺寸,把这些anchor框分别分配给yolov5的8080、4040、20*20网格:
80*80 (w, h):(19.2242, 29.5773),(48.5208, 71.5705),(71.705, 175.408)
40*40 (w, h):(121.482, 344.872),(149.95, 84.7272),(205.012, 195.934)
20*20 (w, h):(264.794, 439.859),(463.518, 231.48),(535.714, 525.014)
欢迎扫码关注本微信公众号,接下来会不定时更新更加精彩的内容,敬请期待~
Original: https://blog.csdn.net/shandianfengfan/article/details/121434110
Author: 萌萌哒程序猴
Title: C++实现Kmeans聚类算法获取COCO目标检测数据集的anchor框
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/560952/
转载文章受原作者版权保护。转载请注明原作者出处!