【超分辨率】【深度学习】SRCNN pytorch代码(附详细注释和数据集)

主要改进:

  1. 断点恢复,可以恢复训练。
  2. 注释掉原test.py的38行才是真正的超分辨率。
    即image = image.resize((image.width // args.scale, image.height // args.scale), resample=pil_image.BICUBIC)
    其中//代表整除的意思。
  3. model.py存在两个与原论文有出入,请仔细思考,如果想不出来可以联系我,但自己思考更有成就感。

关于第二点的注释可以知道,这份代码更注重于研究图像生成,改善的是图像细节而非分辨率。

这里主要是对代码进行讲解,对SRCNN不了解的同学可以先去参考其他博文。

有问题,不知道如何跑代码的同学联系: 809267697@qq.com

下面是这篇代码的步骤。

首先准备好数据集,这里以img-91作为训练集,Set5作为测试集。

运行prepare.py 将两个数据集转为h5格式。(测试集要在命令加上 –eval)

之后运行train.py

import argparse
import glob
import h5py
import numpy as np
import PIL.Image as pil_image
from utils import convert_rgb_to_y

def train(args):
    h5_file = h5py.File(args.output_path, 'w')

    lr_patches = []
    hr_patches = []

    for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):

        hr = pil_image.open(image_path).convert('RGB')

        hr_width = (hr.width // args.scale) * args.scale
        hr_height = (hr.height // args.scale) * args.scale

        hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)

        lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)

        lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)

        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)
        hr = convert_rgb_to_y(hr)
        lr = convert_rgb_to_y(lr)

        for i in range(0, lr.shape[0] - args.patch_size + 1, args.stride):
            for j in range(0, lr.shape[1] - args.patch_size + 1, args.stride):
                lr_patches.append(lr[i:i + args.patch_size, j:j + args.patch_size])
                hr_patches.append(hr[i:i + args.patch_size, j:j + args.patch_size])

    lr_patches = np.array(lr_patches)
    hr_patches = np.array(hr_patches)

    h5_file.create_dataset('lr', data=lr_patches)
    h5_file.create_dataset('hr', data=hr_patches)

    h5_file.close()

def eval(args):
    h5_file = h5py.File(args.output_path, 'w')

    lr_group = h5_file.create_group('lr')
    hr_group = h5_file.create_group('hr')

    for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))):
        hr = pil_image.open(image_path).convert('RGB')
        hr_width = (hr.width // args.scale) * args.scale
        hr_height = (hr.height // args.scale) * args.scale
        hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
        lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
        lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)
        hr = convert_rgb_to_y(hr)
        lr = convert_rgb_to_y(lr)

        lr_group.create_dataset(str(i), data=lr)
        hr_group.create_dataset(str(i), data=hr)

    h5_file.close()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--images-dir', type=str, required=True)
    parser.add_argument('--output-path', type=str, required=True)
    parser.add_argument('--patch-size', type=int, default=32)
    parser.add_argument('--stride', type=int, default=14)
    parser.add_argument('--scale', type=int, default=4)
    parser.add_argument('--eval', action='store_true')
    args = parser.parse_args()

    if not args.eval:
        train(args)
    else:
        eval(args)

之后运行,看不懂注释可以先去其他博文了解SRCNN的网络结构和训练过程。

import argparse
import os
import copy

import numpy as np
from torch import Tensor
import torch
from torch import nn
import torch.optim as optim

import torch.backends.cudnn as cudnn

from torch.utils.data.dataloader import DataLoader

from tqdm import tqdm

from model import SRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--train-file', type=str, required=True)
    parser.add_argument('--eval-file', type=str, required=True)
    parser.add_argument('--outputs-dir', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--num-epochs', type=int, default=400)
    parser.add_argument('--seed', type=int, default=123)
    args = parser.parse_args()

    args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))

    if not os.path.exists(args.outputs_dir):
        os.makedirs(args.outputs_dir)

    cudnn.benchmark = True

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(args.seed)

    model = SRCNN().to(device)

    criterion = nn.MSELoss()

    optimizer = optim.Adam([
        {'params': model.conv1.parameters()},
        {'params': model.conv2.parameters()},
        {'params': model.conv3.parameters(), 'lr': args.lr*0.1}
    ], lr=args.lr)

    train_dataset = TrainDataset(args.train_file)
    train_dataloader = DataLoader(

                dataset=train_dataset,

                batch_size=args.batch_size,

                shuffle=True,

                num_workers=args.num_workers,

                pin_memory=True,

                drop_last=True)

    eval_dataset = EvalDataset(args.eval_file)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

    lossLog=[]
    psnrLog=[]

    for epoch in range(1, args.num_epochs + 1):

        model.train()

        epoch_losses = AverageMeter()

        with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:

            t.set_description('epoch:{}/{}'.format(epoch, args.num_epochs))

            for data in train_dataloader:

                inputs, labels = data

                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs)

                loss = criterion(preds, labels)

                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()

                loss.backward()

                optimizer.step()

                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                t.update(len(inputs))

        lossLog.append(np.array(epoch_losses.avg))

        np.savetxt("lossLog.txt", lossLog)

        torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))

        model.eval()
        epoch_psnr = AverageMeter()

        for data in eval_dataloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                preds = model(inputs).clamp(0.0, 1.0)

            epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

        psnrLog.append(Tensor.cpu(epoch_psnr.avg))
        np.savetxt('psnrLog.txt', psnrLog)

        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_weights = copy.deepcopy(model.state_dict())

        print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))

        torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))

    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))

    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))

之后运行test.py就可以了,其中跟train.py差不多就不注释了。

test.py是放入图片、权重和倍数就行,会生成两张图片。

(a)是原图 (b)是bicubic (c)是SRCNN

Original: https://blog.csdn.net/zhanjuex/article/details/124344864
Author: zhanjuex
Title: 【超分辨率】【深度学习】SRCNN pytorch代码(附详细注释和数据集)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/731649/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球