自己实现的unet3+模型,以及简单分析 (unet3plus tensorflow2 keras)

文章目录

简介

很早之前看了unet3+医学图像分割的论文,本来想直接去github找keras/Tensorflow的实现,奈何找到的似乎都和源码有一些出入,于是自己按照论文和源码写了一下,不过也不能保证和源码完全一致,发出来抛砖引玉。很多讲unet3+的博客都写的挺不错的,要想了解全文可以看看这篇翻译【UNet3+(UNet+++)论文解读 玖零猴】​,这篇文章也简单讲一下自己的理解。

unet3+论文
源码(Pytorch)

一、unet3+

简单来说,unet3+有三个特点:
1 跨尺度连接,防止语义在下采样/上采样之间存在损失
2 全尺度深监督,学习深层次的特征表示
3 为了消除医学图像中噪声导致的假阳性分割,提出一个分类指导模块
4 一个新的混合损失函数(TODO)

呃,前三点其实有不同的观点,我们稍后再谈。

[En]

Er, there are actually different points in the first three points, and we’ll talk about it later.

自己实现的unet3+模型,以及简单分析 (unet3plus tensorflow2 keras)
unet3+的网络结构如上图,总的来说还是非常易懂的,作者认为unet和unet++都没有做到跨尺度的特征图连接,于是想到将编码器不同尺度地信息传递到解码器,解码器中的信息也进行了跨层传递,以此减少信息丢失(真是简单粗暴=_=)。
自己实现的unet3+模型,以及简单分析 (unet3plus tensorflow2 keras)
以解码器3为例,解码器3融合了编码器1、2、3和解码器4、5的特征,这些特征通过最大池化(来自编码器的特征)或上采样(来自解码器的特征)调整到和解码器3一样的特征图大小,并且通过卷积层(源码里是卷积+BN+ReLu)将特征数调整到一致。这些拼接的特征图再经过一个卷积+BN+ReLu块输出特征就OK。
自己实现的unet3+模型,以及简单分析 (unet3plus tensorflow2 keras)
这张图解释了另外两个特点,一个是全尺度深监督,另一个是分类指导模块(CGM)。
全尺度深度监督计算所有解码器每层输出的损失函数。
[En]

Full-scale deep supervision calculates the loss function for the output of each layer of all decoders.

为了防止噪声导致的假阳性分割,作者提出了分类指导模块。分类指导模块是添加在网络瓶颈层(编码器底层,En5)的模块,这一层网络最深,特征图数量最多,且特征图最小,可能过滤掉了一定的噪声。作者在这一层后面添加了一个小的分类头(Dropout + Conv1x1 + Pooling + Sigmoid),这个分类头输出一个概率,表示输入图像中有无目标器官,将这个分类结果和分割头相乘,可以消除假阳性。

结果比较,直接看图叭:

自己实现的unet3+模型,以及简单分析 (unet3plus tensorflow2 keras)

特点讲完了,说说槽点:
1 全尺度连接好是好,而且作者特地提到了,unet3+的参数是少于unet和unet++的,但实际上训练需要的时间和占用的内存好像都更多一些,似乎是因为unet3+用到了更多的卷积操作(比如,unet解码器每层只需要2次卷积,但看看上面的Fig.2,unet3+的每层解码器需要6次卷积)
2 还没想好
3 CGM只是一个简单的模块,在我自己的实验中,就算加了Dropout也很快就过拟合了,图像分割头的验证集损失还在降低,CGM这边的损失函数却已经不降反升了。

; 二、完整代码(keras)

注:小朋友不懂事,写代码是为了好玩,不一定是对的,如果有问题,欢迎指出和讨论,转载请注明出处。

[En]

Note: children are not sensible, the code is written for fun, it is not necessarily correct, if there are problems, welcome to point out and discuss, reprint please indicate the source.

CGM输出这块的实现还是有待商榷的,我的代码里CGM和分割掩膜是分别输出的,所以后面要手动相乘一下。

1.引入库

import tensorflow as tf
import numpy as np
from keras.models import Model
from keras.layers import Conv2D, Input, concatenate, MaxPooling2D, UpSampling2D, Activation, BatchNormalization, LayerNormalization, Dropout, GlobalMaxPooling2D

2.辅助函数


def normalization(input_tensor, normalization):

    if normalization=='batch':
        return(BatchNormalization()(input_tensor))
    elif normalization=='layer':
        return(LayerNormalization()(input_tensor))
    elif normalization == None:
        return input_tensor
    else:
        raise ValueError('Invalid normalization')

