MMDetection框架入门教程(完全版)

网上MMDetection的教程看似有很多,但感觉都不成系统,看完一圈下来还是不知道MMDetection要怎么用。这里还是推荐直接跟着官方教程,结合源码学习MMDetection,相关链接汇总如下:

  1. 官方教程 – MMCV
  2. 官方教程 – MMDetection
  3. 官方教程 – 不得不知的 MMDetection 学习路线(个人经验版)
  4. 西安交大课件 – mmdetection教程(使用篇)

本文会介绍如何在MMDetection中从头开始搭建一套属于自己的算法。前几篇博客算是本人学习过程中的笔记,从源码本身分析了MMDetection的原理,比较细碎,本篇博客会从宏观的角度重新梳理一下MMDetection的使用方法以及流程原理,算是对之前一个月学习过程的总结。

  1. MMDetection框架入门教程(一):Anaconda3下的安装教程(mmdet+mmdet3d)
  2. MMDetection框架入门教程(二):快速上手教程
  3. MMDetection框架入门教程(三):配置文件详细解析
  4. MMDetection框架入门教程(四):注册机制详解
  5. MMDetection框架入门教程(五):Runner和Hook详细解析

; 1. 框架概述

MMDetection是商汤和港中文大学针对目标检测任务推出的一个开源项目,它基于Pytorch实现了大量的目标检测算法,把数据集构建、模型搭建、训练策略等过程都封装成了一个个模块,通过模块调用的方式,我们能够以很少的代码量实现一个新算法,大大提高了 代码复用率

整个MMLab家族除了MMDetection,还包含针对目标跟踪任务的MMTracking,针对3D目标检测任务的MMDetection3D等开源项目,他们都是以Pytorch和MMCV以基础。Pytorch不需要过多介绍,MMCV是一个面向计算机视觉的基础库,最主要作用是提供了 基于Pytorch的通用训练框架,比如我们常提到的Registry、Runner、Hook等功能都是在MMCV中支持的。另外,MMCV还提供了通用IO接口、多种CNN网络结构、高质量实现的常见CUDA算子,这里就不进一步展开了。

  1. 框架整体流程

2.1 Pytorch

我们使用Pytorch构建一个新算法时,通常包含如下几步:

  • 构建数据集:新建一个类,并继承 Dataset类,重写 __getitem__()方法实现数据和标签的加载和遍历功能,并以pipeline的方式定义数据预处理流程
  • 构建数据加载器:传入相应的参数实例化DataLoader
  • 构建模型:新建一个类,并继承 Module类,重写 forward()函数定义模型的前向过程
  • 定义损失函数和优化器:根据算法选择合适和损失函数和优化器
  • 训练和验证:循环从DataLoader中获取数据和标签,送入网络模型,计算loss,根据反传的梯度使用优化器进行迭代优化
  • 其他操作:在主调函数里可以任意穿插训练Tricks、日志打印、检查点保存等操作

2.2 MMDetection

使用Pytorch构建一个新算法时,通常包含如下几步:

  • 注册数据集CustomDataset是MMDetection在原始的 Dataset基础上的再次封装,其 __getitem__()方法会根据训练和测试模式分别重定向到 prepare_train_img()prepare_test_img()函数。用户以继承 CustomDataset类的方式构建自己的数据集时,需要重写 load_annotations()get_ann_info()函数,定义数据和标签的加载及遍历方式。完成数据集类的定义后,还需要使用 DATASETS.register_module()进行模块注册。
  • 注册模型:模型构建的方式和Pytorch类似,都是新建一个 Module的子类然后重写 forward()函数。唯一的区别在于MMDetection中需要继承 BaseModule而不是 ModuleBaseModuleModule的子类,MMLab中的任何模型都必须继承此类。另外,MMDetection将一个完整的模型拆分为backbone、neck和head三部分进行管理,所以用户需要按照这种方式,将算法模型拆解成3个类,分别使用 BACKBONES.register_module()NECKS.register_module()HEADS.register_module()完成模块注册。
  • 构建配置文件:配置文件用于配置算法各个组件的运行参数,大体上可以包含四个部分:datasets、models、schedules和runtime。完成相应模块的定义和注册后,在配置文件中配置好相应的运行参数,然后MMDetection就会通过 Registry类读取并解析配置文件,完成模块的实例化。另外,配置文件可以通过 _base_字段实现继承功能,以提高代码复用率。
  • 训练和验证:在完成各模块的代码实现、模块的注册、配置文件的编写后,就可以使用 ./tools/train.py./tools/test.py对模型进行训练和验证,不需要用户编写额外的代码。

