基于yoloV5-v6分类多检测头模型修改(多国车牌检测)

加我微信拉你进群交流:wu331376411

一 修改背景

基于yoloV5系列越来越强大,适用面越来越广泛,主要是由于训练简单,模型适配性好,推理速度快等优点,yoloV5系列适用非常广泛。
但随着越发强大的系统,导致模型堆叠问题越发严重,输入相同的图片检测的内容不同,或者输入不同的图片检测类似的内容。这些都需要使用多个模型来完成,导致设备负载大,推理堆叠。实际运用场景可能有:多国车牌,使用不同的国家字符,需要用多个对应国家的模型来完成车牌文字检测识别,又比如:ADAS系统,输入相同的图像,不仅仅要检测前方的车辆类型,交通标志,车道线(YOLOP)等等。诸如需求比比皆是,故此在官方的模型上使其共用backbone,使用不同的检测头来完成相对于效果。

二 修改思路

共用backbone,使用多个检测头来分别检测不同国家的车牌。
比如我们定义第一个头是:大陆车牌,第二个头是:港澳车牌,第三个头是:老挝车牌等等。
重点 : 我们创建了多头,但是每次我们输入的图片只是其中一个头的,如果每个头都运行,会很浪费时间,所以我们只运行对应的一个头,这里就需要后期建立一个多头的列表,选择我们数据输入的对应头就OK了。


nc1 : 20
nc2 : 30
nc3 : 40

nc: [nc1,nc2,nc3]
depth_multiple: 0.33
width_multiple: 0.50
anchors:
  - [10,13, 16,30, 33,23]
  - [30,61, 62,45, 59,119]
  - [116,90, 156,198, 373,326]

backbone:

  [[9, 1, Conv, [64, 6, 2, 2]],
   [-1, 1, Conv, [128, 3, 2]],
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],
  ]

head1:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],
   [-1, 3, C3, [512, False]],

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],
   [-1, 3, C3, [256, False]],

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],
   [-1, 3, C3, [512, False]],

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],
   [-1, 3, C3, [1024, False]],

   [[17, 20, 23], 1, Detect, [nc1, anchors]],

  ]

head2:
  [[9, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],
   [-1, 3, C3, [512, False]],

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],
   [-1, 3, C3, [256, False]],

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 29], 1, Concat, [1]],
   [-1, 3, C3, [512, False]],

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 25], 1, Concat, [1]],
   [-1, 3, C3, [1024, False]],

   [[32, 35, 38], 1, Detect, [nc2, anchors]],
  ]

head3:
  [ [ 9, 1, Conv, [ 512, 1, 1 ] ],
    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
    [ [ -1, 6 ], 1, Concat, [ 1 ] ],
    [ -1, 3, C3, [ 512, False ] ],

    [ -1, 1, Conv, [ 256, 1, 1 ] ],
    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
    [ [ -1, 4 ], 1, Concat, [ 1 ] ],
    [ -1, 3, C3, [ 256, False ] ],

    [ -1, 1, Conv, [ 256, 3, 2 ] ],
    [ [ -1, 44 ], 1, Concat, [ 1 ] ],
    [ -1, 3, C3, [ 512, False ] ],

    [ -1, 1, Conv, [ 512, 3, 2 ] ],
    [ [ -1, 40 ], 1, Concat, [ 1 ] ],
    [ -1, 3, C3, [ 1024, False ] ],

    [ [ 47, 50, 53 ], 1, Detect, [ nc3, anchors ] ],
  ]

注意: 每一层的连接方式需要修正,需要看是层的索引值。

