CenterNet根据自己的数据训练模型

本文参考:

1、数据集相关的:https://blog.csdn.net/weixin_42634342/article/details/97697356

2、训练自己的模型参考:

https://bbs.huaweicloud.com/blogs/210374

https://www.huaweicloud.com/articles/ebb05fa50237d7ac7ad6c0b29e38f969.html

https://blog.csdn.net/jiangpeng59/article/details/105732166

3、CenterNet官方安装文档:https://github.com/xingyizhou/CenterNet/blob/master/readme/INSTALL.md

一、数据集处理

1、训练集下载

SeaShips数据集,链接为:http://www.lmars.whu.edu.cn/prof_web/shaozhenfeng/datasets/SeaShips(7000).zip,

如果在linux上直接wget获取即可

该版本数据集共有7000张图片,图片分辨率均为1920*1080,分为六类船只,主要是一些内河道中船只的图片。

该数据集为PascalVOC数据集。

2、voc数据集格式转coco格式

centernet虽然同时支持coco和voc数据集,但是本次我们需要转成coco格式的数据集,方便后续进行操作

转换的参考代码为:https://blog.csdn.net/yang332233/article/details/97205112

只需要修改最后几行加粗的位置相关的代码即可。

转换完毕之后的json文件后续会放到centernet对应的目录下面

二、CenterNet代码编译

1、安装python3.6环境

2、安装pytorch 1.x以及pytorchvision 0.x版本

3、安装cocoapi

COCOAPI=/path/to/clone/cocoapi

git clone https://github.com/cocodataset/cocoapi.git $COCOAPI

cd $COCOAPI/PythonAPI

make

python setup.py install –user

4、下载CenterNet

CenterNet_ROOT=/path/to/clone/CenterNet

git clone https://github.com/xingyizhou/CenterNet $CenterNet_ROOT

5、安装CenterNet依赖的python包

pip install -r requirements.txt

6、编译DCNv2

CenterNet自带的DCNv2只支持pytorch0.4,随意会导致后续编译不成功,所以需要删除$CenterNet_ROOT/src/lib/models/networks/DCNv2目录,然后重新下载最新的DCNv2代码进行编译。

对应的git地址是:git clone https://github.com/CharlesShang/DCNv2.git

下载完成后进行编译:

./make.sh

7、编译NMS组件

cd $CenterNet_ROOT/src/lib/external

make

三、使用已有的线上模型进行预测

1、下载训练好的模型

模型下载地址见:https://github.com/xingyizhou/CenterNet/blob/master/readme/MODEL_ZOO.md,

如果做目标检测,可以下载下图所指的模型,这个模型大概是77M。

模型下载后放到models目录下即可。

CenterNet根据自己的数据训练模型

2、修改结果输出方式

centernet默认是把预测结果图片输出到screen,但是docker中无法显示导致报错,所以需要修改输出方式。

修改 src/lib/detectors/cdet.py。

将debugger.show_all_imgs(pause=self.pause) 注释掉,

换成debugger.save_all_imgs(path=’/home/jhsu/sujh/ljj/CenterNet/output’, genID=True)

如下图所示:

CenterNet根据自己的数据训练模型

3、图片进行目标检测

随便网上找一张图片,执行以下命令:

python src/demo.py ctdet –demo images/17790319373_bd19b24cfc_k.jpg –load_model models/ctdet_coco_dla_2x.pth

运行出错可参考:http://blog.sina.com.cn/s/blog_628cc2b70102ysyi.html,相关问题可以参考进行解决

执行完毕后会在output目录下输出结果,一般是xctdet.png的文件。得到的结果为:

CenterNet根据自己的数据训练模型

四、使用已有的数据集进行训练

1、存放数据集

在CenterNet主目录的data目录下创建MyDataTest,如下图所示:

CenterNet根据自己的数据训练模型
annotations目录中存放第一步的json文件,比如我只有train.json文件

images目录存放7000张jpg文件

2、在CenterNet-master/src/lib/datasets/dataset/文件夹里面,复制coco.py并从命名为my_test.py

打开my_test.py修改:

line13:class COCO修改成class my_test

line14:num_classes = 6 #注意这里不包含背景类,只有6种船的类型

line15:default_resolution = [512, 512] 修改自己需要的训练图片大小,虽然我们的图片是1920*1080,但是无需修改

line16,18:均值方差改自己的,或者也可以不改

Line22:super(COCO, self).init()里面的COCO换成自己的类名my_test

Line23,24:修改自己的数据路径

CenterNet根据自己的数据训练模型

line26-37:修改自己json文件名:

CenterNet根据自己的数据训练模型

line39:类别名字和类别id改成自己的

CenterNet根据自己的数据训练模型

3、dataset_factory.py修改

