Preface
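The file mmcv/runner/base_runner.py defines the base Runner class, which manages the training and evaluation process of a model. Here is the official diagram (put simply, a runner is the class that implements the red box on the right):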
1. The BaseRunner class
This class is the base class of all runner subclasses. Below is its core code (I removed many details because there are too many):
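from abc import ABCMeta, abstractmethod

import mmcv

from .hooks import HOOKS, Hook
from .priority import get_priority


class BaseRunner(metaclass=ABCMeta):
    """The base class of Runner, a training helper for PyTorch.

    All subclasses should implement the following APIs:

    - ``run()``
    - ``train()``
    - ``val()``
    - ``save_checkpoint()``
    """

    def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 max_iters=None,
                 max_epochs=None):
        self.model = model
        self.batch_processor = batch_processor
        self.optimizer = optimizer
        self.work_dir = work_dir
        self.logger = logger
        self.meta = meta
        self.mode = None        # 'train' or 'val', set by train()/val()
        self._hooks = []        # Hook objects, kept sorted by priority
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0
        self._max_epochs = max_epochs
        self._max_iters = max_iters

    @property
    def epoch(self):
        return self._epoch

    @property
    def iter(self):
        return self._iter

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def val(self):
        pass

    @abstractmethod
    def run(self, data_loaders, workflow, **kwargs):
        pass

    @abstractmethod
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl,
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        pass

    def register_hook(self, hook, priority='NORMAL'):
        assert isinstance(hook, Hook)
        hook.priority = get_priority(priority)
        # Insert so that self._hooks stays sorted by priority
        # (a smaller value means a higher priority, i.e. called earlier).
        for i in range(len(self._hooks) - 1, -1, -1):
            if hook.priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                return
        self._hooks.insert(0, hook)

    def call_hook(self, fn_name):
        # Call e.g. hook.before_train_iter(self) on every registered hook
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def register_lr_hook(self, lr_config):
        if lr_config is None:
            return
        if isinstance(lr_config, dict):
            # e.g. dict(policy='step', ...) -> type='StepLrUpdaterHook'
            policy_type = lr_config.pop('policy')
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            lr_config['type'] = policy_type + 'LrUpdaterHook'
            hook = mmcv.build_from_cfg(lr_config, HOOKS)
        else:
            hook = lr_config  # already a Hook object
        self.register_hook(hook, priority='VERY_HIGH')

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None,
                                timer_config=dict(type='IterTimerHook'),
                                custom_hooks_config=None):
        # The other register_*_hook methods follow the same pattern (omitted)
        self.register_lr_hook(lr_config)
        self.register_momentum_hook(momentum_config)
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
        self.register_timer_hook(timer_config)
        self.register_logger_hooks(log_config)
        self.register_custom_hooks(custom_hooks_config)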
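1) Initialization: stores the model, the batch_processor, the optimizer, the working directory, the logger, meta (e.g. the random seed) and the maximum numbers of epochs and iterations. Note also that it initializes a self._hooks list, whose elements are objects instantiated from Hook classes.
2) @abstractmethod: this decorator marks four abstract methods: train, val, run and save_checkpoint. Because the class uses ABCMeta, a subclass that does not implement all four cannot be instantiated (Python raises a TypeError).
3) Hook registration: I will cover hooks in a separate post; for now it is enough to know the general flow. Take the LR hook as an example: the lr_config argument is passed to register_training_hooks, which calls register_lr_hook; that function converts lr_config into the corresponding hook object and finally calls register_hook, which adds the hook to the self._hooks list (ordered by priority).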
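The three points above can be tried out without installing mmcv. The following minimal sketch is a hypothetical simplification, not mmcv code: MiniRunner, PrintHook, IncompleteRunner and ToyRunner are made-up names, and the PRIORITY table copies only a few levels, assuming (as in mmcv's Priority enum, as far as I know) that a smaller number means a higher priority. It shows that an incomplete subclass cannot be instantiated, and that register_hook keeps the hook list sorted so that call_hook fires hooks in priority order.

```python
from abc import ABCMeta, abstractmethod

# Hypothetical, simplified priority table (smaller value = called earlier).
PRIORITY = {'VERY_HIGH': 10, 'HIGH': 30, 'NORMAL': 50, 'LOW': 70}


class PrintHook:
    """Hypothetical hook that records its name when an event fires."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def before_train_iter(self, runner):
        self.log.append(self.name)


class MiniRunner(metaclass=ABCMeta):
    def __init__(self):
        self._hooks = []  # kept sorted by priority value

    @abstractmethod
    def train(self): ...

    @abstractmethod
    def val(self): ...

    @abstractmethod
    def run(self, data_loaders, workflow, **kwargs): ...

    @abstractmethod
    def save_checkpoint(self, out_dir, filename_tmpl): ...

    def register_hook(self, hook, priority='NORMAL'):
        hook.priority = PRIORITY[priority]
        # Walk from the back; insert after the last hook whose value is <= ours.
        for i in range(len(self._hooks) - 1, -1, -1):
            if hook.priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                return
        self._hooks.insert(0, hook)

    def call_hook(self, fn_name):
        # Dispatch one event (e.g. 'before_train_iter') to every hook in order.
        for hook in self._hooks:
            getattr(hook, fn_name)(self)


class IncompleteRunner(MiniRunner):
    def train(self):
        pass  # val, run and save_checkpoint are missing


class ToyRunner(MiniRunner):
    def train(self):
        pass

    def val(self):
        pass

    def run(self, data_loaders, workflow, **kwargs):
        pass

    def save_checkpoint(self, out_dir, filename_tmpl):
        pass


try:
    IncompleteRunner()
except TypeError as err:
    print('cannot instantiate:', err)

log = []
runner = ToyRunner()
runner.register_hook(PrintHook('logger', log), priority='LOW')
runner.register_hook(PrintHook('checkpoint', log))  # NORMAL
runner.register_hook(PrintHook('lr', log), priority='VERY_HIGH')
runner.call_hook('before_train_iter')
print(log)  # ['lr', 'checkpoint', 'logger']
```

Inserting from the back (instead of always at index 0) means hooks with equal priority are called in the order they were registered.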
Next, let's look at the two Runner subclasses.
2. EpochBasedRunner
Without further ado, here is the core code:
import time

import torch

from .base_runner import BaseRunner
from .builder import RUNNERS


@RUNNERS.register_module()
class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.

    This runner trains models epoch by epoch.
    """

    def run_iter(self, data_batch, train_mode, **kwargs):
        # Forward pass on one batch. In training, train_step returns a dict
        # with the loss; backward and optimizer.step() are done by OptimizerHook.
        if train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        self.outputs = outputs

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            self._iter += 1
        self.call_hook('after_train_epoch')
        self._epoch += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')
        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        # workflow is a list of (mode, epochs) tuples, e.g. [('train', 1)]
        assert len(data_loaders) == len(workflow)
        if max_epochs is not None:
            self._max_epochs = max_epochs
        self.call_hook('before_run')
        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                epoch_runner = getattr(self, mode)  # self.train or self.val
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)
        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        pass  # details omitted
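EpochBasedRunner inherits from BaseRunner and therefore implements the four abstract methods. save_checkpoint needs no further explanation. The core is the run method, whose logic is exactly the diagram at the beginning: for each (mode, epochs) entry in workflow, the mode string 'train' or 'val' is resolved with getattr to the train or val method, which is then called for the given number of epochs. train and val iterate over the data loader and call run_iter on each batch, which performs the forward computation of one iteration (during training, the backward pass and optimizer step are then carried out by OptimizerHook in the after_train_iter hook). This runner trains the model epoch by epoch and is the most commonly used runner in mmdet.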
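To see how the workflow drives run, here is a stripped-down, hypothetical ToyEpochRunner that keeps only the control flow of EpochBasedRunner.run (no model, no hooks, no PyTorch; plain lists stand in for DataLoaders). It assumes a workflow of [('train', 2), ('val', 1)], i.e. two training epochs followed by one validation epoch, repeated until max_epochs is reached.

```python
class ToyEpochRunner:
    """Hypothetical, stripped-down copy of EpochBasedRunner.run's control flow."""

    def __init__(self, max_epochs):
        self._max_epochs = max_epochs
        self._epoch = 0
        self._iter = 0
        self.trace = []  # records (mode, epoch, batch) for inspection

    @property
    def epoch(self):
        return self._epoch

    def train(self, data_loader):
        for batch in data_loader:
            self.trace.append(('train', self._epoch, batch))
            self._iter += 1
        self._epoch += 1  # only train() advances the epoch counter

    def val(self, data_loader):
        for batch in data_loader:
            self.trace.append(('val', self._epoch, batch))

    def run(self, data_loaders, workflow):
        assert len(data_loaders) == len(workflow)
        while self.epoch < self._max_epochs:
            for i, (mode, epochs) in enumerate(workflow):
                epoch_runner = getattr(self, mode)  # 'train' -> self.train
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i])


runner = ToyEpochRunner(max_epochs=3)
runner.run([['a', 'b'], ['v']], workflow=[('train', 2), ('val', 1)])
for step in runner.trace:
    print(step)
```

Note that the final ('val', 1) entry still runs after the last training epoch, because the early break only applies when mode == 'train'.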
3. IterBasedRunner
import time

import torch

from .base_runner import BaseRunner
from .builder import RUNNERS


class IterLoader:
    """Wraps a DataLoader so that next() never runs out of data."""

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 0

    @property
    def epoch(self):
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            # The DataLoader is exhausted, i.e. one epoch has finished:
            # count it, let a distributed sampler reshuffle, then restart.
            self._epoch += 1
            if hasattr(self._dataloader.sampler, 'set_epoch'):
                self._dataloader.sampler.set_epoch(self._epoch)
            time.sleep(2)  # prevent possible deadlock during epoch transition
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)
        return data

    def __len__(self):
        return len(self._dataloader)


@RUNNERS.register_module()
class IterBasedRunner(BaseRunner):
    """Iteration-based Runner.

    This runner trains models iteration by iteration.
    """

    def train(self, data_loader, **kwargs):
        # Each call processes exactly one batch (the work of run_iter)
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._epoch = data_loader.epoch
        data_batch = next(data_loader)
        self.call_hook('before_train_iter')
        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
        self.outputs = outputs
        self.call_hook('after_train_iter')
        self._inner_iter += 1
        self._iter += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        data_batch = next(data_loader)
        self.call_hook('before_val_iter')
        outputs = self.model.val_step(data_batch, **kwargs)
        self.outputs = outputs
        self.call_hook('after_val_iter')
        self._inner_iter += 1

    def run(self, data_loaders, workflow, max_iters=None, **kwargs):
        # workflow is a list of (mode, iters) tuples, e.g. [('train', 1000)]
        assert len(data_loaders) == len(workflow)
        if max_iters is not None:
            self._max_iters = max_iters
        self.call_hook('before_run')
        iter_loaders = [IterLoader(x) for x in data_loaders]
        self.call_hook('before_epoch')
        while self.iter < self._max_iters:
            for i, flow in enumerate(workflow):
                self._inner_iter = 0
                mode, iters = flow
                iter_runner = getattr(self, mode)  # self.train or self.val
                for _ in range(iters):
                    if mode == 'train' and self.iter >= self._max_iters:
                        break
                    iter_runner(iter_loaders[i], **kwargs)
        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_epoch')
        self.call_hook('after_run')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='iter_{}.pth',
                        meta=None,
                        save_optimizer=True,
                        create_symlink=True):
        pass  # details omitted
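IterBasedRunner is much the same and also implements the four methods. The main difference from EpochBasedRunner is that it has no run_iter method: because this runner trains for a maximum number of iterations, each call to train or val processes exactly one batch, so the per-iteration computation of run_iter is written directly inside train and val. In addition, there is an extra IterLoader class. Since run keeps calling next() on the data until max_iters is reached, the data must be traversed again after each full pass (epoch). IterLoader does this with try-except: when the underlying iterator raises StopIteration, it increments its epoch counter, re-creates the iterator and returns the first batch of the new pass (see the __next__ method of IterLoader above).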
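The wrap-around behaviour of IterLoader can be checked with a plain list. The sketch below is a simplified, hypothetical version: it drops the sampler.set_epoch call and the time.sleep of the original, and run_by_iters is a made-up helper that mimics the loop in IterBasedRunner.run, fetching one batch per train or val call and counting only training iterations toward max_iters.

```python
class IterLoader:
    """Simplified IterLoader: wraps any re-iterable object (a list here)."""

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 0

    @property
    def epoch(self):
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            # One full pass is finished: bump the epoch and start over.
            self._epoch += 1
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)
        return data

    def __len__(self):
        return len(self._dataloader)


def run_by_iters(loaders, workflow, max_iters):
    """Hypothetical mini version of IterBasedRunner.run: one batch per call."""
    iter_loaders = [IterLoader(x) for x in loaders]
    trace = []
    it = 0  # number of training iterations done so far
    while it < max_iters:
        for i, (mode, iters) in enumerate(workflow):
            for _ in range(iters):
                if mode == 'train' and it >= max_iters:
                    break
                loader = iter_loaders[i]
                batch = next(loader)
                # Record the pass (epoch) this batch came from.
                trace.append((mode, loader.epoch, batch))
                if mode == 'train':
                    it += 1
    return trace


trace = run_by_iters([[1, 2, 3], ['v']], [('train', 2), ('val', 1)], max_iters=5)
for step in trace:
    print(step)
```

The training loader silently wraps from batch 3 back to batch 1 and its epoch counter goes from 0 to 1, which is how IterBasedRunner can report an epoch even though it only counts iterations.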
Summary
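This post introduced the runner in mmcv; almost all mmdet models use one of the two runners above. A post on hooks will follow, so stay tuned. If you have questions, feel free to add me on WeChat (wulele2541612007) and I will invite you to the discussion group.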
Original: https://blog.csdn.net/wulele2/article/details/122148362
Author: 武乐乐~
Title: MMCV之Runner介绍 (An Introduction to the MMCV Runner)
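Original articles are protected by copyright. When reposting, please cite the source: https://www.johngo689.com/497061/
Reposted articles are protected by the original author's copyright. When reposting, please credit the original author.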