基于yoloV5-v6分类多检测头模型修改(多国车牌检测)

                 from  n    params  module                                  arguments
  0                -1  1      3520  models.common.Conv                      [3, 32, 6, 2, 2]
  1                -1  1     18560  models.common.Conv                      [32, 64, 3, 2]
  2                -1  1     18816  models.common.C3                        [64, 64, 1]
  3                -1  1     73984  models.common.Conv                      [64, 128, 3, 2]
  4                -1  2    115712  models.common.C3                        [128, 128, 2]
  5                -1  1    295424  models.common.Conv                      [128, 256, 3, 2]
  6                -1  3    625152  models.common.C3                        [256, 256, 3]
  7                -1  1   1180672  models.common.Conv                      [256, 512, 3, 2]
  8                -1  1   1182720  models.common.C3                        [512, 512, 1]
  9                -1  1    656896  models.common.SPPF                      [512, 512, 5]
 10                -1  1    131584  models.common.Conv                      [512, 256, 1, 1]
 11                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 12           [-1, 6]  1         0  models.common.Concat                    [1]
 13                -1  1    361984  models.common.C3                        [512, 256, 1, False]
 14                -1  1     33024  models.common.Conv                      [256, 128, 1, 1]
 15                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 16           [-1, 4]  1         0  models.common.Concat                    [1]
 17                -1  1     90880  models.common.C3                        [256, 128, 1, False]
 18                -1  1    147712  models.common.Conv                      [128, 128, 3, 2]
 19          [-1, 14]  1         0  models.common.Concat                    [1]
 20                -1  1    296448  models.common.C3                        [256, 256, 1, False]
 21                -1  1    590336  models.common.Conv                      [256, 256, 3, 2]
 22          [-1, 10]  1         0  models.common.Concat                    [1]
 23                -1  1   1182720  models.common.C3                        [512, 512, 1, False]
 24      [17, 20, 23]  1     67425  Detect                                  [20, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]
 25                 9  1    131584  models.common.Conv                      [512, 256, 1, 1]
 26                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 27           [-1, 6]  1         0  models.common.Concat                    [1]
 28                -1  1    361984  models.common.C3                        [512, 256, 1, False]
 29                -1  1     33024  models.common.Conv                      [256, 128, 1, 1]
 30                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 31           [-1, 4]  1         0  models.common.Concat                    [1]
 32                -1  1     90880  models.common.C3                        [256, 128, 1, False]
 33                -1  1    147712  models.common.Conv                      [128, 128, 3, 2]
 34          [-1, 29]  1         0  models.common.Concat                    [1]
 35                -1  1    296448  models.common.C3                        [256, 256, 1, False]
 36                -1  1    590336  models.common.Conv                      [256, 256, 3, 2]
 37          [-1, 25]  1         0  models.common.Concat                    [1]
 38                -1  1   1182720  models.common.C3                        [512, 512, 1, False]
 39      [32, 35, 38]  1     94395  Detect                                  [30, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]
 40                 9  1    131584  models.common.Conv                      [512, 256, 1, 1]
 41                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 42           [-1, 6]  1         0  models.common.Concat                    [1]
 43                -1  1    361984  models.common.C3                        [512, 256, 1, False]
 44                -1  1     33024  models.common.Conv                      [256, 128, 1, 1]
 45                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 46           [-1, 4]  1         0  models.common.Concat                    [1]
 47                -1  1     90880  models.common.C3                        [256, 128, 1, False]
 48                -1  1    147712  models.common.Conv                      [128, 128, 3, 2]
 49          [-1, 44]  1         0  models.common.Concat                    [1]
 50                -1  1    296448  models.common.C3                        [256, 256, 1, False]
 51                -1  1    590336  models.common.Conv                      [256, 256, 3, 2]
 52          [-1, 40]  1         0  models.common.Concat                    [1]
 53                -1  1   1182720  models.common.C3                        [512, 512, 1, False]
 54      [47, 50, 53]  1    121365  Detect                                  [40, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]
self.headi 24
self.headi 39
self.headi 54
Model Summary: 508 layers, 12958705 parameters, 12958705 gradients

三 模型修改

1 网络结构修改

