PyTorch: A Detailed Code Implementation for Image Restoration (Applies to derain, dehaze, deblur, denoise, etc.)

Contents

+ Preface
+ Datasets
  * Training dataset
  * Evaluation dataset
  * Test dataset
+ Network model
+ Custom utilities
+ Training and testing
+ Conclusion

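Preface

Image restoration is a family of tasks that remove degradation from an image, and in deep learning it can be framed as a supervised regression problem. It mainly covers image deraining, dehazing, denoising, deblurring, and demosaicing. The PyTorch code for these tasks is much the same; only the network architecture differs slightly.

Taking image deraining as the example: I previously wrote an article on a PyTorch implementation of image deraining: https://blog.csdn.net/Wenyuanbo/article/details/116541682. My skills were limited at the time and the implementation logic had problems, so I have recently reorganized the code and am sharing it here in the hope that it helps. The project consists of four files, MyDataset.py, NetModel.py, utils.py, and main.py; set the dataset paths to suit your own environment.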


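Datasets

When image deraining is implemented as supervised regression, the dataset generally consists of paired rainy and rain-free images. I habitually renumber all the pairs in matching order from the first to the last, for example 001, 002, 003…. The exact naming scheme is up to you; what matters is that both folders list matching images in the same order, because the dataset classes below pair files by index.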
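The renumbering step can be sketched as a small helper that only builds a rename plan and touches no files. The function name `paired_rename_plan` and the sample file names are made up for illustration, and the sketch assumes that sorting both folders puts matching pairs at the same position.

```python
import os

def paired_rename_plan(input_names, target_names, width=3):
    """Map each (input, target) pair to a shared zero-padded index name."""
    if len(input_names) != len(target_names):
        raise ValueError("input and target folders must hold the same number of images")
    plan = []
    # Assumption: sorting both lists puts matching pairs at the same position.
    pairs = zip(sorted(input_names), sorted(target_names))
    for i, (inp, tar) in enumerate(pairs, start=1):
        stem = str(i).zfill(width)  # 1 -> "001"
        plan.append((inp, stem + os.path.splitext(inp)[1],
                     tar, stem + os.path.splitext(tar)[1]))
    return plan

# Hypothetical file names, listed out of order on purpose.
plan = paired_rename_plan(["rain-2.png", "rain-1.png"],
                          ["norain-2.png", "norain-1.png"])
for old_in, new_in, old_tar, new_tar in plan:
    print(old_in, "->", new_in, "|", old_tar, "->", new_tar)
```

To apply the plan, each old name could be passed to os.rename together with its new name inside the corresponding folder. It is worth printing and checking the plan first, since a wrong pairing silently ruins training.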

MyDataset.py

import os
import random
import torchvision.transforms.functional as ttf
from torch.utils.data import Dataset
from PIL import Image

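Training dataset

The training dataset assembles the training data: it converts each rainy image and its rain-free counterpart to tensors and crops both at the same position.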

class MyTrainDataSet(Dataset):
    def __init__(self, inputPathTrain, targetPathTrain, patch_size=128):
        super(MyTrainDataSet, self).__init__()

        # Sort both listings: os.listdir returns names in arbitrary order,
        # and input/target pairs are matched by index.
        self.inputPath = inputPathTrain
        self.inputImages = sorted(os.listdir(inputPathTrain))

        self.targetPath = targetPathTrain
        self.targetImages = sorted(os.listdir(targetPathTrain))

        self.ps = patch_size  # side length of the square training patch

    def __len__(self):
        return len(self.targetImages)

    def __getitem__(self, index):

        ps = self.ps
        index = index % len(self.targetImages)

        inputImagePath = os.path.join(self.inputPath, self.inputImages[index])
        inputImage = Image.open(inputImagePath).convert('RGB')

        targetImagePath = os.path.join(self.targetPath, self.targetImages[index])
        targetImage = Image.open(targetImagePath).convert('RGB')

        # PIL image -> float tensor of shape (C, H, W) with values in [0, 1]
        inputImage = ttf.to_tensor(inputImage)
        targetImage = ttf.to_tensor(targetImage)

        hh, ww = targetImage.shape[1], targetImage.shape[2]

        # One random top-left corner is shared by both crops so the patches
        # stay aligned (images must be at least ps x ps).
        rr = random.randint(0, hh-ps)
        cc = random.randint(0, ww-ps)

        input_ = inputImage[:, rr:rr+ps, cc:cc+ps]
        target = targetImage[:, rr:rr+ps, cc:cc+ps]

        return input_, target
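The key point of the training crop is that the rainy and rain-free patches come from the same window. Here is a minimal standalone sketch of that step on synthetic tensors. The helper name `paired_random_crop` and the explicit size check are additions for illustration, not part of the project files.

```python
import random
import torch

def paired_random_crop(input_img, target_img, ps):
    """Cut the same ps x ps window out of two (C, H, W) tensors."""
    hh, ww = target_img.shape[1], target_img.shape[2]
    if hh < ps or ww < ps:
        raise ValueError("image is smaller than the patch size")
    rr = random.randint(0, hh - ps)  # randint is inclusive on both ends
    cc = random.randint(0, ww - ps)
    return (input_img[:, rr:rr + ps, cc:cc + ps],
            target_img[:, rr:rr + ps, cc:cc + ps])

# Synthetic pair: the "rainy" input is the clean target plus a constant offset.
clean = torch.rand(3, 160, 200)
rainy = clean + 0.1
in_patch, tar_patch = paired_random_crop(rainy, clean, ps=128)
print(in_patch.shape, tar_patch.shape)
```

Because both patches share one window, their difference is still the constant offset everywhere. Cropping the two images with independent random corners would break that correspondence, and the network would be trained on mismatched pairs.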

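Evaluation dataset

During training, the last epoch does not necessarily give the best result. The evaluation dataset is used to measure the network's performance at the end of every epoch, so that the weights from the best epoch can be saved.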

class MyValueDataSet(Dataset):
    def __init__(self, inputPathTrain, targetPathTrain, patch_size=128):
        super(MyValueDataSet, self).__init__()

        # Sort both listings so that input/target pairs line up by index.
        self.inputPath = inputPathTrain
        self.inputImages = sorted(os.listdir(inputPathTrain))

        self.targetPath = targetPathTrain
        self.targetImages = sorted(os.listdir(targetPathTrain))

        self.ps = patch_size

    def __len__(self):
        return len(self.targetImages)

    def __getitem__(self, index):

        ps = self.ps
        index = index % len(self.targetImages)

        inputImagePath = os.path.join(self.inputPath, self.inputImages[index])
        inputImage = Image.open(inputImagePath).convert('RGB')

        targetImagePath = os.path.join(self.targetPath, self.targetImages[index])
        targetImage = Image.open(targetImagePath).convert('RGB')

        # A fixed center crop keeps the evaluation deterministic across epochs.
        inputImage = ttf.center_crop(inputImage, (ps, ps))
        targetImage = ttf.center_crop(targetImage, (ps, ps))

        input_ = ttf.to_tensor(inputImage)
        target = ttf.to_tensor(targetImage)

        return input_, target
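The reason for evaluating after every epoch can be shown with a tiny sketch that picks the best epoch from a list of validation PSNR values. The helper name `select_best_epoch` and the numbers are invented purely for illustration; they are not measured results.

```python
def select_best_epoch(psnr_per_epoch):
    """Return (1-based epoch, PSNR) of the highest validation score."""
    best_psnr, best_epoch = float("-inf"), None
    for epoch, value in enumerate(psnr_per_epoch, start=1):
        if value > best_psnr:  # strict '>' keeps the earliest epoch on ties
            best_psnr, best_epoch = value, epoch
    return best_epoch, best_psnr

# Made-up validation curve: the score peaks at epoch 3 and then drops.
history = [21.5, 23.0, 24.2, 24.0, 23.7]
best_epoch, best_psnr = select_best_epoch(history)
print(best_epoch, best_psnr)
```

In a curve like this, keeping only the final weights would throw away the better model from epoch 3, which is why the training loop saves a separate checkpoint whenever the validation score improves.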

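Test dataset

The test dataset feeds the rainy input images through the network to obtain the derained results. Note that the input is generally kept at its original size and is not cropped.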

class MyTestDataSet(Dataset):
    def __init__(self, inputPathTest):
        super(MyTestDataSet, self).__init__()

        # Sort the listing so that results are written in a reproducible order.
        self.inputPath = inputPathTest
        self.inputImages = sorted(os.listdir(inputPathTest))

    def __len__(self):
        return len(self.inputImages)

    def __getitem__(self, index):
        index = index % len(self.inputImages)

        inputImagePath = os.path.join(self.inputPath, self.inputImages[index])
        inputImage = Image.open(inputImagePath).convert('RGB')

        # Full-size image, no cropping
        input_ = ttf.to_tensor(inputImage)

        return input_

Network model

A simple 5-layer convolutional neural network serves as the example here; design the actual network yourself.
NetModel.py

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Input layer: 3 RGB channels -> 32 feature channels
        self.inconv = nn.Sequential(
            nn.Conv2d(3, 32, 3, 1, 1),
            nn.ReLU(inplace=True)
        )
        # Three middle layers keep 32 feature channels
        self.midconv = nn.Sequential(
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU(inplace=True),
        )
        # Output layer: 32 feature channels -> 3 RGB channels
        self.outconv = nn.Sequential(
            nn.Conv2d(32, 3, 3, 1, 1),
        )

    def forward(self, x):

        x = self.inconv(x)
        x = self.midconv(x)
        x = self.outconv(x)

        return x
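Before training, it is worth pushing a random tensor through the network to confirm that the channel counts chain correctly and that the output has the same shape as the input. The sketch below rebuilds a compact 5-layer stand-in so that it runs on its own; `make_tiny_net` is a hypothetical helper, not part of NetModel.py.

```python
import torch
import torch.nn as nn

def make_tiny_net(channels=32):
    """Plain 5-layer CNN: 3 -> channels -> channels -> channels -> channels -> 3."""
    layers = [nn.Conv2d(3, channels, 3, 1, 1), nn.ReLU(inplace=True)]
    for _ in range(3):
        layers += [nn.Conv2d(channels, channels, 3, 1, 1), nn.ReLU(inplace=True)]
    layers.append(nn.Conv2d(channels, 3, 3, 1, 1))
    return nn.Sequential(*layers)

net = make_tiny_net()
x = torch.rand(2, 3, 64, 48)  # batch of two RGB images, 64 x 48 pixels
with torch.no_grad():
    y = net(x)
print(x.shape, y.shape)
```

With 3x3 kernels, stride 1, and padding 1, every convolution preserves height and width, so a restoration network built this way accepts images of any size. If the output channels of one layer do not match the input channels of the next, this forward pass fails immediately with a shape error instead of failing later in the middle of training.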

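Custom utilities

The custom utility module mainly provides a function that computes the peak signal-to-noise ratio (PSNR), which is used to evaluate training.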

utils.py

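import torch

def torchPSNR(tar_img, prd_img):
    # Clamp both images to [0, 1] so the peak signal value is 1
    imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1)
    # Root-mean-square error over all pixels and channels
    rmse = (imdff**2).mean().sqrt()
    # PSNR in dB: 20 * log10(peak / RMSE), with peak = 1
    ps = 20*torch.log10(1/rmse)
    return ps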
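To make the formula concrete, here is a small worked example in plain Python that mirrors the same computation, PSNR = 20 * log10(peak / RMSE), on a handful of pixel values. The function name `psnr_from_pixels` and the sample values are illustrative only.

```python
import math

def psnr_from_pixels(target, prediction, max_value=1.0):
    """PSNR in dB for two equally long sequences of pixel values."""
    squared_errors = [(p - t) ** 2 for p, t in zip(prediction, target)]
    rmse = math.sqrt(sum(squared_errors) / len(squared_errors))
    return 20 * math.log10(max_value / rmse)

target = [0.2, 0.5, 0.8, 0.4]
prediction = [t + 0.1 for t in target]  # every pixel is off by 0.1
value = psnr_from_pixels(target, prediction)
print(round(value, 2))
```

An error of 0.1 on every pixel gives an RMSE of 0.1, so the PSNR is 20 * log10(1 / 0.1), which is about 20 dB. Halving the error adds roughly 6 dB. Two identical images have an RMSE of zero, so this plain-Python version would raise ZeroDivisionError in that case.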

Training and testing

main.py

import os
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR
import utils
from NetModel import Net
from MyDataset import *

if __name__ == '__main__':
    EPOCH = 100
    BATCH_SIZE = 18
    LEARNING_RATE = 1e-3
    loss_list = []
    best_psnr = 0
    best_epoch = 0

    inputPathTrain = 'E://Rain100H/inputTrain/'
    targetPathTrain = 'E://Rain100H/targetTrain/'
    inputPathTest = 'E://Rain100H/inputTest/'
    resultPathTest = 'E://Rain100H/resultTest/'
    targetPathTest = 'E://Rain100H/targetTest/'

    myNet = Net()
    myNet = myNet.cuda()
    criterion = nn.MSELoss().cuda()

    optimizer = optim.Adam(myNet.parameters(), lr=LEARNING_RATE)

    datasetTrain = MyTrainDataSet(inputPathTrain, targetPathTrain)

    trainLoader = DataLoader(dataset=datasetTrain, batch_size=BATCH_SIZE, shuffle=True, drop_last=False, num_workers=6, pin_memory=True)

    datasetValue = MyValueDataSet(inputPathTest, targetPathTest)
    valueLoader = DataLoader(dataset=datasetValue, batch_size=16, shuffle=True, drop_last=False, num_workers=6, pin_memory=True)

    datasetTest = MyTestDataSet(inputPathTest)

    testLoader = DataLoader(dataset=datasetTest, batch_size=1, shuffle=False, drop_last=False, num_workers=6, pin_memory=True)

    print('-------------------------------------------------------------------------------------------------------')
    if os.path.exists('./model_best.pth'):
        myNet.load_state_dict(torch.load('./model_best.pth'))

    for epoch in range(EPOCH):
        myNet.train()
        iters = tqdm(trainLoader, file=sys.stdout)
        epochLoss = 0
        timeStart = time.time()
        for index, (x, y) in enumerate(iters, 0):

            optimizer.zero_grad()

            # Tensors can be moved to the GPU directly; the deprecated Variable wrapper is not needed
            input_train, target = x.cuda(), y.cuda()
            output_train = myNet(input_train)

            loss = criterion(output_train, target)

            loss.backward()
            optimizer.step()
            epochLoss += loss.item()

            iters.set_description('Training !!!  Epoch %d / %d,  Batch Loss %.6f' % (epoch+1, EPOCH, loss.item()))

        myNet.eval()
        psnr_val_rgb = []
        for index, (x, y) in enumerate(valueLoader, 0):
            input_, target_value = x.cuda(), y.cuda()
            with torch.no_grad():
                output_value = myNet(input_)
            # Score each image in the batch separately with the PSNR helper from utils.py
            for out_img, tar_img in zip(output_value, target_value):
                psnr_val_rgb.append(utils.torchPSNR(tar_img, out_img))

        psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()

        if psnr_val_rgb > best_psnr:
            best_psnr = psnr_val_rgb
            best_epoch = epoch + 1  # 1-based, matching the per-epoch log
            torch.save(myNet.state_dict(), 'model_best.pth')

        loss_list.append(epochLoss)
        torch.save(myNet.state_dict(), 'model.pth')
        timeEnd = time.time()
        print("------------------------------------------------------------")
        print("Epoch:  {}  Finished,  Time:  {:.4f} s,  Loss:  {:.6f}.".format(epoch+1, timeEnd-timeStart, epochLoss))
        print('-------------------------------------------------------------------------------------------------------')
    print("Training Process Finished ! Best Epoch : {} , Best PSNR : {:.2f}".format(best_epoch, best_psnr))

    print('--------------------------------------------------------------')
    myNet.load_state_dict(torch.load('./model_best.pth'))
    myNet.eval()

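    os.makedirs(resultPathTest, exist_ok=True)  # save_image does not create missing folders
    with torch.no_grad():
        timeStart = time.time()
        for index, x in enumerate(tqdm(testLoader, desc='Testing !!! ', file=sys.stdout), 0):
            torch.cuda.empty_cache()
            input_test = x.cuda()
            output_test = myNet(input_test)
            # '.png' is an assumed output extension; adjust it to match your data
            save_image(output_test, resultPathTest + str(index+1).zfill(3) + '.png')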
        timeEnd = time.time()
        print('---------------------------------------------------------')
        print("Testing Process Finished !!! Time: {:.4f} s".format(timeEnd - timeStart))

    plt.figure(1)
    x = range(0, EPOCH)
    plt.xlabel('epoch')
    plt.ylabel('epoch loss')
    plt.plot(x, loss_list, 'r-')
    plt.show()
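main.py imports MultiStepLR without using it. If a learning-rate schedule is wanted, one possible way to wire it in is sketched below. The milestones and decay factor are hypothetical values chosen for illustration, and a single convolution stands in for the real network so that the snippet runs on its own.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR

model = nn.Conv2d(3, 3, 3, 1, 1)  # stand-in for the real network
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Hypothetical schedule: halve the learning rate after epochs 30 and 60.
scheduler = MultiStepLR(optimizer, milestones=[30, 60], gamma=0.5)

lr_history = []
for epoch in range(100):
    lr_history.append(optimizer.param_groups[0]['lr'])
    # ... one epoch of training batches would run here ...
    optimizer.step()   # the optimizer steps before the scheduler
    scheduler.step()   # advance the schedule once per epoch

print(lr_history[0], lr_history[30], lr_history[60])
```

In the real training loop, scheduler.step() would go at the end of each epoch, after the inner batch loop. Whether a schedule helps, and which milestones suit a given dataset, has to be found by experiment.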

Conclusion

Questions and discussion about image restoration, especially image deraining, are welcome.

Original: https://blog.csdn.net/Wenyuanbo/article/details/120141926
Author: 听 风、
Title: PyTorch: A Detailed Code Implementation for Image Restoration (Applies to derain, dehaze, deblur, denoise, etc.)
