[pytorch] U-Net Medical Segmentation: Code Walkthrough


U-Net for brain segmentation

A PyTorch implementation of a deep-learning segmentation algorithm (U-Net) for FLAIR abnormality segmentation in brain MRI.
GitHub code: U-Net for brain segmentation
Kaggle code: brain-segmentation-pytorch
Dataset download: Brain MRI segmentation
The dataset is small and the code is clear, so it is fairly easy to reproduce.

UNet model

import torch
import torch.nn as nn
from collections import OrderedDict

class UNet(nn.Module):

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)
        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )
unet = UNet()
print(unet)
UNet(
  (encoder1): Sequential(
    (enc1conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc1norm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc1relu1): ReLU(inplace=True)
    (enc1conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc1norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc1relu2): ReLU(inplace=True)
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (encoder2): Sequential(
    (enc2conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc2norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc2relu1): ReLU(inplace=True)
    (enc2conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc2norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc2relu2): ReLU(inplace=True)
  )
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (encoder3): Sequential(
    (enc3conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc3norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc3relu1): ReLU(inplace=True)
    (enc3conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc3norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc3relu2): ReLU(inplace=True)
  )
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (encoder4): Sequential(
    (enc4conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc4norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc4relu1): ReLU(inplace=True)
    (enc4conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc4norm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc4relu2): ReLU(inplace=True)
  )
  (pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (bottleneck): Sequential(
    (bottleneckconv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bottlenecknorm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bottleneckrelu1): ReLU(inplace=True)
    (bottleneckconv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bottlenecknorm2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bottleneckrelu2): ReLU(inplace=True)
  )
  (upconv4): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
  (decoder4): Sequential(
    (dec4conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (dec4norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dec4relu1): ReLU(inplace=True)
    (dec4conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (dec4norm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dec4relu2): ReLU(inplace=True)
  )
  (upconv3): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
  (decoder3): Sequential(
    (dec3conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (dec3norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dec3relu1): ReLU(inplace=True)
    (dec3conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (dec3norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dec3relu2): ReLU(inplace=True)
  )
  (upconv2): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
  (decoder2): Sequential(
    (dec2conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (dec2norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dec2relu1): ReLU(inplace=True)
    (dec2conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (dec2norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dec2relu2): ReLU(inplace=True)
  )
  (upconv1): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))
  (decoder1): Sequential(
    (dec1conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (dec1norm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dec1relu1): ReLU(inplace=True)
    (dec1conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (dec1norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dec1relu2): ReLU(inplace=True)
  )
  (conv): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
)
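
As a quick sanity check (a small sketch, assuming a 256×256 input so that the four pooling stages divide evenly), we can push a dummy batch through the network and confirm that the output keeps the input's spatial size with a single channel:

x = torch.randn(1, 3, 256, 256)   # dummy batch: one 3-channel 256x256 slice
with torch.no_grad():
    y = unet(x)
print(y.shape)   # torch.Size([1, 1, 256, 256]); values lie in (0, 1) because of the final sigmoid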

Data loading

Traversing the files

First, walk through the file paths and names using os.walk. dirpath is the path of the directory currently being visited, dirnames is the list of subdirectory names inside that directory, and filenames is the list of files inside it.

for (dirpath, dirnames, filenames) in os.walk(images_dir):
    for filename in sorted(
        filter(lambda f: ".tif" in f, filenames),
        key=lambda x: int(x.split(".")[-2].split("_")[4]),
    ):
        print(filename)
        filepath = os.path.join(dirpath, filename)
        print(filepath)
TCGA_CS_6665_20010817_1.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_1.tif
TCGA_CS_6665_20010817_1_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_1_mask.tif
TCGA_CS_6665_20010817_2_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_2_mask.tif
TCGA_CS_6665_20010817_2.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_2.tif
TCGA_CS_6665_20010817_3.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_3.tif
TCGA_CS_6665_20010817_3_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_3_mask.tif
TCGA_CS_6665_20010817_4_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_4_mask.tif
TCGA_CS_6665_20010817_4.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_4.tif
TCGA_CS_6665_20010817_5_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_5_mask.tif
TCGA_CS_6665_20010817_5.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_5.tif
TCGA_CS_6665_20010817_6.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_6.tif
TCGA_CS_6665_20010817_6_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_6_mask.tif
TCGA_CS_6665_20010817_7_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_7_mask.tif
TCGA_CS_6665_20010817_7.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_7.tif
TCGA_CS_6665_20010817_8.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_8.tif
TCGA_CS_6665_20010817_8_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_8_mask.tif
TCGA_CS_6665_20010817_9.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_9.tif
TCGA_CS_6665_20010817_9_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_9_mask.tif
TCGA_CS_6665_20010817_10_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_10_mask.tif
TCGA_CS_6665_20010817_10.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_10.tif
TCGA_CS_6665_20010817_11_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_11_mask.tif
TCGA_CS_6665_20010817_11.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_11.tif
TCGA_CS_6665_20010817_12.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_12.tif
TCGA_CS_6665_20010817_12_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_12_mask.tif
TCGA_CS_6665_20010817_13_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_13_mask.tif
TCGA_CS_6665_20010817_13.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_13.tif
TCGA_CS_6665_20010817_14_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_14_mask.tif
TCGA_CS_6665_20010817_14.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_14.tif
TCGA_CS_6665_20010817_15_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_15_mask.tif
TCGA_CS_6665_20010817_15.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_15.tif
TCGA_CS_6665_20010817_16.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_16.tif
TCGA_CS_6665_20010817_16_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_16_mask.tif
TCGA_CS_6665_20010817_17_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_17_mask.tif
TCGA_CS_6665_20010817_17.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_17.tif
TCGA_CS_6665_20010817_18_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_18_mask.tif
TCGA_CS_6665_20010817_18.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_18.tif
TCGA_CS_6665_20010817_19.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_19.tif
TCGA_CS_6665_20010817_19_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_19_mask.tif
TCGA_CS_6665_20010817_20_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_20_mask.tif
TCGA_CS_6665_20010817_20.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_20.tif
TCGA_CS_6665_20010817_21.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_21.tif
TCGA_CS_6665_20010817_21_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_21_mask.tif
TCGA_CS_6665_20010817_22.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_22.tif
TCGA_CS_6665_20010817_22_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_22_mask.tif
TCGA_CS_6665_20010817_23_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_23_mask.tif
TCGA_CS_6665_20010817_23.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_23.tif
TCGA_CS_6665_20010817_24_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_24_mask.tif
TCGA_CS_6665_20010817_24.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_24.tif
TCGA_CS_6669_20020102_1.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_1.tif
TCGA_CS_6669_20020102_1_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_1_mask.tif
TCGA_CS_6669_20020102_2_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_2_mask.tif
TCGA_CS_6669_20020102_2.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_2.tif
TCGA_CS_6669_20020102_3.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_3.tif
TCGA_CS_6669_20020102_3_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_3_mask.tif
TCGA_CS_6669_20020102_4.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_4.tif
TCGA_CS_6669_20020102_4_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_4_mask.tif
TCGA_CS_6669_20020102_5.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_5.tif
TCGA_CS_6669_20020102_5_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_5_mask.tif
TCGA_CS_6669_20020102_6_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_6_mask.tif
TCGA_CS_6669_20020102_6.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_6.tif
TCGA_CS_6669_20020102_7_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_7_mask.tif
TCGA_CS_6669_20020102_7.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_7.tif
TCGA_CS_6669_20020102_8.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_8.tif
TCGA_CS_6669_20020102_8_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_8_mask.tif
TCGA_CS_6669_20020102_9_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_9_mask.tif
TCGA_CS_6669_20020102_9.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_9.tif
TCGA_CS_6669_20020102_10_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_10_mask.tif
TCGA_CS_6669_20020102_10.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_10.tif
TCGA_CS_6669_20020102_11.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_11.tif
TCGA_CS_6669_20020102_11_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_11_mask.tif
TCGA_CS_6669_20020102_12_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_12_mask.tif
TCGA_CS_6669_20020102_12.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_12.tif
TCGA_CS_6669_20020102_13_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_13_mask.tif
TCGA_CS_6669_20020102_13.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_13.tif
TCGA_CS_6669_20020102_14_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_14_mask.tif
TCGA_CS_6669_20020102_14.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_14.tif
TCGA_CS_6669_20020102_15.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_15.tif
TCGA_CS_6669_20020102_15_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_15_mask.tif
TCGA_CS_6669_20020102_16_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_16_mask.tif
TCGA_CS_6669_20020102_16.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_16.tif
TCGA_CS_6669_20020102_17.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_17.tif
TCGA_CS_6669_20020102_17_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_17_mask.tif
TCGA_CS_6669_20020102_18_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_18_mask.tif
TCGA_CS_6669_20020102_18.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_18.tif
TCGA_CS_6669_20020102_19_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_19_mask.tif
TCGA_CS_6669_20020102_19.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_19.tif
TCGA_CS_6669_20020102_20_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_20_mask.tif
TCGA_CS_6669_20020102_20.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_20.tif
TCGA_CS_6669_20020102_21_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_21_mask.tif
TCGA_CS_6669_20020102_21.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_21.tif
TCGA_CS_6669_20020102_22_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_22_mask.tif
TCGA_CS_6669_20020102_22.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_22.tif


Reading the data


volumes = {}
masks = {}
images_dir = './archive/lgg-mri-segmentation/kaggle_3m'
print("reading {} images...".format("train"))
for (dirpath, dirnames, filenames) in os.walk(images_dir):
    image_slices = []
    mask_slices = []
    for filename in sorted(
        filter(lambda f: ".tif" in f, filenames),
        key=lambda x: int(x.split(".")[-2].split("_")[4]),
    ):
        filepath = os.path.join(dirpath, filename)
        if "mask" in filename:
            mask_slices.append(imread(filepath, as_gray=True))
        else:
            image_slices.append(imread(filepath))
    if len(image_slices) > 0:
        patient_id = dirpath.split("/")[-1]
        volumes[patient_id] = np.array(image_slices[1:-1])
        masks[patient_id] = np.array(mask_slices[1:-1])

All patient data is stored in dicts keyed by patient ID. Note that the author discards the first and last slice of each patient (image_slices[1:-1]); the reason for this is not entirely clear.

print(len(volumes))
print(len(volumes['TCGA_DU_8163_19961119']))

Dataset split

Ten patients are chosen at random as the validation set; the remaining patients form the training set.

patients_list = sorted(volumes)
seed=42
subset = "train"

if not subset == "all":
    random.seed(seed)
    validation_patients = random.sample(patients_list, k=10)
    if subset == "validation":
        patients_list = validation_patients
    else:
        patients_list = sorted(
            list(set(patients_list).difference(validation_patients))
        )
print(patients_list)
['TCGA_CS_4941_19960909', 'TCGA_CS_4942_19970222', 'TCGA_CS_4943_20000902', 'TCGA_CS_5393_19990606', 'TCGA_CS_5395_19981004', 'TCGA_CS_5396_20010302', 'TCGA_CS_5397_20010315', 'TCGA_CS_6186_20000601', 'TCGA_CS_6188_20010812', 'TCGA_CS_6290_20000917', 'TCGA_CS_6665_20010817', 'TCGA_CS_6666_20011109', 'TCGA_CS_6669_20020102', 'TCGA_DU_5849_19950405', 'TCGA_DU_5852_19950709', 'TCGA_DU_5853_19950823', 'TCGA_DU_5854_19951104', 'TCGA_DU_5855_19951217', 'TCGA_DU_5871_19941206', 'TCGA_DU_5872_19950223', 'TCGA_DU_5874_19950510', 'TCGA_DU_6399_19830416', 'TCGA_DU_6400_19830518', 'TCGA_DU_6401_19831001', 'TCGA_DU_6405_19851005', 'TCGA_DU_6407_19860514', 'TCGA_DU_7008_19830723', 'TCGA_DU_7010_19860307', 'TCGA_DU_7013_19860523', 'TCGA_DU_7018_19911220', 'TCGA_DU_7019_19940908', 'TCGA_DU_7294_19890104', 'TCGA_DU_7298_19910324', 'TCGA_DU_7299_19910417', 'TCGA_DU_7300_19910814', 'TCGA_DU_7301_19911112', 'TCGA_DU_7302_19911203', 'TCGA_DU_7304_19930325', 'TCGA_DU_7306_19930512', 'TCGA_DU_7309_19960831', 'TCGA_DU_8162_19961029', 'TCGA_DU_8163_19961119', 'TCGA_DU_8164_19970111', 'TCGA_DU_8165_19970205', 'TCGA_DU_8166_19970322', 'TCGA_DU_8167_19970402', 'TCGA_DU_8168_19970503', 'TCGA_DU_A5TP_19970614', 'TCGA_DU_A5TR_19970726', 'TCGA_DU_A5TS_19970726', 'TCGA_DU_A5TT_19980318', 'TCGA_DU_A5TU_19980312', 'TCGA_DU_A5TW_19980228', 'TCGA_DU_A5TY_19970709', 'TCGA_EZ_7264_20010816', 'TCGA_FG_5962_20000626', 'TCGA_FG_5964_20010511', 'TCGA_FG_6688_20020215', 'TCGA_FG_6689_20020326', 'TCGA_FG_6690_20020226', 'TCGA_FG_6691_20020405', 'TCGA_FG_6692_20020606', 'TCGA_FG_7634_20000128', 'TCGA_FG_7637_20000922', 'TCGA_FG_7643_20021104', 'TCGA_FG_8189_20030516', 'TCGA_FG_A4MT_20020212', 'TCGA_FG_A4MU_20030903', 'TCGA_FG_A60K_20040224', 'TCGA_HT_7473_19970826', 'TCGA_HT_7475_19970918', 'TCGA_HT_7602_19951103', 'TCGA_HT_7605_19950916', 'TCGA_HT_7608_19940304', 'TCGA_HT_7680_19970202', 'TCGA_HT_7684_19950816', 'TCGA_HT_7686_19950629', 'TCGA_HT_7690_19960312', 'TCGA_HT_7693_19950520', 'TCGA_HT_7694_19950404', 'TCGA_HT_7855_19951020', 'TCGA_HT_7856_19950831', 'TCGA_HT_7860_19960513', 'TCGA_HT_7874_19950902', 'TCGA_HT_7877_19980917', 'TCGA_HT_7881_19981015', 'TCGA_HT_7882_19970125', 'TCGA_HT_7884_19980913', 'TCGA_HT_8018_19970411', 'TCGA_HT_8105_19980826', 'TCGA_HT_8106_19970727', 'TCGA_HT_8107_19980708', 'TCGA_HT_8111_19980330', 'TCGA_HT_8113_19930809', 'TCGA_HT_8114_19981030', 'TCGA_HT_8563_19981209', 'TCGA_HT_A5RC_19990831', 'TCGA_HT_A616_19991226', 'TCGA_HT_A61A_20000127', 'TCGA_HT_A61B_19991127']

Data preprocessing

The following functions were written by the author; each one is applied to the image volume and its mask together.

def crop_sample(x):
    # Zero out low-intensity background, then crop both volume and mask to the non-zero bounding box
    volume, mask = x
    volume[volume < np.max(volume) * 0.1] = 0
    z_projection = np.max(np.max(np.max(volume, axis=-1), axis=-1), axis=-1)
    z_nonzero = np.nonzero(z_projection)
    z_min = np.min(z_nonzero)
    z_max = np.max(z_nonzero) + 1
    y_projection = np.max(np.max(np.max(volume, axis=0), axis=-1), axis=-1)
    y_nonzero = np.nonzero(y_projection)
    y_min = np.min(y_nonzero)
    y_max = np.max(y_nonzero) + 1
    x_projection = np.max(np.max(np.max(volume, axis=0), axis=0), axis=-1)
    x_nonzero = np.nonzero(x_projection)
    x_min = np.min(x_nonzero)
    x_max = np.max(x_nonzero) + 1
    return (
        volume[z_min:z_max, y_min:y_max, x_min:x_max],
        mask[z_min:z_max, y_min:y_max, x_min:x_max],
    )

def pad_sample(x):
    # Pad the shorter spatial axis with zeros so that every slice becomes square
    volume, mask = x
    a = volume.shape[1]
    b = volume.shape[2]
    if a == b:
        return volume, mask
    diff = (max(a, b) - min(a, b)) / 2.0
    if a > b:
        padding = ((0, 0), (0, 0), (int(np.floor(diff)), int(np.ceil(diff))))
    else:
        padding = ((0, 0), (int(np.floor(diff)), int(np.ceil(diff))), (0, 0))
    mask = np.pad(mask, padding, mode="constant", constant_values=0)
    padding = padding + ((0, 0),)
    volume = np.pad(volume, padding, mode="constant", constant_values=0)
    return volume, mask

def resize_sample(x, size=256):
    # Resize every slice to size x size (nearest-neighbor for the mask, spline interpolation for the volume)
    volume, mask = x
    v_shape = volume.shape
    out_shape = (v_shape[0], size, size)
    mask = resize(
        mask,
        output_shape=out_shape,
        order=0,
        mode="constant",
        cval=0,
        anti_aliasing=False,
    )
    out_shape = out_shape + (v_shape[3],)
    volume = resize(
        volume,
        output_shape=out_shape,
        order=2,
        mode="constant",
        cval=0,
        anti_aliasing=False,
    )
    return volume, mask

def normalize_volume(volume):
    # Rescale intensities to the [p10, p99] range, then standardize (zero mean, unit std) per channel
    p10 = np.percentile(volume, 10)
    p99 = np.percentile(volume, 99)
    volume = rescale_intensity(volume, in_range=(p10, p99))
    m = np.mean(volume, axis=(0, 1, 2))
    s = np.std(volume, axis=(0, 1, 2))
    volume = (volume - m) / s
    return volume

print("preprocessing {} volumes...".format(subset))

volumes_list = [(volumes[k], masks[k]) for k in patients_list]

print("cropping {} volumes...".format(subset))

volumes_list = [crop_sample(v) for v in volumes_list]

print("padding {} volumes...".format(subset))

volumes_list = [pad_sample(v) for v in volumes_list]

print("resizing {} volumes...".format(subset))

volumes_list = [resize_sample(v, size=image_size) for v in volumes_list]  # image_size was 224 in this run (see the shapes printed below)

print("normalizing {} volumes...".format(subset))

volumes_list = [(normalize_volume(v), m) for v, m in volumes_list]


Per-slice sampling probabilities are computed from the masks; they are used later when slices are sampled randomly.


slice_weights = [m.sum(axis=-1).sum(axis=-1) for v, m in volumes_list]
slice_weights = [
    (s + (s.sum() * 0.1 / len(s))) / (s.sum() * 1.1) for s in slice_weights
]
print(len(slice_weights))
print(len(slice_weights[0]))
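
To see what this weighting does, here is a tiny sketch with made-up mask areas: each slice gets a probability proportional to its mask area plus a small uniform floor, and the weights for a patient sum to 1.

import numpy as np

s = np.array([0.0, 100.0, 300.0, 0.0])                 # hypothetical mask areas for one patient's slices
w = (s + (s.sum() * 0.1 / len(s))) / (s.sum() * 1.1)
print(w)         # [0.0227 0.25   0.7045 0.0227]
print(w.sum())   # 1.0 -- empty slices still get a small, non-zero probability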

The mask has only three dimensions, so we add a channel dimension:


volumes_list = [(v, m[..., np.newaxis]) for (v, m) in volumes_list]

Index list

One index identifies the patient, the other the slice within that patient.


num_slices = [v.shape[0] for v, m in volumes_list]
patient_slice_index = list(
    zip(
        sum([[i] * num_slices[i] for i in range(len(num_slices))], []),
        sum([list(range(x)) for x in num_slices], []),
    )
)


zip() packs the corresponding elements of its arguments into tuples and returns them as a sequence of tuples. sum(list_of_lists, []) is used here simply to flatten a list of lists into one flat list.

The resulting index looks like this: the first patient's first slice, the first patient's second slice, the first patient's third slice, ..., then the second patient's first slice, and so on.
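
A tiny sketch with made-up numbers shows how sum and zip build this index:

num_slices = [2, 3]   # hypothetical: patient 0 has 2 slices, patient 1 has 3
patient_ids = sum([[i] * num_slices[i] for i in range(len(num_slices))], [])   # [0, 0, 1, 1, 1]
slice_ids = sum([list(range(x)) for x in num_slices], [])                      # [0, 1, 0, 1, 2]
print(list(zip(patient_ids, slice_ids)))   # [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)]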


__len__ and __getitem__

__len__ returns the length of patient_slice_index, i.e. how many image/mask pairs we have. During training, __getitem__ produces one 2D pair at a time: the image is the network input and the mask is the label.

idx = 50
patient = patient_slice_index[idx][0]
slice_n = patient_slice_index[idx][1]
v, m = volumes_list[patient]
image = v[slice_n]
mask = m[slice_n]
print(len(volumes_list))
print(v.shape)
print(m.shape)
print(image.shape)
print(mask.shape)
100
(18, 224, 224, 3)
(18, 224, 224, 1)
(224, 224, 3)
(224, 224, 1)

First a patient is selected at random via the index, and then a slice is selected at random for that patient.
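
In the dataset's __getitem__ (see the full code below), the random-sampling branch looks roughly like this:

patient = np.random.randint(len(volumes_list))   # pick a patient uniformly at random
slice_n = np.random.choice(
    range(volumes_list[patient][0].shape[0]), p=slice_weights[patient]
)                                                # pick a slice, weighted by mask area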


image = image.transpose(2, 0, 1)
mask = mask.transpose(2, 0, 1)

image_tensor = torch.from_numpy(image.astype(np.float32))
mask_tensor = torch.from_numpy(mask.astype(np.float32))

Finally, since training requires tensors, the channel axis is moved to the front and the arrays are converted to tensors.

Training and validation

DSC

DSC is a common evaluation metric for image segmentation. The Dice similarity coefficient (DSC) is a set-similarity measure, usually used to quantify how similar two samples are; it ranges from 0 to 1, where 1 corresponds to a perfect segmentation and 0 to the worst.

For more details see: common segmentation metrics such as DSC, Hausdorff_95, IoU, PPV, etc.
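
As a minimal sketch (the project's own dsc helper appears in the full code at the end of this post), the coefficient for a pair of binary masks can be computed with NumPy; the small eps here is just an added guard against division by zero:

import numpy as np

def dice_coefficient(pred, true, eps=1e-7):
    # pred and true are arrays of the same shape with values in [0, 1]
    pred = np.round(pred).astype(int)
    true = np.round(true).astype(int)
    intersection = (pred * true).sum()
    return 2.0 * intersection / (pred.sum() + true.sum() + eps)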

validation

There is not much to explain about the training loop itself; it follows the standard PyTorch workflow. Training output:

reading train images...

preprocessing train volumes...

cropping train volumes...

padding train volumes...

resizing train volumes...

normalizing train volumes...

done creating train dataset
reading validation images...

preprocessing validation volumes...

cropping validation volumes...

padding validation volumes...

resizing validation volumes...

normalizing validation volumes...

done creating validation dataset
epoch 1 | loss: 0.8733632518694951
epoch 1 | val_loss: 0.9460358023643494
epoch 1 | val_dsc: 0.1948670436467475
epoch 2 | loss: 0.8402993633196905
epoch 2 | val_loss: 0.931992749373118
epoch 2 | val_dsc: 0.4118475790023708
epoch 3 | loss: 0.8270544547301072
epoch 3 | val_loss: 0.9293850064277649
epoch 3 | val_dsc: 0.4527188003808261
epoch 4 | loss: 0.8215367862811456
epoch 4 | val_loss: 0.9270628492037455
epoch 4 | val_dsc: 0.6138052787773625
epoch 5 | loss: 0.8171369204154382
epoch 5 | val_loss: 0.9256295363108317
epoch 5 | val_dsc: 0.4911489058649566
epoch 6 | loss: 0.8134041795363793
epoch 6 | val_loss: 0.9244295557339987
epoch 6 | val_dsc: 0.6963948791553534
epoch 7 | loss: 0.8091526673390315
epoch 7 | val_loss: 0.9240273038546244
epoch 7 | val_dsc: 0.7177726293762504
epoch 8 | loss: 0.8052632143864265
epoch 8 | val_loss: 0.9218348860740662
epoch 8 | val_dsc: 0.7100711862449238
epoch 9 | loss: 0.801162777038721
epoch 9 | val_loss: 0.9223942359288534
epoch 9 | val_dsc: 0.7482642381530232
epoch 10 | loss: 0.7974189153084388
epoch 10 | val_loss: 0.9184039433797201
.....

Best validation mean DSC: 0.855153

Let's look at how the DSC is computed during validation.

validation_true = []
for i, data in enumerate(loader_valid):
    x, y_true = data
    x, y_true = x.to(device), y_true.to(device)
    print(y_true.shape)
    y_true_np = y_true.detach().cpu().numpy()
    print(y_true_np.shape)
    validation_true.extend(
        [y_true_np[s] for s in range(y_true_np.shape[0])]
    )
    print(len(validation_true))
    print(validation_true[0].shape)
    break
torch.Size([128, 1, 256, 256])
(128, 1, 256, 256)
128
(1, 256, 256)

Here the batch_size is 128, so each batch yields 128 slices, which are accumulated into validation_true and validation_pred.

mean_dsc = np.mean(
    dsc_per_volume(
        validation_pred,
        validation_true,
        loader_valid.dataset.patient_slice_index,
    )
)

At the end of each epoch, the mean DSC over all validation data of that epoch is computed.

if mean_dsc > best_validation_dsc:
    best_validation_dsc = mean_dsc
    torch.save(unet.state_dict(), os.path.join(weights, "unet.pt"))

Whenever a better result appears, the model is saved.

Prediction

We use the previously saved model to run predictions.

device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")

loader_train, loader_valid = data_loaders(batch_size, workers, image_size, aug_scale, aug_angle)
loaders = {"train": loader_train, "valid": loader_valid}
unet = UNet(in_channels=BrainSegmentationDataset.in_channels, out_channels=BrainSegmentationDataset.out_channels)
unet.to(device)
state_dict = torch.load(os.path.join('./', "unet.pt"))
unet.load_state_dict(state_dict)
unet.eval()

input_list = []
pred_list = []
true_list = []

for i, data in enumerate(loader_valid):
    x, y_true = data
    x, y_true = x.to(device), y_true.to(device)
    with torch.set_grad_enabled(False):
        y_pred = unet(x)
        y_pred_np = y_pred.detach().cpu().numpy()
        pred_list.extend([y_pred_np[s] for s in range(y_pred_np.shape[0])])
        y_true_np = y_true.detach().cpu().numpy()
        true_list.extend([y_true_np[s] for s in range(y_true_np.shape[0])])
        x_np = x.detach().cpu().numpy()
        input_list.extend([x_np[s] for s in range(x_np.shape[0])])

The author then post-processes the predictions with the following function. At this point we only have per-slice results; they need to be grouped back together by patient so that the segmentation can be inspected per volume.

def postprocess_per_volume(
    input_list, pred_list, true_list, patient_slice_index, patients
):
    volumes = {}
    num_slices = np.bincount([p[0] for p in patient_slice_index])
    index = 0
    for p in range(len(num_slices)):
        volume_in = np.array(input_list[index : index + num_slices[p]])
        volume_pred = np.round(
            np.array(pred_list[index : index + num_slices[p]])
        ).astype(int)
        volume_true = np.array(true_list[index : index + num_slices[p]])
        volumes[patients[p]] = (volume_in, volume_pred, volume_true)
        index += num_slices[p]
    return volumes

Let's look at how this function works.
First, we count the number of slices per patient in the validation/test set. In the data-loading stage we already looked at loader_valid.dataset.patient_slice_index: each entry is a 2D coordinate that identifies a slice, i.e. the first patient's first slice, the first patient's second slice, and so on.


We then apply np.bincount to the first component of these 2D coordinates; the resulting counts are the number of slices per patient.
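
A small example with a made-up index illustrates the counting:

import numpy as np

patient_slice_index = [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]   # hypothetical index
num_slices = np.bincount([p[0] for p in patient_slice_index])
print(num_slices)   # [3 2]: patient 0 has 3 slices, patient 1 has 2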

Based on the per-patient slice counts, the predictions are then grouped back together.
Then the Dice score per patient is computed and plotted:
dsc_dist = dsc_distribution(volumes)

dsc_dist_plot = plot_dsc(dsc_dist)
imsave("./dsc.png", dsc_dist_plot)

The red line marks the mean and the green line the median.
Finally, let's look at the segmentation quality:
for p in volumes:
    x = volumes[p][0]
    y_pred = volumes[p][1]
    y_true = volumes[p][2]
    for s in range(x.shape[0]):
        image = gray2rgb(x[s, 1])
        image = outline(image, y_pred[s, 0], color=[255, 0, 0])
        image = outline(image, y_true[s, 0], color=[0, 255, 0])
        filename = "{}-{}.png".format(p, str(s).zfill(2))
        filepath = os.path.join("./resultat", filename)
        imsave(filepath, image)

The prediction (red) and the ground truth (green) are outlined on top of the original image.


Full code

Dependencies

import os
import random

from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from tqdm import tqdm
from skimage.exposure import rescale_intensity
from skimage.io import imread, imsave
from skimage.transform import resize, rescale, rotate
from torch.utils.data import Dataset
from torchvision.transforms import Compose

Preprocessing and augmentation functions

def crop_sample(x):
    volume, mask = x
    volume[volume < np.max(volume) * 0.1] = 0
    z_projection = np.max(np.max(np.max(volume, axis=-1), axis=-1), axis=-1)
    z_nonzero = np.nonzero(z_projection)
    z_min = np.min(z_nonzero)
    z_max = np.max(z_nonzero) + 1
    y_projection = np.max(np.max(np.max(volume, axis=0), axis=-1), axis=-1)
    y_nonzero = np.nonzero(y_projection)
    y_min = np.min(y_nonzero)
    y_max = np.max(y_nonzero) + 1
    x_projection = np.max(np.max(np.max(volume, axis=0), axis=0), axis=-1)
    x_nonzero = np.nonzero(x_projection)
    x_min = np.min(x_nonzero)
    x_max = np.max(x_nonzero) + 1
    return (
        volume[z_min:z_max, y_min:y_max, x_min:x_max],
        mask[z_min:z_max, y_min:y_max, x_min:x_max],
    )

def pad_sample(x):
    volume, mask = x
    a = volume.shape[1]
    b = volume.shape[2]
    if a == b:
        return volume, mask
    diff = (max(a, b) - min(a, b)) / 2.0
    if a > b:
        padding = ((0, 0), (0, 0), (int(np.floor(diff)), int(np.ceil(diff))))
    else:
        padding = ((0, 0), (int(np.floor(diff)), int(np.ceil(diff))), (0, 0))
    mask = np.pad(mask, padding, mode="constant", constant_values=0)
    padding = padding + ((0, 0),)
    volume = np.pad(volume, padding, mode="constant", constant_values=0)
    return volume, mask

def resize_sample(x, size=256):
    volume, mask = x
    v_shape = volume.shape
    out_shape = (v_shape[0], size, size)
    mask = resize(
        mask,
        output_shape=out_shape,
        order=0,
        mode="constant",
        cval=0,
        anti_aliasing=False,
    )
    out_shape = out_shape + (v_shape[3],)
    volume = resize(
        volume,
        output_shape=out_shape,
        order=2,
        mode="constant",
        cval=0,
        anti_aliasing=False,
    )
    return volume, mask

def normalize_volume(volume):
    p10 = np.percentile(volume, 10)
    p99 = np.percentile(volume, 99)
    volume = rescale_intensity(volume, in_range=(p10, p99))
    m = np.mean(volume, axis=(0, 1, 2))
    s = np.std(volume, axis=(0, 1, 2))
    volume = (volume - m) / s
    return volume
def transforms(scale=None, angle=None, flip_prob=None):
    transform_list = []

    if scale is not None:
        transform_list.append(Scale(scale))
    if angle is not None:
        transform_list.append(Rotate(angle))
    if flip_prob is not None:
        transform_list.append(HorizontalFlip(flip_prob))

    return Compose(transform_list)

class Scale(object):

    def __init__(self, scale):
        self.scale = scale

    def __call__(self, sample):
        image, mask = sample

        img_size = image.shape[0]

        scale = np.random.uniform(low=1.0 - self.scale, high=1.0 + self.scale)

        image = rescale(
            image,
            (scale, scale),
            multichannel=True,
            preserve_range=True,
            mode="constant",
            anti_aliasing=False,
        )
        mask = rescale(
            mask,
            (scale, scale),
            order=0,
            multichannel=True,
            preserve_range=True,
            mode="constant",
            anti_aliasing=False,
        )

        if scale < 1.0:
            diff = (img_size - image.shape[0]) / 2.0
            padding = ((int(np.floor(diff)), int(np.ceil(diff))),) * 2 + ((0, 0),)
            image = np.pad(image, padding, mode="constant", constant_values=0)
            mask = np.pad(mask, padding, mode="constant", constant_values=0)
        else:
            x_min = (image.shape[0] - img_size) // 2
            x_max = x_min + img_size
            image = image[x_min:x_max, x_min:x_max, ...]
            mask = mask[x_min:x_max, x_min:x_max, ...]

        return image, mask

class Rotate(object):

    def __init__(self, angle):
        self.angle = angle

    def __call__(self, sample):
        image, mask = sample

        angle = np.random.uniform(low=-self.angle, high=self.angle)
        image = rotate(image, angle, resize=False, preserve_range=True, mode="constant")
        mask = rotate(
            mask, angle, resize=False, order=0, preserve_range=True, mode="constant"
        )
        return image, mask

class HorizontalFlip(object):

    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, sample):
        image, mask = sample

        if np.random.rand() > self.flip_prob:
            return image, mask

        image = np.fliplr(image).copy()
        mask = np.fliplr(mask).copy()

        return image, mask

Data loading

class BrainSegmentationDataset(Dataset):
    """Brain MRI dataset for FLAIR abnormality segmentation"""

    in_channels = 3
    out_channels = 1

    def __init__(
        self,
        images_dir,
        transform=None,
        image_size=256,
        subset="train",
        random_sampling=True,
        seed=42,
    ):
        assert subset in ["all", "train", "validation"]

        volumes = {}
        masks = {}
        print("reading {} images...".format(subset))
        for (dirpath, dirnames, filenames) in os.walk(images_dir):
            image_slices = []
            mask_slices = []
            for filename in sorted(
                filter(lambda f: ".tif" in f, filenames),
                key=lambda x: int(x.split(".")[-2].split("_")[4]),
            ):
                filepath = os.path.join(dirpath, filename)
                if "mask" in filename:
                    mask_slices.append(imread(filepath, as_gray=True))
                else:
                    image_slices.append(imread(filepath))
            if len(image_slices) > 0:
                patient_id = dirpath.split("/")[-1]
                volumes[patient_id] = np.array(image_slices[1:-1])
                masks[patient_id] = np.array(mask_slices[1:-1])

        self.patients = sorted(volumes)

        if not subset == "all":
            random.seed(seed)
            validation_patients = random.sample(self.patients, k=10)
            if subset == "validation":
                self.patients = validation_patients
            else:
                self.patients = sorted(
                    list(set(self.patients).difference(validation_patients))
                )

        print("preprocessing {} volumes...".format(subset))

        self.volumes = [(volumes[k], masks[k]) for k in self.patients]

        print("cropping {} volumes...".format(subset))

        self.volumes = [crop_sample(v) for v in self.volumes]

        print("padding {} volumes...".format(subset))

        self.volumes = [pad_sample(v) for v in self.volumes]

        print("resizing {} volumes...".format(subset))

        self.volumes = [resize_sample(v, size=image_size) for v in self.volumes]

        print("normalizing {} volumes...".format(subset))

        self.volumes = [(normalize_volume(v), m) for v, m in self.volumes]

        self.slice_weights = [m.sum(axis=-1).sum(axis=-1) for v, m in self.volumes]
        self.slice_weights = [
            (s + (s.sum() * 0.1 / len(s))) / (s.sum() * 1.1) for s in self.slice_weights
        ]

        self.volumes = [(v, m[..., np.newaxis]) for (v, m) in self.volumes]

        print("done creating {} dataset".format(subset))

        num_slices = [v.shape[0] for v, m in self.volumes]
        self.patient_slice_index = list(
            zip(
                sum([[i] * num_slices[i] for i in range(len(num_slices))], []),
                sum([list(range(x)) for x in num_slices], []),
            )
        )

        self.random_sampling = random_sampling

        self.transform = transform

    def __len__(self):
        return len(self.patient_slice_index)

    def __getitem__(self, idx):
        patient = self.patient_slice_index[idx][0]
        slice_n = self.patient_slice_index[idx][1]

        if self.random_sampling:
            patient = np.random.randint(len(self.volumes))
            slice_n = np.random.choice(
                range(self.volumes[patient][0].shape[0]), p=self.slice_weights[patient]
            )

        v, m = self.volumes[patient]
        image = v[slice_n]
        mask = m[slice_n]

        if self.transform is not None:
            image, mask = self.transform((image, mask))

        image = image.transpose(2, 0, 1)
        mask = mask.transpose(2, 0, 1)

        image_tensor = torch.from_numpy(image.astype(np.float32))
        mask_tensor = torch.from_numpy(mask.astype(np.float32))

        return image_tensor, mask_tensor
def data_loaders(batch_size, workers, image_size, aug_scale, aug_angle):
    dataset_train, dataset_valid = datasets('./archive/lgg-mri-segmentation/kaggle_3m', image_size, aug_scale, aug_angle)

    def worker_init(worker_id):
        np.random.seed(42 + worker_id)

    loader_train = DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=workers,
        worker_init_fn=worker_init,
    )
    loader_valid = DataLoader(
        dataset_valid,
        batch_size=batch_size,
        drop_last=False,
        num_workers=workers,
        worker_init_fn=worker_init,
    )

    return loader_train, loader_valid
def datasets(images, image_size, aug_scale, aug_angle):
    train = BrainSegmentationDataset(
        images_dir=images,
        subset="train",
        image_size=image_size,
        transform=transforms(scale=aug_scale, angle=aug_angle, flip_prob=0.5),
    )
    valid = BrainSegmentationDataset(
        images_dir=images,
        subset="validation",
        image_size=image_size,
        random_sampling=False,
    )
    return train, valid

Network definition

class UNet(nn.Module):

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)
        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )

Metric computation

class DiceLoss(nn.Module):

    def __init__(self):
        super(DiceLoss, self).__init__()
        self.smooth = 1.0

    def forward(self, y_pred, y_true):
        assert y_pred.size() == y_true.size()
        y_pred = y_pred[:, 0].contiguous().view(-1)
        y_true = y_true[:, 0].contiguous().view(-1)
        intersection = (y_pred * y_true).sum()
        dsc = (2. * intersection + self.smooth) / (
            y_pred.sum() + y_true.sum() + self.smooth
        )
        return 1. - dsc

def log_images(x, y_true, y_pred, channel=1):
    images = []
    x_np = x[:, channel].cpu().numpy()
    y_true_np = y_true[:, 0].cpu().numpy()
    y_pred_np = y_pred[:, 0].cpu().numpy()
    for i in range(x_np.shape[0]):
        image = gray2rgb(np.squeeze(x_np[i]))
        image = outline(image, y_pred_np[i], color=[255, 0, 0])
        image = outline(image, y_true_np[i], color=[0, 255, 0])
        images.append(image)
    return images

def gray2rgb(image):
    w, h = image.shape
    image += np.abs(np.min(image))
    image_max = np.abs(np.max(image))
    if image_max > 0:
        image /= image_max
    ret = np.empty((w, h, 3), dtype=np.uint8)
    ret[:, :, 2] = ret[:, :, 1] = ret[:, :, 0] = image * 255
    return ret

def outline(image, mask, color):
    mask = np.round(mask)
    yy, xx = np.nonzero(mask)
    for y, x in zip(yy, xx):
        if 0.0 < np.mean(mask[max(0, y - 1) : y + 2, max(0, x - 1) : x + 2]) < 1.0:
            image[max(0, y) : y + 1, max(0, x) : x + 1] = color
    return image
def dsc(y_pred, y_true):
    y_pred = np.round(y_pred).astype(int)
    y_true = np.round(y_true).astype(int)
    return np.sum(y_pred[y_true == 1]) * 2.0 / (np.sum(y_pred) + np.sum(y_true))

def dsc_distribution(volumes):
    dsc_dict = {}
    for p in volumes:
        y_pred = volumes[p][1]
        y_true = volumes[p][2]
        dsc_dict[p] = dsc(y_pred, y_true)
    return dsc_dict

def dsc_per_volume(validation_pred, validation_true, patient_slice_index):
    dsc_list = []
    num_slices = np.bincount([p[0] for p in patient_slice_index])
    index = 0
    for p in range(len(num_slices)):
        y_pred = np.array(validation_pred[index : index + num_slices[p]])
        y_true = np.array(validation_true[index : index + num_slices[p]])
        dsc_list.append(dsc(y_pred, y_true))
        index += num_slices[p]
    return dsc_list

def postprocess_per_volume(
    input_list, pred_list, true_list, patient_slice_index, patients
):
    volumes = {}
    num_slices = np.bincount([p[0] for p in patient_slice_index])
    index = 0
    for p in range(len(num_slices)):
        volume_in = np.array(input_list[index : index + num_slices[p]])
        volume_pred = np.round(
            np.array(pred_list[index : index + num_slices[p]])
        ).astype(int)
        volume_true = np.array(true_list[index : index + num_slices[p]])
        volumes[patients[p]] = (volume_in, volume_pred, volume_true)
        index += num_slices[p]
    return volumes

def log_loss_summary(loss, step, prefix=""):
    print("epoch {} | {}: {}".format(step + 1, prefix + "loss", np.mean(loss)))

def log_scalar_summary(tag, value, step):
    print("epoch {} | {}: {}".format(step + 1, tag, value))

def plot_dsc(dsc_dist):
    y_positions = np.arange(len(dsc_dist))
    dsc_dist = sorted(dsc_dist.items(), key=lambda x: x[1])
    values = [x[1] for x in dsc_dist]
    labels = [x[0] for x in dsc_dist]
    labels = ["_".join(l.split("_")[1:-1]) for l in labels]
    fig = plt.figure(figsize=(12, 8))
    canvas = FigureCanvasAgg(fig)
    plt.barh(y_positions, values, align="center", color="skyblue")
    plt.yticks(y_positions, labels)
    plt.xticks(np.arange(0.0, 1.0, 0.1))
    plt.xlim([0.0, 1.0])
    plt.gca().axvline(np.mean(values), color="tomato", linewidth=2)
    plt.gca().axvline(np.median(values), color="forestgreen", linewidth=2)
    plt.xlabel("Dice coefficient", fontsize="x-large")
    plt.gca().xaxis.grid(color="silver", alpha=0.5, linestyle="--", linewidth=1)
    plt.tight_layout()
    canvas.draw()
    plt.close()
    s, (width, height) = canvas.print_to_buffer()
    return np.fromstring(s, np.uint8).reshape((height, width, 4))

Training and evaluation

batch_size = 128
epochs = 300
lr = 0.0001
workers = 8
weights = "./"
image_size = 256
aug_scale = 0.05
aug_angle = 15
def train_validate():
    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")

    loader_train, loader_valid = data_loaders(batch_size, workers, image_size, aug_scale, aug_angle)
    loaders = {"train": loader_train, "valid": loader_valid}

    unet = UNet(in_channels=BrainSegmentationDataset.in_channels, out_channels=BrainSegmentationDataset.out_channels)
    unet.to(device)

    dsc_loss = DiceLoss()
    best_validation_dsc = 0.0

    optimizer = optim.Adam(unet.parameters(), lr=lr)

    loss_train = []
    loss_valid = []

    step = 0

    for epoch in range(epochs):
        for phase in ["train", "valid"]:
            if phase == "train":
                unet.train()
            else:
                unet.eval()

            validation_pred = []
            validation_true = []

            for i, data in enumerate(loaders[phase]):
                if phase == "train":
                    step += 1

                x, y_true = data
                x, y_true = x.to(device), y_true.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == "train"):
                    y_pred = unet(x)

                    loss = dsc_loss(y_pred, y_true)

                    if phase == "valid":
                        loss_valid.append(loss.item())
                        y_pred_np = y_pred.detach().cpu().numpy()
                        validation_pred.extend(
                            [y_pred_np[s] for s in range(y_pred_np.shape[0])]
                        )
                        y_true_np = y_true.detach().cpu().numpy()
                        validation_true.extend(
                            [y_true_np[s] for s in range(y_true_np.shape[0])]
                        )

                    if phase == "train":
                        loss_train.append(loss.item())
                        loss.backward()
                        optimizer.step()

            if phase == "train":
                log_loss_summary(loss_train, epoch)
                loss_train = []

            if phase == "valid":
                log_loss_summary(loss_valid, epoch, prefix="val_")
                mean_dsc = np.mean(
                    dsc_per_volume(
                        validation_pred,
                        validation_true,
                        loader_valid.dataset.patient_slice_index,
                    )
                )
                log_scalar_summary("val_dsc", mean_dsc, epoch)
                if mean_dsc > best_validation_dsc:
                    best_validation_dsc = mean_dsc
                    torch.save(unet.state_dict(), os.path.join(weights, "unet.pt"))
                loss_valid = []

    print("\nBest validation mean DSC: {:4f}\n".format(best_validation_dsc))

    state_dict = torch.load(os.path.join(weights, "unet.pt"))
    unet.load_state_dict(state_dict)
    unet.eval()

    input_list = []
    pred_list = []
    true_list = []

    for i, data in enumerate(loader_valid):
        x, y_true = data
        x, y_true = x.to(device), y_true.to(device)
        with torch.set_grad_enabled(False):
            y_pred = unet(x)
            y_pred_np = y_pred.detach().cpu().numpy()
            pred_list.extend([y_pred_np[s] for s in range(y_pred_np.shape[0])])
            y_true_np = y_true.detach().cpu().numpy()
            true_list.extend([y_true_np[s] for s in range(y_true_np.shape[0])])
            x_np = x.detach().cpu().numpy()
            input_list.extend([x_np[s] for s in range(x_np.shape[0])])

    volumes = postprocess_per_volume(
        input_list,
        pred_list,
        true_list,
        loader_valid.dataset.patient_slice_index,
        loader_valid.dataset.patients,
    )

    dsc_dist = dsc_distribution(volumes)

    dsc_dist_plot = plot_dsc(dsc_dist)
    imsave("./dsc.png", dsc_dist_plot)

    for p in volumes:
        x = volumes[p][0]
        y_pred = volumes[p][1]
        y_true = volumes[p][2]
        for s in range(x.shape[0]):
            image = gray2rgb(x[s, 1])
            image = outline(image, y_pred[s, 0], color=[255, 0, 0])
            image = outline(image, y_true[s, 0], color=[0, 255, 0])
            filename = "{}-{}.png".format(p, str(s).zfill(2))
            filepath = os.path.join("./resultat", filename)
            imsave(filepath, image)
train_validate()

Original: https://blog.csdn.net/qq_38736504/article/details/124003427
Author: liyihao76
Title: [pytorch] Unet医学分割 代码详解
