Learn to Apply a GAN to Expand a Dataset in One Day (PyTorch)


Preface

When data are scarce, a GAN (generative adversarial network) can learn the characteristics of the small dataset you do have and generate new samples from it, thereby expanding the dataset. The first part of this article gives a brief overview and the later part walks through a complete example; both are fairly simple and easy to follow. If you are not interested in the theory, feel free to jump straight to the code.

One more note: this is the second version of this article. Last year's version covered the basic GAN. Since many readers asked about DCGAN and about generating color images, I made some improvements: adding CNNs to the GAN turns it into a DCGAN, which extracts image features much better with fewer parameters, improves the results considerably, and supports generating color images (you only need to change one parameter in the first few lines of the code).
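Concretely, in the training script of section 2.2 below, the channel count is controlled by a single variable near the top:

g_d_nc = 3  # image channels: 1 = grayscale (MNIST), 3 = RGB color output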

I. Basic Principles of GAN

1. GAN structure diagram

[Figure: GAN structure diagram]
A GAN consists of two models: a discriminator and a generator. Both take part in training, but once training is finished, only the generator is needed to produce new samples. The generator captures the distribution of the real samples and generates new fake samples from that distribution; the discriminator is a binary classifier that decides whether its input is a real sample or a fake one. Through continual adversarial training, D learns to correctly identify the source of a training sample, while G learns to make its fake samples ever more similar to the real ones.

2. GAN objective function

[Figure: the GAN objective function]
A GAN sets up a game between the generator network and the discriminator network. The discriminator D wants to assign a high probability to a real sample x, while assigning as low a probability as possible of being real to a fake sample G(z). The generator G, in turn, wants its fake sample G(z) to be as similar to x as possible, so that the probability D(G(z)) that the discriminator judges it to be real is as high as possible.
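In symbols, this is the standard minimax objective from the original GAN paper (Goodfellow et al., 2014), which is what the figure above depicts:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]$$

D maximizes V by pushing D(x) toward 1 and D(G(z)) toward 0; G minimizes V by pushing D(G(z)) toward 1.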

II. Example (full code: https://github.com/Programmerfei/Pytorch-Gan-based-dataset-expansion.git)

1. Project flowchart

(This flowchart covers the full experiment: it compares the accuracy of model 1, trained on the original train set, with model 2, trained on train plus the generated fake samples. If you only want to generate data with a GAN, just follow the left half of the flowchart.)

[Figure: project flowchart]

Flowchart walkthrough:
1. Split the original data into train and val.
2. Feed the train images into the GAN to obtain a generator and a discriminator; at the same time, train a CNN on the train images to obtain the first classifier.
3. Feed randomly generated noise into the generator trained in step 2 to obtain a batch of fake images.
4. Combine the fake images from step 3 with the train images to form a new training set: the original data expanded with fake samples.
5. Train the same CNN architecture as in step 2 on the new training set to obtain the second classifier.
6. Run the val images through both classifiers from steps 2 and 5 and compare their accuracy. Conclusion: adding GAN-generated fake samples to the classifier's training data effectively improves the model's generalization, and therefore its recognition accuracy.

2. Project code

(This article only covers training the GAN and using the generator for data expansion, i.e., the left half of the flowchart.)
Note: before running the code, arrange the code and data according to the directory structure in section III, so that imports and data paths resolve correctly.

2.1 Parse the MNIST binary files and save them as images

(You can use other datasets; the GAN will generate whatever kind of images it is trained on.)
Parsing code:

import os
import struct
import numpy as np
import cv2
from tqdm import tqdm

MNIST_data_dir = 'data/MNIST/raw/'   # location of the official binary files
train_val_data_dir = 'data/MNIST/'   # train/ and val/ folders are written here
Number_of_requirements = 500         # images to keep per digit

def read_idx(filename):
    """
    Parse an IDX binary file.
    filename: path to the binary file
    """
    with open(filename, 'rb') as f:
        # IDX header: two zero bytes, a data-type byte, and the number of dimensions
        zero, data_type, dims = struct.unpack('>HBB', f.read(4))
        shape = tuple(struct.unpack('>I', f.read(4))[0] for _ in range(dims))
        return np.frombuffer(f.read(), dtype=np.uint8).reshape(shape)

def save_img(data, labels, t_v):
    """
    Save the parsed images to disk, one folder per digit.
    data: image data parsed from the binary file
    labels: labels
    t_v: 'train' or 'val'
    """
    count_dict = {}
    for i in tqdm(range(len(data)), desc=t_v):
        label = labels[i]
        folder = os.path.join(t_v, str(label))
        if not os.path.exists(folder):
            os.makedirs(folder)
        if sum(count_dict.values()) == 10*Number_of_requirements:
            break

        if str(label) in count_dict and count_dict[str(label)] == Number_of_requirements:
            continue

        cv2.imwrite(os.path.join(folder, f'image_{i}.jpg'), data[i])

        count_dict[str(label)] = count_dict.get(str(label), 0) + 1
    print('Required count reached, stopping:\n', count_dict)

if __name__ == '__main__':
    for data_path, label_path, t_v in zip(['train-images-idx3-ubyte', 't10k-images-idx3-ubyte'],
                                          ['train-labels-idx1-ubyte',
                                              't10k-labels-idx1-ubyte'],
                                          ['train', 'val']):
        data = read_idx(os.path.join(MNIST_data_dir, data_path))
        labels = read_idx(os.path.join(
            MNIST_data_dir, label_path))
        save_img(data, labels, os.path.join(train_val_data_dir, t_v))
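If you want to verify the parsing on its own, a minimal check looks like this (the expected shapes are the standard MNIST ones):

# sanity check: parse the raw training files and inspect the array shapes
images = read_idx(os.path.join(MNIST_data_dir, 'train-images-idx3-ubyte'))
labels = read_idx(os.path.join(MNIST_data_dir, 'train-labels-idx1-ubyte'))
print(images.shape, labels.shape)  # expected: (60000, 28, 28) (60000,)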

2.2 Train the GAN generator and discriminator

Note: adjust the image path and the model save path, and make sure the imported utility files exist.


import matplotlib.pyplot as plt
import os
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from utils.data_reader import LoadData
from utils.network_structure import generator, discriminator
from utils.functional_functions import init_weights

route = 'data/MNIST'  # forward slashes work on both Windows and Linux
result_save_path = 'model/GAN_model'
drop_last = False
if not os.path.exists(result_save_path):
    os.makedirs(result_save_path)

lr_d = 0.002   # discriminator learning rate
lr_g = 0.002   # generator learning rate
batch_size = 100
num_epoch = 300
output_loss_Interval_ratio = 10   # print losses every 10 epochs
save_model_Interval_ratio = 100   # save a grid of fake samples every 100 epochs

g_d_nc = 1     # image channels: 1 = grayscale, 3 = RGB color
g_input = 100  # length of the generator's input noise vector

criterion = nn.BCELoss()

d = discriminator(number_of_channels=g_d_nc).cuda()
g = generator(noise_number=g_input,
              number_of_channels=g_d_nc).cuda()

for number in range(0, 10):

    # re-initialize the weights for each digit and create fresh optimizers,
    # so Adam's moment estimates do not carry over from the previous digit
    d.apply(init_weights), g.apply(init_weights)
    d_optimizer = torch.optim.Adam(
        d.parameters(), lr=lr_d, betas=(0.5, 0.999))
    g_optimizer = torch.optim.Adam(
        g.parameters(), lr=lr_g, betas=(0.5, 0.999))

    train_dataset = LoadData(os.path.join(route, 'train', str(
        number)), number_of_channels=g_d_nc)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=True, drop_last=drop_last)

    loss_list_g, loss_list_d = [], []
    for epoch in tqdm(range(0, num_epoch+1), desc='epoch'):
        batch_d_loss, batch_g_loss = 0, 0
        for img, label in train_loader:
            img_number = len(img)
            real_img = img.cuda()
            real_label = torch.ones(img_number).cuda()
            fake_label = torch.zeros(img_number).cuda()

            # discriminator on real images (target label = 1)
            real_out = d(real_img)
            real_label = real_label.reshape([-1, 1])
            d_loss_real = criterion(real_out, real_label)
            real_scores = real_out

            z = torch.randn(img_number, g_input, 1, 1).cuda()

            # discriminator on fake images (target label = 0);
            # detach() keeps this backward pass away from the generator
            fake_img = g(z).detach()
            fake_out = d(fake_img)
            fake_label = fake_label.reshape([-1, 1])
            d_loss_fake = criterion(fake_out, fake_label)
            fake_scores = fake_out

            d_loss = d_loss_real + d_loss_fake

            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # generator update: try to make the discriminator
            # classify freshly generated fakes as real
            z = torch.randn(img_number, g_input, 1, 1).cuda()
            fake_img = g(z)
            output = d(fake_img)
            g_loss = criterion(output, real_label)

            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            # accumulate plain floats; keeping the loss tensors here
            # would also keep their autograd graphs alive in memory
            batch_d_loss += d_loss.item()
            batch_g_loss += g_loss.item()

        loss_list_g.append(batch_g_loss/len(train_loader))
        loss_list_d.append(batch_d_loss/len(train_loader))

        if epoch % output_loss_Interval_ratio == 0:
            print('\nnumber:{} Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} '
                  'D real: {:.6f},D fake: {:.6f}'.format(
                      number, epoch, num_epoch,
                      batch_d_loss/len(train_loader),
                      batch_g_loss/len(train_loader),
                      real_scores.data.mean(),
                      fake_scores.data.mean()
                  ))

        if not os.path.exists(os.path.join(result_save_path, str(number))):
            os.mkdir(os.path.join(result_save_path, str(number)))

        if epoch % save_model_Interval_ratio == 0:
            save_image(fake_img, os.path.join(result_save_path, str(number),
                                              str(number)+'_fake_epoch'+str(epoch)+'.jpg'))

        for g_or_d, g_d_name in zip([g, d], ['_g_', '_d_']):
            torch.save(g_or_d, os.path.join(result_save_path,
                       str(number), str(number)+g_d_name+'last.pth'))

        plt.plot(range(len(loss_list_g)), loss_list_g, label="g_loss")
        plt.plot(range(len(loss_list_d)), loss_list_d, label="d_loss")
        plt.xlabel("epoch")
        plt.ylabel("loss")
        plt.legend()
        plt.savefig(os.path.join(result_save_path, str(number), 'loss.jpg'))
        plt.clf()

    print('\n')

How the generator's output evolves during training:

[Figure: generator samples at different stages of training]

2.3 Use the generator to produce fake samples and expand the dataset

Note: adjust the save path for the generated images and the path where the models are stored.
Expansion code:

import os
from tqdm import tqdm

import torch
from torchvision.utils import save_image

img_number = 500  # fake images to generate per digit

result_save_path = 'model/GAN_model'
fakedata_save_path = 'data/MNIST_fake/train/'

if not os.path.exists(fakedata_save_path):
    os.makedirs(fakedata_save_path)

for number in range(0, 10):
    # load the generator trained for this digit
    g = torch.load(os.path.join(result_save_path, str(number),
                                str(number) + '_g_last.pth'))
    fake_save_dir = os.path.join(fakedata_save_path, str(number))
    if not os.path.exists(fake_save_dir):
        os.mkdir(fake_save_dir)

    g.eval()
    # recover the noise dimension from the generator's first layer
    g_input = next(g.children())[0].in_channels

    for i in tqdm(range(img_number), desc=f'number{number}'):
        z = torch.randn(1, g_input, 1, 1).cuda()
        fake_img = g(z).detach()
        save_image(fake_img, os.path.join(fake_save_dir,
                            str(number) + '_fake_' + str(i) + '.jpg'))
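Step 4 of the flowchart, merging train with the fake images into the expanded training set, is not part of the code above. A minimal sketch, assuming a hypothetical data/MNIST_expanded/train output folder, could look like this:

import os
import shutil

real_dir = 'data/MNIST/train'
fake_dir = 'data/MNIST_fake/train'
merged_dir = 'data/MNIST_expanded/train'  # hypothetical output path

# copy the real and fake images of each digit into one expanded train set
for src in (real_dir, fake_dir):
    for number in os.listdir(src):
        dst = os.path.join(merged_dir, number)
        os.makedirs(dst, exist_ok=True)
        for name in os.listdir(os.path.join(src, number)):
            shutil.copy(os.path.join(src, number, name), dst)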

Here are a few real samples and fake samples for comparison. Real samples:

[Figure: real MNIST samples]
Fake samples:
[Figure: GAN-generated fake samples]

III. Directory Structure

1. Directory structure diagram

[Figure: directory structure]
Notes: download the official MNIST files directly (about 52.4 MB). Entries without an extension are folders; entries with an extension are files of the corresponding type.
Create the directories and put the dataset in place according to this structure before running the code. The trained models and generated images will also end up under the corresponding folders inside model and data.
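The layout implied by the paths used in the code is roughly the following (the project root name is arbitrary):

<project root>/
├── data/
│   ├── MNIST/
│   │   ├── raw/          # official MNIST binary files go here
│   │   ├── train/        # created by 2.1: one folder per digit
│   │   └── val/          # created by 2.1
│   └── MNIST_fake/
│       └── train/        # created by 2.3: generated fake images
├── model/
│   └── GAN_model/        # created by 2.2: models and sample images per digit
└── utils/
    ├── data_reader.py
    ├── functional_functions.py
    └── network_structure.py

The scripts from sections 2.1 to 2.3 live at the project root.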

2. Code in utils

data_reader.py

from torch.utils.data import Dataset
from torchvision.transforms import transforms
from PIL import Image
import os

imgsz = 28  # side length the images are resized to

class LoadData(Dataset):
    def __init__(self, dir_path, number_of_channels):
        # the digit folder's name is its label, so the last character of
        # the path serves as the label for every image inside it
        self.imgs_info = [(os.path.join(dir_path, img), dir_path[-1])
                          for img in os.listdir(dir_path)]

        self.tf = transforms.Compose([
            transforms.Resize((imgsz, imgsz)),
            transforms.ToTensor(),
            # note: Grayscale(n) outputs n identical channels;
            # drop this transform for true color data
            transforms.Grayscale(number_of_channels),
            # map pixel values from [0, 1] to [-1, 1] to match
            # the generator's Tanh output range
            transforms.Normalize([0.5]*number_of_channels,
                                 [0.5]*number_of_channels)
        ])

    def __getitem__(self, index):
        img_path, label = self.imgs_info[index]
        img = Image.open(img_path)
        img = img.convert('RGB')
        img = self.tf(img)
        return img, float(label)

    def __len__(self):
        return len(self.imgs_info)
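A quick usage check, assuming the digit-0 train folder from section 2.1 exists:

from torch.utils.data import DataLoader

dataset = LoadData('data/MNIST/train/0', number_of_channels=1)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
img, label = next(iter(loader))
print(img.shape)  # torch.Size([4, 1, 28, 28]), pixel values in [-1, 1]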

functional_functions.py

import torch.nn as nn

def init_weights(m):
    # uniformly re-initialize the weights of every module that has them
    # (conv, batch norm, ...); used to reset G and D for each digit
    if hasattr(m, 'weight'):
        nn.init.uniform_(m.weight, -0.1, 0.1)

network_structure.py

import torch.nn as nn

ndf = 64  # base channel count for the discriminator
ngf = 64  # base channel count for the generator

"""
About transposed convolution:
When padding=0, the kernel just overlaps the edge of the input by one unit,
so padding can be read as the number of steps the kernel moves toward the center.
Likewise, stride is no longer the step size of the kernel but the spacing
inserted between input units; with stride=1 there are no gaps.
The output size is (input_size - 1) * stride - 2 * padding + kernel_size.
"""

class generator(nn.Module):
    def __init__(self, noise_number, number_of_channels):
        """
        noise_number: length of the input noise vector
        number_of_channels: channels of the generated image
        """
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            # (noise_number, 1, 1) -> (ngf*4, 4, 4)
            nn.ConvTranspose2d(noise_number, ngf * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            # -> (ngf*2, 7, 7)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            # -> (ngf, 14, 14)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            # -> (number_of_channels, 28, 28), values in [-1, 1]
            nn.ConvTranspose2d(ngf, number_of_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        out = self.gen(x)
        return out

class discriminator(nn.Module):
    def __init__(self, number_of_channels):
        """
        number_of_channels: channels of the input image
        """
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            # (number_of_channels, 28, 28) -> (ndf, 14, 14)
            nn.Conv2d(number_of_channels, ndf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (ndf*2, 7, 7)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (ndf*4, 4, 4)
            nn.Conv2d(ndf * 2, ndf * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (1, 1, 1): probability that the input is real
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.dis(x).view(x.shape[0], -1)
        return x

class classification_model(nn.Module):
    def __init__(self, n_classes, number_of_channels):
        """
        n_classes: number of classes
        number_of_channels: channels of the input image
        """
        super(classification_model, self).__init__()
        self.structure = nn.Sequential(
            # (number_of_channels, 28, 28) -> (6, 14, 14) after pooling
            nn.Conv2d(number_of_channels, 6, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),

            # -> (16, 5, 5) after pooling
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),

            # -> (n_classes, 1, 1)
            nn.Conv2d(16, n_classes, kernel_size=5, stride=1, padding=0),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        out = self.structure(x)
        out = out.reshape(out.shape[0], -1)
        return out
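As a quick sanity check of the two DCGAN networks, here is a minimal sketch assuming the default 28x28 grayscale setup and a 100-dimensional noise vector:

import torch
from utils.network_structure import generator, discriminator

g = generator(noise_number=100, number_of_channels=1)
d = discriminator(number_of_channels=1)
z = torch.randn(4, 100, 1, 1)

fake = g(z)
print(fake.shape)     # torch.Size([4, 1, 28, 28])
print(d(fake).shape)  # torch.Size([4, 1]); each value is P(real)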

Original: https://blog.csdn.net/qq_45904885/article/details/119989699
Author: programmer.Mr.Fei
Title: Learn to Apply a GAN to Expand a Dataset in One Day (PyTorch)