2.3 流程对比

虽然从步骤上看MMDetection相比Pytorch的算法实现步骤存在挺大差异,但底层的逻辑实现和Pytorch本质上还是一样的,可以参考下图对照着进行理解,其中蓝色部分表示Pytorch流程,橙色部分表示MMDetection流程,绿色部分表示和算法框架无关的通用流程。

在开始接触MMDetection的算法实现流程之前,必须要先对注册机制和Hook机制有一个大致的了解,推荐先快速阅读,对注册机制和Hook机制先有一个大体上的了解,看完第五章后再回过头来看注册机制和Hook机制的细节部分会有更深的体会。

; 3. 注册机制

3.1 Registry类

MMDetection作为MMCV的下游项目,继承了MMCV的模块管理方式——注册机制。简单来说,注册机制就是维护几张查询表,key是模块的名称,value是模块的句柄,每张查询表都管理一批功能相似的不同模块。我们每新建一个模块,都要根据模块实现的功能将对应的 key-value查询对保存到对应的查询表中,这个保存的过程就称为” 注册“。当我们想要调用某个模块时,只需要根据模块名称从查询表中找到对应的模块句柄,然后就能完成模块初始化或方法调用等操作。MMCV通过 Registry类来实现字符串(key)到类(value)的映射。

Registry的构造函数如下所示,变量 self._module_dict就是上面提到的”查询表”,注册的模块都会存到这个字典类型的变量里,新建一个Registry实例就是新建一张查询表。另外,Registry还支持继承机制。

from mmcv.utils import Registry

class Registry:

    def __init__(self, name, build_func=None, parent=None, scope=None):

        self._name = name

        self._module_dict = dict()

        self._children = dict()

        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func

        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None

模块的注册通过Registry的成员函数 register_module()来实现, register_module()内部又会调用另一个私有函数 _register_module(),模块注册的核心功能其实是在 _register_module()中实现的。核心代码也很简单,就是将传入的 module_namemodule_class保存到字典 self._module_dict中。

def _register_module(self, module_class, module_name=None, force=False):

    if module_name is None:
        module_name = module_class.__name__

    if isinstance(module_name, str):
        module_name = [module_name]

    for name in module_name:

        if not force and name in self._module_dict:
            raise KeyError(f'{name} is already registered in {self.name}')

        self._module_dict[name] = module_class

在我们通过字符串获取到一个模块的句柄后,可以通过 self.build_func函数句柄来实例化这个模块。 build_func可以人为指定,也可以从父类继承,一般来说都是默认使用 build_from_cfg()函数,即使用配置参数 cfg来初始化该模块。配置参数 cfg是一个字典,里面的 type字段是模块名称的字符串,其他字段则对应模块构造函数的输入参数。

def build_from_cfg(cfg, registry, default_args=None):
    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):

        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):

        obj_cls = obj_type
    else:
        raise TypeError(f'type must be a str or valid type, but got {type(obj_type)}')

    try:
        return obj_cls(**args)
    except Exception as e:
        raise type(e)(f'{obj_cls.__name__}: {e}')

考虑到 registry参数需要指向当前注册器本身,我们一般是调用Registry类的 build()方法而不是 self.build_func

def build(self, *args, **kwargs):
    return self.build_func(*args, **kwargs, registry=self)

下面是一个小例子,模拟了网络模型的注册和调用过程。注意一下,我们打印Registry对象时,实际上打印的是 self._module_dict中的values。


