【DeeplabV3+】A Detailed Explanation of the DeeplabV3+ Network Structure

Before discussing the DeeplabV3+ network, let's first look at dilated convolution.

1 Regular Convolution vs. Dilated Convolution

1.1 A Brief Introduction to Dilated Convolution

Dilated convolution (also called atrous convolution) is shown in the figure below, where r is the spacing between adjacent kernel elements (r = 1 recovers a regular convolution).

[Figure: dilated convolution with rate r]
Pooling enlarges the receptive field, reduces data dimensionality, and cuts computation, but it loses information, which has become a bottleneck for semantic segmentation.

Dilated convolution enlarges the receptive field without losing information. More precisely, it discards nothing, but each output uses only a sparse subset of the pixels in its receptive field (the positions skipped by the dilation gaps).

1.2 Advantages of Dilated Convolution

  • Enlarged receptive field: as a network deepens, each pixel's receptive field grows, but the feature map shrinks and spatial resolution drops. Dilated convolution addresses both: a larger receptive field helps detect and segment large objects, while the preserved resolution allows precise localization.
  • Capturing multi-scale context: (r-1) zeros are inserted between adjacent kernel elements; r is user-configurable, and different values of r capture information at different scales (see the sketch after this list).
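
A minimal PyTorch sketch of these two points (the tensor sizes are arbitrary, chosen only for illustration): with kernel size k and rate r, the effective kernel size is k + (k-1)(r-1), so a 3×3 kernel with r=2 covers a 5×5 area while the output resolution stays unchanged.

import torch
import torch.nn as nn

x = torch.randn(1, 1, 32, 32)

# Regular 3x3 convolution (r=1): effective kernel size 3
regular = nn.Conv2d(1, 1, kernel_size=3, padding=1)

# Dilated 3x3 convolution with r=2: one zero between kernel elements,
# effective kernel size 3 + (3-1)*(2-1) = 5; padding=2 keeps the resolution
dilated = nn.Conv2d(1, 1, kernel_size=3, padding=2, dilation=2)

print(regular(x).shape)  # torch.Size([1, 1, 32, 32])
print(dilated(x).shape)  # torch.Size([1, 1, 32, 32])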

2 An Overview of the DeeplabV3+ Model

DeeplabV3+ is an excellent method in the semantic segmentation field, and the model performs very well.

DeeplabV3+ works mainly on the model architecture: it introduces control over the resolution at which the encoder extracts features, using dilated convolution to balance accuracy against runtime.

DeeplabV3+ makes heavy use of dilated convolution in its Encoder (see Section 1), enlarging the receptive field without losing information so that each convolution output covers a wider range of the input.

[Figure: DeeplabV3+ Encoder-Decoder architecture]
For a detailed explanation of this figure, see the Bilibili video "Pytorch 搭建自己的DeeplabV3+语义分割平台"; it is strongly recommended!

In the Encoder, the preliminary feature map that has been downsampled four times (three also works, depending on your needs) is processed by parallel atrous convolutions with different rates (the r from Section 1). The branch outputs are then concatenated, and a 1×1 convolution compresses the features. The Encoder's result is the green feature map in the figure; this module is known as ASPP, the enhanced feature-extraction network (a small numeric example follows).
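
To make the multi-rate point concrete (a quick back-of-the-envelope check, not from the original post): applying the effective kernel size 3 + 2(r-1) from Section 1.2, the rates 6, 12, and 18 used by the ASPP code below cover 13, 25, and 37 pixels respectively on the encoder feature map.

# Effective kernel size of a 3x3 atrous convolution: 3 + 2*(r-1)
for r in (6, 12, 18):
    print(f"rate {r}: effective kernel {3 + 2 * (r - 1)}")
# rate 6: effective kernel 13
# rate 12: effective kernel 25
# rate 18: effective kernel 37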

In the Decoder, the preliminary feature map that has been downsampled twice is adjusted with a 1×1 convolution to change its channel count, then stacked with the upsampled Encoder output (the features produced after the atrous convolutions). After stacking, two 3×3 convolutions refine the result (described as depthwise separable convolutions; the code below implements them as plain 3×3 convolutions). At this point we obtain a final effective feature map, a condensed representation of the whole image.

After obtaining this final feature map, a 1×1 convolution adjusts the channel count to num_classes; the result is then upsampled (resized) so that the output's height and width match the input image.
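
As a sanity check, here is a hypothetical shape trace through the code in Section 3, assuming a 512×512 input and downsample_factor=16:

# input image:            (1, 3, 512, 512)
# low-level features:     (1, 24, 128, 128)   MobileNetV2 features[:4], stride 4
# deep features:          (1, 320, 32, 32)    remaining backbone blocks, stride 16
# ASPP output:            (1, 256, 32, 32)
# shortcut (1x1 conv):    (1, 48, 128, 128)
# upsample + concat:      (1, 304, 128, 128)  256 + 48 channels
# cat_conv:               (1, 256, 128, 128)
# cls_conv:               (1, num_classes, 128, 128)
# final interpolate:      (1, num_classes, 512, 512)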

3 DeeplabV3+ Network Code

Read the code alongside the figure above and the inline comments. It runs as-is, provided the MobileNetV2 backbone module (nets/mobilenetv2.py from the original repository) is on your path.

import torch
import torch.nn as nn
import torch.nn.functional as F

from nets.mobilenetv2 import mobilenetv2

class MobileNetV2(nn.Module):
    def __init__(self, downsample_factor=8, pretrained=True):
        super(MobileNetV2, self).__init__()
        from functools import partial

        model           = mobilenetv2(pretrained)
        # Drop the final 1x1 conv block; keep only the feature-extraction layers
        self.features   = model.features[:-1]

        self.total_idx  = len(self.features)
        # Indices of the blocks where stride-2 downsampling happens
        self.down_idx   = [2, 4, 7, 14]

        if downsample_factor == 8:
            # Output stride 8: make the last two stride-2 stages stride-1
            # and compensate with dilation 2, then dilation 4
            for i in range(self.down_idx[-2], self.down_idx[-1]):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=2)
                )
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=4)
                )
        elif downsample_factor == 16:
            # Output stride 16: only the last stage is dilated
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=2)
                )

    def _nostride_dilate(self, m, dilate):
        # Turn stride-2 convs into stride-1 convs and add dilation, so the
        # feature map keeps its resolution without shrinking the receptive field
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate//2, dilate//2)
                    m.padding = (dilate//2, dilate//2)
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x):
        # First 4 blocks: shallow features at 1/4 resolution (decoder shortcut)
        low_level_features = self.features[:4](x)
        # Remaining blocks: deep features at 1/8 or 1/16 resolution (ASPP input)
        x = self.features[4:](low_level_features)
        return low_level_features, x

class ASPP(nn.Module):
    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super(ASPP, self).__init__()
        # Branch 1: 1x1 convolution
        self.branch1 = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
        )
        # Branch 2: 3x3 atrous convolution, rate 6
        self.branch2 = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 3, 1, padding=6*rate, dilation=6*rate, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
        )
        # Branch 3: 3x3 atrous convolution, rate 12
        self.branch3 = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 3, 1, padding=12*rate, dilation=12*rate, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
        )
        # Branch 4: 3x3 atrous convolution, rate 18
        self.branch4 = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 3, 1, padding=18*rate, dilation=18*rate, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
        )
        # Branch 5: image-level features (global average pooling)
        self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
        self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
        self.branch5_relu = nn.ReLU(inplace=True)

        # 1x1 convolution to compress the five concatenated branches
        self.conv_cat = nn.Sequential(
                nn.Conv2d(dim_out*5, dim_out, 1, 1, padding=0, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
        )

    def forward(self, x):
        [b, c, row, col] = x.size()

        # Four parallel conv branches on the same input
        conv1x1 = self.branch1(x)
        conv3x3_1 = self.branch2(x)
        conv3x3_2 = self.branch3(x)
        conv3x3_3 = self.branch4(x)

        # Fifth branch: global average pool, 1x1 conv, then upsample back
        global_feature = torch.mean(x, 2, True)
        global_feature = torch.mean(global_feature, 3, True)
        global_feature = self.branch5_conv(global_feature)
        global_feature = self.branch5_bn(global_feature)
        global_feature = self.branch5_relu(global_feature)
        global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)

        # Concatenate all five branches and compress with a 1x1 conv
        feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
        result = self.conv_cat(feature_cat)
        return result

class DeepLab(nn.Module):
    def __init__(self, num_classes, backbone="mobilenet", pretrained=True, downsample_factor=16):
        super(DeepLab, self).__init__()
        if backbone == "mobilenet":
            # Backbone outputs: deep features (320 ch) and low-level features (24 ch)
            self.backbone = MobileNetV2(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels = 320
            low_level_channels = 24
        else:
            raise ValueError('Unsupported backbone - {}, Use mobilenet, xception.'.format(backbone))

        # ASPP on the deep features; the rate scales with the output stride
        self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16//downsample_factor)

        # 1x1 conv to reduce the low-level features to 48 channels
        self.shortcut_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )

        # Two 3x3 convs to refine the concatenated (48 + 256)-channel features
        self.cat_conv = nn.Sequential(
            nn.Conv2d(48+256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),

            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Dropout(0.1),
        )
        # 1x1 conv to project to per-class logits
        self.cls_conv = nn.Conv2d(256, num_classes, 1, stride=1)

    def forward(self, x):
        H, W = x.size(2), x.size(3)

        # Backbone: low-level features (stride 4) and deep features (stride 8/16)
        low_level_features, x = self.backbone(x)
        x = self.aspp(x)
        low_level_features = self.shortcut_conv(low_level_features)

        # Upsample ASPP output to the low-level resolution, concat, then refine
        x = F.interpolate(x, size=(low_level_features.size(2), low_level_features.size(3)), mode='bilinear', align_corners=True)
        x = self.cat_conv(torch.cat((x, low_level_features), dim=1))

        # Per-pixel class logits, upsampled back to the input resolution
        x = self.cls_conv(x)
        x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)
        return x

if __name__ == "__main__":
    num_classes = 21
    model = DeepLab(num_classes, backbone="mobilenet", pretrained=False, downsample_factor=16)
    model.eval()
    print(model)

    from torchsummaryX import summary
    summary(model, torch.randn(1, 3, 512, 512))

Output (abridged):

DeepLab(
  (backbone): MobileNetV2(
    (features): Sequential(
      (0): Sequential(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      )
      (1): InvertedResidual(
        (conv): Sequential(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          ...

...
164_cat_conv.Dropout_7        -
165_cls_conv         88.080384M
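
As a quick usage note (not part of the original post): to turn the model's logits into a per-pixel class map, take an argmax over the channel dimension. A minimal sketch, reusing the model built in the __main__ block above:

with torch.no_grad():
    img = torch.randn(1, 3, 512, 512)   # stand-in for a normalized RGB image
    logits = model(img)                 # (1, num_classes, 512, 512)
    pred = logits.argmax(dim=1)         # (1, 512, 512) per-pixel class indices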

Original: https://blog.csdn.net/weixin_45377629/article/details/124083978
Author: 寻找永不遗憾
Title: 【DeeplabV3+】DeeplabV3+网络结构详解