修改结构的时候主要需要注意,这里是多个头,我们创建一个多头列表,输入对应头数据来完成模型训练即可。

  • 修改模型初始化
    基于yoloV5-v6分类多检测头模型修改(多国车牌检测)
    主要是需要记录头的数量,骨干网络的层数,不同头的层数(列表)
    头的数量可以根据索引来进行输入数据,训练对应的头,推理的时候也是对应输入头的索引即可,
    骨干是共用的,所以记录数量,后期好用于网络结构拼接。
    不同的头可以使用不同的层数,针对难度大的数据可以使用较多的卷积,默认是15层。
  • 初始化detect层
    detect层的m.stride值,默认是[8,16,32]。由于都有不同的头,anchor对应的下采样比例可能出现不一样,可能需要使用不同的anchor来进行初始化,所以这里每个头的m.stride 都需要进行初始化。用一个循环完成。
    基于yoloV5-v6分类多检测头模型修改(多国车牌检测)
  • 网络拼接
    网络拼接的时候需要主要,共用主干后,对应的值会有一些变话,都可以更具传入的头和对应头的层数进行查询,这里的计算大家可以自己算一下,需要注意的是P4,P5拼接的层数是头数量15的倍数,_forward_once函数中,新加代码乘以15的由来。
    基于yoloV5-v6分类多检测头模型修改(多国车牌检测)

基于yoloV5-v6分类多检测头模型修改(多国车牌检测)

还有一些小的修改,大家可以自己查看yolo_plate.py文件。基本都是和输入头索引对应的detect层的位置,也就是前面计算的 self.headi_forward

; 2数据读取修改

  • 修改数据读取配置文件
    添加 headnum 头的数量,用于数据读取的循环值。
    依次写个头的对应的数据路径,类别,以及类别名称即可。
    基于yoloV5-v6分类多检测头模型修改(多国车牌检测)
  • 数据读取成dataloader
    这里是多个头的数据,所以创建的时候使用列表来进行存储
    基于yoloV5-v6分类多检测头模型修改(多国车牌检测)
    修改create_dataloader方法,返回列表值即可
    基于yoloV5-v6分类多检测头模型修改(多国车牌检测)
    这里需要记录类别数量,名字等等对应即可,修改较为简单,省去,不清楚的可以去查看源码。
  • 数据训练数据读取
    这里我们创建了多个头的数据dataloader,我们训练的时候是同时进行训练的,所以每次从一个dataloader中读取相同张数的数据,进行一个batch训练,然后将loss进行相加然后回传。
    由于数据长短不同,所以我们按照最长的数据进行设置一个epoch的长度,如果短的读取完了,再次创建train_loader来进行重复读取训练。
    基于yoloV5-v6分类多检测头模型修改(多国车牌检测)
    数据运行逻辑:
    基于yoloV5-v6分类多检测头模型修改(多国车牌检测)

3 训练工程常见问题修改

  • 根据检测头的数量修改读取数据的路径:
    基于yoloV5-v6分类多检测头模型修改(多国车牌检测)
  • general.py文件,修改读取数量路径,修改为列表形式。
    基于yoloV5-v6分类多检测头模型修改(多国车牌检测)
  • 根据数量dataloader 读取对应的bar数据读取器,列表形式
    基于yoloV5-v6分类多检测头模型修改(多国车牌检测)

; 四 模型训练

我使用416大小训练了2个头的内容,map涨点很快,训练速度和之前的训练过程相当,稍微慢一丢丢


python train_plate.py --data data/mydata.yaml --batch 256 --epochs 400 --weights weights/yolov5s.pt   --imgsz 416  --device '0,1'  --cfg models/yolov5s_plate.yaml  --hyp data/hyps/palte_head.yaml --name car_plate_head_size416

基于yoloV5-v6分类多检测头模型修改(多国车牌检测)
模型收敛的比单个头训练的更快一些。

五 模型开源

目前还有一些内容没有更新完成,完成后上传github

Original: https://blog.csdn.net/small_wu/article/details/127084546
Author: 五小白
Title: 基于yoloV5-v6分类多检测头模型修改(多国车牌检测)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/662438/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球