MODELS = Registry('myModels')

@MODELS.register_module()
class ResNet(object):
    def __init__(self, depth):
        self.depth = depth
        print('Initialize ResNet{}'.format(depth))

class FPN(object):
    def __init__(self, in_channel):
        self.in_channel= in_channel
        print('Initialize FPN{}'.format(in_channel))
MODELS.register_module(name='FPN', module=FPN)

print(MODELS)
""" 打印结果为:
Registry(name=myModels, items={'ResNet': , 'FPN': })
"""

backbone_cfg = dict(type='ResNet', depth=101)
neck_cfg = dict(type='FPN', in_channel=256)

my_backbone = MODELS.build(backbone_cfg)
my_neck = MODELS.build(neck_cfg)
print(my_backbone, my_neck)
""" 打印结果为:
Initialize ResNet101
Initialize FPN256

"""

3.2 注册机制小结

注册机制是一种模块管理手段,按照不同的模块功能对模块进行分组管理,每个分组都由一张查询表维护,查询表记录了模块名称(字符串)到模块本身(本身)的映射关系,将映射关系记录到查询表的过程称为 “注册”。一旦模块完成注册,只要根据模块名称就能很方便的索引到具体的模块句柄,之后就可以按照正常程序流程,对模块进行初始化和使用。一个模块的注册到使用包含5个步骤:

  1. 新建一个类,实现自定义功能
  2. 将该类注册到对应的查询表中( register_module
  3. 在配置文件中指定该模块的初始化参数
  4. 通过build函数对模块进行实例化( build_from_cfg
  5. 使用该实例对象执行功能函数

  6. Hook机制

4.1 Hook类

MMDetection的整个算法过程就像一个黑盒子:给定输入后(配置文件),黑盒子就会吐出算法结果。整个过程封装度非常高,几乎不需要手写什么代码,但是我们如何在算法执行过程中加入自定义操作呢?这就是Hook机制的作用。

简单来说, Hook可以理解为一种触发器,可以在程序预定义的位置执行预定义的函数。MMCV根据算法的生命周期预定义了6个可以插入自定义函数的位点,用户可以在每个位点自由地插入任意数量的函数操作,如下图所示:

这6个位置基本涵盖了自定义操作可能出现的位置,MMCV已经实现了部分常用Hook,其中默认Hook不需要用户自行注册,通过配置文件配置对应的参数即可;定制Hook则需要用户在配置文件中手动配置 custom_hooks字段进行注册。

Hook类本身代码很少,只提供了预定义位置的接口函数,任何自定义的Hook都需要继承 Hook类,然后根据需要重写对应的接口函数。比如检查点保存操作通常发生在每次迭代或epoch后,所以我们需要重写 after_train_iterafter_train_epoch

class Hook:
    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass

@HOOKS.register_module()
class CheckpointHook(Hook):
    def __init__(self,
                 interval=-1,
                 by_epoch=True,
                 save_optimizer=True,
                 out_dir=None,
                 max_keep_ckpts=-1,
                 **kwargs):
        ...
    def after_train_iter(self, runner):
        ...
    def after_train_epoch(self, runner):
        ...

和其他模块不同,当我们定义好一个Hook(并注册到 HOOKS注册器中)之后,还需要注册到Runner中才能使用, 前后一共进行两次注册。第一次注册到 HOOKS是为了程序能够根据Hook名称找到对应的模块,第二次注册到Runner中是为了程序执行到预定义位置时能够调用对应的函数。

Runner是MMCV用来管理训练过程的一个类,它内部会维护一个list类型变量 self._hooks,我们需要把训练过程会调用的Hook 实例对象按照优先级顺序全部添加到 self._hooks中,这个过程通过 Runner.register_hook()函数实现。MMCV预定义了几种优先级, 数字越小表示优先级越高, 如果觉得默认的分级方式颗粒度过大, 也 可以直接传入0~100的整数进行精细划分。

def register_hook(self, hook, priority='NORMAL'):
    """预定义优先级
    +--------------+------------+
    | Level        | Value      |
    +==============+============+
    | HIGHEST      | 0          |
    +--------------+------------+
    | VERY_HIGH    | 10         |
    +--------------+------------+
    | HIGH         | 30         |
    +--------------+------------+
    | ABOVE_NORMAL | 40         |
    +--------------+------------+
    | NORMAL       | 50         |
    +--------------+------------+
    | BELOW_NORMAL | 60         |
    +--------------+------------+
    | LOW          | 70         |
    +--------------+------------+
    | VERY_LOW     | 90         |
    +--------------+------------+
    | LOWEST       | 100        |
    +--------------+------------+
"""
    hook.priority = priority

    inserted = False
    for i in range(len(self._hooks) - 1, -1, -1):
        if priority >= self._hooks[i].priority:
            self._hooks.insert(i + 1, hook)
            inserted = True
            break
    if not inserted:
        self._hooks.insert(0, hook)

将Hook实例加入到 self._hooks中之后,然后就可以在预定义位置调用 call_hook()来调用各个Hook实例中的对应方法。 call_hook()称为回调函数。


self.call_hook('after_train_epoch')

while self.epoch < self._max_epochs:

    self.call_hook('before_train_epoch')

    for i, data_batch in enumerate(self.data_loader):

        self.call_hook('before_train_iter')

        self.model.train_step()

        self.call_hook('after_train_iter')

    self.call_hook('after_train_epoch')

self.call_hook('after_train_epoch')

调用 call_hook()时会遍历 self._hooks中所有Hook实例,并根据 fn_name调用Hook实例的指定成员函数。比如 fn_name='before_train_epoch'时, call_hook()会挨个调用所有Hook的 before_train_epoch()函数。而且由于 self._hooks已经按照优先级进行过排序, call_hook()会先调用优先级高的Hook方法。

def call_hook(self, fn_name):
    for hook in self._hooks:
        getattr(hook, fn_name)(self)

4.2 Hook机制小结

Hook是一种设置在程序固定位置的触发器,当程序执行到预设位点时则会触发断点,执行Hook函数的流程,结束后再回到断点位置继续执行主流程的代码。实现一个Hook包含5个步骤:

  1. 定义一个类,继承Hook基类
  2. 根据自定义Hook的功能有选择地重写Hook基类中对应的函数
  3. 注册自定义Hook模块到HOOKS查询表中( register_module
  4. 实例化Hook模块并注册到Runner中( register_hook
  5. 使用回调函数调用重写的Hook函数( call_hook

  6. 算法实现流程

2.2节提到,使用MMDetection实现一个新算法,包含注册数据集、注册模型、构建配置文件、训练/验证这四个步骤。要理解MMDetection的算法实现流程, 必须要吃透Config、Registry、Runner和Hook这四个类

5.1 注册数据集

定义自己的数据集时,需要新写一个继承 CustomDataset的Dataset类,然后重写 load_annotations()函数和 get_ann_info()函数。官方文档上说,用户如果要使用 CustomDataset,要将现有数据集转换成MMDetection兼容的格式(COCO格式或中间格式) 。但我看了一下底层的代码并没有发现有这个限制,只要你的数据格式能和你实现的 load_annotations()get_ann_info()对应上即可。

"""
中间数据格式:
[
    {
        'filename': 'a.jpg',                                        # 图片路径
        'width': 1280,                                              # 图片尺寸
        'height': 720,
        'ann': {                                                    # 标注信息
            'bboxes':  (n, 4),                  # 标注框坐标(x1, y1, x2, y2)
            'labels':  (n, ),                   # 标注框类别
            'bboxes_ignore':  (k, 4),           # 不关注的标注框坐标(可选)
            'labels_ignore':  (k, )                 # 不关注的标注框类别(可选)
        }
    },
    ...

]
"""

class CustomDataset(Dataset):
    CLASSES = None
    def __init__(self,
                 ann_file,
                 pipeline,
                 classes=None,
                 data_root=None,
                 img_prefix='',
                 seg_prefix=None,
                 proposal_file=None,
                 test_mode=False,
                 filter_empty_gt=True):
        self.ann_file = ann_file
        self.data_root = data_root
        self.img_prefix = img_prefix
        self.seg_prefix = seg_prefix
        self.proposal_file = proposal_file
        self.test_mode = test_mode
        self.filter_empty_gt = filter_empty_gt
        self.CLASSES = self.get_classes(classes)

        self.data_infos = self.load_annotations(self.ann_file)

        if not test_mode:
            valid_inds = self._filter_imgs()
            self.data_infos = [self.data_infos[i] for i in valid_inds]

        self.pipeline = Compose(pipeline)

在Pytorch中 Dataset的遍历是通过重写 __getitem__()函数实现的,但MMDetection的 CustomDataset虽然是 Dataset的子类,却没有要求我们重写 __getitem__()函数,原因是为了方便训练模式和测试模式下的数据管理,MMDetection已经重写了 __getitem__()函数,可以根据当前运行模式调用 prepare_train_img()prepare_test_img(),两者的区别在于是否加载训练标签。所以我们只需要重写 load_annotations()get_ann_info()函数,剩下的部分交给MMDetection就可以了。

def __getitem__(self, idx):
    if self.test_mode:
        return self.prepare_test_img(idx)
    else:
        return self.prepare_train_img(idx)

def prepare_train_img(self, idx):
    img_info = self.data_infos[idx]

    ann_info = self.get_ann_info(idx)
    results = dict(img_info=img_info, ann_info=ann_info)
    return self.pipeline(results)

def prepare_test_img(self, idx):
    img_info = self.data_infos[idx]
    results = dict(img_info=img_info)
    return self.pipeline(results)

完成自定义的Dataset类后别忘记加上 @DATASETS.register_module()将当前模块注册到DATASETS表中。

5.2 注册模型

网络模型的定义比较简单,相比Pytorch只有两个区别:

  1. 继承的父类从 Module变成了 BaseModule
  2. 需要按照backbone、neck和head的结构将模型拆解成3个部分,分别定义并注册到 BACKBONESNECKS以及 HEADS当中。

5.3 构建配置文件

2.2节有提到,在MMDetection框架下,我们不需要另外实现迭代训练/测试流程的代码,只需要执行现成的train.py或test.py即可。但MMDetection怎么知道我们需要哪些模块呢?这就是配置文件起到的作用。

5.3.1 配置文件的构成

配置文件是由一系列变量定义组成的文本文件,其中 dict类型的变量表示一个个的模块, dict变量必须包含 type字段,表示模块名称, 其它字段则和模块构造函数的参数一一对应,届时用于该模块的初始化(见第本文3章的 build_from_cfg()函数)。该模块必须是已经注册的,否则后续MMDetection无法根据 type值找到对应的模块。配置文件除了 dict类型的变量以外,还可以是其他任意类型,一般是辅助 dict变量定义的中间变量,比如:

test_pipeline = [
    dict(type='LoadMultiViewImageFromFiles', to_float32=True),
    dict(type='NormalizeMultiviewImage', **img_norm_cfg),
    dict(type='PadMultiViewImage', size_divisor=32)
]
evaluation = dict(interval=2, pipeline=test_pipeline)

配置文件也支持继承操作,通过 _base_变量来实现。 _base_是一个 list类型变量,里面存储的是要继承的配置文件的路径。在解析配置文件的时候,文件解析器以递归的方式(其他配置文件也可能包含 _base_变量)解析所有配置文件。任何配置文件往上追溯都会继承以下四个文件,分别对应数据集(datasets)、模型(models)、训练策略(schedules)和运行时的默认配置(default_runtime):

_base_ = [
    'mmdetection/configs/_base_/models/fast_rcnn_r50_fpn.py',
    'mmdetection/configs/_base_/datasets/coco_detection.py',
    'mmdetection/configs/_base_/schedules/schedule_1x.py',
    'mmdetection/configs/_base_/default_runtime.py',
]

如果你对上面继承这4个基础配置文件的配置文件进行打印,可以看到如下内容,这也是 任何一个完整配置文件都应该包含的配置信息。当然,你也可以任意增加自定义的配置信息。所以我们平常新建一个配置文件的时候,一般都是继承这4个基础配置文件,然后在此基础上进行针对性调整。


model = dict(
    type='FastRCNN',
    backbone=dict(
        type='ResNet',
        ...,
    ),
    neck=dict(
        type='FPN',
        ...,
    ),
    roi_head=dict(
        type='StandardRoIHead',
        ...,
        loss_cls=dict(...),
        loss_bbox=dict(...),
    ),
    train_cfg=dict(
        assigner=dict(...),
        sampler=dict(...),
        ...
    ),
    test_cfg =dict(
        nms=dict(...),
        ...,
    )
)

dataset_type = '...'
data_root = '...'
img_norm_cfg = dict(...)
train_pipeline = [
    ...,
]
test_pipeline = [...]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipline=trian_pipline,
    ),
    val=dict(
        ...,
        pipline=test_pipline,
    ),
    test=dict(
        ...,
        pipline=test_pipline,
    )
)

evaluation = dict(interval=1, metric='bbox')
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[8, 11])
runner = dict(type='EpochBasedRunner', max_epochs=12)

checkpoint_config = dict(interval=1)
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
custom_hooks = [dict(type='NumClassCheckHook')]
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]

另外还有一些可选的配置参数,比如 custom_imports,用于导入用户自定义的模块,当配置文件解析器解析到该字段时,会调用 import_modules_from_strings()函数将字段 imports包含的模块导入到程序中。

custom_imports = dict(imports=['os.path', 'numpy'],
                      allow_failed_imports=False)

5.3.2 配置文件的修改

修改配置文件时会遇到2种情况:

  1. 修改已有dict的某个参数:直接重写对应的参数
  2. 需要删掉原有dict的所有参数,然后用一组全新的参数代替:增加 _delete_=True字段

以修改学习率和更换优化器为例解释这两种情况下应该怎么修改配置文件:


optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)

