MMDet逐行解读之AnchorGenerator

文章目录

前言

本篇主要介绍mmdet/core/anchor/anchor_generator.py文件下的AnchorGenerator类。以RetinaNet的配置作为说明。

# RetinaNet-style anchor settings; every key except 'type' is a constructor
# argument of AnchorGenerator.
anchor_generator_cfg = {
    'type': 'AnchorGenerator',
    # Scales are derived as octave_base_scale * 2**(i / scales_per_octave)
    # for i in range(scales_per_octave), i.e. three scales here.
    'octave_base_scale': 4,
    'scales_per_octave': 3,
    # Three aspect ratios are combined with every scale.
    'ratios': [0.5, 1.0, 2.0],
    # One stride per feature level (five levels).
    'strides': [8, 16, 32, 64, 128],
}

1、base_anchors的生成

所谓base_anchors,是在初始化AnchorGenerator类时,借助gen_base_anchors方法为每个特征层各生成的9个基础anchor(3个scales × 3个ratios),这些anchor的坐标都是原图尺度上的。

@ANCHOR_GENERATORS.register_module()
class AnchorGenerator(object):
    """Anchor generator whose base anchors are built at construction time.

    Args:
        strides: One stride per feature level. Each entry is passed through
            ``_pair`` (presumably expanding a scalar into a 2-tuple --
            confirm against the helper's definition).
        ratios: Aspect ratios of the anchors; stored as a ``torch.Tensor``
            because ``gen_single_level_base_anchors`` applies ``torch.sqrt``
            to them.
        scales: Explicit anchor scales. Mutually exclusive with the pair
            ``octave_base_scale`` / ``scales_per_octave``.
        base_sizes: Base anchor size of each level. Defaults to the smaller
            component of each level's stride.
        scale_major: Selects the flattening order of the ratio/scale grid in
            ``gen_single_level_base_anchors``.
        octave_base_scale: Base scale multiplied onto the octave scales.
        scales_per_octave: Number of scales per octave; the scales are
            ``octave_base_scale * 2**(i / scales_per_octave)`` for
            ``i in range(scales_per_octave)``.
        centers: Optional per-level anchor centers, indexed by level in
            ``gen_base_anchors``.
        center_offset: Factor multiplied by the base size to obtain the
            anchor center when ``centers`` is not given.

    Raises:
        AssertionError: If both, or neither, of ``scales`` and the
            ``octave_base_scale`` / ``scales_per_octave`` pair are given.
    """

    def __init__(self,
                 strides,
                 ratios,
                 scales=None,
                 base_sizes=None,
                 scale_major=True,
                 octave_base_scale=None,
                 scales_per_octave=None,
                 centers=None,
                 center_offset=0.):
        self.strides = [_pair(stride) for stride in strides]

        # Without explicit base sizes, use the smaller stride component of
        # each level.
        self.base_sizes = [min(stride) for stride in self.strides
                           ] if base_sizes is None else base_sizes

        # Exactly one way of specifying the scales must be used.
        assert ((octave_base_scale is not None
                and scales_per_octave is not None) ^ (scales is not None)), \
            'scales and octave_base_scale with scales_per_octave cannot' \
            ' be set at the same time'
        if scales is not None:
            self.scales = torch.Tensor(scales)
        elif octave_base_scale is not None and scales_per_octave is not None:
            octave_scales = np.array(
                [2**(i / scales_per_octave) for i in range(scales_per_octave)])
            scales = octave_scales * octave_base_scale
            self.scales = torch.Tensor(scales)

        # gen_base_anchors and gen_single_level_base_anchors read these from
        # self, so they must be stored before the base anchors are generated.
        self.octave_base_scale = octave_base_scale
        self.scales_per_octave = scales_per_octave
        self.ratios = torch.Tensor(ratios)
        self.scale_major = scale_major
        self.centers = centers
        self.center_offset = center_offset

        self.base_anchors = self.gen_base_anchors()

