U-Net论文详解

U-Net论文详解

UNet算法Pytorch实现: https://github.com/codecat0/CV/tree/main/Semantic_Segmentation/UNet

U-Net结构由一个用于捕获上下文信息的压缩路径和一个支持精确定位的对称扩展路径构成。实验结果表明可以从很少的图像进行端到端的训练,并在ISBI挑战上优于先前最优的方法(滑动窗口卷积网络),并获得了冠军

1. 背景介绍

卷积网络的典型应用是分类任务,其中图像的输出是一个单一的类标签。然而在许多视觉任务中,特别是生物医学图像处理中,期望的输出应该包含定位,即给每一个像素点分配一个类标签。

于是滑动窗口卷积网络通过提供像素点周围的局部区域来预测每个像素的类别标签。但是这样的方法存在两个缺点:

  1. 速度特别慢,网络必须为每一个窗口单元单独运行,并且窗口单元重合而导致大量冗余
  2. 在定位精度和上下文信息之间的权衡。大的窗口单元需要更多的max pooling层,这会降低精度;而小的窗口单元捕获的上下文信息较少。

于是本文提出了U-Net网络

2. U-Net网络架构

U-Net论文详解

网络是一个经典的全卷积网络。网络的输入是一张572×572经过镜像操作的图像。为了使得每次下采样后特征图的尺寸为偶数。

U-Net论文详解

网络的左侧为 压缩路径,由 4个block构成, 每个block由2个未padding的卷积和一个最大池化构成,其中每次卷积特征图的尺寸为减小2,最大池化后会缩小一半。

现在大部分采用same padding的卷积,这样就不用对输入进行镜像操作,而且在拼接压缩路径与对应的扩展路径也不用进行裁剪,而且裁剪会使得特征图不对称

网络的右侧为 扩展路径,同样由 4个block构成,每个block开始之前通过 反卷积将特征图的尺寸扩大一倍,然后与压缩路径对应的特征图拼接, 由于采用未padding的卷积,左侧压缩路径的特征图的尺寸比右侧扩展路径的特征图的大,所以需要先进行裁剪,使其大小相同,然后拼接,然后 经过两次未padding的卷积进一步提取特征

最后根据自己的任务,输出对应大小的预测特征图

现在大部分采用 双线性插值代替反卷积,而且效果会更好

; 3. 数据增强

我们主要通过平移和旋转不变性以及灰度值的变化来增强模型的鲁棒性,特别地, 任意的弹性形变对训练非常有帮助

4. Pytorch实现

import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Encoder, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x_pooled = self.pool(x)
        return x, x_pooled

class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Decoder, self).__init__()
        self.up_sample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x_prev, x):
        x = self.up_sample(x)
        x_shape = x.shape[2:]
        x_prev_shape = x.shape[2:]
        h_diff = x_prev_shape[0] - x_shape[0]
        w_diff = x_prev_shape[1] - x_shape[1]

        x_tmp = torch.zeros(x_prev.shape).to(x.device)
        x_tmp[:, :, h_diff//2: h_diff+x_shape[0], w_diff//2: x_shape[1]] = x
        x = torch.cat([x_prev, x_tmp], dim=1)
        x = self.block1(x)
        x = self.block2(x)
        return x

class UNet(nn.Module):

    def __init__(self, num_classes=2):
        super(UNet, self).__init__()

        self.down_sample1 = Encoder(in_channels=3, out_channels=64)
        self.down_sample2 = Encoder(in_channels=64, out_channels=128)
        self.down_sample3 = Encoder(in_channels=128, out_channels=256)
        self.down_sample4 = Encoder(in_channels=256, out_channels=512)

        self.mid1 = nn.Sequential(
            nn.Conv2d(512, 1024, 3, bias=False),
            nn.ReLU(inplace=True)
        )
        self.mid2 = nn.Sequential(
            nn.Conv2d(1024, 1024, 3, bias=False),
            nn.ReLU(inplace=True)
        )

        self.up_sample1 = Decoder(in_channels=1024, out_channels=512)
        self.up_sample2 = Decoder(in_channels=512, out_channels=256)
        self.up_sample3 = Decoder(in_channels=256, out_channels=128)
        self.up_sample4 = Decoder(in_channels=128, out_channels=64)

        self.classifier = nn.Conv2d(64, num_classes, 1)

    def forward(self, x):
        x1, x = self.down_sample1(x)
        x2, x = self.down_sample2(x)
        x3, x = self.down_sample3(x)
        x4, x = self.down_sample4(x)

        x = self.mid1(x)
        x = self.mid2(x)

        x = self.up_sample1(x4, x)
        x = self.up_sample2(x3, x)
        x = self.up_sample3(x2, x)
        x = self.up_sample4(x1, x)

        x = self.classifier(x)
        return x

if __name__ == '__main__':
    input = torch.rand(1, 3, 384, 384)
    model = UNet(2)
    out = model(input)
    print(out.shape)

Original: https://blog.csdn.net/qq_42735631/article/details/122182266
Author: 何如千泷
Title: U-Net论文详解

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/642158/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球