GAN (Generative Adversarial Network): A Handwritten Digit Implementation

GAN: Passing Fakes Off as Real

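The basic idea behind a GAN is quite simple: at its core are two neural networks with conflicting objectives, each using increasingly sophisticated methods to "fool" the other. This setup can be understood as a minimax game from game theory.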

To illustrate, imagine two kinds of people, the police and a criminal, and consider their conflicting goals:

  • The criminal's goal: to devise sophisticated counterfeiting methods so that the police cannot tell counterfeit money from real money.
  • The police's goal: to devise sophisticated methods of inspecting currency so that counterfeit money can be told apart from real money.

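As this process continues, the police develop increasingly sophisticated techniques for detecting counterfeit money, and the criminal develops increasingly sophisticated techniques for forging it. This is the basic idea of the "adversarial process" in a GAN.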

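A GAN uses this adversarial process to train two neural networks that compete with each other until they reach an ideal equilibrium. The police and the criminal in the example correspond to these two networks.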

One network is the generator G(Z): it takes random noise as input and generates data that closely resembles the existing dataset.

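The other network is the discriminator D(X): it receives generated data alongside real data and tries to tell which samples are generated and which are real. At its core the discriminator is a binary classifier whose output is the probability that the input comes from the real dataset (as opposed to synthetic, or fake, data).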
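As a minimal sketch of this binary classification, the discriminator's loss can be written as binary cross-entropy over its output probabilities. The function name `binary_cross_entropy` and the sample probabilities are illustrative assumptions, not part of the original code; the full implementation relies on Keras' built-in `binary_crossentropy` loss instead.

```python
import numpy as np

def binary_cross_entropy(y_true, p_real, eps=1e-7):
    """Mean binary cross-entropy between labels (1 = real, 0 = fake)
    and the discriminator's predicted probability of 'real'."""
    y_true = np.asarray(y_true, dtype=np.float64)
    # Clip to avoid log(0) for fully confident predictions.
    p = np.clip(np.asarray(p_real, dtype=np.float64), eps, 1.0 - eps)
    return float(np.mean(-(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))))

# Hypothetical discriminator outputs for two real and two generated samples.
labels = [1, 1, 0, 0]
confident = binary_cross_entropy(labels, [0.9, 0.8, 0.1, 0.2])
undecided = binary_cross_entropy(labels, [0.5, 0.5, 0.5, 0.5])
print(confident, undecided)
```

A discriminator that outputs 0.5 for every sample has a loss of log 2 (about 0.693), which is a useful reference value when reading the training log.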

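The ideal equilibrium mentioned above means that the generator can model the real data and the discriminator outputs a probability of 0.5, that is, generated data is indistinguishable from real data. In other words, the discriminator cannot tell whether a new sample from the generator is real or fake; both are equally likely.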
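This balance is usually stated through the minimax objective of the standard GAN formulation. The notation below is an assumption added for illustration and does not appear in the text above: $p_{\text{data}}$ is the real data distribution, $p_z$ the noise distribution, and $p_g$ the distribution of generated samples.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\bigl[\log D(x)\bigr]
  + \mathbb{E}_{z \sim p_z}\bigl[\log\bigl(1 - D(G(z))\bigr)\bigr]

D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)},
\qquad
p_g = p_{\text{data}} \;\Rightarrow\; D^*(x) = \tfrac{1}{2}
```

For a fixed generator the best possible discriminator is $D^*$; once the generator matches the data distribution, $D^*$ reduces to the constant 0.5.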

Training process
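One training step alternates two updates: the discriminator is trained on a mixed batch of real and generated images, then the generator is trained through the stacked model with every label set to "real". The NumPy sketch below shows how the two sets of labels are built, mirroring the `train` method in the full code. The helper names `discriminator_batch` and `generator_targets` are hypothetical, and the arrays are random stand-ins for MNIST images and generator output.

```python
import numpy as np

def discriminator_batch(real_images, fake_images):
    """Stack real and generated images and label them 1 (real) / 0 (fake)."""
    x = np.concatenate((real_images, fake_images))
    y = np.ones((x.shape[0], 1))
    y[real_images.shape[0]:, :] = 0
    return x, y

def generator_targets(batch_size):
    """Targets for the generator update: all ones, because the generator
    is rewarded when the discriminator calls its images real."""
    return np.ones((batch_size, 1))

rng = np.random.default_rng(0)
batch_size = 4
real = rng.random((batch_size, 28, 28, 1))               # stand-in for MNIST images
fake = rng.random((batch_size, 28, 28, 1))               # stand-in for generator output
noise = rng.uniform(-1.0, 1.0, size=(batch_size, 100))   # generator input

x, y = discriminator_batch(real, fake)
print(x.shape, y.ravel())
print(noise.shape, generator_targets(batch_size).ravel())
```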


Environment

  • tensorflow 2.4.1
  • numpy
  • matplotlib

Dataset

MNIST handwritten digits
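Before training, the pixel values are scaled from [0, 255] to [0, 1] and a channel axis is added, as the full code does inside `MNIST_DCGAN.__init__`. A small sketch of that step follows. The function name `preprocess` is hypothetical, and a random uint8 array stands in for the MNIST images so the snippet runs without downloading the dataset.

```python
import numpy as np

def preprocess(images):
    """Scale uint8 images to [0, 1] and reshape to (N, 28, 28, 1) float32."""
    scaled = images / 255.0
    return scaled.reshape(-1, 28, 28, 1).astype(np.float32)

# Random stand-in with the same layout as the MNIST images: (N, 28, 28) uint8.
dummy = np.random.randint(0, 256, size=(8, 28, 28), dtype=np.uint8)
x_train = preprocess(dummy)
print(x_train.shape, x_train.dtype, x_train.min(), x_train.max())
```

The [0, 1] range matches the sigmoid activation on the generator's last layer, so real and generated images share the same value range.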

Full code

'''
tensorflow 2.4.1
numpy
matplotlib
'''

# Let TensorFlow allocate GPU memory on demand instead of reserving it all upfront.
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
import numpy as np
import time
# The dataset itself is loaded inside MNIST_DCGAN.__init__.
from tensorflow.keras.datasets import mnist

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Flatten, Reshape
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, UpSampling2D
from tensorflow.keras.layers import LeakyReLU, Dropout
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.optimizers import Adam,RMSprop

import matplotlib.pyplot as plt

class ElapsedTimer(object):
    def __init__(self):
        self.start_time = time.time()
    def elapsed(self,sec):
        if sec < 60:
            return str(sec) + " sec"
        elif sec < (60 * 60):
            return str(sec / 60) + " min"
        else:
            return str(sec / (60 * 60)) + " hr"
    def elapsed_time(self):
        print("Elapsed: %s " % self.elapsed(time.time() - self.start_time) )

class DCGAN(object):
    def __init__(self, img_rows=28, img_cols=28, channel=1):

        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channel = channel
        self.D = None
        self.G = None
        self.AM = None
        self.DM = None

    def discriminator(self):
        if self.D:
            return self.D
        self.D = Sequential()
        depth = 64
        dropout = 0.4

        input_shape = (self.img_rows, self.img_cols, self.channel)
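        # Output size as a function of the padding mode:
        #   padding='same':  output = ceil(input / stride)
        #   padding='valid': output = ceil((input - kernel + 1) / stride)
        #
        # First layer: 64 kernels of size 5x5, stride 2, input (28, 28, 1).
        # With padding='same' and stride 2 the output is (14, 14, 64).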
        self.D.add(Conv2D(64, 5, strides=2, input_shape=input_shape,padding='same'))
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))

        self.D.add(Conv2D(128, 5, strides=2, padding='same'))
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))

        self.D.add(Conv2D(256, 5, strides=2, padding='same'))
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))

        self.D.add(Conv2D(512, 5, strides=1, padding='same'))
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))

        self.D.add(Conv2D(256, 5, strides=1, padding='same'))
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))

        self.D.add(Flatten())
        self.D.add(Dense(1))
        self.D.add(Activation('sigmoid'))
        self.D.summary()
        return self.D

    def generator(self):
        if self.G:
            return self.G
        self.G = Sequential()
        dropout = 0.4
        depth = 64+64+64+64
        dim = 7

        self.G.add(Dense(dim*dim*depth, input_dim=100))
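        # The momentum parameter governs how the mean and variance are tracked:
        # the layer keeps moving_mean and moving_variance from earlier batches and,
        # borrowing the momentum idea from optimizers, carries their effect over
        # to the current batch. Typical values are 0.9 or 0.99.
        # After many batches (the factor 0.9 multiplied many times over), the
        # influence of the earliest batches fades.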
        self.G.add(BatchNormalization(momentum=0.9))
        self.G.add(Activation('relu'))
        self.G.add(Reshape((dim, dim, depth)))
        self.G.add(Dropout(dropout))

        self.G.add(UpSampling2D())
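        # After convolutions extract features, the output is usually smaller than
        # the input. When the image has to be brought back to a larger size for
        # further computation (e.g. semantic segmentation), the mapping from low
        # to high resolution is called upsampling. Here UpSampling2D doubles the
        # feature map from 7x7 to 14x14.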
        self.G.add(Conv2DTranspose(int(depth/2), 5, padding='same'))
        self.G.add(BatchNormalization(momentum=0.9))
        self.G.add(Activation('relu'))

        self.G.add(UpSampling2D())
        self.G.add(Conv2DTranspose(int(depth/4), 5, padding='same'))
        self.G.add(BatchNormalization(momentum=0.9))
        self.G.add(Activation('relu'))

        self.G.add(Conv2DTranspose(int(depth/8), 5, padding='same'))
        self.G.add(BatchNormalization(momentum=0.9))
        self.G.add(Activation('relu'))

        self.G.add(Conv2DTranspose(1, 5, padding='same'))
        self.G.add(Activation('sigmoid'))
        self.G.summary()
        return self.G

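    def discriminator_model(self):
        if self.DM:
            return self.DM
        optimizer = RMSprop(learning_rate=0.0002, decay=6e-8)
        self.DM = Sequential()
        self.DM.add(self.discriminator())

        self.DM.compile(loss='binary_crossentropy', optimizer=optimizer,\
            metrics=['accuracy'])
        return self.DM

    def adversarial_model(self):
        if self.AM:
            return self.AM
        optimizer = RMSprop(learning_rate=0.0001, decay=3e-8)
        # Compile the discriminator-only model first, while D is still trainable.
        self.discriminator_model()
        self.AM = Sequential()
        self.AM.add(self.generator())
        self.AM.add(self.discriminator())
        # Freeze D inside the stacked model so this step updates only the generator.
        # Keras applies the trainable flags recorded at compile time, so self.DM
        # still updates the discriminator.
        self.D.trainable = False

        self.AM.compile(loss='binary_crossentropy', optimizer=optimizer,\
            metrics=['accuracy'])
        return self.AM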

class MNIST_DCGAN(object):
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channel = 1
        (X_train, y_train), (X_test, y_test) = mnist.load_data()
        X_train = X_train / 255.0
        self.x_train = X_train.reshape(-1, 28, 28, 1).astype(np.float32)

        self.DCGAN = DCGAN()
        self.discriminator =  self.DCGAN.discriminator_model()
        self.adversarial = self.DCGAN.adversarial_model()
        self.generator = self.DCGAN.generator()
    def train(self, train_steps=2000, batch_size=256, save_interval=0):
        noise_input = None
        if save_interval>0:
            noise_input = np.random.uniform(-1.0, 1.0, size=[16, 100])
        for i in range(train_steps):

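            # In the first step the generator weights are still at their random
            # initialization, so its output is essentially noise. Once the
            # discriminator has been trained and the loss computed, the generator
            # weights are updated as well.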
            images_train = self.x_train[np.random.randint(0,self.x_train.shape[0], size=batch_size), :, :, :]
            noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
            images_fake = self.generator.predict(noise)

            # Save a 4x4 grid of generated images every 5 steps.
            # The ./g directory must already exist.
            if i%5==0:
                plt.figure(figsize=(24, 24))
                for j in range(16):
                    plt.subplot(4, 4, j + 1)
                    image = images_fake[j, :, :, :]
                    image = np.reshape(image, [28,28])
                    plt.imshow(image, cmap='gray')
                    plt.axis('off')
                plt.tight_layout()
                filename = './g/img_{}.png'.format(i)
                plt.savefig(filename)
                plt.close('all')

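            # Discriminator update: the discriminator is shown real images together
            # with fake images from the generator and classifies them. Any image it
            # misclassifies adds to the discriminator loss, and the discriminator
            # updates its weights through backpropagation.
            #
            # Generator update (further below): the generator receives noise and
            # produces fake images, which are passed to the discriminator. The
            # generator loss penalizes the generator for samples the discriminator
            # classifies as fake, and the weights are updated by backpropagating
            # from the discriminator into the generator.
            x = np.concatenate((images_train, images_fake))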
            y = np.ones([2*batch_size, 1])
            y[batch_size:, :] = 0

            d_loss = self.discriminator.train_on_batch(x, y)

            # Core step: train the generator through the stacked model,
            # labelling its fake images as real (1).
            y = np.ones([batch_size, 1])
            noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
            a_loss = self.adversarial.train_on_batch(noise, y)

            log_mesg = "%d: [D loss: %f, acc: %f]" % (i, d_loss[0], d_loss[1])
            log_mesg = "%s  [A loss: %f, acc: %f]" % (log_mesg, a_loss[0], a_loss[1])
            print(log_mesg)

            if save_interval>0:
                if (i+1)%save_interval==0:
                    self.plot_images(save2file=True, samples=noise_input.shape[0],\
                        noise=noise_input, step=(i+1))

    def plot_images(self, save2file=False, fake=True, samples=16, noise=None, step=0):
        filename = 'mnist.png'
        if fake:
            if noise is None:
                noise = np.random.uniform(-1.0, 1.0, size=[samples, 100])
            else:
                filename = "mnist_%d.png" % step
            images = self.generator.predict(noise)
        else:
            i = np.random.randint(0, self.x_train.shape[0], samples)
            images = self.x_train[i, :, :, :]

        plt.figure(figsize=(10,10))
        for i in range(images.shape[0]):
            plt.subplot(4, 4, i+1)
            image = images[i, :, :, :]
            image = np.reshape(image, [self.img_rows, self.img_cols])
            plt.imshow(image, cmap='gray')
            plt.axis('off')
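        plt.tight_layout()
        # Write the grid to disk when requested, otherwise display it.
        if save2file:
            plt.savefig(filename)
            plt.close('all')
        else:
            plt.show()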

if __name__ == '__main__':
    mnist_dcgan = MNIST_DCGAN()
    timer = ElapsedTimer()
    mnist_dcgan.train(train_steps=10000, batch_size=128, save_interval=1000)
    timer.elapsed_time()
    mnist_dcgan.plot_images(fake=True)
    mnist_dcgan.plot_images(fake=False, save2file=True)
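The padding rules quoted in the discriminator comments can be checked with a few lines of arithmetic. The helper `conv_output_size` below is a hypothetical name introduced for illustration; it applies the 'same' and 'valid' formulas to trace the spatial size through the discriminator's five convolutions (strides 2, 2, 2, 1, 1, all with a 5x5 kernel and 'same' padding).

```python
import math

def conv_output_size(size, kernel, stride, padding):
    """Spatial output size of a convolution along one dimension."""
    if padding == 'same':
        return math.ceil(size / stride)
    if padding == 'valid':
        return math.ceil((size - kernel + 1) / stride)
    raise ValueError("padding must be 'same' or 'valid'")

size = 28
sizes = [size]
for stride in (2, 2, 2, 1, 1):   # strides of the five Conv2D layers
    size = conv_output_size(size, kernel=5, stride=stride, padding='same')
    sizes.append(size)

print(sizes)               # [28, 14, 7, 4, 4, 4]
print(size * size * 256)   # 4096 values reach the final Dense layer
```

The generator goes the other way: it starts from a 7x7 feature map, and each of its two `UpSampling2D` layers doubles the size, giving 14x14 and then 28x28.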

Results

References
https://www.cnblogs.com/dereen/p/gan.html
https://zhuanlan.zhihu.com/p/43047326
https://www.zhihu.com/question/306213462

Original: https://blog.csdn.net/qq_44936246/article/details/120970272
Author: 醉公子~
Title: GAN生成对抗网络—-手写数据实现 (GAN: A Handwritten Digit Implementation)
