不平衡数据分类网络-Pytorch试验

不平衡数据分类网络-Pytorch试验

注意:本试验在参考此代码的基础上。为方便起见,之后简称A

1.1 制作不平衡数据集 (下载的为平衡数据集)

脚本:cifar10_to_png.py脚本:image2train_test.py

直接从原始CIFAR-10采样,通过控制每一类采样的个数,就可以产生类别不平衡的训练数据。
步骤
1)在A提取图片的基础上 ;

2)将数据集分成训练集和测试集 ;

3)在训练集中根据自定义的类别占比,采样不同数量的类别,得到不平衡训练集;

4)在测试集中,采样相同小数量的类别,得到平衡测试集。

PS:为了尽可能近似实际项目中的情况,故训练集中的样本数量设置的比较少。
且第二步的意义是为了防止数据泄露。

2. 数据加载 (参考A)

3. 搭建网络 (参考A)

采用的VGG16网络 参考此博客介绍

4. 训练网络

4.1 训练普通交叉熵损失函数的网络

loss = celoss(outputs, labels)

4.2 训练Class-Balanced Loss 的网络

Class-Balanced Loss Based on Effective Number of Samples论文解读参考此博客

不平衡数据分类网络-Pytorch试验
β \beta β为常数,论文中设置为( N − 1 ) / N (N-1)/N (N −1 )/N,N N N 为总样本数目。n y n_y n y ​ 为第 y y y 类的样本数目。

训练时遇到bug:UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at …\c10/core/TensorImpl.h:1156.) return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)

解决办法:
这是pytorch1.9的bug,下个版本将修复,我将pytorch降级成1.8就不报这个错了。

; 5. 训练结果

5.1 第一组试验

数据集:
1)训练集:10类不平衡样本按如下比例分配

trainnum = 1000
class_ratio = [19, 17, 15, 13, 11, 9, 7, 5, 3, 1]

2)测试集:10类平衡样本每类数量为:

testnum = 50

混淆矩阵如有不懂参考此博客具体代码实现

e p o c h = 500 epoch = 500 e p o c h =5 0 0 时,在测试集上得到的混淆矩阵如下:

不平衡数据分类网络-Pytorch试验
e p o c h = 500 epoch = 500 e p o c h =5 0 0 时,利用类平衡损失函数,在测试集得到的混淆矩阵为:
不平衡数据分类网络-Pytorch试验
不平衡数据分类网络-Pytorch试验

图1 交叉熵损失函数

不平衡数据分类网络-Pytorch试验

图2 类平衡损失函数

e p o c h = 2500 epoch = 2500 e p o c h =2 5 0 0 时,在测试集上得到的混淆矩阵如下:

不平衡数据分类网络-Pytorch试验
e p o c h = 2000 epoch = 2000 e p o c h =2 0 0 0 时,利用类平衡损失函数,在测试集得到的混淆矩阵为:
不平衡数据分类网络-Pytorch试验
不平衡数据分类网络-Pytorch试验

图1 交叉熵损失函数

不平衡数据分类网络-Pytorch试验

图2 类平衡损失函数

结论:类平衡损失函数效果不明显。

可能有如下原因:

1)整体样本数量不是特别多,同类样本之间的特征不是特别统一。后续补做试验

2)没根据Loss去判断网络是否收敛。后续修改程序
5.2 第二组实验
将训练集扩大至5000个,测试集仍然是50 ∗ 10 50*10 5 0 ∗1 0,加入loss曲线,e p o c h = 1000 epoch = 1000 e p o c h =1 0 0 0 利用类平衡损失函数,结果如下:

不平衡数据分类网络-Pytorch试验

不平衡数据分类网络-Pytorch试验
不平衡数据分类网络-Pytorch试验

图1 训练集=1000

不平衡数据分类网络-Pytorch试验

图2 训练集=5000

召回率与精度如下:

labelname  recall-5000  recall-1000  precision-5000 precision-1000
airplane    52.0%           24.3%       44.8%           36.0%
automobile  90.0%           35.5%       48.9%           76.0%
bird        32.0%           29.3%       69.6%           34.0%
cat         38.0%           30.4%       24.4%           34.0%
deer        46.0%           31.3%       46.9%           42.0%
dog         76.0%           44.1%       34.9%           30.0%
frog        26.0%           51.6%       65.0%           64.0%
horse       32.0%           32.0%       80.0%           16.0%
ship        48.0%           84.6%       57.1%           22.0%
truck       14.0%           25.0%       77.8%           2.0%

结论:扩大训练集,训练效果更好。

5.3 第三组实验
利用图像增强将不平衡训练集,调整至平衡数据集。

不平衡数据分类网络-Pytorch试验
不平衡数据分类网络-Pytorch试验
不平衡数据分类网络-Pytorch试验

图1 训练集=5000利用类平衡损失函数

不平衡数据分类网络-Pytorch试验

图2 训练集=5000用图像增强实现重采样

召回率与精度如下:

labelname  recall-lossblance  recall-resample  precision-lossbalance    precision-resample
airplane        52.0%           92.0%               44.8%                       24.9%
automobile      90.0%           92.0%               48.9%                       59.0%
bird            32.0%           38.0%               69.6%                       65.5%
cat             38.0%           56.0%               24.4%                       32.6%
deer            46.0%           48.0%               46.9%                       60.0%
dog             76.0%           30.0%               34.9%                       65.2%
frog            26.0%           38.0%               65.0%                       76.0%
horse           32.0%           32.0%               80.0%                       76.2%
ship            48.0%           20.0%               57.1%                       83.3%
truck           14.0%           2.0%                77.8%                       100.0%

结论:综合来说,重采样效果更好,不过也可能是由于重采样的原因,导致小类样本训练可能存在过拟合(对有限的样本特征学习的很好,反而不敢预测),导致其召回率很低,精度可以。

5.4 第四组实验

结合重采样和重加权的方法,进行训练,结果如下:

不平衡数据分类网络-Pytorch试验

不平衡数据分类网络-Pytorch试验
不平衡数据分类网络-Pytorch试验

图1 用图像增强实现重采样

不平衡数据分类网络-Pytorch试验

图2 重采样和重加权结合

召回率与精度如下:

labelname  recall-lossblance-rs  recall-resample  precision-lossbalance-rs  precision-resample
airplane        52.0%           92.0%               35.1%                       24.9%
automobile      90.0%           92.0%               45.9%                       59.0%
bird            50.0%           38.0%               42.4%                       65.5%
cat             56.0%           56.0%               27.2%                       32.6%
deer            48.0%           48.0%               33.3%                       60.0%
dog             22.0%           30.0%               55.0%                       65.2%
frog            56.0%           38.0%               60.9%                       76.0%
horse           34.0%           32.0%               70.8%                       76.2%
ship            4.0%            20.0%               66.7%                       83.3%
truck           2.0%            2.0%                100.0%                      100.0%

结论:效果似乎没有单纯重采样效果好
原因:???
下一步寻找针对样本不平衡问题的评价指标

Original: https://blog.csdn.net/a1315sf/article/details/121095509
Author: HNU_土间太平
Title: 不平衡数据分类网络-Pytorch试验

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/662895/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球