An Introduction to the Runner in MMCV


Preface

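The file mmcv/runner/base_runner.py defines the base runner class, which manages the training and evaluation process of a model. Here is the official diagram (put simply, a runner is the class that implements the red boxes on the right):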

[Figure: official diagram of the MMCV Runner]

1. The BaseRunner class

This class is the base class of all runner subclasses. Here is the core code (I have removed many details because there are simply too many):

from abc import ABCMeta, abstractmethod


class BaseRunner(metaclass=ABCMeta):
    """The base class of Runner, a training helper for PyTorch.

    All subclasses should implement the following APIs:

    - run()
    - train()
    - val()
    - save_checkpoint()
    """

    def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 max_iters=None,
                 max_epochs=None):
        self._hooks = []  # list of registered Hook instances

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def val(self):
        pass

    @abstractmethod
    def run(self, data_loaders, workflow, **kwargs):
        pass

    @abstractmethod
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl,
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        pass

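    def register_hook(self, hook, priority='NORMAL'):
        # Abridged: the full method inserts the hook according to its
        # priority, so self._hooks stays sorted.
        self._hooks.insert(0, hook)

    def register_lr_hook(self, lr_config):
        # Abridged: the full method first builds `hook` from lr_config.
        self.register_hook(hook, priority='VERY_HIGH')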

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None,
                                timer_config=dict(type='IterTimerHook'),
                                custom_hooks_config=None):
        self.register_lr_hook(lr_config)
        self.register_momentum_hook(momentum_config)
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
        self.register_timer_hook(timer_config)
        self.register_logger_hooks(log_config)
        self.register_custom_hooks(custom_hooks_config)

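1) Initialization: the arguments cover the model, batch processor, optimizer, working directory, logger, meta (e.g. the random seed), and the maximum numbers of epochs and iterations. Note that it also creates the list self._hooks, whose elements are instances of Hook classes.
2) @abstractmethod: this decorator marks four abstract methods: train, val, run and save_checkpoint. Any subclass that inherits from this class must implement all four.
3) Hook registration: I will cover hooks in a separate post, so a rough idea of the flow is enough here. Taking lr_hook as an example: the lr_config argument is first passed to register_training_hooks, which internally calls register_lr_hook to turn lr_config into the corresponding hook object, and finally register_hook adds that lr_hook to the self._hooks list.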
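The three points above can be sketched in plain Python. This is a simplified illustration under stated assumptions, not MMCV's actual implementation: `ToyBaseRunner`, `ToyRunner`, `IncompleteRunner`, `ToyHook`, the `PRIORITIES` table and its numeric values are all hypothetical, only two of the four abstract methods are kept, and building a hook from a config is reduced to constructing a `ToyHook` by name.

```python
from abc import ABCMeta, abstractmethod

# Hypothetical priority table for this sketch: smaller value = earlier in list.
PRIORITIES = {'VERY_HIGH': 10, 'NORMAL': 50, 'LOW': 70}


class ToyHook:
    """Stand-in for a Hook; it only carries a name."""

    def __init__(self, name):
        self.name = name


class ToyBaseRunner(metaclass=ABCMeta):

    def __init__(self):
        self._hooks = []

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def val(self):
        pass

    def register_hook(self, hook, priority='NORMAL'):
        hook.priority = PRIORITIES[priority]
        # Walk back from the end so the list stays sorted by priority;
        # hooks with equal priority keep their registration order.
        index = len(self._hooks)
        while index > 0 and self._hooks[index - 1].priority > hook.priority:
            index -= 1
        self._hooks.insert(index, hook)

    def register_lr_hook(self, lr_config):
        if lr_config is None:
            return
        # Turn the config into a hook object (here: just a named ToyHook).
        hook = ToyHook(lr_config['policy'].title() + 'LrUpdaterHook')
        self.register_hook(hook, priority='VERY_HIGH')

    def register_training_hooks(self, lr_config, custom_hooks=()):
        self.register_lr_hook(lr_config)
        for hook in custom_hooks:
            self.register_hook(hook)


class IncompleteRunner(ToyBaseRunner):
    """Forgets to implement val(), so it cannot be instantiated."""

    def train(self):
        pass


class ToyRunner(ToyBaseRunner):

    def train(self):
        pass

    def val(self):
        pass


try:
    IncompleteRunner()
except TypeError as err:
    print('cannot instantiate:', type(err).__name__)

runner = ToyRunner()
runner.register_hook(ToyHook('logger'), priority='LOW')
runner.register_training_hooks(dict(policy='step'),
                               custom_hooks=[ToyHook('custom')])
print([h.name for h in runner._hooks])
# ['StepLrUpdaterHook', 'custom', 'logger']
```

Even though the logger hook is registered first, the learning-rate hook ends up at the front of the list because it is registered with the higher priority.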
Next, let's look at the two Runner subclasses.

2. EpochBasedRunner

Without further ado, here is the core code:

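import time

import torch

# Relative imports, as used inside the mmcv/runner/ package.
from .base_runner import BaseRunner
from .builder import RUNNERS


@RUNNERS.register_module()
class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.

    This runner train models epoch by epoch.
    """

    def run_iter(self, data_batch, train_mode, **kwargs):
        # One forward pass; train_mode selects train_step or val_step.
        if train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        self.outputs = outputs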

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        self.call_hook('before_run')
        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                # Each workflow entry is (mode, epochs), e.g. ('train', 1).
                mode, epochs = flow
                epoch_runner = getattr(self, mode)  # self.train or self.val
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        pass

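EpochBasedRunner inherits from BaseRunner, so it implements all four abstract methods. I won't go into save_checkpoint. The core is the run method, whose logic follows the diagram at the top: depending on whether the mode field of a workflow entry is 'train' or 'val', it calls the train method or the val method. Both train and val in turn call run_iter to perform the forward computation for one iteration. This runner trains the model epoch by epoch and is the most commonly used runner in mmdet.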
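To make the workflow dispatch concrete, here is a minimal sketch of the same control flow without PyTorch. It is an illustration only: `ToyEpochRunner` and its `log` list are hypothetical, `call_hook` merely records the stage name instead of invoking registered hooks, and plain lists stand in for data loaders.

```python
class ToyEpochRunner:
    """Hypothetical, stripped-down runner that only records what it does."""

    def __init__(self, max_epochs):
        self._max_epochs = max_epochs
        self.epoch = 0
        self.log = []

    def call_hook(self, fn_name):
        # A real runner would call the method `fn_name` on every hook.
        self.log.append(fn_name)

    def train(self, data_loader):
        self.call_hook('before_train_epoch')
        for batch in data_loader:
            self.log.append(f'train_iter({batch})')
        self.call_hook('after_train_epoch')
        self.epoch += 1  # only training advances the epoch counter

    def val(self, data_loader):
        self.call_hook('before_val_epoch')
        for batch in data_loader:
            self.log.append(f'val_iter({batch})')
        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow):
        self.call_hook('before_run')
        while self.epoch < self._max_epochs:
            for i, (mode, epochs) in enumerate(workflow):
                epoch_runner = getattr(self, mode)  # 'train' -> self.train
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i])
        self.call_hook('after_run')


runner = ToyEpochRunner(max_epochs=3)
train_data, val_data = ['a', 'b'], ['x']
# Train for two epochs, then validate once, and repeat until max_epochs.
runner.run([train_data, val_data], workflow=[('train', 2), ('val', 1)])

epoch_starts = [e for e in runner.log
                if e.startswith('before_') and e.endswith('_epoch')]
print(epoch_starts)
# ['before_train_epoch', 'before_train_epoch', 'before_val_epoch',
#  'before_train_epoch', 'before_val_epoch']
```

With max_epochs=3 and a workflow of two training epochs followed by one validation epoch, the second round stops training after a single epoch because the epoch limit has been reached, but validation still runs once more.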

3. IterBasedRunner

class IterLoader:

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 0

    @property
    def epoch(self):
        return self._epoch

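    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            # The underlying loader is exhausted, i.e. one epoch is done.
            self._epoch += 1
            if hasattr(self._dataloader.sampler, 'set_epoch'):
                # Tell the sampler which epoch is starting.
                self._dataloader.sampler.set_epoch(self._epoch)
            time.sleep(2)  # Prevent possible deadlock during epoch transition
            # Restart iteration from the beginning of the data.
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data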

    def __len__(self):
        return len(self._dataloader)

@RUNNERS.register_module()
class IterBasedRunner(BaseRunner):
    """Iteration-based Runner.

    This runner train models iteration by iteration.

    """

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._epoch = data_loader.epoch
        data_batch = next(data_loader)
        self.call_hook('before_train_iter')
        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
        self.outputs = outputs
        self.call_hook('after_train_iter')
        self._inner_iter += 1
        self._iter += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        data_batch = next(data_loader)
        self.call_hook('before_val_iter')
        outputs = self.model.val_step(data_batch, **kwargs)
        self.outputs = outputs
        self.call_hook('after_val_iter')
        self._inner_iter += 1

    def run(self, data_loaders, workflow, max_iters=None, **kwargs):
        self.call_hook('before_run')
        iter_loaders = [IterLoader(x) for x in data_loaders]
        self.call_hook('before_epoch')
        while self.iter < self._max_iters:
            for i, flow in enumerate(workflow):
                self._inner_iter = 0
                mode, iters = flow
                iter_runner = getattr(self, mode)
                for _ in range(iters):
                    if mode == 'train' and self.iter >= self._max_iters:
                        break
                    iter_runner(iter_loaders[i], **kwargs)
        time.sleep(1)
        self.call_hook('after_epoch')
        self.call_hook('after_run')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='iter_{}.pth',
                        meta=None,
                        save_optimizer=True,
                        create_symlink=True):
        pass

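IterBasedRunner is much the same and also implements the four methods. The only difference from EpochBasedRunner is that it has no run_iter method: because this runner trains up to a maximum number of iterations, the train and val methods each perform the single-iteration computation themselves. In addition, there is an extra IterLoader class. Its job is to start traversing the data again once an epoch has been fully iterated, which it does with the try-except in __next__.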
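The restart-on-exhaustion behaviour of IterLoader can be tried out without PyTorch. The following is a minimal sketch, assuming a plain list in place of a DataLoader; `ToyIterLoader` is a hypothetical name, and the sampler handling and the sleep from the original class are left out.

```python
class ToyIterLoader:
    """Wraps a re-iterable object so that next() never runs out of data."""

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 0

    @property
    def epoch(self):
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            # One full pass is finished: count it and start over.
            self._epoch += 1
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)
        return data

    def __len__(self):
        return len(self._dataloader)


loader = ToyIterLoader(['b0', 'b1', 'b2'])  # a list stands in for a DataLoader
seen = [(next(loader), loader.epoch) for _ in range(7)]
print(seen)
# [('b0', 0), ('b1', 0), ('b2', 0), ('b0', 1), ('b1', 1), ('b2', 1), ('b0', 2)]
```

Seven calls to next() on a three-batch loader wrap around twice, and the epoch counter increases each time the underlying iterator is exhausted. This is why the iteration-based train method can simply call next(data_loader) once per step.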

Summary

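This article introduced the runners in mmcv; almost all mmdet models use one of the two runners above. A follow-up post on hooks is coming, so stay tuned. If you have questions, feel free to add me on WeChat (wulele2541612007) and I will invite you to the discussion group.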

Original: https://blog.csdn.net/wulele2/article/details/122148362
Author: 武乐乐~
Title: MMCV之Runner介绍
