An Illustrated Guide to WGAN and Its Variant WGAN-GP, with TensorFlow 2 Implementations of Both

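* Building a WGAN (Wasserstein GAN)

  + Introducing the Wasserstein loss
  + Implementing the 1-Lipschitz constraint
  + Training procedure
  + Implementing the gradient penalty (WGAN-GP)
* Complete code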

Building a WGAN (Wasserstein GAN)

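Since GANs were first proposed, many papers have tried to tackle the instability of GAN training with heuristics, such as trying different network architectures, hyperparameters, and optimizers. The introduction of the Wasserstein GAN (WGAN) was a major breakthrough in this line of research.
WGAN alleviates, and in some cases eliminates, many of the problems that arise when training GANs. Its fundamental change relative to the original GAN is the loss function. In theory, if the real and generated distributions have disjoint supports, the JS divergence is stuck at the constant log 2, so it gives the generator no useful gradient (and, as a function of the generator's parameters, it can even be discontinuous). WGAN solves this problem with a new loss that, under mild assumptions on the generator, is continuous everywhere and differentiable almost everywhere!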
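To make the disjoint-support argument concrete, here is a small pure-Python sketch (an illustration, not code from the original article). It puts all the mass of a "real" and a "fake" distribution on single points of a 1-D grid, so their supports never overlap, and compares the JS divergence with the Wasserstein-1 distance. It assumes both distributions live on the same evenly spaced grid, in which case W1 equals the sum of absolute CDF differences times the grid spacing; the helper names are hypothetical.

```python
import math

def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between two discrete distributions."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]

    def kl(a, b):
        # Terms with a_i = 0 contribute nothing to KL(a || b)
        return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def wasserstein_1d(p, q, spacing=1.0):
    """W1 on a shared, evenly spaced 1-D grid: sum of |CDF_p - CDF_q| times the spacing."""
    cdf_p = cdf_q = total = 0.0
    for pi, qi in zip(p, q):
        cdf_p += pi
        cdf_q += qi
        total += abs(cdf_p - cdf_q) * spacing
    return total

def point_mass(n, k):
    """Distribution on n grid points with all of its mass at index k."""
    return [1.0 if i == k else 0.0 for i in range(n)]

real = point_mass(10, 0)
for shift in (1, 3, 6):
    fake = point_mass(10, shift)
    print(shift, round(js_divergence(real, fake), 4), wasserstein_1d(real, fake))
```

For every shift the JS divergence stays at log 2 ≈ 0.6931, so it says nothing about how far apart the two distributions are, while W1 grows linearly with the shift and can therefore guide the generator.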

Introducing the Wasserstein loss

We are all familiar with the objective function of the original GAN, so here is only a brief review:
$$\min_G \max_D V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z\sim p_z(z)}[\log(1 - D(G(z)))]$$
where $D$ is the discriminator, $G$ the generator, $x$ the real data, and $z$ the latent variable.
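In practice, the generator is usually trained to maximize $\log D(G(z))$ instead of minimizing $\log(1 - D(G(z)))$, because the latter saturates early in training. Rewriting the objective this way gives the following form of the value function: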

$$\mathbb{E}_{x\sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z\sim p_z(z)}[\log D(G(z))]$$
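WGAN uses a new loss based on the Earth Mover's distance, also called the Wasserstein distance. It measures the distance, or amount of work, needed to transform one distribution into another. Mathematically, it is the infimum, over all joint distributions (couplings) $\gamma$ of the real distribution $p_{data}$ and the generated distribution $p_g$, of the expected distance between paired samples:
$$W(p_{data}, p_g) = \inf_{\gamma \in \Pi(p_{data}, p_g)} \mathbb{E}_{(x, y)\sim\gamma}[\lVert x - y \rVert]$$
By the Kantorovich-Rubinstein duality, this equals the supremum of the expression below over all 1-Lipschitz functions $D$, so the WGAN value function becomes: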
$$\mathbb{E}_{x\sim p_{data}(x)}[D(x)] - \mathbb{E}_{z\sim p_z(z)}[D(G(z))]$$
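We will use this function to derive the loss function. Negating it (optimizers minimize) and estimating the expectations over a batch of $N$ samples, each term can be written as: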

$$-\frac{1}{N}\sum_{i=1}^{N} y_i D(x_i)$$
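This is the mean of the discriminator's outputs multiplied by −1, where $y_i$ is a label: +1 for real images and −1 for fake images. We can therefore implement the Wasserstein loss as a custom TensorFlow Keras loss function: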

def wasserstein_loss(self, y_true, y_pred):
    # y_true is +1 for real images and -1 for fake images
    w_loss = -tf.reduce_mean(y_true * y_pred)
    return w_loss

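This loss aims to maximize the scores given to real images relative to those given to fake images. Because it outputs an unbounded score rather than a probability, the discriminator in WGAN is also called a critic.
Since WGAN removes the sigmoid activation from the discriminator's output, the critic's predictions are unbounded, so the critic has to be constrained to be 1-Lipschitz.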
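A quick way to see what the ±1 labels do is to evaluate the loss by hand. The following pure-Python sketch mirrors `-tf.reduce_mean(y_true * y_pred)` on plain lists; the critic scores are made-up numbers chosen only for illustration.

```python
def wasserstein_loss(y_true, y_pred):
    """Pure-Python version of -mean(y_true * y_pred)."""
    return -sum(y * p for y, p in zip(y_true, y_pred)) / len(y_pred)

# Hypothetical critic scores, for illustration only
real_scores = [2.0, 1.5, 3.0, 2.5]
fake_scores = [-1.0, 0.5, -0.5, 0.0]

loss_real = wasserstein_loss([1.0] * 4, real_scores)   # -mean(D(real))
loss_fake = wasserstein_loss([-1.0] * 4, fake_scores)  # +mean(D(fake))
critic_loss = loss_real + loss_fake                    # -(mean D(real) - mean D(fake))

# The generator reuses the same loss, but labels its own samples +1
generator_loss = wasserstein_loss([1.0] * 4, fake_scores)  # -mean(D(G(z)))

print(loss_real, loss_fake, critic_loss, generator_loss)  # -2.25 -0.25 -2.5 0.25
```

Minimizing `critic_loss` widens the gap between the mean real score and the mean fake score, which is exactly the negated batch estimate of the WGAN value function.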

Implementing the 1-Lipschitz constraint

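The mathematical assumption behind the Wasserstein loss is that the critic is a 1-Lipschitz function. A critic D(x) is 1-Lipschitz if it satisfies the following inequality: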
$$|D(x_1) - D(x_2)| \le \lVert x_1 - x_2 \rVert_2$$
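For two images $x_1$ and $x_2$, the absolute difference between the critic's outputs must be no greater than the distance between the two images. In other words, the critic's output should not change too much between different images, whether they are real or fake. When WGAN was proposed, the authors could not find a proper way to enforce this inequality. Instead, they clipped the critic's weights to small values, which keeps the output of every layer, and therefore the critic's final output, within a small range. In the WGAN paper, the weights are clipped to the range [-0.01, 0.01]. Strictly speaking, clipping only makes the critic K-Lipschitz for some constant K, which scales the estimated Wasserstein distance by K.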
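Why does clipping the weights bound the Lipschitz constant? For a linear map the answer is exact, so here is a pure-Python sketch with a hypothetical linear critic D(x) = w·x (a toy assumption; a real critic also has nonlinearities, which are themselves Lipschitz, like LeakyReLU). By Cauchy-Schwarz its Lipschitz constant under the L2 distance is ‖w‖₂, and clipping every weight to [-c, c] guarantees ‖w‖₂ ≤ c·√n.

```python
import math
import random

CLIP = 0.01  # clipping range used in the WGAN paper
N = 16       # hypothetical input dimension

def clip(values, c):
    """Clip every weight into [-c, c], like tf.clip_by_value."""
    return [max(-c, min(c, v)) for v in values]

def critic(w, x):
    """Toy linear critic D(x) = w . x (no bias, no nonlinearity)."""
    return sum(wi * xi for wi, xi in zip(w, x))

def l2_distance(a, b):
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))

rng = random.Random(0)
w = clip([rng.gauss(0.0, 1.0) for _ in range(N)], CLIP)

# Lipschitz constant of x -> w . x under the L2 distance is ||w||_2
lipschitz = math.sqrt(sum(wi * wi for wi in w))

# Empirically check |D(x1) - D(x2)| / ||x1 - x2|| on random pairs of inputs
worst_ratio = 0.0
for _ in range(1000):
    x1 = [rng.uniform(-1.0, 1.0) for _ in range(N)]
    x2 = [rng.uniform(-1.0, 1.0) for _ in range(N)]
    ratio = abs(critic(w, x1) - critic(w, x2)) / l2_distance(x1, x2)
    worst_ratio = max(worst_ratio, ratio)

print(lipschitz, CLIP * math.sqrt(N), worst_ratio)
```

The bound is c·√n rather than 1, which is the "K-Lipschitz for some K" caveat above: clipping controls the slope of the critic, but only up to a constant that depends on the architecture.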
Weight clipping can be implemented in two ways. One way is to write a custom constraint and use it when instantiating a new layer, as follows:

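import tensorflow as tf

class WeightsClip(tf.keras.constraints.Constraint):
    """Clip every weight of a layer into [min_value, max_value]."""
    def __init__(self, min_value=-0.01, max_value=0.01):
        self.min_value = min_value
        self.max_value = max_value
    def __call__(self, w):
        # Applied by Keras to the variable after each optimizer update
        return tf.clip_by_value(w, self.min_value, self.max_value)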

The constraint can then be passed to any layer argument that accepts a constraint, as follows:

import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization, Conv2D

model = tf.keras.Sequential(name='critics')
model.add(Conv2D(16, 3, strides=2, padding='same',
                 kernel_constraint=WeightsClip(),
                 bias_constraint=WeightsClip()))
model.add(BatchNormalization(beta_constraint=WeightsClip(),
                             gamma_constraint=WeightsClip()))

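However, adding constraint code every time a layer is created bloats the code. Since we clip every layer without exception, we can instead loop over the layers, read their weights, clip them, and write them back, as shown below: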

# Clip the weights of every critic layer to [-0.01, 0.01]
for layer in critic.layers:
    weights = layer.get_weights()
    weights = [tf.clip_by_value(w, -0.01, 0.01) for w in weights]
    layer.set_weights(weights)

Training procedure

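In the original GAN theory, the discriminator should be trained before the generator. In practice, however, the discriminator trains faster, and as it gets better, the gradients it passes to the generator gradually vanish. With the Wasserstein loss, gradients can be derived everywhere, so we no longer need to worry about the critic becoming too strong relative to the generator.
Therefore, in WGAN the critic is trained five times for every generator training step. To do this, we write the critic's training step as a separate function that can be called in a loop: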

# Train the critic n_critic (= 5) times per generator step
for _ in range(self.n_critic):
    real_images = next(data_generator)
    critic_loss = self.train_critic(real_images, batch_size)

Training step for the generator:

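# Build the critic and freeze it before chaining it to the generator
self.critic = self.build_critic()
self.critic.trainable = False
self.generator = self.build_generator()
# Combined model: z -> generator -> frozen critic -> score
critic_output = self.critic(self.generator.output)
self.model = Model(self.generator.input, critic_output)
self.model.compile(loss=self.wasserstein_loss, optimizer=RMSprop(3e-4))
# Unfreeze the critic so it can still be trained on its own
self.critic.trainable = True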

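In the code above, the critic layers are frozen by setting trainable = False, and the critic is chained to the generator to create a new model, which is then compiled. Afterwards, we can set the critic back to trainable; this does not affect the model we have already compiled, because Keras takes the trainable state into account when the model is compiled.
We use the train_on_batch() API to run a single training step, which automatically performs the forward pass, loss computation, backpropagation, and weight update. The generated images are given the real label (+1), so the generator minimizes −mean(D(G(z))), i.e., it tries to raise the critic's score on its own samples: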

g_loss = self.model.train_on_batch(g_input, real_labels)
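The alternating schedule (n_critic clipped critic updates, then one generator update) can be seen end to end on a deliberately tiny problem. The following pure-Python sketch is a hypothetical 1-D toy, not the article's Fashion-MNIST model: real data is N(3, 1), the generator is G(z) = z + b with a single parameter b, and the critic is D(x) = w·x with a single clipped weight w. It uses plain SGD instead of RMSprop, with learning rates picked only so the toy moves quickly.

```python
import random

# Hypothetical 1-D toy problem (not the article's image model):
#   real data  x ~ N(3, 1)
#   generator  G(z) = z + b,  z ~ N(0, 1)   (one parameter b)
#   critic     D(x) = w * x                 (one weight w, clipped like WGAN)
REAL_MEAN = 3.0
CLIP = 0.01
N_CRITIC, BATCH, STEPS = 5, 64, 300
LR_CRITIC, LR_GEN = 0.01, 10.0  # plain SGD; values chosen for this toy only

rng = random.Random(42)
w, b = 0.0, 0.0
b_history = []

def mean(values):
    return sum(values) / len(values)

for _ in range(STEPS):
    # 1) Train the critic N_CRITIC times per generator step
    for _ in range(N_CRITIC):
        real = [rng.gauss(REAL_MEAN, 1.0) for _ in range(BATCH)]
        fake = [rng.gauss(0.0, 1.0) + b for _ in range(BATCH)]
        # critic loss = -(mean D(real) - mean D(fake)) = -w * (mean(real) - mean(fake))
        grad_w = -(mean(real) - mean(fake))
        w = w - LR_CRITIC * grad_w
        w = max(-CLIP, min(CLIP, w))  # weight clipping after every critic update
    # 2) One generator step: loss = -mean D(G(z)) = -w * (mean(z) + b), so d loss / d b = -w
    grad_b = -w
    b = b - LR_GEN * grad_b
    b_history.append(b)

print(f"final b = {b:.2f}, mean of last 100 b = {mean(b_history[-100:]):.2f}")
```

In this toy, b should climb toward the real mean 3 and then hover around it, while w never leaves [-0.01, 0.01]; the critic only ever tells the generator which direction to move, which is the role the Wasserstein gradient plays in the full model.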

The figure below shows the WGAN generator architecture:

[Figure: WGAN generator architecture]

The figure below shows the WGAN critic architecture:

[Figure: WGAN critic architecture]

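Although it improves on the original GAN, WGAN is still very hard to train, and the images it produces are no better than those of the original GAN. Next, we implement WGAN-GP, a variant of WGAN that trains faster and produces sharper images.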

Implementing the gradient penalty (WGAN-GP)

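As the WGAN authors themselves acknowledged, weight clipping is not an ideal way to enforce the Lipschitz constraint. It has two drawbacks: underused network capacity and exploding/vanishing gradients. When we clip the weights, we also limit the critic's ability to learn: clipping forces the network to learn only simple functions, so the capacity of the neural network goes underused. Second, the clipping value needs careful tuning. If it is set too high, the gradients can explode, violating the Lipschitz constraint; if it is set too low, the gradients vanish as they are backpropagated through the network. In addition, weight clipping pushes the weights toward the two extremes of the clipping range, as shown in the figure below: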

[Figure: weight clipping pushes the critic's weights toward the two ends of the clipping range]
Therefore, a gradient penalty (GP) was proposed to replace weight clipping for enforcing the Lipschitz constraint. It relies on the fact that a differentiable function is 1-Lipschitz if and only if its gradient norm is at most 1 everywhere:
$$\text{Gradient penalty} = \lambda\, \mathbb{E}_{\hat x}\left[\left(\lVert \nabla_{\hat x} D(\hat x) \rVert_2 - 1\right)^2\right]$$
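The step from the Lipschitz inequality to a penalty on the gradient norm can be made explicit. For a differentiable critic, the mean value theorem applied along the segment between two inputs gives the following short derivation (a standard argument, written here in the document's notation):

```latex
\begin{aligned}
D(x_1) - D(x_2) &= \nabla D(\xi)^\top (x_1 - x_2),
  \qquad \xi = x_2 + t\,(x_1 - x_2) \text{ for some } t \in (0, 1) \\
|D(x_1) - D(x_2)| &\le \lVert \nabla D(\xi) \rVert_2 \, \lVert x_1 - x_2 \rVert_2
  \qquad \text{(Cauchy--Schwarz)} \\
&\le \lVert x_1 - x_2 \rVert_2
  \qquad \text{whenever } \lVert \nabla D \rVert_2 \le 1 \text{ everywhere}
\end{aligned}
```

Note that the penalty is two-sided: it pushes the norm toward 1 rather than only below 1, and it is evaluated only at the sampled points $\hat x$ rather than everywhere. The WGAN-GP paper motivates both choices by showing that the optimal critic has unit gradient norm almost everywhere on lines between paired real and generated samples.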
We will go through each variable in the equation and implement it in code.

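We usually use $x$ to denote real images, but the equation contains $\hat x$ instead. $\hat x$ is a point-wise interpolation between real and fake images, with the interpolation ratio epsilon drawn from the uniform distribution on [0, 1]: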

# One epsilon per image, broadcast over height, width and channels
epsilon = tf.random.uniform((batch_size, 1, 1, 1))
interpolates = epsilon * real_images + (1 - epsilon) * fake_images

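Following the WGAN-GP paper, for our purposes it is enough to understand that, because the gradient is taken at mixtures of real and fake images, we do not need to compute this term separately for real and fake images.
The term $\nabla_{\hat x} D(\hat x)$ is the gradient of the critic's output with respect to the interpolates. Once again, we can obtain it with tf.GradientTape():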

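with tf.GradientTape() as gradient_tape:
    # interpolates is a tensor, not a variable, so the tape must watch it explicitly
    gradient_tape.watch(interpolates)
    critic_interpolates = self.critic(interpolates)
# Pass interpolates itself, not [interpolates]: a one-element list would add an extra
# leading axis, so the norm below would no longer be computed per image.
# During training, this block must run inside the outer tape that records the critic loss.
grad = gradient_tape.gradient(critic_interpolates, interpolates)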

The next step is to compute the L2 norm $\lVert \nabla_{\hat x} D(\hat x) \rVert_2$. We square each value, sum them, and then take the square root:

grad_loss = tf.square(grad)
grad_loss = tf.reduce_sum(grad_loss, axis=np.arange(1, len(grad_loss.shape)))
grad_loss = tf.sqrt(grad_loss)

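When calling tf.reduce_sum(), we exclude the first axis, because it is the batch dimension; this yields one norm per image. The penalty pushes the gradient norm toward 1, which is the final step in computing the gradient loss: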

grad_loss = tf.reduce_mean(tf.square(grad_loss - 1))

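$\lambda$ in the equation is the weight of the gradient penalty relative to the other critic losses; here it is set to 10. We now add the critic losses and the gradient penalty together, backpropagate, and update the weights: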

total_loss = loss_real + loss_fake + LAMBDA * grad_loss
gradients = total_tape.gradient(total_loss, self.critic.variables)
self.optimizer_critic.apply_gradients(zip(gradients, self.critic.variables))
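The whole penalty computation can be traced on a toy critic whose gradient is known in closed form. The sketch below is pure Python and purely illustrative: the quadratic critic D(x) = 0.5·Σ aᵢxᵢ² (gradient aᵢxᵢ) and all helper names are hypothetical, and each "image" is a flat list of numbers. It follows the same steps as the TensorFlow code: one epsilon per sample, interpolation, a per-sample L2 norm of the gradient, then λ times the mean squared deviation from 1.

```python
import math
import random

LAMBDA = 10.0  # penalty weight used in the article

def toy_critic_grad(a, x):
    """Analytic gradient of a hypothetical toy critic D(x) = 0.5 * sum(a_i * x_i**2)."""
    return [ai * xi for ai, xi in zip(a, x)]

def gradient_penalty(a, real_batch, fake_batch, rng):
    """LAMBDA * mean over the batch of (||grad D(x_hat)||_2 - 1) ** 2."""
    penalties = []
    for real, fake in zip(real_batch, fake_batch):
        eps = rng.uniform(0.0, 1.0)  # one epsilon per sample, shared by all its values
        x_hat = [eps * r + (1.0 - eps) * f for r, f in zip(real, fake)]
        grad = toy_critic_grad(a, x_hat)
        norm = math.sqrt(sum(g * g for g in grad))  # per-sample L2 norm (batch axis excluded)
        penalties.append((norm - 1.0) ** 2)
    return LAMBDA * sum(penalties) / len(penalties)

rng = random.Random(0)
a = [1.0, 1.0]
real_batch = [[rng.gauss(0.0, 1.0) for _ in range(2)] for _ in range(8)]
fake_batch = [[rng.gauss(0.0, 1.0) for _ in range(2)] for _ in range(8)]
print(gradient_penalty(a, real_batch, fake_batch, rng))
```

A sample whose gradient norm is exactly 1 contributes nothing, and a sample with gradient norm 2 contributes λ·(2 − 1)² = 10, so the penalty only pulls the critic back where its slope departs from 1.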

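That is everything that needs to be added to WGAN to turn it into WGAN-GP. We do, however, need to remove the following:

  1. Weight clipping
  2. Batch normalization in the critic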

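The gradient penalty penalizes the norm of the critic's gradient independently for each input. Batch normalization, however, changes the gradients through the batch statistics, because each normalized output depends on the other samples in the batch. To avoid this problem, batch normalization is removed from the critic (the WGAN-GP paper suggests layer normalization as a per-sample alternative).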
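The coupling that batch normalization introduces is easy to measure numerically. This pure-Python sketch (illustrative only; `batch_norm` here is a simplified scalar version without the learned scale and shift, and the helper names are hypothetical) estimates by finite differences how much the output for sample 0 changes when sample 3 is nudged.

```python
def batch_norm(batch, eps=1e-5):
    """Normalize a batch of scalars with the batch mean and variance (no learned scale/shift)."""
    mean = sum(batch) / len(batch)
    var = sum((x - mean) ** 2 for x in batch) / len(batch)
    return [(x - mean) / (var + eps) ** 0.5 for x in batch]

def per_sample_op(batch):
    """A per-sample operation for comparison: each output depends on its own input only."""
    return [2.0 * x for x in batch]

def cross_sensitivity(fn, batch, i, j, h=1e-6):
    """Finite-difference estimate of d fn(batch)[i] / d batch[j]."""
    bumped = list(batch)
    bumped[j] += h
    return (fn(bumped)[i] - fn(batch)[i]) / h

batch = [1.0, 2.0, 3.0, 4.0]
print(cross_sensitivity(batch_norm, batch, 0, 3))     # nonzero: sample 3 changes sample 0's output
print(cross_sensitivity(per_sample_op, batch, 0, 3))  # 0.0: samples stay independent
```

With batch normalization, the gradient for one input depends on the rest of the batch, so a penalty on "each input's gradient norm" no longer measures what the Lipschitz constraint is about.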

The critic architecture is the same as in WGAN, except that batch normalization is removed:

[Figure: WGAN-GP critic architecture]

Below are samples generated by the trained WGAN-GP:

[Figure: samples generated by the trained WGAN-GP]

They look sharp and clean, closely resembling the samples in the Fashion-MNIST dataset. Training was very stable and converged quickly!

Complete code


import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.activations import relu
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import RMSprop, Adam
from tensorflow.keras.metrics import binary_accuracy

import tensorflow_datasets as tfds

import numpy as np
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings('ignore')
print("Tensorflow", tf.__version__)

ds_train, ds_info = tfds.load('fashion_mnist', split='train',shuffle_files=True,with_info=True)
fig = tfds.show_examples(ds_train, ds_info)

batch_size = 64
image_shape = (32, 32, 1)

def preprocess(features):
    image = tf.image.resize(features['image'], image_shape[:2])
    image = tf.cast(image, tf.float32)
    image = (image-127.5)/127.5
    return image

ds_train = ds_train.map(preprocess)
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(batch_size, drop_remainder=True).repeat()

train_num = ds_info.splits['train'].num_examples
train_steps_per_epoch = round(train_num/batch_size)
print(train_steps_per_epoch)

"""
WGAN
"""
class WGAN():
    def __init__(self, input_shape):

        self.z_dim = 128
        self.input_shape = input_shape

        self.loss_critic_real = {}
        self.loss_critic_fake = {}
        self.loss_critic = {}
        self.loss_generator = {}

        self.n_critic = 5
        self.critic = self.build_critic()
        self.critic.trainable = False

        self.optimizer_critic = RMSprop(5e-5)

        self.generator = self.build_generator()
        critic_output = self.critic(self.generator.output)
        self.model = Model(self.generator.input, critic_output)
        self.model.compile(loss = self.wasserstein_loss,
                           optimizer =  RMSprop(5e-5))
        self.critic.trainable = True

    def wasserstein_loss(self, y_true, y_pred):

        w_loss = -tf.reduce_mean(y_true*y_pred)

        return w_loss

    def build_generator(self):

        DIM = 128
        model = tf.keras.Sequential(name='Generator')

        model.add(layers.Input(shape=[self.z_dim]))

        model.add(layers.Dense(4*4*4*DIM))
        model.add(layers.BatchNormalization())
        model.add(layers.ReLU())
        model.add(layers.Reshape((4,4,4*DIM)))

        model.add(layers.UpSampling2D((2,2), interpolation="bilinear"))
        model.add(layers.Conv2D(2*DIM, 5, padding='same'))
        model.add(layers.BatchNormalization())
        model.add(layers.ReLU())

        model.add(layers.UpSampling2D((2,2), interpolation="bilinear"))
        model.add(layers.Conv2D(DIM, 5, padding='same'))
        model.add(layers.BatchNormalization())
        model.add(layers.ReLU())

        model.add(layers.UpSampling2D((2,2), interpolation="bilinear"))
        model.add(layers.Conv2D(image_shape[-1], 5, padding='same', activation='tanh'))

        return model

    def build_critic(self):

        DIM = 128
        model = tf.keras.Sequential(name='critics')

        model.add(layers.Input(shape=self.input_shape))

        model.add(layers.Conv2D(1*DIM, 5, strides=2, padding='same'))
        model.add(layers.LeakyReLU(0.2))

        model.add(layers.Conv2D(2*DIM, 5, strides=2, padding='same'))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU(0.2))

        model.add(layers.Conv2D(4*DIM, 5, strides=2, padding='same'))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU(0.2))

        model.add(layers.Flatten())
        model.add(layers.Dense(1))

        return model

    def train_critic(self, real_images, batch_size):
        # Labels: +1 for real images, -1 for fake images.
        # Shape (batch_size, 1) matches the critic output and avoids implicit broadcasting.
        real_labels = tf.ones((batch_size, 1))
        fake_labels = -tf.ones((batch_size, 1))

        g_input = tf.random.normal((batch_size, self.z_dim))
        fake_images = self.generator.predict(g_input)

        with tf.GradientTape() as total_tape:

            pred_fake = self.critic(fake_images)
            pred_real = self.critic(real_images)

            loss_fake = self.wasserstein_loss(fake_labels, pred_fake)
            loss_real = self.wasserstein_loss(real_labels, pred_real)

            total_loss = loss_fake + loss_real

        # Compute and apply the gradients outside the recording context
        gradients = total_tape.gradient(total_loss, self.critic.trainable_variables)
        self.optimizer_critic.apply_gradients(zip(gradients, self.critic.trainable_variables))

        # Weight clipping: clip every weight of every critic layer to [-0.01, 0.01]
        for layer in self.critic.layers:
            weights = layer.get_weights()
            weights = [tf.clip_by_value(w, -0.01, 0.01) for w in weights]
            layer.set_weights(weights)

        return loss_fake, loss_real

    def train(self, data_generator, batch_size, steps, interval=200):

        val_g_input = tf.random.normal((batch_size, self.z_dim))
        real_labels = tf.ones(batch_size)

        for i in range(steps):
            for _ in range(self.n_critic):
                real_images = next(data_generator)
                loss_fake, loss_real = self.train_critic(real_images, batch_size)
                critic_loss = loss_fake + loss_real

            g_input = tf.random.normal((batch_size, self.z_dim))
            g_loss = self.model.train_on_batch(g_input, real_labels)

            self.loss_critic_real[i] = loss_real.numpy()
            self.loss_critic_fake[i] = loss_fake.numpy()
            self.loss_critic[i] = critic_loss.numpy()
            self.loss_generator[i] = g_loss

            if i%interval == 0:
                msg = "Step {}: g_loss {:.4f} critic_loss {:.4f} critic fake {:.4f}  critic_real {:.4f}"\
                .format(i, g_loss, critic_loss, loss_fake, loss_real)
                print(msg)

                fake_images = self.generator.predict(val_g_input)
                self.plot_images(fake_images)
                self.plot_losses()

    def plot_images(self, images):
        grid_row = 1
        grid_col = 8
        f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col*2.5, grid_row*2.5))
        for row in range(grid_row):
            for col in range(grid_col):
                if self.input_shape[-1]==1:
                    axarr[col].imshow(images[col,:,:,0]*0.5+0.5, cmap='gray')
                else:
                    axarr[col].imshow(images[col]*0.5+0.5)
                axarr[col].axis('off')
        plt.show()

    def plot_losses(self):
        fig, (ax1, ax2) = plt.subplots(2, sharex=True)
        fig.set_figwidth(10)
        fig.set_figheight(6)
        ax1.plot(list(self.loss_critic.values()), label='Critic loss', alpha=0.7)
        ax1.set_title("Critic loss")
        ax2.plot(list(self.loss_generator.values()), label='Generator loss', alpha=0.7)
        ax2.set_title("Generator loss")

        plt.xlabel('Steps')
        plt.show()

wgan = WGAN(image_shape)
wgan.generator.summary()

wgan.critic.summary()

wgan.train(iter(ds_train), batch_size, 2000, 100)

z = tf.random.normal((8, 128))
generated_images = wgan.generator.predict(z)
wgan.plot_images(generated_images)

wgan.generator.save_weights('./wgan_models/wgan_fashion_minist.weights')

"""
WGAN_GP
"""
class WGAN_GP():
    def __init__(self, input_shape):

        self.z_dim = 128
        self.input_shape = input_shape

        self.n_critic = 5
        self.penalty_const = 10
        self.critic = self.build_critic()
        self.critic.trainable = False

        self.optimizer_critic = Adam(1e-4, 0.5, 0.9)

        self.generator = self.build_generator()
        critic_output = self.critic(self.generator.output)
        self.model = Model(self.generator.input, critic_output)
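        self.model.compile(loss=self.wasserstein_loss, optimizer=Adam(1e-4, 0.5, 0.9))
        # Unfreeze the critic for its own updates; the compiled combined model keeps it frozen
        self.critic.trainable = True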

    def wasserstein_loss(self, y_true, y_pred):

        w_loss = -tf.reduce_mean(y_true*y_pred)

        return w_loss

    def build_generator(self):

        DIM = 128
        model = Sequential([
            layers.Input(shape=[self.z_dim]),

            layers.Dense(4*4*4*DIM),
            layers.BatchNormalization(),
            layers.ReLU(),
            layers.Reshape((4,4,4*DIM)),

            layers.UpSampling2D((2,2), interpolation='bilinear'),
            layers.Conv2D(2*DIM, 5, padding='same'),
            layers.BatchNormalization(),
            layers.ReLU(),

            layers.UpSampling2D((2,2), interpolation='bilinear'),
            layers.Conv2D(2*DIM, 5, padding='same'),
            layers.BatchNormalization(),
            layers.ReLU(),

            layers.UpSampling2D((2,2), interpolation='bilinear'),
            layers.Conv2D(image_shape[-1], 5, padding='same', activation='tanh')
        ],name='Generator')

        return model

    def build_critic(self):

        DIM = 128
        model = Sequential([
            layers.Input(shape=self.input_shape),

            layers.Conv2D(1*DIM, 5, strides=2, padding='same', use_bias=False),
            layers.LeakyReLU(0.2),

            layers.Conv2D(2*DIM, 5, strides=2, padding='same', use_bias=False),
            layers.LeakyReLU(0.2),

            layers.Conv2D(4*DIM, 5, strides=2, padding='same', use_bias=False),
            layers.LeakyReLU(0.2),

            layers.Flatten(),
            layers.Dense(1)
        ], name='critics')

        return model

    def gradient_loss(self, grad):
        loss = tf.square(grad)
        loss = tf.reduce_sum(loss, axis=np.arange(1, len(loss.shape)))
        loss = tf.sqrt(loss)
        loss = tf.reduce_mean(tf.square(loss - 1))
        loss = self.penalty_const * loss
        return loss

    def train_critic(self, real_images, batch_size):
        # +1 for real images, -1 for fake images, shaped like the critic output
        real_labels = tf.ones((batch_size, 1))
        fake_labels = -tf.ones((batch_size, 1))

        g_input = tf.random.normal((batch_size, self.z_dim))
        fake_images = self.generator.predict(g_input)

        with tf.GradientTape() as total_tape:

            pred_fake = self.critic(fake_images)
            pred_real = self.critic(real_images)

            loss_fake = self.wasserstein_loss(fake_labels, pred_fake)
            loss_real = self.wasserstein_loss(real_labels, pred_real)

            # Random points on straight lines between real and fake images (one epsilon per image)
            epsilon = tf.random.uniform((batch_size, 1, 1, 1))
            interpolates = epsilon * real_images + (1 - epsilon) * fake_images

            # Inner tape: gradient of the critic output w.r.t. the interpolates. It runs inside
            # total_tape, so the penalty can be differentiated again w.r.t. the critic weights.
            with tf.GradientTape() as gradient_tape:
                gradient_tape.watch(interpolates)
                critic_interpolates = self.critic(interpolates)
            # Pass the tensor itself (not [interpolates]) so the gradient keeps shape
            # (batch, H, W, C) and gradient_loss computes one norm per image
            gradients_interpolates = gradient_tape.gradient(critic_interpolates, interpolates)
            gradient_penalty = self.gradient_loss(gradients_interpolates)

            total_loss = loss_fake + loss_real + gradient_penalty

        gradients = total_tape.gradient(total_loss, self.critic.variables)
        self.optimizer_critic.apply_gradients(zip(gradients, self.critic.variables))
        return loss_fake, loss_real, gradient_penalty

    def train(self, data_generator, batch_size, steps, interval=100):
        val_g_input = tf.random.normal((batch_size, self.z_dim))
        real_labels = tf.ones(batch_size)

        for i in range(steps):
            for _ in range(self.n_critic):
                real_images = next(data_generator)
                loss_fake, loss_real, gradient_penalty = self.train_critic(real_images, batch_size)
                critic_loss = loss_fake + loss_real + gradient_penalty

            g_input = tf.random.normal((batch_size, self.z_dim))
            g_loss = self.model.train_on_batch(g_input, real_labels)
            if i%interval == 0:
                msg = "Step {}: g_loss {:.4f} critic_loss {:.4f} critic fake {:.4f}  critic_real {:.4f} penalty {:.4f}".format(i, g_loss, critic_loss, loss_fake, loss_real, gradient_penalty)
                print(msg)

                fake_images = self.generator.predict(val_g_input)
                self.plot_images(fake_images)

    def plot_images(self, images):
        grid_row = 1
        grid_col = 8
        f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col*2.5, grid_row*2.5))
        for row in range(grid_row):
            for col in range(grid_col):
                if self.input_shape[-1]==1:
                    axarr[col].imshow(images[col,:,:,0]*0.5+0.5, cmap='gray')
                else:
                    axarr[col].imshow(images[col]*0.5+0.5)
                axarr[col].axis('off')
        plt.show()

wgan = WGAN_GP(image_shape)
wgan.train(iter(ds_train), batch_size, 5000, 100)

wgan.model.summary()

wgan.critic.summary()

z = tf.random.normal((8, 128))
generated_images = wgan.generator.predict(z)
wgan.plot_images(generated_images)

Original: https://blog.csdn.net/LOVEmy134611/article/details/117230911
Author: 盼小辉丶
Title: 图文详解WGAN及其变体WGAN-GP并利用Tensorflow2实现WGAN与WGAN-GP

Original articles are protected by copyright. When reposting, please credit the source: https://www.johngo689.com/7679/

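Reposted articles remain under the original author's copyright. When reposting, please credit the original author and source!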
