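It has been a while since my last post. This is my 21st article, and I am still determined to keep writing. Procrastination gets in the way often, but I have not forgotten. Today I will share a PyTorch implementation of a generative adversarial network.

Ian J. Goodfellow proposed the generative adversarial network (GAN) in 2014, opening up another major branch of deep learning that now stands alongside convolutional neural networks (CNN) and recurrent neural networks (RNN/LSTM). GANs remain one of the active research topics in computer vision, with many papers published every year, and they have been applied to many vision tasks, including image synthesis and data augmentation, image translation and transformation, defect detection, image denoising and reconstruction, and image segmentation.
The core idea of a GAN, however, is still the pair of basic models that Goodfellow described in the 2014 paper: the generator and the discriminator.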

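Generator (G): produces an output sample G(z) from input noise z. Goal: make the generated samples follow the same distribution as the target samples, so that they fool the discriminator.
Discriminator (D): estimates the probability that an input sample is real, learning from data what distinguishes real samples from generated ones.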
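The two roles above are usually summarized by the minimax objective from the 2014 paper. It is written here with D(x) as the discriminator's probability that x is real and G(z) as the generator's output for noise z; the distribution names p_data and p_z are the conventional notation, not symbols used elsewhere in this post:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

D tries to push both terms up by scoring real samples near 1 and generated samples near 0, while G tries to push the second term down by making D(G(z)) close to 1.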
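From panel (a) to panel (d) of the figure, the distribution generated from the input noise moves closer and closer to the real distribution X, until the two networks reach a balance. This stable state is called a Nash equilibrium, and the film "A Beautiful Mind" is related to it.
The code below implements a discriminator and a generator on the MNIST dataset, so that the trained generator can produce images of handwritten digits automatically. The input z is random noise of dimension 100, and the output is 784 values representing a 28x28 handwritten digit sample. The loss has two parts: the generator's loss, and the discriminator's loss on real and on generated samples. Both networks are trained with backpropagation. With epoch=100, training yields the final generator output:
[Figure: handwritten digits produced by the trained generator]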
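The two loss terms described above can be worked through by hand. The sketch below uses only the standard library and applies the per-sample binary cross-entropy formula that `BCELoss` averages over a batch. The probabilities `d_real` and `d_fake` are made-up values for illustration, and the `eps` clipping is an assumption of this sketch, not PyTorch's exact numerical handling.

```python
import math


def bce(prob, label):
    """Binary cross-entropy for a single prediction."""
    eps = 1e-12  # assumption of this sketch: guard against log(0)
    prob = min(max(prob, eps), 1.0 - eps)
    return -(label * math.log(prob) + (1.0 - label) * math.log(1.0 - prob))


# Hypothetical discriminator outputs for one real and one generated sample.
d_real = 0.9  # D(x): probability assigned to a real image
d_fake = 0.2  # D(G(z)): probability assigned to a generated image

# Discriminator loss: real samples are labelled 1, generated samples 0.
d_loss = bce(d_real, 1.0) + bce(d_fake, 0.0)

# Generator loss: generated samples are labelled 1, so the loss shrinks
# as the discriminator is fooled more often.
g_loss = bce(d_fake, 1.0)

print(f"d_loss = {d_loss:.3f}, g_loss = {g_loss:.3f}")
```

Labelling generated samples as 1 in the generator step means minimizing -log D(G(z)). This is the commonly used non-saturating variant, not a direct minimization of log(1 - D(G(z))).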
Discriminator and generator code:

    import torch as t
    import torch.nn.functional as F
    import torchvision as tv
    from torch.utils.data import DataLoader

    # Scale MNIST pixels from [0, 1] to [-1, 1] to match the generator's tanh output
    transform = tv.transforms.Compose([tv.transforms.ToTensor(),
                                       tv.transforms.Normalize((0.5,), (0.5,))])
    train_ts = tv.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_ts = tv.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    train_dl = DataLoader(train_ts, batch_size=128, shuffle=True, drop_last=False)
    test_dl = DataLoader(test_ts, batch_size=128, shuffle=True, drop_last=False)


    class Generator(t.nn.Module):
        def __init__(self, g_input_dim, g_output_dim):
            super(Generator, self).__init__()
            # Widths grow 256 -> 512 -> 1024 before projecting to the image size
            self.fc1 = t.nn.Linear(g_input_dim, 256)
            self.fc2 = t.nn.Linear(self.fc1.out_features, self.fc1.out_features * 2)
            self.fc3 = t.nn.Linear(self.fc2.out_features, self.fc2.out_features * 2)
            self.fc4 = t.nn.Linear(self.fc3.out_features, g_output_dim)

        # forward method
        def forward(self, x):
            x = F.leaky_relu(self.fc1(x), 0.2)
            x = F.leaky_relu(self.fc2(x), 0.2)
            x = F.leaky_relu(self.fc3(x), 0.2)
            return t.tanh(self.fc4(x))


    class Discriminator(t.nn.Module):
        def __init__(self, d_input_dim):
            super(Discriminator, self).__init__()
            # Widths shrink 1024 -> 512 -> 256 before the single probability output
            self.fc1 = t.nn.Linear(d_input_dim, 1024)
            self.fc2 = t.nn.Linear(self.fc1.out_features, self.fc1.out_features // 2)
            self.fc3 = t.nn.Linear(self.fc2.out_features, self.fc2.out_features // 2)
            self.fc4 = t.nn.Linear(self.fc3.out_features, 1)

        # forward method
        def forward(self, x):
            # training=self.training turns dropout off when the module is in eval mode
            x = F.leaky_relu(self.fc1(x), 0.2)
            x = F.dropout(x, 0.3, training=self.training)
            x = F.leaky_relu(self.fc2(x), 0.2)
            x = F.dropout(x, 0.3, training=self.training)
            x = F.leaky_relu(self.fc3(x), 0.2)
            x = F.dropout(x, 0.3, training=self.training)
            return t.sigmoid(self.fc4(x))
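One detail of the listing above is worth spelling out: `Normalize((0.5,), (0.5,))` puts the real images in the range [-1, 1], which is also the range of the generator's final `tanh`. The following is a minimal stdlib-only sketch of that per-pixel arithmetic. The helper names `normalize` and `denormalize` and the value `raw_activation` are hypothetical and do not come from the original code.

```python
import math


def normalize(pixel, mean=0.5, std=0.5):
    """Per-pixel arithmetic of Normalize((0.5,), (0.5,)): maps [0, 1] to [-1, 1]."""
    return (pixel - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Inverse mapping, for turning generator output back into pixel intensities."""
    return value * std + mean


# ToTensor yields pixel values in [0, 1]; after normalization they span [-1, 1].
print(normalize(0.0), normalize(1.0))

# tanh is bounded in (-1, 1), so any generator activation lands in the same range.
raw_activation = 3.7  # hypothetical pre-tanh value from the last linear layer
pixel = denormalize(math.tanh(raw_activation))
print(0.0 <= pixel <= 1.0)
```

Before a generated sample is displayed or saved, it would typically be passed through something like `denormalize` so that its values fall back into [0, 1].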
Define separate training functions for the generator and the discriminator, then start training. The code is as follows:

    # Uses t, Generator, Discriminator and train_dl from the listing above
    z_dim = 100
    mnist_dim = 784
    # loss
    criterion = t.nn.BCELoss()

    # networks and optimizers (falls back to the CPU when CUDA is not available)
    device = 'cuda' if t.cuda.is_available() else 'cpu'
    gnet = Generator(g_input_dim=z_dim, g_output_dim=mnist_dim).to(device)
    dnet = Discriminator(mnist_dim).to(device)
    lr = 0.0002
    G_optimizer = t.optim.Adam(gnet.parameters(), lr=lr)
    D_optimizer = t.optim.Adam(dnet.parameters(), lr=lr)


    def D_train(x):
        # =======================Train the discriminator=======================#
        dnet.zero_grad()
        bs = x.size(0)  # the last batch can be smaller than 128

        # train discriminator on real
        x_real = x.view(-1, mnist_dim).to(device)
        y_real = t.ones(bs, 1, device=device)

        D_output = dnet(x_real)
        D_real_loss = criterion(D_output, y_real)

        # train discriminator on fake
        z = t.randn(bs, z_dim, device=device)
        x_fake = gnet(z).detach()  # this step needs no gradient for G
        y_fake = t.zeros(bs, 1, device=device)

        D_output = dnet(x_fake)
        D_fake_loss = criterion(D_output, y_fake)

        # gradient backprop & optimize ONLY D's parameters
        D_loss = D_real_loss + D_fake_loss
        D_loss.backward()
        D_optimizer.step()

        return D_loss.item()


    def G_train(x):
        # =======================Train the generator=======================#
        gnet.zero_grad()
        bs = x.size(0)  # x only supplies the batch size here

        z = t.randn(bs, z_dim, device=device)
        y = t.ones(bs, 1, device=device)  # label fakes as real to fool D

        G_output = gnet(z)
        D_output = dnet(G_output)
        G_loss = criterion(D_output, y)

        # gradient backprop & optimize ONLY G's parameters
        G_loss.backward()
        G_optimizer.step()

        return G_loss.item()


    n_epoch = 100
    for epoch in range(1, n_epoch + 1):
        D_losses, G_losses = [], []
        for batch_idx, (x, _) in enumerate(train_dl):
            D_losses.append(D_train(x))
            G_losses.append(G_train(x))

        # mean discriminator and generator loss over the epoch
        print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            epoch, n_epoch, t.mean(t.FloatTensor(D_losses)), t.mean(t.FloatTensor(G_losses))))
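The training loop reads the batch size from each batch instead of assuming 128 throughout. A quick calculation shows why: with `drop_last=False`, the 60,000 MNIST training images do not divide evenly into batches of 128. The variable names in this sketch are illustrative only.

```python
import math

train_size = 60000  # number of images in the MNIST training split
batch_size = 128    # batch_size passed to DataLoader in the article

# drop_last=False keeps the final, partially filled batch
num_batches = math.ceil(train_size / batch_size)
last_batch = train_size - (num_batches - 1) * batch_size

print(num_batches, last_batch)
```

If the label tensors were always built with 128 rows, their shape would not match the discriminator output on that final batch, and `BCELoss` would reject the mismatch.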