optimizer = dict(lr=0.001)

optimizer = dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001)

optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0001)

optimizer = dict(type='AdamW', lr=0.0001, weight_decay=0.0001)

5.3.3 配置文件的解析

解析配置文件其实是train.py和test.py要做的事,这里放到和构建配置文件一块讲了,逻辑上会更通畅一些。

一般使用Config类来管理配置文件。使用 Config.fromfile(filename)来读取配置文件(也可以直接传入一个dict),返回一个Config类实例cfg,然后可以通过 print(cfg.pretty_text)的方式来打印配置文件信息,或者通过 cfg.dump(filepath)来保存配置文件信息。

from mmcv import Config

cfg = Config.fromfile('../configs/test_config.py')

fromfile()函数源码如下,其核心函数是 _file2dict()_file2dict()会根据文本顺序,按照key = value的格式解析配置文件,得到一个名为 cfg_dict的字典,如果存在 _base_字段,还会对 _base_包含的每个文件路径再调用一次 _file2dict()函数,将文件中包含的配置参数加入到 cfg_dict中,实现配置文件的 继承功能。需要注意的是, _file2dict()内部会对 _base_中不同文件包含的键值进行校验, 不同基础配置文件中不允许出现重复的键值,否则Config不知道以哪个配置文件为准。

