tensorflow2.2_实现SENet

  1. SENet介绍

SENet 是 ImageNet Challenge 图像识别比赛 2017 年的冠军,是来自 Momenta 公司的团队完成。他们提出了 Squeeze-and-Excitation Networks(简称 SENet)。SENet一般不单独使用,通常都是与其它模型结合使用,使其效果更好。
在一般的卷积层中通过卷积核会生成许多不同的特征图,但在这些特征图中并不是所有的特征图都是很重要的,也许有些特征可以忽略。如果我们可以将重要的特征加强,而不重要的特征可以减弱,这样我们的模型效果可能会更好。
所以SENet就可以实现这样的效果,它的核心思想是: 给特征图增加注意力和门控机制,增强重要的特征图的信息,减弱不重要的特征图的信息

tensorflow2.2_实现SENet
其中:
  • W、H、C分别代表图片的宽、高、通道数。
  • Global Pooling代表全局池化。
  • FC代表全连接层
  • ReLu、Sigmoid分别代表激活函数分别使用ReLu和Sigmoid。
  • r代表缩减率,意思是在第一个全连接层缩减的通道数。

如上图,左边是普通的残差结构,右边是加上了SENet的残差结构。
加上SENet后,首先是做平均池化,得到特征图的压缩特征。第二层进行全连接层,我们也可以使用1×1的卷积核来代替,效果是一样的之后使用ReLu激活函数。第三层就是全连接层,也可以使用1×1卷积来代替,之后使用Sigmoid函数,使输出范围在0~1之间,起到 门控的作用。Sigmoid输出的激活值最后会乘以初始残差结构最后一个卷积层的输出结果,对特征图的数值大小进行控制。如果是重要的特征图,会保持比较大的数值;如果是不重要的特征图,特征图的数值就会变小。
下图是论文中的图,和上面介绍的差不多。

tensorflow2.2_实现SENet
下图是论文中三个网络的模型结构。
tensorflow2.2_实现SENet
其都是在每个block后加上SENet。
下图是个模型的比较
tensorflow2.2_实现SENet

; 2. SEnet实现代码

def SE_block(x_0, r = 16):
    channels = x_0.shape[-1]
    x = GlobalAvgPool2D()(x_0)

    x = x[:, None, None, :]

    x = Conv2D(filters=channels//r, kernel_size=1, strides=1)(x)
    x = Activation('relu')(x)
    x = Conv2D(filters=channels, kernel_size=1, strides=1)(x)
    x = Activation('sigmoid')(x)
    x = Multiply()([x_0, x])

    return x
  1. Resnet18与SENet结合

Resnet50也可以在残差块结构的最后加上SENet。

from os import name
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (Dense, ZeroPadding2D, Conv2D,
                                     MaxPool2D, GlobalAvgPool2D, Input,
                                     BatchNormalization, Activation, Add, Multiply)
from tensorflow.keras.models import Model
from plot_model import plot_model
from tensorflow.python.keras.layers.pooling import AveragePooling2D

def SE_block(x_0, r = 16):
    channels = x_0.shape[-1]
    x = GlobalAvgPool2D()(x_0)

    x = x[:, None, None, :]

    x = Conv2D(filters=channels//r, kernel_size=1, strides=1)(x)
    x = Activation('relu')(x)
    x = Conv2D(filters=channels, kernel_size=1, strides=1)(x)
    x = Activation('sigmoid')(x)
    x = Multiply()([x_0, x])

    return x

def block(x, filters, strides=2, r=16, conv_short=True):
    if conv_short:
        short_cut = Conv2D(filters=filters, kernel_size=1, strides=strides, padding='valid')(x)
        short_cut = BatchNormalization(epsilon=1.001e-5)(short_cut)
    else:
        short_cut = x

    x = Conv2D(filters=filters, kernel_size=3, strides=strides, padding='same')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Activation('relu')(x)

    x = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Activation('relu')(x)

    x = SE_block(x, r=r)

    x = Add()([x, short_cut])
    x = Activation('relu')(x)

    return x

def SE_Resnet18(inputs, classes):
    x = Conv2D(filters=64, kernel_size=(7, 7), strides=(2, 2), padding='same', activation='relu')(inputs)
    x = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)
    x = block(x, filters=64, strides=1, conv_short=False)
    x = block(x, filters=64, strides=1, conv_short=False)
    x = block(x, filters=128, strides=2, conv_short=True)
    x = block(x, filters=128, strides=1, conv_short=False)
    x = block(x, filters=256, strides=2, conv_short=True)
    x = block(x, filters=256, strides=1, conv_short=False)
    x = block(x, filters=512, strides=2, conv_short=True)
    x = block(x, filters=512, strides=1, conv_short=False)

    x = GlobalAvgPool2D()(x)
    x = Dense(classes, activation='softmax')(x)
    return x

if __name__ == '__main__':
    is_show_picture = False
    inputs = Input(shape=(224,224,3))
    classes = 17
    model = Model(inputs=inputs, outputs=SE_Resnet18(inputs, classes))
    model.summary()
    print(len(model.layers))
    for i in range(len(model.layers)):
        print(i, model.layers[i])
    if is_show_picture:
        plot_model(model,
           to_file='./nets_picture/SE_Resnet18.png',
           )
        print("plot_model------------------------>")

Original: https://blog.csdn.net/qq_42025868/article/details/122490891
Author: Haohao+++
Title: tensorflow2.2_实现SENet

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/692099/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球