def conv2d_block(input_tensor, filters, kernel_size,
                norm_type, use_residual, act_type='relu',
                double_features = False, dilation=[1, 1]):

    x = Conv2D(filters, kernel_size, padding='same', dilation_rate=dilation[0], use_bias=False, kernel_initializer='he_normal')(input_tensor)
    x = normalization(x, norm_type)
    x = Activation(act_type)(x)

    if double_features:
        filters *= 2

    x = Conv2D(filters, kernel_size, padding='same', dilation_rate=dilation[1], use_bias=False, kernel_initializer='he_normal')(x)
    x = normalization(x, norm_type)

    if use_residual:
        if K.int_shape(input_tensor)[-1] != K.int_shape(x)[-1]:
            shortcut = Conv2D(filters, kernel_size=1, padding='same', use_bias=False, kernel_initializer='he_normal')(input_tensor)
            shortcut = normalization(shortcut, norm_type)
            x = add([x, shortcut])
        else:
            x = add([x, input_tensor])

    x = Activation(act_type)(x)

    return x

def down_layer_2d(input_tensor, down_pattern, filters, norm_type=None):
    if down_pattern == 'maxpooling':
        x = MaxPooling2D(pool_size=(2, 2))(input_tensor)
    elif down_pattern == 'avgpooling':
        x = AveragePooling2D(pool_size=(2, 2))(input_tensor)
    elif down_pattern == 'conv':
        x = Conv2D(filters, kernel_size=(2, 2), strides=(2, 2), padding='same', use_bias=False if norm_type is None else True, kernel_initializer='he_normal')(input_tensor)
        normalization(x, norm_type)
    elif down_pattern == 'normconv':
        x = normalization(input_tensor, norm_type)
        x = Conv2D(filters, kernel_size=(2, 2), strides=(2, 2), padding='same', kernel_initializer='he_normal')(x)
    else:
        raise ValueError('Invalid down_pattern')
    return x

def conv_norm_act(input_tensor, filters, kernel_size , norm_type='batch', act_type='relu', dilation=1):
    output_tensor = Conv2D(filters, kernel_size, padding='same', dilation_rate=(dilation, dilation), use_bias=False if norm_type is not None else True, kernel_initializer='he_normal')(input_tensor)
    output_tensor = normalization(output_tensor, normalization=norm_type)
    output_tensor = Activation(act_type)(output_tensor)
    return output_tensor

def aggregate(l1, l2, l3, l4, l5, filters, kernel_size, norm_type='batch', act_type='relu'):
    out = concatenate([l1, l2, l3, l4, l5], axis = -1)
    out = Conv2D(filters * 5, kernel_size, padding = 'same', use_bias=False if norm_type is not None else True, kernel_initializer = 'he_normal')(out)
    out = normalization(out, norm_type)
    out = Activation(act_type)(out)

    return out

def cgm_block(input_tensor, class_num, dropout_rate = 0.):
    x = Dropout(rate = dropout_rate)(input_tensor)
    x = Conv2D(class_num, 1, padding='same', kernel_initializer='he_normal')(x)

    x = GlobalMaxPooling2D()(x)
    x = Activation('sigmoid', name='cgm_output')(x)

    return x

3.搭建网络