def fromfile(filename,
             use_predefined_variables=True,
             import_custom_modules=True):
    cfg_dict, cfg_text = Config._file2dict(filename,
                                           use_predefined_variables)

    if import_custom_modules and cfg_dict.get('custom_imports', None):
        import_modules_from_strings(**cfg_dict['custom_imports'])
    return Config(cfg_dict, cfg_text=cfg_text, filename=filename)

调用 _file2dict()解析得到的 cfg_dict格式如下,配置文件中的文本信息全部转换成了变量存储在一个字典类型之中。

另外有两点需要补充一下,其一是构造Config对象的时候,会将python的 dict数据类型转换为 ConfigDict类型进行处理。 ConfigDict是第三方库addict中 Dict的子类(Dict又是python dict的子类),因为python原生的 dict类型不支持 .&#x5C5E;&#x6027;的访问方式,特别是 dict内部嵌套了多层dict的时候,如果按照key的访问方式,代码写起来非常低效,而 Dict类通过重写 __getattr__()的方式实现了 .&#x5C5E;&#x6027;的访问方式。所以继承了 DictConfigDict也支持使用 .&#x5C5E;&#x6027;的方式访问字典中的各个成员值。

from mmcv import ConfigDict

model = ConfigDict(dict(backbone=dict(type='ResNet', depth=50)))