将数据集加入CenterNet-master/src/lib/datasets/dataset_factory.py

Line14 添加:from .dataset.my_test import my_test

Line29添加: ‘my_test’:my_test

CenterNet根据自己的数据训练模型

4、/src/lib/opts.py修改

加入自己数据集

CenterNet根据自己的数据训练模型

line336: 修改ctdet任务使用的默认数据集为新添加的数据集,如下(修改分辨率,类别数,均值,方差,数据集名字):

CenterNet根据自己的数据训练模型

5、CenterNet-master/src/lib/utils/debugger.py修改

Line 45添加:

CenterNet根据自己的数据训练模型

Line 458添加:

CenterNet根据自己的数据训练模型

6、训练数据

参数说明:

(1)arch:代表选择的backbone的类型

(2)img_size:控制图片长和宽

(3)lr和lr_step:控制学习率大小及变化

(4)batch_size:一个批次处理的图片个数

(5)num_epochs:学习数据集的总次数

(6)num_works:开启多少个线程加载数据集

在src目录下使用命令:

python main.py ctdet –dataset my_test –exp_id my_test –batch_size 4 –lr 0.001 –gpus 1 –num_workers 4

或者:nohup python main.py ctdet –dataset my_test –exp_id my_test –batch_size 4 –lr 0.001 –gpus 1 –num_workers 4 > nohup3.out 2>&1 &

默认是迭代140轮完成训练,如果嫌时间太久了,可以修改opts.py文件如下,指定运行5次就可以了,或者命令中带–num_epochs 5也是可以的。

CenterNet根据自己的数据训练模型

运行之后的日志如下图所示:

CenterNet根据自己的数据训练模型
CenterNet根据自己的数据训练模型

查看linux进程,会发现有5个正在跑的进程,因为指定了4个work同时进行训练,这4个work是多线程进行图片加载,另外1个是在训练模型。

CenterNet根据自己的数据训练模型

模型运行完毕之后,会在exp/ctet/my_test下生成两个pth模型文件。

7、生成训练的loss曲线图

按照上一步进行训练后,会在exp/ctdet/my_test/logs_xxx的目录下生成log.txt文件。

里面的数据只记录每一轮迭代完之后的loss信息,progressbar每一张的loss数据是不会写进log.txt文件的。

具体的信息如下图所示:

CenterNet根据自己的数据训练模型

然后就是读取上面的信息,生成loss的曲线图,参考代码如下:

import matplotlib.pyplot as plt

def plot_loss_curve(log_file):
loss_data = open(log_file)
all_lines = loss_data.readlines()
print(all_lines[4].split(' '))
total_loss = []
hm_loss = []
wh_loss = []
off_loss = []
val_loss = []
spend_time = []
num_lines = len(all_lines)

for line in range(num_lines):
total_loss1 = all_lines[line].split(' ')[4]
hm_loss1 = all_lines[line].split(' ')[7]
wh_loss1 = all_lines[line].split(' ')[10]
off_loss1 = all_lines[line].split(' ')[13]
spend_time1 = all_lines[line].split(' ')[16]

print(total_loss1)
print(spend_time1)

total_loss.append(float(total_loss1))
hm_loss.append(float(hm_loss1))
wh_loss.append(float(wh_loss1))
off_loss.append(float(off_loss1))
spend_time.append(float(spend_time1))

return total_loss

if __name__ == '__main__':
loss_res18 = plot_loss_curve("D:\\temp\\centernet_log.txt")
fig = plt.figure(figsize=(10, 4))
ax = fig.add_subplot(111)
ax.plot(range(len(loss_res18)), loss_res18, 'c', label='building', linewidth=1)
ax.set_xlim([1, 6])
ax.set_xticks(range(0, 5, 1))
ax.set_yticklabels(['jan', 'feb', 'mar'])
ax.set_xlabel('epochs')
ax.set_ylabel('loss_value')
ax.text(8750, 20, "plane", color='red')
ax.set_title('loss_of_CenterNet')
ax.legend(loc='best')
ax.grid()
plt.show()

生成的曲线如下图所示:

CenterNet根据自己的数据训练模型

8、图片目标检测操作

从网上找一张船的图片放到images目录下,然后运行如下命令:

python src/demo.py ctdet –demo images/001987.jpg –load_model /workspace/hugh/CenterNet-master/exp/ctdet/my_test/model_best.pth –vis_thresh 0.1

上一步模型训练只迭代了5轮,模型准确度是不高的,如果不设置vis_thresh会导致图片中检测不到目标。

CenterNet根据自己的数据训练模型

Original: https://blog.csdn.net/benben044/article/details/126525532
Author: benben044
Title: CenterNet根据自己的数据训练模型

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/681779/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球