现在具体看下gen_base_anchors方法:

    def gen_base_anchors(self):
        """Build the base anchors of every feature level.

        Each entry of ``self.base_sizes`` is handed to
        ``gen_single_level_base_anchors`` together with the shared scales and
        ratios. When ``self.centers`` is set, the center of the matching
        level is forwarded; otherwise ``None`` is passed.

        Returns:
            list(torch.Tensor): One tensor of base anchors per entry of
                ``self.base_sizes``.
        """
        return [
            self.gen_single_level_base_anchors(
                size,
                scales=self.scales,
                ratios=self.ratios,
                center=None if self.centers is None else self.centers[level])
            for level, size in enumerate(self.base_sizes)
        ]

    def gen_single_level_base_anchors(self,
                                      base_size,
                                      scales,
                                      ratios,
                                      center=None):
        """Generate the base anchors of one feature level.

        Args:
            base_size: Base anchor size, used as both width and height
                before scaling.
            scales (torch.Tensor): Scale factors applied to the base size.
            ratios (torch.Tensor): Aspect ratios; ``sqrt(ratio)`` scales the
                height and its reciprocal scales the width.
            center (optional): ``(x, y)`` center of the anchors. When
                ``None``, both coordinates are
                ``self.center_offset * base_size``.

        Returns:
            torch.Tensor: Anchors as ``(x1, y1, x2, y2)`` rows, one row per
                ratio/scale combination.
        """
        if center is not None:
            cx, cy = center
        else:
            cx = self.center_offset * base_size
            cy = self.center_offset * base_size

        height_factor = torch.sqrt(ratios)
        width_factor = 1 / height_factor

        # self.scale_major only decides which factor indexes the rows of the
        # ratio/scale grid before it is flattened.
        if self.scale_major:
            widths = (base_size * width_factor[:, None] *
                      scales[None, :]).view(-1)
            heights = (base_size * height_factor[:, None] *
                       scales[None, :]).view(-1)
        else:
            widths = (base_size * scales[:, None] *
                      width_factor[None, :]).view(-1)
            heights = (base_size * scales[:, None] *
                       height_factor[None, :]).view(-1)

        corners = [
            cx - 0.5 * widths,
            cy - 0.5 * heights,
            cx + 0.5 * widths,
            cy + 0.5 * heights,
        ]
        return torch.stack(corners, dim=-1)

其实上面代码就是下图干的事情:在每个stride(即每个特征层)上,3个scales × 3个ratios = 9个base anchor。

MMDet逐行解读之AnchorGenerator

2、grid_anchors的生成

在生成base_anchors的基础上,还需要通过平移每个anchor的中心,把它们铺到整张特征图的每个位置上。这一步由grid_anchors方法实现:

    def grid_anchors(self, featmap_sizes, device='cuda'):
        """Generate the grid anchors of every feature level.

        Args:
            featmap_sizes: Feature map size of each level; its length must
                equal ``self.num_levels``.
            device (str, optional): Device the anchors are placed on.
                Defaults to 'cuda'.

        Returns:
            list: The result of ``single_level_grid_anchors`` for each level,
                in level order.
        """
        assert self.num_levels == len(featmap_sizes)
        return [
            self.single_level_grid_anchors(
                self.base_anchors[level].to(device),
                featmap_sizes[level],
                self.strides[level],
                device=device)
            for level in range(self.num_levels)
        ]

贴下single_level_grid_anchors方法(以及它用到的_meshgrid辅助方法):

    def _meshgrid(self, x, y, row_major=True):
        """Generate mesh grid of x and y

        Args:
            x (torch.Tensor): Grids of x dimension.

            y (torch.Tensor): Grids of y dimension.

            row_major (bool, optional): Whether to return y grids first.

                Defaults to True.

        Returns:
            tuple[torch.Tensor]: The mesh grids of x and y.

"""
        xx = x.repeat(len(y))
        yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
        if row_major:
            return xx, yy
        else:
            return yy, xx

    def single_level_grid_anchors(self,
                                  base_anchors,
                                  featmap_size,
                                  stride=(16, 16),
                                  device='cuda'):
        """Shift the base anchors to every location of one feature map.

        Note:
            This function is usually called by method ``grid_anchors``.

        Args:
            base_anchors (torch.Tensor): The base anchors of a feature grid.
            featmap_size (tuple[int]): Feature map size, unpacked as
                ``(h, w)``.
            stride (tuple[int], optional): Stride of the feature map;
                ``stride[0]`` spaces the x shifts and ``stride[1]`` the y
                shifts. Defaults to (16, 16).
            device (str, optional): Device the shifts are created on.
                Defaults to 'cuda'.

        Returns:
            torch.Tensor: All anchors of the feature map, reshaped to rows
                of four values.
        """
        rows, cols = featmap_size

        x_offsets = torch.arange(0, cols, device=device) * stride[0]
        y_offsets = torch.arange(0, rows, device=device) * stride[1]
        grid_x, grid_y = self._meshgrid(x_offsets, y_offsets)

        # The same (x, y) shift is applied to both corners of an anchor.
        offsets = torch.stack([grid_x, grid_y, grid_x, grid_y], dim=-1)
        offsets = offsets.type_as(base_anchors)

        # Broadcast: every location's shift is added to every base anchor.
        shifted = base_anchors[None, :, :] + offsets[:, None, :]
        return shifted.view(-1, 4)