def unet3p_2d(input_shape, initial_features=32, kernel_size=3,
              class_num=1, norm_type='batch', double_features=False,
              use_residual=False, down_pattern='maxpooling', using_deep_supervision=True,
              using_cgm=False, cgm_drop_rate=0.5, show_summary=True):
    '''
    input_shape: (height, width, channel)
    initial_features: int, 初始特征图数量,每次下采样特征图数量加倍, unet3+原文中用的是64
    kernel_size: int, 卷积核大小
    class_num: int, 图像分割的类别数
    norm_type: str, 标准化方式, 'batch' 或 'layer', unet3+使用的是BatchNormalization
    double_features: bool, 在conv2d_block模块中是否在第二个卷积中将特征图数量翻倍,3dunet论文中提出该方法可以避免瓶颈问题,通常可以设为False
    use_residual: bool, 编码器部分是否使用残差连接
    down_pattern: str, 下采样方式, 'maxpooling' 或 'avgpooling' 或 'conv' 或 'normconv', unet3+使用的是MaxPooling
    using_deep_supervision: bool, 是否使用全尺度深度监督
    using_cgm: bool, 是否使用分类指导模块(CGM)
    cgm_drop_rate: float, CGM模块中Dropout比率
    show_summary: bool, 是否显示模型概况
    '''

    if class_num == 1:
        last_layer_activation = 'sigmoid'
    else:
        last_layer_activation = 'softmax'

    inputs = Input(input_shape)

    xe1 = conv2d_block(input_tensor=inputs, filters=initial_features, kernel_size=kernel_size,
                    norm_type=norm_type, double_features=double_features, use_residual=use_residual)
    xe1_pool = down_layer_2d(input_tensor=xe1, down_pattern=down_pattern, filters=initial_features)

    xe2 = conv2d_block(input_tensor=xe1_pool, filters=initial_features * 2, kernel_size=kernel_size,
                       norm_type=norm_type, double_features=double_features, use_residual=use_residual)
    xe2_pool = down_layer_2d(input_tensor=xe2, down_pattern=down_pattern, filters=initial_features * 2)

    xe3 = conv2d_block(input_tensor=xe2_pool, filters=initial_features * 4, kernel_size=kernel_size,
                       norm_type=norm_type, double_features=double_features, use_residual=use_residual)
    xe3_pool = down_layer_2d(input_tensor=xe3, down_pattern=down_pattern, filters=initial_features * 4)

    xe4 = conv2d_block(input_tensor=xe3_pool, filters=initial_features * 8, kernel_size=kernel_size,
                       norm_type=norm_type, double_features=double_features, use_residual=use_residual)
    xe4_pool = down_layer_2d(input_tensor=xe4, down_pattern=down_pattern, filters=initial_features * 8)

    xe5 = conv2d_block(input_tensor=xe4_pool, filters=initial_features * 16, kernel_size=kernel_size,
                       norm_type=norm_type, double_features=double_features, use_residual=use_residual)

    if using_cgm:
        cgm = cgm_block(input_tensor = xe5 , class_num = class_num ,dropout_rate = cgm_drop_rate)

    xd4_from_xe5 = UpSampling2D(size=(2,2), interpolation='bilinear')(xe5)
    xd4_from_xe5 = conv_norm_act(input_tensor=xd4_from_xe5, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd4_from_xe4 = conv_norm_act(input_tensor=xe4, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd4_from_xe3 = MaxPooling2D(pool_size = (2, 2))(xe3)
    xd4_from_xe3 = conv_norm_act(input_tensor=xd4_from_xe3, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd4_from_xe2 = MaxPooling2D(pool_size = (4, 4))(xe2)
    xd4_from_xe2 = conv_norm_act(input_tensor=xd4_from_xe2, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd4_from_xe1 = MaxPooling2D(pool_size = (8, 8))(xe1)
    xd4_from_xe1 = conv_norm_act(input_tensor=xd4_from_xe1, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd4 = aggregate(xd4_from_xe5, xd4_from_xe4, xd4_from_xe3, xd4_from_xe2, xd4_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)

    xd3_from_xe5 = UpSampling2D(size=(4, 4), interpolation='bilinear')(xe5)
    xd3_from_xe5 = conv_norm_act(input_tensor=xd3_from_xe5, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd3_from_xd4 = UpSampling2D(size=(2, 2), interpolation='bilinear')(xd4)
    xd3_from_xd4 = conv_norm_act(input_tensor=xd3_from_xd4, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd3_from_xe3 = conv_norm_act(input_tensor=xe3, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd3_from_xe2 = MaxPooling2D(pool_size = (2, 2))(xe2)
    xd3_from_xe2 = conv_norm_act(input_tensor=xd3_from_xe2, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd3_from_xe1 = MaxPooling2D(pool_size = (4, 4))(xe1)
    xd3_from_xe1 = conv_norm_act(input_tensor=xd3_from_xe1, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd3 = aggregate(xd3_from_xe5, xd3_from_xd4, xd3_from_xe3, xd3_from_xe2, xd3_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)

    xd2_from_xe5 = UpSampling2D(size=(8, 8), interpolation='bilinear')(xe5)
    xd2_from_xe5 = conv_norm_act(input_tensor=xd2_from_xe5, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd2_from_xd4 = UpSampling2D(size=(4, 4), interpolation='bilinear')(xd4)
    xd2_from_xd4 = conv_norm_act(input_tensor=xd2_from_xd4, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd2_from_xd3 = UpSampling2D(size=(2, 2), interpolation='bilinear')(xd3)
    xd2_from_xd3 = conv_norm_act(input_tensor=xd2_from_xd3, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd2_from_xe2 = conv_norm_act(input_tensor=xe2, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd2_from_xe1 = MaxPooling2D(pool_size = (2, 2))(xe1)
    xd2_from_xe1 = conv_norm_act(input_tensor=xd2_from_xe1, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd2 = aggregate(xd2_from_xe5, xd2_from_xd4, xd2_from_xd3, xd2_from_xe2, xd2_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)

    xd1_from_xe5 = UpSampling2D(size=(16, 16), interpolation='bilinear')(xe5)
    xd1_from_xe5 = conv_norm_act(input_tensor=xd1_from_xe5, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd1_from_xd4 = UpSampling2D(size=(8, 8), interpolation='bilinear')(xd4)
    xd1_from_xd4 = conv_norm_act(input_tensor=xd1_from_xd4, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd1_from_xd3 = UpSampling2D(size=(4, 4), interpolation='bilinear')(xd3)
    xd1_from_xd3 = conv_norm_act(input_tensor=xd1_from_xd3, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd1_from_xd2 = UpSampling2D(size=(2, 2), interpolation='bilinear')(xd2)
    xd1_from_xd2 = conv_norm_act(input_tensor=xd1_from_xd2, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd1_from_xe1 = conv_norm_act(input_tensor=xe1, filters=initial_features, kernel_size=kernel_size ,norm_type=norm_type)
    xd1 = aggregate(xd1_from_xe5, xd1_from_xd4, xd1_from_xd3, xd1_from_xd2, xd1_from_xe1, filters=initial_features, kernel_size=kernel_size, norm_type=norm_type)

    if using_deep_supervision:
        xd55 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xe5)
        xd55 = UpSampling2D(size=(16, 16))(xd55)
        xd55 = Activation(last_layer_activation, name='output_de5')(xd55)

        xd44 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xd4)
        xd44 = UpSampling2D(size=(8, 8))(xd44)
        xd44 = Activation(last_layer_activation, name='output_de4')(xd44)

        xd33 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xd3)
        xd33 = UpSampling2D(size=(4, 4))(xd33)
        xd33 = Activation(last_layer_activation, name='output_de3')(xd33)

        xd22 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xd2)
        xd22 = UpSampling2D(size=(2, 2))(xd22)
        xd22 = Activation(last_layer_activation, name='output_de2')(xd22)

        xd11 = Conv2D(class_num, kernel_size, activation=None, padding='same')(xd1)
        xd11 = Activation(last_layer_activation, name='output_de1')(xd11)

        if using_cgm: outputs=[xd11, xd22, xd33, xd44, xd55, cgm]
        else: outputs=[xd11, xd22, xd33, xd44, xd55]

    else:
        conv_output = Conv2D(class_num, 1, activation=last_layer_activation, name='output')(xd1)
        if using_cgm: outputs=[conv_output, cgm]
        else: outputs = conv_output

    model = Model(inputs, outputs)
    if show_summary: model.summary()

    return model

4.创建模型

如果以上代码都在同一个.py文件下,可以加上以下代码尝试构建网络:

if __name__ == '__main__':
    model = unet3p_2d(input_shape=(256, 256, 1), initial_features=32, kernel_size=3,
                      class_num=1, norm_type='batch', double_features=False,
                      use_residual=False, down_pattern='maxpooling',
                      using_deep_supervision=True, using_cgm=False, show_summary=True)
    model = unet3p_2d(input_shape=(256, 256, 1), initial_features=32, kernel_size=3,
                      class_num=1, norm_type='batch', double_features=False,
                      use_residual=False, down_pattern='maxpooling',
                      using_deep_supervision=True, using_cgm=True, show_summary=True)
    model = unet3p_2d(input_shape=(256, 256, 1), initial_features=32, kernel_size=3,
                      class_num=1, norm_type='batch', double_features=False,
                      use_residual=False, down_pattern='maxpooling',
                      using_deep_supervision=False, using_cgm=False, show_summary=True)

如果用到了预训练的主干网络,需要修改下编码器(En)部分。

感觉自己好菜,不知道能不能顺利be yeah,哎TAT

Original: https://blog.csdn.net/weixin_42723174/article/details/125306304
Author: 求你涨点吧
Title: 自己实现的unet3+模型,以及简单分析 (unet3plus tensorflow2 keras)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/496178/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球