print(model.backbone.type)

其二,为了兼容配置文件名中出现小数点的情况, _file2dict()会在C盘下创建一个临时文件夹进行操作,如果C盘有访问权限设置,可能会出现报错,不过这个问题只会出现在Windows系统下。

5.3.4 配置文件小结

简单回顾一下,配置文件是一个包含多个 dict变量的文本文件,每个 dict对应一个具体的模块(该模块必须已经注册), dict必须要有 type字段,其他字段和该模块的构造参数相对应。当对调用 build()函数对模块进行实例化的时候,会根据 type字符串的值从查询表中找到对应的模块句柄,并使用 dict中其他字段的值作为构造参数对该模块进行初始化。

5.4 训练和测试

用MMDetection实现一个算法包含四个步骤,第一第二步注册数据集和模型的目的是构建基础模块(数据流和模型),第三步构建配置文件的目的是指定需要的模块以及模块的输入参数,接下来第四步就是根据配置文件把事先定义好的模块一个个拎出来,传入指定的输入参数,然后按照算法流程依次串起来。

5.4.1 train.py文件

我们先过一遍官方提供的 train.py代码(我只保留了核心功能代码),然后再介绍MMDetection是如何使用Runner和Hook来调度整个训练流程的,这样理解起来会更快一些。

train.py的主调函数做了4件事情,一个是利用Config类对我们第三步构建好的配置文件进行解析,然后对模型和数据集进行初始化,最后将模型和数据集传入 train_detector()函数,准备开始训练流程。

