【生成对抗网络】GAN入门与代码实现(二)

文章目录

*
1 导包
2 数据准备
3 生成器模型
4 判别器模型
5 编写损失函数,定义优化器
6 获取模型&定义训练批次函数
7 定义可视化方法
8 主训练方法
9 开始训练
10 训练结果

生成对抗网络系列
【生成对抗网络】GAN入门与代码实现(一)
【生成对抗网络】GAN入门与代码实现(二)
【生成对抗网络】基于DCGAN的二次元人物头像生成(TensorFlow2)
【生成对抗网络】ACGAN的代码实现

上篇博客:【生成对抗网络】GAN入门与代码实现(一)
本篇主要介绍简单GAN的另一种实现方法(不使用卷积),依然使用TensorFlow2进行搭建,主要运用了TensorFlow2中的求导机制进行自定义训练,自由度更高。对比上篇博客中的实现方法可加深对GAN的编写理解。

1 导包

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import os
import matplotlib.pyplot as plt
%matplotlib inline

2 数据准备

我们使用MNIST手写数据集作为训练生成的数据。


(train_images,train_labels),(_,_) = tf.keras.datasets.mnist.load_data()

train_images = train_images.reshape(train_images.shape[0],28,28,1).astype('float32')
train_images = (train_images - 127.5)/127.5

定义相关参数。

BATCH_SIZE = 300
BUFFER_SIZE = 60000

EPOCHS = 300
noise_dim = 100

将原数据创建为Dataset数据,便于训练。

datasets = tf.data.Dataset.from_tensor_slices(train_images)

datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

for item in datasets:
    print(item.shape)
    print(type(item))
    break

3 生成器模型

输入100维的随机向量,输出一张(28,28,1)维的图片。

def generator_model():

    generator = keras.models.Sequential([
        keras.layers.Input(shape=(100,)),
        keras.layers.Dense(256),
        keras.layers.LeakyReLU(alpha = 0.2),
        keras.layers.BatchNormalization(momentum = 0.8),
        keras.layers.Dense(512),
        keras.layers.LeakyReLU(alpha = 0.2),
        keras.layers.BatchNormalization(momentum = 0.8),
        keras.layers.Dense(1024),
        keras.layers.LeakyReLU(alpha = 0.2),
        keras.layers.BatchNormalization(momentum = 0.8),
        keras.layers.Dense(np.prod((28,28,1)),activation='tanh'),
        keras.layers.Reshape((28,28,1))
    ])

    return generator

4 判别器模型

输入图片,输出1维的判定结果(最后没有使用激活函数)。

def discriminator_model():

    discriminator = keras.models.Sequential([
        keras.layers.Flatten(),
        keras.layers.Dense(512),
        keras.layers.LeakyReLU(alpha = 0.2),
        keras.layers.Dense(256),
        keras.layers.LeakyReLU(alpha = 0.2),
        keras.layers.Dense(1)
    ])

    return discriminator

5 编写损失函数,定义优化器

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

判别器损失:

鉴别器的目标是将真实图像区分为1,将生成的图像区分为0,因此损失函数将真实图像与1进行比较,生成0的图像,并计算损失。

[En]

The goal of the discriminator is to distinguish the real picture to 1 and the generated picture to 0, so the loss function compares the real picture with 1, generates the picture with 0, and calculates the loss.


def discriminator_loss(real_out,fake_out):
    real_loss = cross_entropy(tf.ones_like(real_out),real_out)
    fake_loss = cross_entropy(tf.zeros_like(fake_out),fake_out)
    return real_loss + fake_loss

生成器损失:

生成器的目标是使自身生成的图像在鉴别器中被判定为1(实数),因此需要将损失函数与1进行比较。

[En]

The goal of the generator is to make the image generated by itself be judged as 1 (real) in the discriminator, so the loss function needs to be compared with 1.


def generator_loss(fake_out):
    fake_loss = cross_entropy(tf.ones_like(fake_out),fake_out)
    return fake_loss

定义优化器:

learning_rate:0.0002

beta_1:0.5


generator_opt = tf.keras.optimizers.Adam(2e-4,0.5)
discriminator_opt = tf.keras.optimizers.Adam(2e-4,0.5)

参数简介:
learning_rate:一个张量,浮点值,或者是一个tf.keras.optimizer .schedules时间表。LearningRateSchedule,或者一个不带参数并返回要使用的实际值的可调用对象,即学习速率。默认为0.001。
beta_1:一个浮点值或一个常量浮点张量,或者一个不带参数并返回实际值的可调用对象。一阶矩的指数衰减率估计。默认为0.9

6 获取模型&定义训练批次函数

generator = generator_model()
discriminator = discriminator_model()

定义训练函数,使用Tensorflow中的自动求导与根据梯度更新参数的方法来训练生成器与判别器。

@tf.function

def train_step(images):
    noise = tf.random.normal([BATCH_SIZE,noise_dim])

    with tf.GradientTape() as gen_tape,tf.GradientTape() as disc_tape:
        real_out = discriminator(images,training = True)
        gen_image = generator(noise,training = True)
        fake_out = discriminator(gen_image,training = True)

        gen_loss = generator_loss(fake_out)
        disc_loss = discriminator_loss(real_out,fake_out)

    gradient_gen = gen_tape.gradient(gen_loss,generator.trainable_variables)
    gradient_disc = disc_tape.gradient(disc_loss,discriminator.trainable_variables)

    generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))
    discriminator_opt.apply_gradients(zip(gradient_disc,discriminator.trainable_variables))

7 定义可视化方法

为了在训练的过程中查看生成器输出图片的效果,我们定义6个100维度的随机数来检测训练过程中的生成器模型,使用matlibplot中的方法绘制图片。

num_example_to_generate = 6

seed = np.random.normal(0,1,(num_example_to_generate,noise_dim))

def generate_plot_image(test_noise):

    pre_image = generator(test_noise,training = False)

    fig = plt.figure(figsize=(16,3))
    for i in range(pre_image.shape[0]):
        plt.subplot(1,6,i+1)
        plt.imshow((pre_image[i,:,:,:] + 1)/2)
        plt.axis('off')
    plt.show()

8 主训练方法

训练epochs次,每次epoch中从dataset依次取出batch个数据调用步骤6中的方法进行训练,每次epoch结束后调用步骤7中的方法绘制几张图片查看生成器的生成效果。


def train(dataset,epochs):
    for epoch in range(1,epochs+1):
        print("epoch:",epoch)
        for image_batch in dataset:
            train_step(image_batch)
            print(".",end="")
        generate_plot_image(seed)

9 开始训练

train(datasets,EPOCHS)

10 训练结果

epoch 1:

【生成对抗网络】GAN入门与代码实现(二)
epoch 10:
【生成对抗网络】GAN入门与代码实现(二)
epoch 20:
【生成对抗网络】GAN入门与代码实现(二)
epoch 50:
【生成对抗网络】GAN入门与代码实现(二)
epoch 100:
【生成对抗网络】GAN入门与代码实现(二)
epoch 290:
【生成对抗网络】GAN入门与代码实现(二)

Original: https://blog.csdn.net/AwesomeP/article/details/124530118
Author: 宛如近在咫尺
Title: 【生成对抗网络】GAN入门与代码实现(二)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/514548/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球