【UNet3+】Remote Sensing Image Segmentation

1.1. Problem Introduction

  • Image segmentation
    In computer vision, image segmentation is the process of partitioning a digital image into multiple sub-regions. Its purpose is to simplify or change the representation of the image so that it is easier to understand and analyze. Image segmentation is typically used to locate objects and boundaries in images; more precisely, it is the process of assigning a label to every pixel in an image such that pixels with the same label share certain visual characteristics.
  • Experiment task
    This example briefly shows how to use the UNet3+ model for remote sensing image segmentation: we need to segment and annotate the buildings present in remote sensing images.

1.2. Dataset Overview

In 2019, Wuhan University released the Aerial Imagery Dataset. The original aerial data comes from the Land Information New Zealand website. The dataset contains 8,189 remote sensing images with 0.3 m resolution and a size of 512×512 pixels, covering 187,000 buildings in total. It consists of an image folder holding the remote sensing images and a label folder holding the segmentation masks. A sample is shown below:

(Figure: a sample remote sensing image and its building-segmentation label)

Download link for the dataset: Aerial Imagery Dataset – AI Studio

2. The UNet3+ Model

2.1. Background

Hinton et al. (2006) proposed an Encoder-Decoder structure. At the time, its main purpose was not segmentation but image compression and denoising: an input image is encoded by downsampling into a feature representation smaller than the original (compression), and then decoded, ideally reconstructing the original image.

Later, Jonathan Long et al. (2015) proposed FCN (Fully Convolutional Networks) based on this topology. Since its introduction, FCN has become the basic framework for semantic segmentation, and subsequent algorithms (such as UNet) are essentially improvements within this framework. Among them, UNet, with its simple and easy-to-understand symmetric structure and excellent performance, became one of the templates that many later networks improved upon.

UNet (2015) is the most widely used network in medical image segmentation. It uses skip connections to combine high-level semantic feature maps from the decoder with low-level semantic feature maps of the corresponding scale from the encoder, and its performance is closely tied to how multi-scale features are fused within the network. To avoid plain skip connections fusing semantically dissimilar features, the later UNet++ (2018) improved the network with a nested structure and dense skip connections. The latest UNet3+ (2020) fuses deep and shallow features through full-scale skip connections while supervising the features at every scale with deep supervision; it also reduces network parameters while improving computational efficiency.

(Figure: architecture comparison of UNet, UNet++, and UNet3+)

2.2. Model Introduction

Huang et al. (2020) proposed the UNet3+ model and conducted extensive experiments with it on liver and spleen datasets, finding improved performance that surpassed many baselines. The three innovations of UNet3+ are introduced below:

(1) Full-scale Skip Connections

UNet3+ makes full use of multi-scale features by introducing Full-scale Skip Connections, which combine low-level and high-level semantics from feature maps at all scales, with fewer parameters.

In many segmentation studies, feature maps at different scales carry different information: low-level feature maps capture rich spatial information and highlight object boundaries, while high-level feature maps embody positional information about where objects are. Therefore, each decoder layer of UNet3+ fuses smaller-scale and same-scale low-level feature maps from the encoder with larger-scale high-level feature maps from the decoder, capturing fine-grained and coarse-grained semantics at full scale.

(Figure: building the decoder feature map $X_{De}^3$ from full-scale skip connections)

As shown in the figure above, to construct the feature map $X_{De}^3$, the 3rd decoder layer receives not only the feature map $X_{En}^3$ from the same-scale encoder layer, but also the feature maps $X_{En}^1$ and $X_{En}^2$ from smaller-scale encoder layers (downsampled before being received, to unify the resolutions), as well as the feature maps $X_{De}^5$ and $X_{De}^4$ from larger-scale decoder layers (upsampled before being received, to unify the resolutions). After unifying the resolutions, we also unify the number of feature maps using 64 filters of size 3×3, to reduce redundant information. With that done, the five sets of feature maps can be fused by concatenation along the channel dimension, yielding 320 feature maps. These are then convolved with 320 filters of size 3×3, followed by batch normalization and ReLU (Rectified Linear Unit), which gives $X_{De}^3$.
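
To make the shapes concrete, here is a minimal runnable sketch of this fusion for $X_{De}^3$ (an illustration only, not the final implementation in Section 3; the dummy tensors assume a 256×256 input and the channel widths filters = [64, 128, 256, 512, 1024] used later in this article):

import paddle
from paddle import nn

# Dummy full-scale inputs for decoder stage 3 (batch size 1):
e1 = paddle.rand([1,   64, 256, 256])    # X_En^1
e2 = paddle.rand([1,  128, 128, 128])    # X_En^2
e3 = paddle.rand([1,  256,  64,  64])    # X_En^3
d4 = paddle.rand([1,  320,  32,  32])    # X_De^4 (already fused: 5*64=320)
d5 = paddle.rand([1, 1024,  16,  16])    # X_De^5 = X_En^5

conv64 = lambda c: nn.Conv2D(c, 64, 3, 1, 1)        # unify channels to 64
f1 = conv64(64)(nn.MaxPool2D(4, 4)(e1))             # downsample by 4
f2 = conv64(128)(nn.MaxPool2D(2, 2)(e2))            # downsample by 2
f3 = conv64(256)(e3)                                # same scale, conv only
f4 = conv64(320)(nn.Upsample(scale_factor=2, mode="bilinear")(d4))
f5 = conv64(1024)(nn.Upsample(scale_factor=4, mode="bilinear")(d5))

x = paddle.concat([f1, f2, f3, f4, f5], axis=1)     # 5 * 64 = 320 channels
x = nn.ReLU()(nn.BatchNorm2D(320)(nn.Conv2D(320, 320, 3, 1, 1)(x)))
print(x.shape)   # [1, 320, 64, 64], i.e. X_De^3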

Thus, the computation of the feature map $X_{De}^i$ can be summarized as:

$$
X_{De}^i =
\begin{cases}
X_{En}^i, & i = N \\[4pt]
H\!\left(\left[\,C(D(X_{En}^k))_{k=1}^{i-1},\; C(X_{En}^i),\; C(U(X_{De}^k))_{k=i+1}^{N}\,\right]\right), & i = 1, \dots, N-1
\end{cases}
$$
where $i$ indexes the encoder/decoder layers along the encoding direction, $N$ is the total number of encoder layers, $C$ denotes a convolution operation, $U$ and $D$ denote upsampling and downsampling respectively, $H$ denotes the feature-fusion mechanism (one convolution layer + one batch normalization layer + one ReLU layer), and $[\,\cdot\,]$ denotes concatenation along the channel dimension.

(2) Full-scale Deep Supervision

UNet3+ adopts Full-scale Deep Supervision, learning hierarchical representations from the aggregated full-scale feature maps and optimizing a hybrid loss function to enhance organ boundaries.

Unlike UNet++, which applies deep supervision only to full-resolution feature maps, every decoder stage in UNet3+ has a side output that is supervised by the ground truth. To realize deep supervision, the side output of each decoder stage is fed into a 3×3 convolution layer, a bilinear upsampling layer, and a sigmoid layer, as sketched below.
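
A minimal sketch of one such side output (hypothetical shapes: a 320-channel decoder map at 1/4 of a 256×256 input, with 2 classes):

import paddle
from paddle import nn
from paddle.nn import functional as F

d3 = paddle.rand([1, 320, 64, 64])                          # decoder feature map
side = nn.Conv2D(320, 2, 3, 1, 1)(d3)                       # 3x3 conv -> class maps
side = nn.Upsample(scale_factor=4, mode="bilinear")(side)   # restore full resolution
side = F.sigmoid(side)                                      # supervised by ground truth
print(side.shape)   # [1, 2, 256, 256]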

To further enhance organ boundaries, UNet3+ adopts a Multi-Scale Structural Similarity index (MS-SSIM) loss to assign higher weights to fuzzy boundaries: since the MS-SSIM value grows as the difference between regional distributions grows, UNet3+ thereby pays more attention to fuzzy boundaries. Suppose we crop two N×N patches $p$ and $g$ from the segmentation result $P$ and the ground truth $G$ respectively, with $p = \{p_j : j = 1, \dots, N^2\}$ and $g = \{g_j : j = 1, \dots, N^2\}$; then the MS-SSIM loss of $p$ and $g$ is defined as:

$$
\ell_{ms\text{-}ssim} = 1 - \prod_{m=1}^{M} \left(\frac{2\mu_p \mu_g + C_1}{\mu_p^2 + \mu_g^2 + C_1}\right)^{\beta_m} \left(\frac{2\sigma_{pg} + C_2}{\sigma_p^2 + \sigma_g^2 + C_2}\right)^{\gamma_m}
$$
where $M$ is the total number of scales (the authors set it to 5), $\mu_p, \mu_g$ and $\sigma_p, \sigma_g$ are the means and standard deviations of $p$ and $g$, and $\sigma_{pg}$ is their covariance. $\beta_m, \gamma_m$ define the relative importance of the two components at each scale, and the small constants $C_1 = 0.01^2, C_2 = 0.03^2$ are set to avoid division by zero.
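
Below is a minimal single-scale sketch of this loss (an illustration under simplifying assumptions, not the authors' implementation: local statistics are estimated with an 11×11 mean filter, only one scale is evaluated, and the exponents $\beta_m, \gamma_m$ are dropped, whereas MS-SSIM multiplies the two weighted terms over M scales):

import paddle
import paddle.nn.functional as F

def ssim_loss(p, g, C1=0.01 ** 2, C2=0.03 ** 2, win=11):
    ''' Single-scale SSIM loss on NxCxHxW probability maps in [0, 1] '''
    mu_p = F.avg_pool2d(p, win, 1, win // 2)                    # local means
    mu_g = F.avg_pool2d(g, win, 1, win // 2)
    var_p = F.avg_pool2d(p * p, win, 1, win // 2) - mu_p ** 2   # variances
    var_g = F.avg_pool2d(g * g, win, 1, win // 2) - mu_g ** 2
    cov = F.avg_pool2d(p * g, win, 1, win // 2) - mu_p * mu_g   # covariance
    lum = (2 * mu_p * mu_g + C1) / (mu_p ** 2 + mu_g ** 2 + C1)
    cs = (2 * cov + C2) / (var_p + var_g + C2)
    return 1 - (lum * cs).mean()

print(ssim_loss(paddle.rand([1, 1, 64, 64]), paddle.rand([1, 1, 64, 64])))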

UNet3+ combines the focal loss, the MS-SSIM loss, and the IoU loss into a hybrid loss for segmentation at three different levels (pixel level, patch level, and image level), which captures both large-scale structures with clear boundaries and fine structures. The hybrid loss is defined as:

$$
\ell_{seg} = \ell_{fl} + \ell_{ms\text{-}ssim} + \ell_{iou}
$$

(3) Classification-guided Module

UNet3+ introduces a Classification-guided Module (CGM) that reduces over-segmentation of organ-free images through joint training with an image-level classification task.

In most medical image segmentation experiments, noise from the background lingers in the shallower layers, causing over-segmentation of images that contain no organs. To address this, UNet3+ adds an extra classification task that predicts whether the input image contains an organ.

(Figure: the classification-guided module built on the deepest feature map $X_{De}^5$)

As shown above, the deepest feature map $X_{De}^5$ passes in turn through a Dropout layer, a 1×1 convolution layer, a max-pooling layer, and a sigmoid layer, producing a 2-dimensional tensor that represents the probabilities of $X_{De}^5$ containing / not containing an organ. This tensor is then processed with argmax, yielding a binary classification result of 0s and 1s. Next, these classification results are multiplied with each side segmentation output to obtain the corrected side outputs. By optimizing the binary cross-entropy loss we obtain more accurate classification results, which guide the model away from over-segmenting organ-free images.
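
Here is a toy illustration of this masking step (hypothetical tensors; the real logic appears in the dot_product method of the model code below): the sample classified as 0 has its side output zeroed out, while the sample classified as 1 keeps it unchanged.

import paddle

seg = paddle.rand([2, 2, 4, 4])            # two side outputs (B, N, H, W)
cls = paddle.to_tensor([[0.], [1.]])       # per-sample argmax results
masked = (seg.reshape([2, 2, -1]) * cls.unsqueeze(-1)).reshape([2, 2, 4, 4])
print(masked[0].sum(), masked[1].sum())    # first is all zeros, second unchanged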

3. Code Implementation

3.0. Preliminaries

  • *Import the modules

Note: this project requires Paddle 2.0+. Please adjust the hyperparameters batch_size and img_size according to your GPU memory!

import cv2
import os
import random
import zipfile
import numpy as np
from typing import Union, Tuple
from PIL import Image, ImageEnhance
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap as LSC

import paddle
from paddle import nn
from paddle.io import DataLoader, Dataset
from paddle.nn import initializer as I, functional as F
from paddle.optimizer import AdamW
from paddle.optimizer.lr import CosineAnnealingDecay
  • *Set the hyperparameters
BATCH_SIZE = 4
EPOCHS = 16
LOG_GAP = 500

N_CLASSES = 2
IMG_SIZE = (256, 256)

INIT_LR = 3e-4

SRC_PATH = "./data/data69911/BuildData.zip"
DST_PATH = "./data"
DATA_PATH = {
    "img": DST_PATH + "/image",
    "lab": DST_PATH + "/label",
}
INFER_PATH = {
    "img": ["./work/1.jpg", "./work/2.jpg"],
    "lab": ["./work/1.png", "./work/2.png"],
}
MODEL_PATH = "UNet3+.pdparams"

3.1. Data Preparation

  • Unpack the dataset
    Since the data is stored as a zip archive, we first need to unpack it.
if not os.path.isdir(DATA_PATH["img"]) or not os.path.isdir(DATA_PATH["lab"]):
    z = zipfile.ZipFile(SRC_PATH, "r")
    z.extractall(path=DST_PATH)
    z.close()
print("The dataset has been unpacked successfully!")
  • Split the dataset
    We split the data into training and test sets at a 9:1 ratio, producing two lists that map image paths to label paths.
train_list, test_list = [], []
images = os.listdir(DATA_PATH["img"])

for idx, img in enumerate(images):
    lab = os.path.join(DATA_PATH["lab"], img.replace(".jpg", ".png"))
    img = os.path.join(DATA_PATH["img"], img)
    if idx % 10 != 0:
        train_list.append((img, lab))
    else:
        test_list.append((img, lab))
  • Data augmentation
    Data augmentation aims mainly to reduce overfitting: transforming the training images yields a network that generalizes better to the application scenario.
    Since the model is fairly complex and prone to overfitting when trained directly, we use data augmentation to enrich the diversity of the training data. The methods used in this experiment include: randomly changing brightness, contrast, saturation, and sharpness; randomly rotating and flipping the image; and randomly adding Gaussian noise.
def random_brightness(img, lab, low=0.5, high=1.5):
    ''' Randomly adjust brightness (0.5~1.5) '''
    x = random.uniform(low, high)
    img = ImageEnhance.Brightness(img).enhance(x)
    return img, lab

def random_contrast(img, lab, low=0.5, high=1.5):
    ''' Randomly adjust contrast (0.5~1.5) '''
    x = random.uniform(low, high)
    img = ImageEnhance.Contrast(img).enhance(x)
    return img, lab

def random_color(img, lab, low=0.5, high=1.5):
    ''' Randomly adjust saturation (0.5~1.5) '''
    x = random.uniform(low, high)
    img = ImageEnhance.Color(img).enhance(x)
    return img, lab

def random_sharpness(img, lab, low=0.5, high=1.5):
    ''' Randomly adjust sharpness (0.5~1.5) '''
    x = random.uniform(low, high)
    img = ImageEnhance.Sharpness(img).enhance(x)
    return img, lab

def random_rotate(img, lab, low=0, high=360):
    ''' Randomly rotate the image (0~360 degrees) '''
    angle = random.choice(range(low, high))
    img, lab = img.rotate(angle), lab.rotate(angle)
    return img, lab

def random_flip(img, lab, prob=0.5):
    ''' Randomly flip the image (p=0.5) '''
    if random.random() < prob:
        img = img.transpose(Image.FLIP_TOP_BOTTOM)
        lab = lab.transpose(Image.FLIP_TOP_BOTTOM)
    if random.random() < prob:
        img = img.transpose(Image.FLIP_LEFT_RIGHT)
        lab = lab.transpose(Image.FLIP_LEFT_RIGHT)
    return img, lab

def random_noise(img, lab, low=0, high=10):
    ''' Randomly add Gaussian noise (sigma in 0~10) '''
    img = np.asarray(img).astype("float32")
    sigma = np.random.uniform(low, high)
    noise = np.random.randn(img.shape[0], img.shape[1], 3) * sigma
    # Add the noise in float, then clamp to the valid pixel range;
    # adding uint8 noise directly would wrap around at 255.
    img = np.clip(img + noise, 0, 255).astype("uint8")
    img = Image.fromarray(img)
    return img, lab

def image_augment(img, lab, prob=0.5):
    ''' Randomly apply several augmentations in sequence '''
    opts = [random_brightness, random_contrast, random_color, random_flip,
            random_noise, random_rotate, random_sharpness,]
    for func in opts:
        if random.random() < prob:
            img, lab = func(img, lab)
    return img, lab
  • Data preprocessing
    We need to resize and normalize the dataset images.
class MyDataset(Dataset):
    ''' Custom dataset class
    * label_list: list mapping image paths to label paths
    * transform: image-processing function
    * augment: data-augmentation function
    '''
    def __init__(self, label_list, transform, augment=None):
        super(MyDataset, self).__init__()
        random.shuffle(label_list)
        self.label_list = label_list
        self.transform = transform
        self.augment = augment

    def __getitem__(self, index):
        ''' Fetch one sample by index '''
        img_path, lab_path = self.label_list[index]
        img, lab = self.transform(img_path, lab_path, self.augment)
        return img, lab

    def __len__(self):
        ''' Return the total number of samples '''
        return len(self.label_list)

def data_mapper(img_path, lab_path, augment=None):
    ''' Load, binarize, resize, and normalize one image-label pair '''
    img = Image.open(img_path).convert("RGB")
    # cv2.imread returns BGR, so convert with COLOR_BGR2GRAY
    lab = cv2.cvtColor(cv2.imread(lab_path), cv2.COLOR_BGR2GRAY)
    # Binarize the label (inverted: >170 -> 0, otherwise -> 255)
    _, lab = cv2.threshold(src=lab,
                           thresh=170,
                           maxval=255,
                           type=cv2.THRESH_BINARY_INV)
    lab = Image.fromarray(lab).convert("L")

    img = img.resize(IMG_SIZE, Image.ANTIALIAS)
    lab = lab.resize(IMG_SIZE, Image.NEAREST)   # NEAREST keeps the mask binary
    if augment is not None:
        img, lab = augment(img, lab)

    img = np.array(img).astype("float32").transpose((2, 0, 1))
    lab = np.array(lab).astype("int32")[np.newaxis, ...]

    img = paddle.to_tensor(img / 255.0)   # scale pixels to [0, 1]
    lab = paddle.to_tensor(lab // 255)    # map {0, 255} to {0, 1}
    return img, lab
train_dataset = MyDataset(train_list, data_mapper, image_augment)
test_dataset = MyDataset(test_list, data_mapper, augment=None)
  • Define the data loaders
    We build separate loaders for training and testing; the training loader shuffles the data and serves it in batches.
train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          num_workers=4,
                          shuffle=True,
                          drop_last=False)

test_loader = DataLoader(test_dataset,
                         batch_size=1,
                         num_workers=4,
                         shuffle=False,
                         drop_last=False)
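
As an optional sanity check, we can fetch one batch from the training loader and confirm the tensor shapes:

img, lab = next(iter(train_loader))
print(img.shape, lab.shape)   # expected: [4, 3, 256, 256] and [4, 1, 256, 256]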

3.2. Network Configuration

This experiment uses the UNet3+ model. The UNet family consists of a downsampling stage (encoder, feature extraction) and an upsampling stage (decoder, resolution restoration), and is named after its U-shaped structure.

  • *Define the weight-initialization function
def init_weights(net, init_type="normal"):
    ''' Initialize the weights and biases of a layer in place
    * net: the layer (or container of layers) to initialize
    * init_type: initialization scheme (normal/xavier/kaiming/truncated)
    '''
    if init_type == "normal":
        init_w = I.Normal()
    elif init_type == "xavier":
        init_w = I.XavierNormal()
    elif init_type == "kaiming":
        init_w = I.KaimingNormal()
    elif init_type == "truncated":
        init_w = I.TruncatedNormal()
    else:
        error = "Initialization method [%s] is not implemented!"
        raise NotImplementedError(error % init_type)

    # A ParamAttr only takes effect at construction time, so we apply
    # the initializers to the existing parameters directly.
    for m in net.sublayers(include_self=True):
        if isinstance(m, nn.Conv2D):
            init_w(m.weight)
            if m.bias is not None:
                I.Constant(0.0)(m.bias)
        elif isinstance(m, nn.BatchNorm2D):
            I.Constant(1.0)(m.weight)
            I.Constant(0.0)(m.bias)
  • *Build the encoder
class Encoder(nn.Layer):
    ''' Encoder block
    * in_size: number of input channels
    * out_size: number of output channels
    * is_batchnorm: whether to use batch normalization
    * n: number of convolution layers (default 2)
    * ks: kernel size (default 3)
    * s: stride (default 1)
    * p: padding (default 1)
    '''
    def __init__(self, in_size, out_size, is_batchnorm,
                 n=2, ks=3, s=1, p=1):
        super(Encoder, self).__init__()
        self.n = n

        for i in range(1, self.n+1):
            if is_batchnorm:
                block = nn.Sequential(nn.Conv2D(in_size, out_size, ks, s, p),
                                      nn.BatchNorm2D(out_size),
                                      nn.ReLU())
            else:
                block = nn.Sequential(nn.Conv2D(in_size, out_size, ks, s, p),
                                      nn.ReLU())
            setattr(self, "block%d" % i, block)
            in_size = out_size

        for m in self.children():
            init_weights(m, init_type="kaiming")

    def forward(self, x):
        for i in range(1, self.n+1):
            block = getattr(self, "block%d" % i)
            x = block(x)
        return x
  • *Build the decoder
class Decoder(nn.Layer):
    ''' Decoder block
    * cur_stage(int): stage index of this decoder
    * cat_size(int): per-source channel count after unification
    * up_size(int): total channel count after fusion
    * filters(list): filter counts of the encoder stages
    * ks: kernel size (default 3)
    * s: stride (default 1)
    * p: padding (default 1)
    '''
    def __init__(self, cur_stage, cat_size, up_size,
                 filters, ks=3, s=1, p=1):
        super(Decoder, self).__init__()
        self.n = len(filters)

        for idx, num in enumerate(filters):
            idx += 1
            if idx < cur_stage:
                # Smaller-scale encoder maps: max-pool down, then conv
                ps = 2 ** (cur_stage - idx)
                block = nn.Sequential(nn.MaxPool2D(ps, ps, ceil_mode=True),
                                      nn.Conv2D(num, cat_size, ks, s, p),
                                      nn.BatchNorm2D(cat_size),
                                      nn.ReLU())
            elif idx == cur_stage:
                # Same-scale encoder map: conv only
                block = nn.Sequential(nn.Conv2D(num, cat_size, ks, s, p),
                                      nn.BatchNorm2D(cat_size),
                                      nn.ReLU())
            else:
                # Larger-scale decoder maps: bilinear upsample, then conv
                us = 2 ** (idx - cur_stage)
                num = num if idx == 5 else up_size
                block = nn.Sequential(nn.Upsample(scale_factor=us, mode="bilinear"),
                                      nn.Conv2D(num, cat_size, ks, s, p),
                                      nn.BatchNorm2D(cat_size),
                                      nn.ReLU())
            setattr(self, "block%d" % idx, block)

        self.fusion = nn.Sequential(nn.Conv2D(up_size, up_size, ks, s, p),
                                    nn.BatchNorm2D(up_size),
                                    nn.ReLU())

        for m in self.children():
            init_weights(m, init_type="kaiming")

    def forward(self, inputs):
        outputs = []
        for i in range(self.n):
            block = getattr(self, "block%d" % (i+1))
            outputs.append( block(inputs[i]) )
        hd = self.fusion(paddle.concat(outputs, 1))
        return hd
  • *Define the network
class UNet3Plus(nn.Layer):
    ''' UNet3+ with Deep Supervision and Class-guided Module
    * in_channels: number of input channels (default 3)
    * n_classes: number of classes (default 2)
    * is_batchnorm: whether to use batch normalization (default True)
    * deep_sup: whether to enable Deep Supervision
    * set_cgm: whether to enable the Classification-guided Module
    '''
    def __init__(self, in_channels=3, n_classes=2,
                 is_batchnorm=True, deep_sup=True, set_cgm=True):
        super(UNet3Plus, self).__init__()
        self.deep_sup = deep_sup
        self.set_cgm = set_cgm
        filters = [64, 128, 256, 512, 1024]
        cat_channels = filters[0]
        cat_blocks = 5
        up_channels = cat_channels * cat_blocks

        self.conv_e1 = Encoder(in_channels, filters[0], is_batchnorm)
        self.pool_e1 = nn.MaxPool2D(kernel_size=2)
        self.conv_e2 = Encoder(filters[0], filters[1], is_batchnorm)
        self.pool_e2 = nn.MaxPool2D(kernel_size=2)
        self.conv_e3 = Encoder(filters[1], filters[2], is_batchnorm)
        self.pool_e3 = nn.MaxPool2D(kernel_size=2)
        self.conv_e4 = Encoder(filters[2], filters[3], is_batchnorm)
        self.pool_e4 = nn.MaxPool2D(kernel_size=2)
        self.conv_e5 = Encoder(filters[3], filters[4], is_batchnorm)

        self.conv_d4 = Decoder(4, cat_channels, up_channels, filters)
        self.conv_d3 = Decoder(3, cat_channels, up_channels, filters)
        self.conv_d2 = Decoder(2, cat_channels, up_channels, filters)
        self.conv_d1 = Decoder(1, cat_channels, up_channels, filters)

        if self.set_cgm:
            # Classification branch: predicts organ / no-organ from X_De^5
            self.cls = nn.Sequential(nn.Dropout(p=0.5),
                                     nn.Conv2D(filters[4], 2, 1),
                                     nn.AdaptiveMaxPool2D(1),
                                     nn.Sigmoid())
        if self.deep_sup:
            # Bilinear upsampling of each side output to full resolution
            self.upscore5 = nn.Upsample(scale_factor=16, mode="bilinear")
            self.upscore4 = nn.Upsample(scale_factor=8, mode="bilinear")
            self.upscore3 = nn.Upsample(scale_factor=4, mode="bilinear")
            self.upscore2 = nn.Upsample(scale_factor=2, mode="bilinear")

            # 3x3 convolutions producing the per-stage side outputs
            self.outconv5 = nn.Conv2D(filters[4], n_classes, 3, 1, 1)
            self.outconv4 = nn.Conv2D(up_channels, n_classes, 3, 1, 1)
            self.outconv3 = nn.Conv2D(up_channels, n_classes, 3, 1, 1)
            self.outconv2 = nn.Conv2D(up_channels, n_classes, 3, 1, 1)
        self.outconv1 = nn.Conv2D(up_channels, n_classes, 3, 1, 1)

        for m in self.sublayers():
            if isinstance(m, (nn.Conv2D, nn.BatchNorm2D)):
                init_weights(m, init_type="kaiming")

    def dot_product(self, seg, cls):
        B, N, H, W = seg.shape
        seg = seg.reshape((B, N, H * W))
        clssp = paddle.ones((1, N))
        ecls = (cls * clssp).reshape((B, N, 1))
        final = (seg * ecls).reshape((B, N, H, W))
        return final

    def forward(self, x):
        # Encoder path (channels grow, resolution shrinks)
        e1 = self.conv_e1(x)
        e2 = self.pool_e1(self.conv_e2(e1))
        e3 = self.pool_e2(self.conv_e3(e2))
        e4 = self.pool_e3(self.conv_e4(e3))
        e5 = self.pool_e4(self.conv_e5(e4))

        if self.set_cgm:
            # Image-level classification result: probabilities -> {0, 1}
            cls_branch = self.cls(e5).squeeze(3).squeeze(2)
            cls_branch_max = cls_branch.argmax(axis=1)
            cls_branch_max = cls_branch_max[:, np.newaxis].astype("float32")

        # Decoder path with full-scale skip connections
        d5 = e5
        d4 = self.conv_d4((e1, e2, e3, e4, d5))
        d3 = self.conv_d3((e1, e2, e3, d4, d5))
        d2 = self.conv_d2((e1, e2, d3, d4, d5))
        d1 = self.conv_d1((e1, d2, d3, d4, d5))

        if self.deep_sup:
            y5 = self.upscore5( self.outconv5(d5) )
            y4 = self.upscore4( self.outconv4(d4) )
            y3 = self.upscore3( self.outconv3(d3) )
            y2 = self.upscore2( self.outconv2(d2) )
            y1 = self.outconv1(d1)
            if self.set_cgm:
                y5 = self.dot_product(y5, cls_branch_max)
                y4 = self.dot_product(y4, cls_branch_max)
                y3 = self.dot_product(y3, cls_branch_max)
                y2 = self.dot_product(y2, cls_branch_max)
                y1 = self.dot_product(y1, cls_branch_max)
            return F.sigmoid(y1), F.sigmoid(y2), F.sigmoid(y3),\
                   F.sigmoid(y4), F.sigmoid(y5)
        else:
            y1 = self.outconv1(d1)
            if self.set_cgm:
                y1 = self.dot_product(y1, cls_branch_max)
            return F.sigmoid(y1)
  • *Instantiate the model
model = UNet3Plus(n_classes=N_CLASSES, deep_sup=False, set_cgm=False)
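
Optionally, we can inspect the network structure and parameter count before training:

paddle.summary(model, input_size=(1, 3) + IMG_SIZE)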

  • *Define the loss function
class DiceLoss(nn.Layer):
    ''' Dice Loss for Segmentation Tasks'''

    def __init__(self,
                 n_classes: int = 2,
                 smooth: Union[float, Tuple[float, float]] = (0, 1e-6),
                 sigmoid_x: bool = False,
                 softmax_x: bool = True,
                 onehot_y: bool = True,
                 square_xy: bool = True,
                 include_bg: bool = True,
                 reduction: str = "mean"):
        ''' Args:
        * n_classes: number of classes.
        * smooth: smoothing parameters of the dice coefficient.
        * sigmoid_x: whether to apply sigmoid to the prediction.
        * softmax_x: whether to apply softmax to the prediction.
        * onehot_y: whether to one-hot encode the label.
        * square_xy: whether to square the prediction and label.
        * include_bg: whether to include the background class when computing dice.
        * reduction: reduction function of the dice loss.
        '''
        super(DiceLoss, self).__init__()
        if reduction not in ["mean", "sum"]:
            raise NotImplementedError(
                "reduction of dice loss should be 'mean' or 'sum'!"
            )
        if isinstance(smooth, float):
            self.smooth = (smooth, smooth)
        else:
            self.smooth = smooth

        self.n_classes = n_classes
        self.sigmoid_x = sigmoid_x
        self.softmax_x = softmax_x
        self.onehot_y = onehot_y
        self.square_xy = square_xy
        self.include_bg = include_bg
        self.reduction = reduction

    def forward(self, pred, mask):
        (sm_nr, sm_dr) = self.smooth

        if self.sigmoid_x:
            pred = F.sigmoid(pred)
        if self.n_classes > 1:
            if self.softmax_x and self.n_classes == pred.shape[1]:
                pred = F.softmax(pred, axis=1)
            if self.onehot_y:
                mask = mask if mask.ndim < 4 else mask.squeeze(axis=1)
                mask = F.one_hot(mask.astype("int64"), self.n_classes)
                mask = mask.transpose((0, 3, 1, 2))
            if not self.include_bg:
                pred = pred[:, 1:] if pred.shape[1] > 1 else pred
                mask = mask[:, 1:] if mask.shape[1] > 1 else mask
        if pred.ndim != mask.ndim or pred.shape[1] != mask.shape[1]:
            raise ValueError(
                f"The shape of pred({pred.shape}) and " +
                f"mask({mask.shape}) should be the same."
            )

        reduce_dims = paddle.arange(2, pred.ndim).tolist()
        intersect = paddle.sum(pred * mask, axis=reduce_dims)
        if self.square_xy:
            pred, mask = paddle.pow(pred, 2), paddle.pow(mask, 2)
        pred_sum = paddle.sum(pred, axis=reduce_dims)
        mask_sum = paddle.sum(mask, axis=reduce_dims)
        loss = 1. - (2 * intersect + sm_nr) / (pred_sum + mask_sum + sm_dr)

        if self.reduction == "sum":
            loss = paddle.sum(loss)
        else:
            loss = paddle.mean(loss)
        return loss
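
As a quick smoke test (shapes chosen arbitrarily for illustration), the loss can be evaluated on random tensors:

pred = paddle.rand([2, N_CLASSES, 8, 8])             # raw predictions
mask = paddle.randint(0, N_CLASSES, [2, 1, 8, 8])    # integer label map
print(DiceLoss(n_classes=N_CLASSES)(pred, mask))     # scalar dice loss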
  • *Define the evaluation metric
def dice_func(pred: np.ndarray, mask: np.ndarray,
              n_classes: int, ignore_bg: bool = False):
    ''' Compute the per-class dice coefficient (for NumPy arrays) '''
    def sub_dice(x: np.ndarray, y: np.ndarray, sm: float = 1e-6):
        intersect = np.sum(x * y)
        y_sum = np.sum(y)
        x_sum = np.sum(x)
        return (2 * intersect + sm) / (x_sum + y_sum + sm)

    assert pred.shape == mask.shape
    assert isinstance(ignore_bg, bool)
    return [
        sub_dice(pred == i, mask == i)
        for i in range(int(ignore_bg), n_classes)
    ]
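
For example, a prediction identical to its mask scores a dice of 1.0 for every class:

a = np.random.randint(0, N_CLASSES, (8, 8))
print(dice_func(a, a.copy(), N_CLASSES))   # -> [1.0, 1.0]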

3.3. Model Training

model.train()
scheduler = CosineAnnealingDecay(
    learning_rate=INIT_LR,
    T_max=EPOCHS,
)
optimizer = AdamW(
    learning_rate=scheduler,
    parameters=model.parameters(),
    weight_decay=1e-5
)
dice_loss = DiceLoss(n_classes=N_CLASSES)
loss_list = []

for ep in range(EPOCHS):
    ep_loss_list = []
    for batch_id, data in enumerate(train_loader()):
        image, label = data
        pred = model(image)
        loss = dice_loss(pred, label)
        if batch_id % LOG_GAP == 0:
            print("Epoch:%2d,Batch:%3d,Loss:%.5f" % (ep, batch_id, loss))
        ep_loss_list.append(loss.item())
        optimizer.clear_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()
    loss_list.append(np.mean(ep_loss_list))
    print("【Train】Epoch:%2d,Loss:%.5f" % (ep, loss_list[-1]))
paddle.save(model.state_dict(), MODEL_PATH)

The training results are as follows:

Epoch: 0,Batch:  0,Loss:0.41813
Epoch: 0,Batch:500,Loss:0.51309
Epoch: 0,Batch:1000,Loss:0.32444
Epoch: 0,Batch:1500,Loss:0.22436
Epoch: 0,Batch:2000,Loss:0.07805
【Train】Epoch: 0,Loss:0.19494
Epoch: 1,Batch:  0,Loss:0.50250
Epoch: 1,Batch:500,Loss:0.50011
Epoch: 1,Batch:1000,Loss:0.05158
Epoch: 1,Batch:1500,Loss:0.06440
Epoch: 1,Batch:2000,Loss:0.09458
【Train】Epoch: 1,Loss:0.16005
Epoch: 2,Batch:  0,Loss:0.25998
Epoch: 2,Batch:500,Loss:0.03422
Epoch: 2,Batch:1000,Loss:0.50014
Epoch: 2,Batch:1500,Loss:0.35213
Epoch: 2,Batch:2000,Loss:0.04104
【Train】Epoch: 2,Loss:0.14942
Epoch: 3,Batch:  0,Loss:0.06672
Epoch: 3,Batch:500,Loss:0.05075
Epoch: 3,Batch:1000,Loss:0.03801
Epoch: 3,Batch:1500,Loss:0.05001
Epoch: 3,Batch:2000,Loss:0.03976
【Train】Epoch: 3,Loss:0.14288
Epoch: 4,Batch:  0,Loss:0.06034
Epoch: 4,Batch:500,Loss:0.08312
Epoch: 4,Batch:1000,Loss:0.50062
Epoch: 4,Batch:1500,Loss:0.03367
Epoch: 4,Batch:2000,Loss:0.03980
【Train】Epoch: 4,Loss:0.13926
Epoch: 5,Batch:  0,Loss:0.05745
Epoch: 5,Batch:500,Loss:0.04486
Epoch: 5,Batch:1000,Loss:0.06463
Epoch: 5,Batch:1500,Loss:0.08085
Epoch: 5,Batch:2000,Loss:0.03778
【Train】Epoch: 5,Loss:0.13551
Epoch: 6,Batch:  0,Loss:0.02407
Epoch: 6,Batch:500,Loss:0.50000
Epoch: 6,Batch:1000,Loss:0.50007
Epoch: 6,Batch:1500,Loss:0.05890
Epoch: 6,Batch:2000,Loss:0.03876
【Train】Epoch: 6,Loss:0.13283
Epoch: 7,Batch:  0,Loss:0.05039
Epoch: 7,Batch:500,Loss:0.02733
Epoch: 7,Batch:1000,Loss:0.02768
Epoch: 7,Batch:1500,Loss:0.03542
Epoch: 7,Batch:2000,Loss:0.14349
【Train】Epoch: 7,Loss:0.13040
Epoch: 8,Batch:  0,Loss:0.02584
Epoch: 8,Batch:500,Loss:0.11713
Epoch: 8,Batch:1000,Loss:0.04467
Epoch: 8,Batch:1500,Loss:0.04462
Epoch: 8,Batch:2000,Loss:0.02022
【Train】Epoch: 8,Loss:0.12809
Epoch: 9,Batch:  0,Loss:0.04599
Epoch: 9,Batch:500,Loss:0.01690
Epoch: 9,Batch:1000,Loss:0.02768
Epoch: 9,Batch:1500,Loss:0.50053
Epoch: 9,Batch:2000,Loss:0.04013
【Train】Epoch: 9,Loss:0.12536
Epoch:10,Batch:  0,Loss:0.03324
Epoch:10,Batch:500,Loss:0.36780
Epoch:10,Batch:1000,Loss:0.03769
Epoch:10,Batch:1500,Loss:0.50011
Epoch:10,Batch:2000,Loss:0.50002
【Train】Epoch:10,Loss:0.12356
Epoch:11,Batch:  0,Loss:0.03896
Epoch:11,Batch:500,Loss:0.11487
Epoch:11,Batch:1000,Loss:0.03414
Epoch:11,Batch:1500,Loss:0.06988
Epoch:11,Batch:2000,Loss:0.05266
【Train】Epoch:11,Loss:0.12255
Epoch:12,Batch:  0,Loss:0.02918
Epoch:12,Batch:500,Loss:0.50000
Epoch:12,Batch:1000,Loss:0.50000
Epoch:12,Batch:1500,Loss:0.05509
Epoch:12,Batch:2000,Loss:0.06147
【Train】Epoch:12,Loss:0.12097
Epoch:13,Batch:  0,Loss:0.03541
Epoch:13,Batch:500,Loss:0.03809
Epoch:13,Batch:1000,Loss:0.04672
Epoch:13,Batch:1500,Loss:0.02856
Epoch:13,Batch:2000,Loss:0.02951
【Train】Epoch:13,Loss:0.11975
Epoch:14,Batch:  0,Loss:0.06455
Epoch:14,Batch:500,Loss:0.03240
Epoch:14,Batch:1000,Loss:0.05857
Epoch:14,Batch:1500,Loss:0.02092
Epoch:14,Batch:2000,Loss:0.02371
【Train】Epoch:14,Loss:0.11936
Epoch:15,Batch:  0,Loss:0.50000
Epoch:15,Batch:500,Loss:0.03537
Epoch:15,Batch:1000,Loss:0.50006
Epoch:15,Batch:1500,Loss:0.05185
Epoch:15,Batch:2000,Loss:0.50004
【Train】Epoch:15,Loss:0.11859
  • *Visualize the training process
fig = plt.figure(figsize=[10, 5])

ax = fig.add_subplot(111, facecolor="#E8E8F8")
ax.set_xlabel("Epochs", fontsize=18)
ax.set_ylabel("Loss", fontsize=18)
plt.tick_params(labelsize=14)
ax.plot(range(len(loss_list)), loss_list, color="orangered")
ax.grid(linewidth=1.5, color="white")

fig.tight_layout()
plt.show()
plt.close()

(Figure: mean training loss per epoch)

3.4. Model Evaluation

model.eval()
model.set_state_dict(
    paddle.load(MODEL_PATH)
)
dice_accs = []

for batch_id, data in enumerate(test_loader()):
    image, label = data
    pred = model(image)
    pred = pred.argmax(axis=1).squeeze(axis=0).cpu().numpy()
    label = label.squeeze(0).squeeze(0).cpu().numpy()
    dice = dice_func(pred, label, N_CLASSES)
    dice_accs.append(dice)
print("Eval \t Dice: %.5f" % (np.mean(dice_accs)))

The evaluation result is as follows:

Eval     Dice: 0.94400

3.5. Model Prediction

def show_result(img_path, lab_path, pred):
    ''' Display the original image, the label, and the prediction '''

    def add_subimg(img, loc, title, cmap=None):
        ''' Add a subplot displaying one image '''
        plt.subplot(loc)
        plt.title(title)
        plt.imshow(img, cmap)
        plt.xticks([])
        plt.yticks([])

    def colormap(colors=['#A0C185', '#A6A6A6']):
        ''' Build a custom ColorMap '''
        return LSC.from_list('cmap', colors, 256)

    img = Image.open(img_path).resize(IMG_SIZE)
    lab = Image.open(lab_path).resize(IMG_SIZE)
    pred = pred.argmax(axis=1).numpy().reshape(IMG_SIZE)
    plt.figure(figsize=(12, 4))
    add_subimg(img, 131, "Image")
    add_subimg(lab, 132, "Label")
    add_subimg(pred, 133, "Predict", colormap())
    plt.tight_layout()
    plt.show()
    plt.close()
model.eval()
model.set_state_dict(
    paddle.load(MODEL_PATH)
)

for i in range(len(INFER_PATH["img"])):
    img_path, lab_path = INFER_PATH["img"][i], INFER_PATH["lab"][i]
    img, lab = data_mapper(img_path, lab_path)
    pred = model(img[np.newaxis, ...])
    show_result(img_path, lab_path, pred)

The segmentation result for the first image pair:

(Figure: image, label, and prediction for the first pair)

The segmentation result for the second image pair:

(Figure: image, label, and prediction for the second pair)

Closing Remarks

  • If you find any problems with this project, or have better suggestions, feel free to leave a comment below~
  • Project link: Project – AI Studio; click fork to run it directly on AI Studio~
  • My homepage: Homepage – AI Studio; come follow me on AI Studio~
  • You are also welcome to visit my personal blog at any time~

Original: https://blog.csdn.net/TXK_Kevin/article/details/125114674
Author: Kevin Tang
Title: 【UNet3+】Remote Sensing Image Segmentation
