机器学习-生成对抗网络WGAN-GP实战(四-1)

上一篇文章简单介绍了WGAN-GP的原理,本文来实现WGAN-GP的实战。

还是建议大家先读机器学习-生成对抗网络变种(三)

之前的博客写了DCGAN的实战代码,实际上在生成器和判别器网络构建方面都相差不大。

大家可以参照机器学习-生成对抗网络实战(二-1),进行对照学习。

目录

Part1判别器和生成器网络的设计:

自定义生成器类:

自定义判别器类:

Part1判别器和生成器网络的设计:

自定义生成器类:

class Generator(keras.Model):

    def __init__(self):
        super(Generator, self).__init__()

        # z: [b, 100] => [b, 3*3*512] => [b, 3, 3, 512] => [b, 64, 64, 3]
        self.fc = layers.Dense(3*3*512)

        self.conv1 = layers.Conv2DTranspose(256, 3, 3, 'valid')
        self.bn1 = layers.BatchNormalization()

        self.conv2 = layers.Conv2DTranspose(128, 5, 2, 'valid')
        self.bn2 = layers.BatchNormalization()

        self.conv3 = layers.Conv2DTranspose(3, 4, 3, 'valid')

它的基本功能是使用转置卷积来生成图像,但前向传播略有不同。

[En]

Its essential function is to use transpose convolution to generate images, but forward propagation is slightly different.

    def call(self, inputs, training=None):
        # [z, 100] => [z, 3*3*512]
        x = self.fc(inputs)
        x = tf.reshape(x, [-1, 3, 3, 512])
        x = tf.nn.leaky_relu(x)
        x = tf.nn.leaky_relu(self.bn1(self.conv1(x), training=training))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        x = self.conv3(x)
        x = tf.tanh(x)

        return x

大家应该能注意到此时网络的激活函数除了最后一层都使用的leaky_relu激活函数,而最后一层使用的是tanh激活函数。这实际上是一系列的训练技巧,并不能从理论层面解释为什么这些激活函数比之前使用的relu效果好,大家记住就OK。

自定义判别器类:

class Discriminator(keras.Model):

    def __init__(self):
        super(Discriminator, self).__init__()

        # [b, 64, 64, 3] => [b, 1]
        self.conv1 = layers.Conv2D(64, 5, 3, 'valid')

        self.conv2 = layers.Conv2D(128, 5, 3, 'valid')
        self.bn2 = layers.BatchNormalization()

        self.conv3 = layers.Conv2D(256, 5, 3, 'valid')
        self.bn3 = layers.BatchNormalization()

        # [b, h, w ,c] => [b, -1]
        self.flatten = layers.Flatten()
        self.fc = layers.Dense(1)

这一块和前面的DCGAN原理基本类似。最后卷积层提取完特征值之后打平输入全连接层,最后输出一个二分结果。

    def call(self, inputs, training=None):

        x = tf.nn.leaky_relu(self.conv1(inputs))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))

        # [b, h, w, c] => [b, -1]
        x = self.flatten(x)
        # [b, -1] => [b, 1]
        logits = self.fc(x)

        return logits

此时使用的激活函数是leaky_relu大家注意区分,最后的二分输出此处不必激活优化,后面会自动优化。

代码来自于《TensorFlow深度学习》-龙龙老师

机器学习-生成对抗网络WGAN-GP实战(四-1)

Original: https://blog.csdn.net/weixin_46737548/article/details/124301168
Author: weixin_46737548
Title: 机器学习-生成对抗网络WGAN-GP实战(四-1)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/514038/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球