def main():

    cfg = Config.fromfile(args.config)

    model = build_detector(cfg.model)

    model.init_weights()

    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))

    train_detector(model, datasets, cfg)

train_detector()函数主要是构建了dataloader,初始化了优化器以及runner和hooks,最后调用runner.run开始正式的迭代训练流程。其中涉及到了Runner的概念,不过这里先不展开,我们只要知道Runner也是一个模块,负责模型的迭代训练。

def train_detector(model, dataset, cfg):

    runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner['type']

    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            runner_type=runner_type) for ds in dataset
    ]

    model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    optimizer = build_optimizer(model, cfg.optimizer)

    runner = build_runner(
        cfg.runner,
        default_args=dict(model=model, optimizer=optimizer)

    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))

     if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        for hook_cfg in cfg.custom_hooks:
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

    if cfg.resume_from:

        runner.resume(cfg.resume_from)
    elif cfg.load_from:

        runner.load_checkpoint(cfg.load_from)

    runner.run(data_loaders, cfg.workflow)

虽然官方 train.py文件写的很长,但是把核心代码扒出来一看,其实都是我们在Pytorch中熟悉的操作。整个train.py的流程如下图所示

  1. 首先解析传入的配置文件,并实例化配置文件中的各个模块;
  2. 然后使用datasets构造 data_loader,这里model使用了MMDataParallel进行了一层封装,主要是为了后续进行分布式训练;
  3. 然后使用data_loader和optimizer初始化一个Runner类对象 runner
  4. 注册训练过程中需要使用的Hooks
  5. 根据配置文件指定的工作流 workflow执行runner.run()函数进行迭代训练

下面对runner.run()函数内部进行展开介绍。

5.4.2 Runner类

Runner分为 EpochBasedRunnerIterBasedRunner,顾名思义,前者以epoch的方式管理流程,后者以iter的方式管理流程,它们都是BaseRunner的子类。EpochBasedRunner和IterBasedRunner本身没有重写构造函数,直接继承了BaseRunner的构造函数:

