生成拮抗网络(GAN)是由Goodfellow等人于2014年开发的。
事实上,它彻底改变了计算机视觉中的图像生成领域:没有人能相信这些惊人而生动的图像实际上是纯粹由机器生成的。
事实上,人们曾经认为生成的任务是不可能的,并对GAN的力量感到震惊,因为传统上,没有事实来比较我们生成的图像。
介绍了GAN创建背后的简单直觉,然后介绍了PyTorch实现的卷积GAN及其训练过程。
#甘背后的直觉
与传统的分类方法不同,我们的网络预测可以直接与事实的正确答案进行比较,但生成图像的“正确性”很难定义和衡量。古德费勒
其他人在他们的原始论文《生成对抗网络》 (_生成性对抗网络_
)提出了一个有趣的想法:使用训练好的分类器来区分生成的图像和实际图像。如果有这样的分类器,我们可以创建和训练一个生成器网络,直到它的输出图像可以完全欺骗分类器。
氮化镓管道
GAN就是这个过程的产物:它包括一个根据给定数据集生成图像的生成器,以及一个区分图像是真实还是生成的鉴别器(分类器)。GAN详细配管见图1。
#损失函数
生成器和鉴别器都很难优化,因为可以想象,这两个网络的目标是完全相反的:生成器想创造尽可能真实的东西,而鉴别器想区分生成的材料。
为了说明这一点,假设D(x)是鉴别器的输出,即x是真实图像的概率,G(z)
是我们发电机的输出。鉴别器类似于二进制分类器,因此鉴别器的目标是最大化函数:
本质上是二元交叉熵损失,开头没有负号。另一方面,生成器的目标是最小化鉴别器做出正确判断的机会,因此它的目标是最小化函数。因此,最终的损失函数将是两个分类器(minimax
游戏),如下所示:
理论上,这将收敛到鉴别器,预测所有事件的概率为0.5。
但是在实践中,minimax游戏往往会导致网络无法收敛,因此仔细调整训练过程是非常重要的。像学习率这样的超参数对于训练GAN显然更重要:稍有变化就会导致
不管输入噪声如何,氮化镓都能产生输出。
#计算环境
#图书馆
我们通过PyTorch库(包括torchvision)构建整个程序。GAN的生成结果用Matplotlib可视化。
图书馆的。以下代码导入所有库:
importGAN.py
''导入必要的库以创建生成性对抗网络该代码主要使用PyTorch库开发' ' '导入时间导入torc himport torch.nn作为n导入torch.optim作为optimfrom torch.utils.data导入DataLoaderfrom torchvision导入数据集来自torch vision . transforms imp ort transforms from模型导入鉴别器,generatorimport numpy作为npimport matplotlib.pyplot作为plt
#数据集
在GAN训练中,数据集是一个重要的方面。图像的非结构化性质意味着任何给定的类别(如狗、猫或手写数字)都可以有一个可能的数据分布,而这个分布最终是GAN。
生成内容的基础。
为了演示,本文将使用最简单的MNIST数据集,其中包含从0到9的60,000个手写数字图像。事实上,像MNIST这样的非结构化数据集可以用于
在Graviti上找到的。这是一家年轻的创业公司。他们希望通过非结构化数据集帮助社区。他们的平台上有一些最好的公共非结构化数据集,包括MNIST。
#硬件
要求最好的方法是用 GPU 训练神经网络,它可以显著地提高训练速度。但是,如果只有 CPU
可用,你仍然可以测试程序。要使你的程序能够自行确定硬件,你可以使用以下方法:
> torchDevice.py
"""Determine if any GPUs are available"""device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 实施
# 网络架构
由于数字的简单性,这两种架构――判别器和生成器,都是由全连接层构建的。请注意,在某些情况下,全连接的 GAN 也比 DCGAN 略微容易收敛。
以下是两种架构的 PyTorch 实现:
GANArchitecture.py
"""Network ArchitecturesThe following are the discriminator and generator architectures"""class discriminator(nn.Module): def __init__(self): super(discriminator, self).__init__() self.fc1 = nn.Linear(784, 512) self.fc2 = nn.Linear(512, 1) self.activation = nn.LeakyReLU(0.1) def forward(self, x): x = x.view(-1, 784) x = self.activation(self.fc1(x)) x = self.fc2(x) return nn.Sigmoid()(x)class generator(nn.Module): def __init__(self): super(generator, self).__init__() self.fc1 = nn.Linear(128, 1024) self.fc2 = nn.Linear(1024, 2048) self.fc3 = nn.Linear(2048, 784) self.activation = nn.ReLU() def forward(self, x): x = self.activation(self.fc1(x)) x = self.activation(self.fc2(x)) x = self.fc3(x) x = x.view(-1, 1, 28, 28) return nn.Tanh()(x)
# 训练
在训练 GAN
时,我们优化了判别器的结果,同时也改进了我们的生成器。这样,在每次迭代过程中会有两个相互矛盾的损失来同时优化它们。我们送入生成器的是随机噪声,而生成器理应根据给定噪声的微小差异来生成图像:
trainGAN.py
"""Network training procedureEvery step both the loss for disciminator and generator is updatedDiscriminator aims to classify reals and fakesGenerator aims to generate images as realistic as possible"""for epoch in range(epochs): for idx, (imgs, _) in enumerate(train_loader): idx += 1 # Training the discriminator # Real inputs are actual images of the MNIST dataset # Fake inputs are from the generator # Real inputs should be classified as 1 and fake as 0 real_inputs = imgs.to(device) real_outputs = D(real_inputs) real_label = torch.ones(real_inputs.shape[0], 1).to(device) noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5 noise = noise.to(device) fake_inputs = G(noise) fake_outputs = D(fake_inputs) fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device) outputs = torch.cat((real_outputs, fake_outputs), 0) targets = torch.cat((real_label, fake_label), 0) D_loss = loss(outputs, targets) D_optimizer.zero_grad() D_loss.backward() D_optimizer.step() # Training the generator # For generator, goal is to make the discriminator believe everything is 1 noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5 noise = noise.to(device) fake_inputs = G(noise) fake_outputs = D(fake_inputs) fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device) G_loss = loss(fake_outputs, fake_targets) G_optimizer.zero_grad() G_loss.backward() G_optimizer.step() if idx % 100 == 0 or idx == len(train_loader): print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item())) if (epoch+1) % 10 == 0: torch.save(G, 'Generator_epoch_{}.pth'.format(epoch)) print('Model saved.')
# 结果
当 100 个轮数(epoch)之后,我们可以绘制数据集,并看到从随机噪音中生成的数字的结果:
图 2:GAN 生成的结
如上图所示,生成的结果看起来确实相当像真实的结果。鉴于网络非常简单,所以结果看起来确实很有希望!
# 超越单纯的内容创作
GAN 的创造与计算机视觉领域的先前工作如此不同。随后的众多应用使学术界对深度网络的能力感到惊讶。下面将介绍一些令人惊讶的工作。
# CycleGAN
Zhu 等人的 CycleGAN 引入了一种概念,它无需配对样本就可以将图像从 X 域翻译成 Y
域。马被转化为斑马,夏日的阳光被转化为暴风雪,CycleGAN 的结果令人惊讶且准确。
3:Zhu 等人的 CycleGAN 生成的结果。
# GauGAN
Nvidia 利用 GAN 的力量,把简单的绘画,根据画笔的语义,转换成优雅而逼真的照片。尽管训练资源的计算成本很高,但它创造了一个全新的研究和应用领域。
4:GaoGAN 的生成结果。左为原图,右为生成的结果。
# AdvGAN
GAN 还扩展到清理对抗性图像,并将其转化为不会欺骗分类器的干净样本。关于对抗性攻击和防御的更多信息可以在这里到。
# 结语
所以,你已经拥有了它!希望这篇文章对如何构建 GAN 提供了一个概览。完整的实现可以在下面的 Github 资源库中找到:
https://github.com/ttchengab/MnistGAN
作者简介:
Ta-ying Cheng,中国香港人,牛津大学哲学博士新生,爱好 3D 视觉、深度学习。
原文链接:
https://towardsdatascience.com/building-a-gan-with-pytorch-237b4b07ca9a