分享

追根溯源,从第一篇论文开始理解和优化GAN

 taotao_2016 2019-08-26
作者:Mirantha Jayathilaka
编译:ronghuaiyang

导读

充分理解一个概念的最好方法是追根溯源,对于GANs来说,我们就从它的第一篇论文开始吧。

追根溯源,从第一篇论文开始理解和优化GAN

Fig 1. 基于基本GAN算法的人脸伪图像生成的改进训练

自从生成对抗性网络(GANs)的架构首次由Ian Goodfellow引入以来,围绕它的宣传就一直在增长,而且它的许多改进和应用也变得越来越吸引人。但是对于任何想要开始学习GANs的人来说,确定从哪里开始是相当棘手的。不要慌,这篇文章会指引你的。

就像许多事情一样,充分理解一个概念的最好方法是追根溯源。抓住第一原则。对于GANs,这是原始论文—>(https:///abs/1406.2661)。要理解这类论文有两种方法,理论和实践。我通常更喜欢后者,但如果你喜欢深入研究数学,我的好友Sameera发表了一篇从理论上分析整个算法的好文章。同时,这篇文章将以Keras最纯粹的形式展示该算法的一个简单实现。让我们开始吧。

在一个GAN的底层设置中有两个模型,生成器和判别器,生成器不断与判别器竞争,判别器学习区分模型分布(例如生成的假图像)和数据分布(例如真实图像)。这一概念常见于著名的伪造者vs警察场景,生成模型被认为是伪造者生成假现金,而判别器模型是警察试图检测假现金。其思想是,假币制造者和警察之间不断竞争,双方的能力都在不断提高,但假币制造者最终达到了制造的假币与真币难以区分的阶段。现在我们把它写进代码。

本文提供的示例脚本用于生成人脸的假图像。我们试图用该算法实现的最终结果如图1所示。

构建生成模型

生成模型应该接收一些噪声并输出理想的图像。这里我们使用Keras顺序模型以及Dense层和Batch Normalization层。使用的激活函数是Leaky Relu。参考下面的代码片段。生成模型可分为几个块,一个块由Dense层->激活->Batch Normalization组成。添加三个这样的块,最后一个块将像素转换为我们期望输出的图像的形状。模型的输入是一个形状为(100,)的噪声向量,最后返回这个模型。请注意,随着模型的发展,每个dense层中的节点数量是如何增长的。

def build_generator(self): noise_shape = (100,) model = Sequential() model.add(Dense(256, input_shape=noise_shape)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(1024)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(np.prod(self.img_shape), activation='tanh')) model.add(Reshape(self.img_shape)) model.summary() noise = Input(shape=noise_shape) img = model(noise) return Model(noise, img)

构建判别模型

鉴别器接收图像的输入,将其扁平化,并将其通过两个dense->激活块,最终输出1到0之间的标量。输出1表示图像为真实的,否则为0。就这么简单。参考下面的代码。

 def build_discriminator(self):  img_shape = (self.img_rows, self.img_cols, self.channels)  model = Sequential()  model.add(Flatten(input_shape=img_shape)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(256)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(1, activation='sigmoid')) model.summary()  img = Input(shape=img_shape) validity = model(img)  return Model(img, validity)

注:-你可以修改这些模型,拥有更多的块,更多的batch norm层,不同的激活等。根据本例,这个模型足以理解GANs.背后的概念

寻找损失和训练

为了训练这两个模型,我们计算了三个损失,在这个例子中都使用了二元交叉熵。

首先是判别器。它在训练时走两条路,如下面的代码所示。首先为真实图像输出1(数组' img '),然后为生成的图像输出0(数组' gen_img ')。随着训练的进行,判别器在这一任务上得到了改进。但是我们的最终目标是在理论上达到这样一个点,即判别器对两种输入都输出0.5(即不确定是假的还是真的)。

d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1))) d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))

接下来是训练生成器,这是比较棘手的一点。为了做到这一点,我们首先为给定的生成器的输出制定了一个组合模型判别器。记住!理想情况下,我们希望这是1,这意味着判别器识别一个假生成的图像为真实图像的。我们将组合模型的输出与1进行对比。参见下面的代码。

 z = Input(shape=(100,)) img = self.generator(z)  valid = self.discriminator(img)  self.combined = Model(z, valid) self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)  g_loss = self.combined.train_on_batch(noise, valid_y)

这基本上就是代码的要点,以便简单地理解GANs的工作原理。

完整的代码可以在我的GitHub上找到这里 。你可以参考所有的附加代码,这些代码用于导入RGB图像、初始化模型和在代码中记录结果。还要注意,在训练期间,为了能够在CPU上运行,将mini batch图像设置为Hi32图像。此外,示例中使用的实际图像是来自CelebA数据集的5000张图像。这是一个开源的数据集,我已经将它上传到我的Floydhub以便于下载,你可以在这里找到。

所以在32个batches,5000个epoch的训练中,我测试了三种优化算法。使用Keras,这个过程就像导入和替换优化器函数的名称一样简单。所有Keras内置的优化器都可以在这里找到。

在每个实例中绘制损失用来理解模型的行为。

  1. 使用SGD(随机梯度下降优化器)。输出和损失变化分别如图2和图3所示。

追根溯源,从第一篇论文开始理解和优化GAN

Fig 2. 使用SGD作为GAN的优化器

追根溯源,从第一篇论文开始理解和优化GAN

Fig 3. 图中显示了使用SGD训练GAN时损失的变化

注释—虽然收敛过程的噪声很多,但我们可以看到,随着时间的推移,生成器的损失在减小,这意味着判别器倾向于将假图像检测为真实的。

  1. 使用RMSProp优化器,输出和损失变化分别如图4和图5所示。

追根溯源,从第一篇论文开始理解和优化GAN

Fig 4. 使用RMSProp作为优化器的GAN的输出

损失:

追根溯源,从第一篇论文开始理解和优化GAN

Fig 5. 显示使用RMSProp训练GAN时损失变化的图

注释—这里我们也看到生成器的损失在减少,这是件好事。令人惊讶的是,实际图像上的判别器损失增加了,这很有趣。

  1. 用 Adam优化器。输出和损失变化分别如图6和图7所示。

追根溯源,从第一篇论文开始理解和优化GAN

Fig 6. 使用Adam优化器的GAN的输出

追根溯源,从第一篇论文开始理解和优化GAN

Fig 7. 图中显示了用Adam训练GAN时损失的变化

注释—Adam优化器生成了迄今为止最好看的结果。注意判别器在假图像上的损失如何保持较大的值,这意味着判别器倾向于将假图像检测为真实图像。

备注

我希望这篇文章从实践的角度对GANs的内部工作原理进行了基本的了解,了解如何改进基本模型。开源社区中有许多针对不同应用程序的GANs的实现,对第一原则有良好的理解将极大地帮助你理解这些改进。此外,GANs是相对较新的深度学习,有许多研究途径开放给感兴趣的人。

英文原文:https:///understanding-and-optimizing-gans-going-back-to-first-principles-e5df8835ae18

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多