3、valid_flags介绍

简单说下这个方法作用:在模型批次训练过程中,往往会对图像进行pad,pad会出现黑边,后面撒anchor时在pad部分也会撒上anchor,其实这部分anchor应该忽略掉。故该函数就是赋予每个anchor一个标签,若anchor在有效像素位置上,则赋为True;否则赋为False。

    def valid_flags(self, featmap_sizes, pad_shape, device='cuda'):
        """Generate the valid flags of the anchors on every feature level.

        Args:
            featmap_sizes: ``(h, w)`` of the feature map at each level; its
                length must equal ``self.num_levels``.
            pad_shape: Shape whose first two entries are read as ``(h, w)``
                (presumably the padded image shape -- confirm against the
                caller).
            device (str, optional): Device the flags are placed on.
                Defaults to 'cuda'.

        Returns:
            list(torch.Tensor): One bool tensor per level, as returned by
                ``single_level_valid_flags``.
        """
        assert self.num_levels == len(featmap_sizes)
        h, w = pad_shape[:2]
        multi_level_flags = []
        for i in range(self.num_levels):
            anchor_stride = self.strides[i]
            feat_h, feat_w = featmap_sizes[i]
            # stride[0] steps along x (width) and stride[1] along y (height),
            # the same convention single_level_grid_anchors uses for shifts.
            valid_feat_h = min(int(np.ceil(h / anchor_stride[1])), feat_h)
            valid_feat_w = min(int(np.ceil(w / anchor_stride[0])), feat_w)
            flags = self.single_level_valid_flags((feat_h, feat_w),
                                                  (valid_feat_h, valid_feat_w),
                                                  self.num_base_anchors[i],
                                                  device=device)
            multi_level_flags.append(flags)
        return multi_level_flags

    def single_level_valid_flags(self,
                                 featmap_size,
                                 valid_size,
                                 num_base_anchors,
                                 device='cuda'):
        """Generate the valid flags of the anchors in a single feature map.

        Args:
            featmap_size (tuple[int]): ``(h, w)`` of the full feature map.
            valid_size (tuple[int]): ``(h, w)`` of the valid region of the
                feature map; must not exceed ``featmap_size``.
            num_base_anchors (int): Number of base anchors per location
                (9 for three scales combined with three ratios).
            device (str, optional): Device where the flags will be put on.
                Defaults to 'cuda'.

        Returns:
            torch.Tensor: Flat bool tensor holding ``num_base_anchors``
                flags for each feature-map location.

        Raises:
            AssertionError: If the valid region is larger than the feature
                map in either dimension.
        """
        feat_h, feat_w = featmap_size
        valid_h, valid_w = valid_size
        assert valid_h <= feat_h and valid_w <= feat_w
        valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
        valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
        # Only the leading valid_w columns and valid_h rows are valid.
        valid_x[:valid_w] = 1
        valid_y[:valid_h] = 1
        valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
        # A location is valid only when both its column and its row are.
        valid = valid_xx & valid_yy
        # Repeat each location's flag once per base anchor.
        valid = valid[:, None].expand(valid.size(0),
                                      num_base_anchors).contiguous().view(-1)
        return valid

总结

下篇会介绍MaxIOUAssigner,敬请期待。若有问题欢迎+vx:wulele2541612007,拉你进群探讨交流。

Original: https://blog.csdn.net/wulele2/article/details/122409507
Author: 武乐乐~
Title: MMDet逐行解读之AnchorGenerator

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/686470/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球