class BaseRunner(metaclass=ABCMeta):
    def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 max_iters=None,
                 max_epochs=None):

BaseRunner的任何子类都需要实现 run()train()val()save_checkpoint()四个方法,这也是Runner的核心方法,接下来就以EpochBasedRunner类为例对这四个函数进行详细分析。

run()函数
run()是Runner类的主调函数,会根据workflow指定的工作流,对data_loaders中的数据进行处理。目前MMCV支持训练和验证两种工作流,对于EpochBasedRunner而言,workflow配置为 [('train', 2)&#xFF0C;('val', 1)]表示先训练2个epoch,然后验证一个epoch; [('train', 1)]表示只进行训练,不进行验证。如果是IterBasedRunner, [('train', 2)&#xFF0C;('val', 1)]则表示先训练2个iter,然后验证一个iter。然后 getattr(self, mode)会根据不同mode调用self.train()函数和self.val()函数。

def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
    while self.epoch < self._max_epochs:
        for i, flow in enumerate(workflow):
            mode, epochs = flow

            epoch_runner = getattr(self, mode)

            for _ in range(epochs):
                if mode == 'train' and self.epoch >= self._max_epochs:
                    break

                epoch_runner(data_loaders[i], **kwargs)

train()和val()函数
train()val()函数循环调用 run_iter()完成一个epoch流程。函数开头的self.model.train()和self.model.eval()实际上调用的是torch.nn.module.Module的成员函数,将当前模块设置为训练模式或验证模式,两种不同模式下batchnorm、dropout等层的操作会有区别。然后由于测试过程不需要梯度回传,所以val函数加了一个装饰器 @torch.no_grad()

def train(self, data_loader, **kwargs):

    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(self.data_loader)
    for i, data_batch in enumerate(self.data_loader):
        self.run_iter(data_batch, train_mode=True, **kwargs)
        self._iter += 1

    self._epoch += 1

@torch.no_grad()
def val(self, data_loader, **kwargs):

    self.model.eval()
    self.mode = 'val'
    self.data_loader = data_loader
    for i, data_batch in enumerate(self.data_loader):
        self.run_iter(data_batch, train_mode=False)

def run_iter(self, data_batch, train_mode, **kwargs):
    if self.batch_processor is not None:
        outputs = self.batch_processor(self.model, data_batch, train_mode=train_mode, **kwargs)
    elif train_mode:
        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
    else:
        outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)

    self.outputs = outputs

train()val()的核心函数是 run_iter(),根据train_mode参数调用 model.train_step()model.val_step(),这两个函数最终都会指向我们自己模型的 forward()函数,返回模型的前向推理结果(一般是Loss值)。Runner到我们自己的模型中间还会经过MMDataParallel、BaseDetector、SingleStageDetector(或TwoStageDetector)四个类,最终调用我们自己模型的 forward()函数,执行推理过程。

细心的同学可能会问,怎么从头到尾没看到梯度反传优化这一步骤?MMDetection的梯度优化是通过一个实现了 after_train_iter()的Hook实现的,其优先级为ABOVE_NORMAL。

@HOOKS.register_module()
class OptimizerHook(Hook):
    def after_train_iter(self, runner):
        runner.optimizer.zero_grad()
        runner.outputs['loss'].backward()
        if self.grad_clip is not None:
            grad_norm = self.clip_grads(runner.model.parameters())
            if grad_norm is not None:

                runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                         runner.outputs['num_samples'])
        runner.optimizer.step()

save_checkpoint()函数
save_checkpoint()函数比较简单,就不过多说明了,最终是调用torch.save将检查点按下列格式保存成文件。

checkpoint = {
              'meta': dict(),
              'state_dict': dict(),
              'optimizer': dict())
}

Original: https://blog.csdn.net/qq_16137569/article/details/121316235
Author: Maples丶丶
Title: MMDetection框架入门教程(完全版)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/623920/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球