# An Illustrated Guide to WGAN and Its Variant WGAN-GP, with TensorFlow 2 Implementations of Both

## Building a WGAN (Wasserstein GAN)

WGAN alleviates, and in some cases eliminates, many of the problems that plague GAN training. Its fundamental improvement over the original GAN is a change of loss function. In theory, if two distributions do not overlap, the JS divergence between them is no longer continuous and therefore not differentiable, so the gradient becomes zero. WGAN solves this problem with a new loss function that is continuous and differentiable everywhere!

### Introducing the Wasserstein loss

The value function of the original GAN is:

$$\min_G \max_D V(D,G)=E_{x\sim p_{data}(x)}[\log D(x)] + E_{z\sim p_z(z)}[\log(1-D(G(z)))]$$

By transforming the form above, we obtain the following value function:

$$E_{x\sim p_{data}(x)}[\log D(x)] + E_{z\sim p_z(z)}[\log D(G(z))]$$
WGAN uses a new loss function known as the earth mover's distance, or Wasserstein distance. It measures the distance, or amount of work, required to transform one distribution into another. Mathematically, it is the minimum transport cost over all joint distributions between the real and generated images, and the WGAN value function becomes:

$$E_{x\sim p_{data}(x)}[D(x)] - E_{z\sim p_z(z)}[D(G(z))]$$
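To build intuition for the earth mover's distance, consider the one-dimensional case, where it has a simple closed form for two equal-sized samples: sort both sides and average the point-wise transport cost. This is an illustrative numpy sketch, not part of the original post:

```python
import numpy as np

# 1-D earth mover (Wasserstein-1) distance between two equal-sized empirical
# samples: sort both sides and average the point-wise transport cost.
def wasserstein_1d(u, v):
    u, v = np.sort(np.asarray(u, float)), np.sort(np.asarray(v, float))
    return float(np.mean(np.abs(u - v)))

# Shifting every point by 3 costs exactly 3 units of work per unit of mass.
print(wasserstein_1d([1, 2, 3], [4, 5, 6]))  # 3.0
```

Unlike the JS divergence, this distance keeps shrinking smoothly as the two samples move toward each other, which is exactly the property WGAN exploits.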

We will use this value function to derive the loss function. The first term can be written as:

$$-\frac{1}{N}\sum_{i=1}^{N} y_i D(x_i)$$

With $y_i = 1$ for real images and $y_i = -1$ for fake images, this is implemented as:

```python
def wasserstein_loss(self, y_true, y_pred):
    w_loss = -tf.reduce_mean(y_true * y_pred)
    return w_loss
```
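As a quick numeric sanity check (a standalone numpy sketch, not from the original post): with labels of -1 for fake images, the loss falls as the critic scores fakes lower, which is exactly the direction the critic is trained in.

```python
import numpy as np

def wasserstein_loss(y_true, y_pred):
    # Same formula as the Keras version above: -mean(y_true * y_pred)
    return -np.mean(y_true * y_pred)

fake_labels = -np.ones(4)
confident = np.array([-1.0, -0.8, -1.2, -1.0])   # critic scores fakes low
unsure    = np.array([ 0.1, -0.1,  0.2,  0.0])   # critic is undecided

# The loss is lower (better) when the critic correctly pushes fake scores down.
print(wasserstein_loss(fake_labels, confident) < wasserstein_loss(fake_labels, unsure))  # True
```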


### Implementing the 1-Lipschitz constraint

The Wasserstein loss rests on the mathematical assumption that the critic is a 1-Lipschitz function. We say the critic D(x) is 1-Lipschitz if it satisfies the inequality:

$$|D(x_1) - D(x_2)| \leq |x_1 - x_2|$$
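To see why weight clipping relates to this condition, consider a toy linear "critic" D(x) = w·x (an illustrative numpy sketch, not from the original post): its Lipschitz constant is ‖w‖, so bounding the weights bounds how fast D can change.

```python
import numpy as np

# A linear critic D(x) = w·x with weights clipped to [-0.01, 0.01].
rng = np.random.default_rng(0)
w = np.clip(rng.normal(size=8), -0.01, 0.01)
D = lambda x: w @ x

x1, x2 = rng.normal(size=8), rng.normal(size=8)
lhs = abs(D(x1) - D(x2))
rhs = np.linalg.norm(x1 - x2)

# By Cauchy-Schwarz, |D(x1) - D(x2)| <= ||w|| * ||x1 - x2||, and clipping keeps
# ||w|| well below 1, so the 1-Lipschitz inequality holds.
print(lhs <= rhs)  # True
```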

Weight clipping can be implemented in two ways. The first is to write a custom constraint class and use it when instantiating new layers, as follows:

```python
class WeightsClip(tf.keras.constraints.Constraint):
    def __init__(self, min_value=-0.01, max_value=0.01):
        self.min_value = min_value
        self.max_value = max_value

    def __call__(self, w):
        return tf.clip_by_value(w, self.min_value, self.max_value)
```


You can then pass the constraint to every layer argument that accepts one, as follows:

```python
model = tf.keras.Sequential(name='critics')
# Convolutional layers clip their kernels and biases (the layer arguments here
# are illustrative; the original listing lost them):
model.add(layers.Conv2D(DIM, 5, strides=2, padding='same',
                        kernel_constraint=WeightsClip(),
                        bias_constraint=WeightsClip()))
# Batch normalization layers clip their beta and gamma parameters:
model.add(layers.BatchNormalization(beta_constraint=WeightsClip(),
                                    gamma_constraint=WeightsClip()))
```


However, adding constraint code to every layer makes the code bloated. Since we do not need to pick and choose which layers to clip, we can instead loop over all layers, read their weights, clip them, and write them back, as shown below:

```python
for layer in critic.layers:
    weights = layer.get_weights()
    weights = [tf.clip_by_value(w, -0.01, 0.01) for w in weights]
    layer.set_weights(weights)
```
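The effect of this loop can be sketched with plain numpy arrays standing in for layer weights (illustrative only; `np.clip` plays the role of `tf.clip_by_value`):

```python
import numpy as np

# Two "layers" worth of weights, some outside the clipping range.
weights = [np.array([[0.5, -0.3], [0.002, -0.8]]), np.array([0.02, -0.005])]
clipped = [np.clip(w, -0.01, 0.01) for w in weights]

# After clipping, every weight lies in [-0.01, 0.01]; in-range values are untouched.
print(all(np.all(np.abs(w) <= 0.01) for w in clipped))  # True
print(clipped[1][1])  # -0.005 (unchanged)
```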


### The training procedure

In each training step, the critic is trained `n_critic` times before the generator is trained once:

```python
for _ in range(self.n_critic):
    real_images = next(data_generator)
    critic_loss = self.train_critic(real_images, batch_size)
```


The training setup for the generator: the critic is frozen while the combined model is compiled, so that only the generator's weights are updated by the Wasserstein loss:

```python
self.critic = self.build_critic()
self.critic.trainable = False       # freeze the critic inside the combined model

self.generator = self.build_generator()
critic_output = self.critic(self.generator.output)
self.model = Model(self.generator.input, critic_output)
self.model.compile(loss=self.wasserstein_loss, optimizer=RMSprop(3e-4))

self.critic.trainable = True        # unfreeze for the critic's own training
```


```python
g_loss = self.model.train_on_batch(g_input, real_labels)
```


### Implementing the gradient penalty (WGAN-GP)

WGAN-GP replaces weight clipping with a penalty on the norm of the critic's gradient:

$$Gradient\ penalty = \lambda E_{\hat x}\left[(\lVert \nabla_{\hat x} D(\hat x) \rVert_2 - 1)^2\right]$$

We will look at each variable in the equation and implement it in code. $\hat x$ is an image interpolated between a real and a generated image, with a uniformly sampled interpolation coefficient per sample:

```python
epsilon = tf.random.uniform((batch_size, 1, 1, 1))
interpolates = epsilon * real_images + (1 - epsilon) * fake_images
```
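The broadcasting in this step can be verified with a small numpy sketch (illustrative, not from the original post): one epsilon per sample is broadcast over height, width, and channels, so each interpolated image lies on the line segment between its real/fake pair.

```python
import numpy as np

rng = np.random.default_rng(1)
batch_size = 4
real = rng.uniform(-1, 1, size=(batch_size, 32, 32, 1))
fake = rng.uniform(-1, 1, size=(batch_size, 32, 32, 1))

epsilon = rng.uniform(size=(batch_size, 1, 1, 1))   # one coefficient per sample
interpolates = epsilon * real + (1 - epsilon) * fake

# Every interpolated pixel lies between the corresponding real and fake pixels.
low, high = np.minimum(real, fake), np.maximum(real, fake)
print(np.all((interpolates >= low - 1e-12) & (interpolates <= high + 1e-12)))  # True
```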


The term $\nabla_{\hat x} D(\hat x)$ is the gradient of the critic's output with respect to the interpolated images. We can use `tf.GradientTape()` again to obtain the gradient:

```python
with tf.GradientTape() as gradient_tape:
    # watch() is needed because interpolates is a tensor, not a trainable variable
    gradient_tape.watch(interpolates)
    critic_interpolates = self.critic(interpolates)
grad = gradient_tape.gradient(critic_interpolates, [interpolates])[0]
```


The next term is the L2 norm of that gradient, $\lVert \nabla_{\hat x} D(\hat x) \rVert_2$. We square each value, sum them per sample, and take the square root:

```python
grad_loss = tf.square(grad)
grad_loss = tf.reduce_sum(grad_loss, axis=np.arange(1, len(grad_loss.shape)))
grad_loss = tf.sqrt(grad_loss)
```


Finally, the penalty is the mean squared deviation of this norm from 1:

```python
grad_loss = tf.reduce_mean(tf.square(grad_loss - 1))
```


The critic's total loss is then the Wasserstein loss plus the weighted gradient penalty:

```python
total_loss = loss_real + loss_fake + LAMBDA * grad_loss
```
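For a linear critic the penalty can be checked analytically (a numpy sketch under the assumption D(x) = Σ w·x, so the input gradient is exactly w for every interpolate):

```python
import numpy as np

LAMBDA = 10.0
rng = np.random.default_rng(2)
w = rng.normal(size=(32, 32, 1))

# dD/dx is the same w for each of the 4 interpolated images.
grad = np.broadcast_to(w, (4, 32, 32, 1))

# Square, sum per sample, square root: the per-sample L2 norm of the gradient.
norms = np.sqrt(np.sum(np.square(grad), axis=(1, 2, 3)))

# lambda * E[(||grad|| - 1)^2], which here collapses to a closed form.
penalty = LAMBDA * np.mean(np.square(norms - 1))

print(np.allclose(penalty, LAMBDA * (np.linalg.norm(w) - 1) ** 2))  # True
```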


With the gradient penalty in place, two ingredients of the WGAN critic are removed:

1. Weight clipping
2. Batch normalization in the critic

The gradient penalty penalizes the critic's gradient norm independently for each input. However, batch normalization ties the samples together through batch statistics. To avoid this conflict, batch normalization is removed from the critic.
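The coupling can be demonstrated in a few lines of numpy (illustrative, not from the original post): with batch normalization, perturbing one sample changes the normalized output of every other sample in the batch.

```python
import numpy as np

# A minimal batch normalization (training-mode statistics, no scale/shift).
def batch_norm(x, eps=1e-5):
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

x = np.array([[1.0], [2.0], [3.0]])
y = batch_norm(x)

x2 = x.copy()
x2[0] = 100.0            # perturb only the first sample...
y2 = batch_norm(x2)

# ...yet the normalized outputs of the OTHER samples change too, because the
# batch mean and variance changed. This breaks a per-sample gradient penalty.
print(np.allclose(y[1:], y2[1:]))  # False
```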

## Complete code


```python
import tensorflow as tf
from tensorflow.keras import layers, Sequential, Model
from tensorflow.keras.activations import relu
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import binary_accuracy
from tensorflow.keras.optimizers import RMSprop   # used by both WGAN classes below

import tensorflow_datasets as tfds

import numpy as np
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings('ignore')
print("Tensorflow", tf.__version__)
```

```python
# The dataset-loading line was missing from the original listing; Fashion-MNIST
# is assumed here, judging from the saved-weights filename further below.
ds_train, ds_info = tfds.load('fashion_mnist', split='train', with_info=True)
fig = tfds.show_examples(ds_train, ds_info)
```

```python
batch_size = 64
image_shape = (32, 32, 1)

def preprocess(features):
    image = tf.image.resize(features['image'], image_shape[:2])
    image = tf.cast(image, tf.float32)
    image = (image - 127.5) / 127.5   # scale pixel values from [0, 255] to [-1, 1]
    return image

ds_train = ds_train.map(preprocess)
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(batch_size, drop_remainder=True).repeat()

train_num = ds_info.splits['train'].num_examples
train_steps_per_epoch = round(train_num / batch_size)
print(train_steps_per_epoch)
```

"""
WGAN
"""
class WGAN():
def __init__(self, input_shape):

self.z_dim = 128
self.input_shape = input_shape

self.loss_critic_real = {}
self.loss_critic_fake = {}
self.loss_critic = {}
self.loss_generator = {}

self.n_critic = 5
self.critic = self.build_critic()
self.critic.trainable = False

self.optimizer_critic = RMSprop(5e-5)

self.generator = self.build_generator()
critic_output = self.critic(self.generator.output)
self.model = Model(self.generator.input, critic_output)
self.model.compile(loss = self.wasserstein_loss,
optimizer =  RMSprop(5e-5))
self.critic.trainable = True

def wasserstein_loss(self, y_true, y_pred):

w_loss = -tf.reduce_mean(y_true*y_pred)

return w_loss

def build_generator(self):

DIM = 128
model = tf.keras.Sequential(name='Generator')

return model

def build_critic(self):

DIM = 128
model = tf.keras.Sequential(name='critics')

return model

def train_critic(self, real_images, batch_size):

real_labels = tf.ones(batch_size)
fake_labels = -tf.ones(batch_size)

g_input = tf.random.normal((batch_size, self.z_dim))
fake_images = self.generator.predict(g_input)

pred_fake = self.critic(fake_images)
pred_real = self.critic(real_images)

loss_fake = self.wasserstein_loss(fake_labels, pred_fake)
loss_real = self.wasserstein_loss(real_labels, pred_real)

total_loss = loss_fake + loss_real

for layer in self.critic.layers:
weights = layer.get_weights()
weights = [tf.clip_by_value(w, -0.01, 0.01) for w in weights]
layer.set_weights(weights)

return loss_fake, loss_real

def train(self, data_generator, batch_size, steps, interval=200):

val_g_input = tf.random.normal((batch_size, self.z_dim))
real_labels = tf.ones(batch_size)

for i in range(steps):
for _ in range(self.n_critic):
real_images = next(data_generator)
loss_fake, loss_real = self.train_critic(real_images, batch_size)
critic_loss = loss_fake + loss_real

g_input = tf.random.normal((batch_size, self.z_dim))
g_loss = self.model.train_on_batch(g_input, real_labels)

self.loss_critic_real[i] = loss_real.numpy()
self.loss_critic_fake[i] = loss_fake.numpy()
self.loss_critic[i] = critic_loss.numpy()
self.loss_generator[i] = g_loss

if i%interval == 0:
msg = "Step {}: g_loss {:.4f} critic_loss {:.4f} critic fake {:.4f}  critic_real {:.4f}"\
.format(i, g_loss, critic_loss, loss_fake, loss_real)
print(msg)

fake_images = self.generator.predict(val_g_input)
self.plot_images(fake_images)
self.plot_losses()

def plot_images(self, images):
grid_row = 1
grid_col = 8
f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col*2.5, grid_row*2.5))
for row in range(grid_row):
for col in range(grid_col):
if self.input_shape[-1]==1:
axarr[col].imshow(images[col,:,:,0]*0.5+0.5, cmap='gray')
else:
axarr[col].imshow(images[col]*0.5+0.5)
axarr[col].axis('off')
plt.show()

def plot_losses(self):
fig, (ax1, ax2) = plt.subplots(2, sharex=True)
fig.set_figwidth(10)
fig.set_figheight(6)
ax1.plot(list(self.loss_critic.values()), label='Critic loss', alpha=0.7)
ax1.set_title("Critic loss")
ax2.plot(list(self.loss_generator.values()), label='Generator loss', alpha=0.7)
ax2.set_title("Generator loss")

plt.xlabel('Steps')
plt.show()

wgan = WGAN(image_shape)
wgan.generator.summary()

wgan.critic.summary()

wgan.train(iter(ds_train), batch_size, 2000, 100)

z = tf.random.normal((8, 128))
generated_images = wgan.generator.predict(z)
wgan.plot_images(generated_images)

wgan.generator.save_weights('./wgan_models/wgan_fashion_minist.weights')

"""
WGAN_GP
"""
class WGAN_GP():
def __init__(self, input_shape):

self.z_dim = 128
self.input_shape = input_shape

self.n_critic = 5
self.penalty_const = 10
self.critic = self.build_critic()
self.critic.trainable = False

self.generator = self.build_generator()
critic_output = self.critic(self.generator.output)
self.model = Model(self.generator.input, critic_output)

def wasserstein_loss(self, y_true, y_pred):

w_loss = -tf.reduce_mean(y_true*y_pred)

return w_loss

def build_generator(self):

DIM = 128
model = Sequential([
layers.Input(shape=[self.z_dim]),

layers.Dense(4*4*4*DIM),
layers.BatchNormalization(),
layers.ReLU(),
layers.Reshape((4,4,4*DIM)),

layers.UpSampling2D((2,2), interpolation='bilinear'),
layers.BatchNormalization(),
layers.ReLU(),

layers.UpSampling2D((2,2), interpolation='bilinear'),
layers.BatchNormalization(),
layers.ReLU(),

layers.UpSampling2D((2,2), interpolation='bilinear'),
],name='Generator')

return model

def build_critic(self):

DIM = 128
model = Sequential([
layers.Input(shape=self.input_shape),

layers.LeakyReLU(0.2),

layers.LeakyReLU(0.2),

layers.LeakyReLU(0.2),

layers.Flatten(),
layers.Dense(1)
], name='critics')

return model

loss = tf.reduce_sum(loss, axis=np.arange(1, len(loss.shape)))
loss = tf.sqrt(loss)
loss = tf.reduce_mean(tf.square(loss - 1))
loss = self.penalty_const * loss
return loss

def train_critic(self, real_images, batch_size):
real_labels = tf.ones(batch_size)
fake_labels = -tf.ones(batch_size)

g_input = tf.random.normal((batch_size, self.z_dim))
fake_images = self.generator.predict(g_input)

pred_fake = self.critic(fake_images)
pred_real = self.critic(real_images)

loss_fake = self.wasserstein_loss(fake_labels, pred_fake)
loss_real = self.wasserstein_loss(real_labels, pred_real)

epsilon = tf.random.uniform((batch_size, 1, 1, 1))
interpolates = epsilon * real_images + (1-epsilon) * fake_images

critic_interpolates = self.critic(interpolates)

total_loss = loss_fake + loss_real + gradient_penalty

def train(self, data_generator, batch_size, steps, interval=100):
val_g_input = tf.random.normal((batch_size, self.z_dim))
real_labels = tf.ones(batch_size)

for i in range(steps):
for _ in range(self.n_critic):
real_images = next(data_generator)
loss_fake, loss_real, gradient_penalty = self.train_critic(real_images, batch_size)
critic_loss = loss_fake + loss_real + gradient_penalty

g_input = tf.random.normal((batch_size, self.z_dim))
g_loss = self.model.train_on_batch(g_input, real_labels)
if i%interval == 0:
msg = "Step {}: g_loss {:.4f} critic_loss {:.4f} critic fake {:.4f}  critic_real {:.4f} penalty {:.4f}".format(i, g_loss, critic_loss, loss_fake, loss_real, gradient_penalty)
print(msg)

fake_images = self.generator.predict(val_g_input)
self.plot_images(fake_images)

def plot_images(self, images):
grid_row = 1
grid_col = 8
f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col*2.5, grid_row*2.5))
for row in range(grid_row):
for col in range(grid_col):
if self.input_shape[-1]==1:
axarr[col].imshow(images[col,:,:,0]*0.5+0.5, cmap='gray')
else:
axarr[col].imshow(images[col]*0.5+0.5)
axarr[col].axis('off')
plt.show()

wgan = WGAN_GP(image_shape)
wgan.train(iter(ds_train), batch_size, 5000, 100)

wgan.model.summary()

wgan.critic.summary()

z = tf.random.normal((8, 128))
generated_images = wgan.generator.predict(z)
wgan.plot_images(generated_images)


Original: https://blog.csdn.net/LOVEmy134611/article/details/117230911
Author: 盼小辉丶
Title: 图文详解WGAN及其变体WGAN-GP并利用Tensorflow2实现WGAN与WGAN-GP
