
A Hands-On Guide to Writing a Generative Adversarial Network

 imelee 2017-04-15

Today we pick up from the previous post, "#9 GAN 101: the ultimate introduction with a plain-language walkthrough", and walk through writing a generative adversarial network step by step. The reference code is: https://github.com/AYLIEN/gan-intro

Key Python libraries: TensorFlow, numpy, matplotlib, scipy

As we covered last time, a generative adversarial network trains two models at the same time, called the generator and the discriminator. The generator does its utmost to imitate the real distribution and produce data; the discriminator does its utmost to tell real samples apart from the generator's imitation samples, until the discriminator can no longer distinguish real samples from fakes.

(Figure: the training process of a generative adversarial network. Source: http://blog./introduction-generative-adversarial-networks-code-tensorflow/)

The figure above shows the training process of a generative adversarial network, and the code explained here implements exactly that process.
The green curve is a Gaussian distribution (the real distribution); its mean and variance are fixed, so the distribution stays put. The red curve is the generator's distribution; during training it competes with the discriminator and keeps reshaping itself to imitate the green Gaussian. The blue curve is the discriminator: by convention, the higher its value at a point, the more likely a sample there is judged to be real data. You can see that around the real data's mean of 4 the blue curve keeps dropping, which means the discriminator can hardly tell the generator's samples apart from the real data any more.

Next, let's chew through some of the most convoluted TensorFlow source-code logic David 9 has seen:

First, the overall structure:

(Figure: overall structure of a generative adversarial network. Source: https://ishmaelbelghazi./ALI)

As mentioned above, two neural models are trained in alternation. The generator takes a noise distribution as input and maps it into a distribution that closely resembles the real one, producing counterfeit samples. The discriminator takes the generator's counterfeit samples as input and decides whether each sample is real or not. If, in the end, it cannot tell the difference, congratulations: the model has trained well.

The mapping performed by our generator looks a lot like the figure below:

(Figure: the generator's mapping from the noise distribution Z to the real data distribution X)

Z is just a uniform distribution with a little noise added on top, and X is the real distribution. We want a neural network that, fed evenly spaced input values, tells us how large the probability density (pdf) is at each value. Clearly the pdf around -1 should be fairly large.
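For comparison, the real distribution X in the full source at the end of this post is a Gaussian with mean 4 and standard deviation 0.5, sampled like this:

    class DataDistribution(object):
        def __init__(self):
            self.mu = 4       # mean of the real data
            self.sigma = 0.5  # standard deviation

        def sample(self, N):
            # draw N samples from N(mu, sigma) and return them sorted
            samples = np.random.normal(self.mu, self.sigma, N)
            samples.sort()
            return samples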

How do we write the code for Z? It's simple:

    class GeneratorDistribution(object):
        def __init__(self, range):
            self.range = range

        def sample(self, N):
            # N evenly spaced points across [-range, range], perturbed by a little noise
            return np.linspace(-self.range, self.range, N) + \
                np.random.random(N) * 0.01
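As a quick sanity check (a hypothetical usage example, not part of the reference code), sampling five points with range=8 gives the evenly spaced values plus a tiny jitter:

    import numpy as np

    gen = GeneratorDistribution(range=8)
    print(gen.sample(5))
    # roughly [-8. -4.  0.  4.  8.], each value shifted up by at most 0.01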

The sampled values look roughly like the figure below:

(Figure: samples drawn from the generator's input distribution Z)

It's just the evenly spaced values with a tiny bit of noise added.

The generator is a network with one linear layer, one nonlinear layer, and a final linear layer.
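Taken from the full source at the end of this post, its definition looks like this:

    def generator(input, h_dim):
        # linear -> softplus nonlinearity -> linear, mapping noise to a 1-D sample
        h0 = tf.nn.softplus(linear(input, h_dim, 'g0'))
        h1 = linear(h0, 1, 'g1')
        return h1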

The discriminator needs to be a bit more powerful, so we use a deeper, three-hidden-layer network:

    def discriminator(input, hidden_size):
        h0 = tf.tanh(linear(input, hidden_size * 2, 'd0'))
        h1 = tf.tanh(linear(h0, hidden_size * 2, 'd1'))
        h2 = tf.tanh(linear(h1, hidden_size * 2, 'd2'))
        h3 = tf.sigmoid(linear(h2, 1, 'd3'))
        return h3
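Both the generator and the discriminator are assembled from the same linear() helper, a fully connected layer; its definition, taken from the full source at the end of this post, is:

    def linear(input, output_dim, scope=None, stddev=1.0):
        # one fully connected layer: y = x * w + b, with variables created
        # (or reused) inside the given variable scope
        norm = tf.random_normal_initializer(stddev=stddev)
        const = tf.constant_initializer(0.0)
        with tf.variable_scope(scope or 'linear'):
            w = tf.get_variable('w', [input.get_shape()[1], output_dim], initializer=norm)
            b = tf.get_variable('b', [output_dim], initializer=const)
            return tf.matmul(input, w) + b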

Next we build the TensorFlow graph, together with the loss functions for the discriminator and the generator:

    with tf.variable_scope('G'):
        z = tf.placeholder(tf.float32, shape=(None, 1))
        G = generator(z, hidden_size)

    with tf.variable_scope('D') as scope:
        x = tf.placeholder(tf.float32, shape=(None, 1))
        D1 = discriminator(x, hidden_size)
        scope.reuse_variables()
        D2 = discriminator(G, hidden_size)

    loss_d = tf.reduce_mean(-tf.log(D1) - tf.log(1 - D2))
    loss_g = tf.reduce_mean(-tf.log(D2))

The most magical line is probably this one:

    loss_d = tf.reduce_mean(-tf.log(D1) - tf.log(1 - D2))

We have a single discriminator model; the only difference between D1 and D2 is that D1's input is real data while D2's input is the generator's counterfeit data. Note that in this code the discriminator's output is "the probability that a sample comes from the real distribution". So the optimization goal is to make D1's output as large as possible and D2's output as small as possible.

Conversely, when optimizing the generator we want to fool the discriminator, so we push D2's output to be as large as possible:

    loss_g = tf.reduce_mean(-tf.log(D2))
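In the notation of Goodfellow et al.'s original GAN paper (cited in the source's docstring), these two lines are simply batch estimates of

    \mathcal{L}_D = -\,\mathbb{E}_{x \sim p_{\mathrm{data}}}\big[\log D(x)\big] \;-\; \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
    \mathcal{L}_G = -\,\mathbb{E}_{z \sim p_z}\big[\log D(G(z))\big]

where D1 plays the role of D(x) and D2 the role of D(G(z)). Note that loss_g uses the non-saturating form -log D(G(z)) rather than the paper's minimax term log(1 - D(G(z))); this is the standard trick, recommended in the original paper, that gives the generator stronger gradients early in training.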

David 9 has now covered the trickiest parts for you. For how to write the optimizer and the training loop, please refer to the source code~
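One quick orientation point before the full listing: the optimizer() helper in the source below is plain gradient descent with an exponentially decaying learning rate, and its var_list argument is what restricts each update to a single network's parameters (lightly condensed here from the listing):

    def optimizer(loss, var_list):
        # gradient descent whose learning rate decays by 5% every 150 steps;
        # var_list restricts the update to one network's parameters
        initial_learning_rate = 0.005
        decay = 0.95
        num_decay_steps = 150
        batch = tf.Variable(0)
        learning_rate = tf.train.exponential_decay(
            initial_learning_rate, batch, num_decay_steps, decay, staircase=True)
        return tf.train.GradientDescentOptimizer(learning_rate).minimize(
            loss, global_step=batch, var_list=var_list)

In GAN.train(), loss_d is then minimized over the variables under the 'D/' scope and loss_g over those under 'G/', alternating one discriminator step and one generator step per iteration.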

Full source code:

    '''
    An example of distribution approximation using Generative Adversarial Networks in TensorFlow.
    Based on the blog post by Eric Jang: http://blog./2016/06/generative-adversarial-nets-in.html,
    and of course the original GAN paper by Ian Goodfellow et. al.: https:///abs/1406.2661.
    The minibatch discrimination technique is taken from Tim Salimans et. al.: https:///abs/1606.03498.
    '''
    from __future__ import absolute_import
    from __future__ import print_function
    from __future__ import unicode_literals
    from __future__ import division

    import argparse
    import numpy as np
    from scipy.stats import norm
    import tensorflow as tf
    import matplotlib.pyplot as plt
    from matplotlib import animation
    import seaborn as sns

    sns.set(color_codes=True)

    seed = 42
    np.random.seed(seed)
    tf.set_random_seed(seed)


    class DataDistribution(object):
        def __init__(self):
            self.mu = 4
            self.sigma = 0.5

        def sample(self, N):
            samples = np.random.normal(self.mu, self.sigma, N)
            samples.sort()
            return samples


    class GeneratorDistribution(object):
        def __init__(self, range):
            self.range = range

        def sample(self, N):
            return np.linspace(-self.range, self.range, N) + \
                np.random.random(N) * 0.01


    def linear(input, output_dim, scope=None, stddev=1.0):
        norm = tf.random_normal_initializer(stddev=stddev)
        const = tf.constant_initializer(0.0)
        with tf.variable_scope(scope or 'linear'):
            w = tf.get_variable('w', [input.get_shape()[1], output_dim], initializer=norm)
            b = tf.get_variable('b', [output_dim], initializer=const)
            return tf.matmul(input, w) + b


    def generator(input, h_dim):
        h0 = tf.nn.softplus(linear(input, h_dim, 'g0'))
        h1 = linear(h0, 1, 'g1')
        return h1


    def discriminator(input, h_dim, minibatch_layer=True):
        h0 = tf.tanh(linear(input, h_dim * 2, 'd0'))
        h1 = tf.tanh(linear(h0, h_dim * 2, 'd1'))

        # without the minibatch layer, the discriminator needs an additional layer
        # to have enough capacity to separate the two distributions correctly
        if minibatch_layer:
            h2 = minibatch(h1)
        else:
            h2 = tf.tanh(linear(h1, h_dim * 2, scope='d2'))

        h3 = tf.sigmoid(linear(h2, 1, scope='d3'))
        return h3


    def minibatch(input, num_kernels=5, kernel_dim=3):
        x = linear(input, num_kernels * kernel_dim, scope='minibatch', stddev=0.02)
        activation = tf.reshape(x, (-1, num_kernels, kernel_dim))
        diffs = tf.expand_dims(activation, 3) - tf.expand_dims(tf.transpose(activation, [1, 2, 0]), 0)
        eps = tf.expand_dims(np.eye(int(input.get_shape()[0]), dtype=np.float32), 1)
        abs_diffs = tf.reduce_sum(tf.abs(diffs), 2) + eps
        minibatch_features = tf.reduce_sum(tf.exp(-abs_diffs), 2)
        return tf.concat(1, [input, minibatch_features])


    def optimizer(loss, var_list):
        initial_learning_rate = 0.005
        decay = 0.95
        num_decay_steps = 150
        batch = tf.Variable(0)
        learning_rate = tf.train.exponential_decay(
            initial_learning_rate,
            batch,
            num_decay_steps,
            decay,
            staircase=True
        )
        optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(
            loss,
            global_step=batch,
            var_list=var_list
        )
        return optimizer


    class GAN(object):
        def __init__(self, data, gen, num_steps, batch_size, minibatch, log_every, anim_path):
            self.data = data
            self.gen = gen
            self.num_steps = num_steps
            self.batch_size = batch_size
            self.minibatch = minibatch
            self.log_every = log_every
            self.mlp_hidden_size = 4
            self.anim_path = anim_path
            self.anim_frames = []
            self._create_model()

        def _create_model(self):
            # In order to make sure that the discriminator is providing useful gradient
            # information to the generator from the start, we're going to pretrain the
            # discriminator using a maximum likelihood objective. We define the network
            # for this pretraining step scoped as D_pre.
            with tf.variable_scope('D_pre'):
                self.pre_input = tf.placeholder(tf.float32, shape=(self.batch_size, 1))
                self.pre_labels = tf.placeholder(tf.float32, shape=(self.batch_size, 1))
                D_pre = discriminator(self.pre_input, self.mlp_hidden_size, self.minibatch)
                self.pre_loss = tf.reduce_mean(tf.square(D_pre - self.pre_labels))
                self.pre_opt = optimizer(self.pre_loss, None)

            # This defines the generator network - it takes samples from a noise
            # distribution as input, and passes them through an MLP.
            with tf.variable_scope('G'):
                self.z = tf.placeholder(tf.float32, shape=(self.batch_size, 1))
                self.G = generator(self.z, self.mlp_hidden_size)

            # The discriminator tries to tell the difference between samples from the
            # true data distribution (self.x) and the generated samples (self.z).
            #
            # Here we create two copies of the discriminator network (that share parameters),
            # as you cannot use the same network with different inputs in TensorFlow.
            with tf.variable_scope('D') as scope:
                self.x = tf.placeholder(tf.float32, shape=(self.batch_size, 1))
                self.D1 = discriminator(self.x, self.mlp_hidden_size, self.minibatch)
                scope.reuse_variables()
                self.D2 = discriminator(self.G, self.mlp_hidden_size, self.minibatch)

            # Define the loss for discriminator and generator networks (see the original
            # paper for details), and create optimizers for both
            #self.pre_loss = tf.reduce_mean(tf.square(D_pre - self.pre_labels))
            self.loss_d = tf.reduce_mean(-tf.log(self.D1) - tf.log(1 - self.D2))
            self.loss_g = tf.reduce_mean(-tf.log(self.D2))

            vars = tf.trainable_variables()
            self.d_pre_params = [v for v in vars if v.name.startswith('D_pre/')]
            self.d_params = [v for v in vars if v.name.startswith('D/')]
            self.g_params = [v for v in vars if v.name.startswith('G/')]

            #self.pre_opt = optimizer(self.pre_loss, self.d_pre_params)
            self.opt_d = optimizer(self.loss_d, self.d_params)
            self.opt_g = optimizer(self.loss_g, self.g_params)

        def train(self):
            with tf.Session() as session:
                tf.initialize_all_variables().run()

                # pretraining discriminator
                num_pretrain_steps = 1000
                for step in xrange(num_pretrain_steps):
                    d = (np.random.random(self.batch_size) - 0.5) * 10.0
                    labels = norm.pdf(d, loc=self.data.mu, scale=self.data.sigma)
                    pretrain_loss, _ = session.run([self.pre_loss, self.pre_opt], {
                        self.pre_input: np.reshape(d, (self.batch_size, 1)),
                        self.pre_labels: np.reshape(labels, (self.batch_size, 1))
                    })
                self.weightsD = session.run(self.d_pre_params)

                # copy weights from pre-training over to new D network
                for i, v in enumerate(self.d_params):
                    session.run(v.assign(self.weightsD[i]))

                for step in xrange(self.num_steps):
                    # update discriminator
                    x = self.data.sample(self.batch_size)
                    z = self.gen.sample(self.batch_size)
                    loss_d, _ = session.run([self.loss_d, self.opt_d], {
                        self.x: np.reshape(x, (self.batch_size, 1)),
                        self.z: np.reshape(z, (self.batch_size, 1))
                    })

                    # update generator
                    z = self.gen.sample(self.batch_size)
                    loss_g, _ = session.run([self.loss_g, self.opt_g], {
                        self.z: np.reshape(z, (self.batch_size, 1))
                    })

                    if step % self.log_every == 0:
                        #pass
                        print('{}: {}\t{}'.format(step, loss_d, loss_g))

                    if self.anim_path:
                        self.anim_frames.append(self._samples(session))

                if self.anim_path:
                    self._save_animation()
                else:
                    self._plot_distributions(session)

        def _samples(self, session, num_points=10000, num_bins=100):
            '''
            Return a tuple (db, pd, pg), where db is the current decision
            boundary, pd is a histogram of samples from the data distribution,
            and pg is a histogram of generated samples.
            '''
            xs = np.linspace(-self.gen.range, self.gen.range, num_points)
            bins = np.linspace(-self.gen.range, self.gen.range, num_bins)

            # decision boundary
            db = np.zeros((num_points, 1))
            for i in range(num_points // self.batch_size):
                db[self.batch_size * i:self.batch_size * (i + 1)] = session.run(self.D1, {
                    self.x: np.reshape(
                        xs[self.batch_size * i:self.batch_size * (i + 1)],
                        (self.batch_size, 1)
                    )
                })

            # data distribution
            d = self.data.sample(num_points)
            pd, _ = np.histogram(d, bins=bins, density=True)

            # generated samples
            zs = np.linspace(-self.gen.range, self.gen.range, num_points)
            g = np.zeros((num_points, 1))
            for i in range(num_points // self.batch_size):
                g[self.batch_size * i:self.batch_size * (i + 1)] = session.run(self.G, {
                    self.z: np.reshape(
                        zs[self.batch_size * i:self.batch_size * (i + 1)],
                        (self.batch_size, 1)
                    )
                })
            pg, _ = np.histogram(g, bins=bins, density=True)

            return db, pd, pg

        def _plot_distributions(self, session):
            db, pd, pg = self._samples(session)
            db_x = np.linspace(-self.gen.range, self.gen.range, len(db))
            p_x = np.linspace(-self.gen.range, self.gen.range, len(pd))
            f, ax = plt.subplots(1)
            ax.plot(db_x, db, label='decision boundary')
            ax.set_ylim(0, 1)
            plt.plot(p_x, pd, label='real data')
            plt.plot(p_x, pg, label='generated data')
            plt.title('1D Generative Adversarial Network')
            plt.xlabel('Data values')
            plt.ylabel('Probability density')
            plt.legend()
            plt.show()

        def _save_animation(self):
            f, ax = plt.subplots(figsize=(6, 4))
            f.suptitle('1D Generative Adversarial Network', fontsize=15)
            plt.xlabel('Data values')
            plt.ylabel('Probability density')
            ax.set_xlim(-6, 6)
            ax.set_ylim(0, 1.4)
            line_db, = ax.plot([], [], label='decision boundary')
            line_pd, = ax.plot([], [], label='real data')
            line_pg, = ax.plot([], [], label='generated data')
            frame_number = ax.text(
                0.02,
                0.95,
                '',
                horizontalalignment='left',
                verticalalignment='top',
                transform=ax.transAxes
            )
            ax.legend()

            db, pd, _ = self.anim_frames[0]
            db_x = np.linspace(-self.gen.range, self.gen.range, len(db))
            p_x = np.linspace(-self.gen.range, self.gen.range, len(pd))

            def init():
                line_db.set_data([], [])
                line_pd.set_data([], [])
                line_pg.set_data([], [])
                frame_number.set_text('')
                return (line_db, line_pd, line_pg, frame_number)

            def animate(i):
                frame_number.set_text(
                    'Frame: {}/{}'.format(i, len(self.anim_frames))
                )
                db, pd, pg = self.anim_frames[i]
                line_db.set_data(db_x, db)
                line_pd.set_data(p_x, pd)
                line_pg.set_data(p_x, pg)
                return (line_db, line_pd, line_pg, frame_number)

            anim = animation.FuncAnimation(
                f,
                animate,
                init_func=init,
                frames=len(self.anim_frames),
                blit=True
            )
            anim.save(self.anim_path, fps=30, extra_args=['-vcodec', 'libx264'])


    def main(args):
        model = GAN(
            DataDistribution(),
            GeneratorDistribution(range=8),
            args.num_steps,
            args.batch_size,
            args.minibatch,
            args.log_every,
            args.anim
        )
        model.train()


    def parse_args():
        parser = argparse.ArgumentParser()
        parser.add_argument('--num-steps', type=int, default=1200,
                            help='the number of training steps to take')
        parser.add_argument('--batch-size', type=int, default=12,
                            help='the batch size')
        parser.add_argument('--minibatch', type=bool, default=False,
                            help='use minibatch discrimination')
        parser.add_argument('--log-every', type=int, default=10,
                            help='print loss after this many steps')
        parser.add_argument('--anim', type=str, default=None,
                            help='name of the output animation file (default: none)')
        return parser.parse_args()


    if __name__ == '__main__':
        '''
        data_sample = DataDistribution()
        d = data_sample.sample(10)
        print(d)
        '''
        main(parse_args())

 

References:

  1. An introduction to Generative Adversarial Networks (with code in TensorFlow)
  2. Generative Adversarial Nets in TensorFlow (Part I)

David 9

Weibo: http://weibo.com/herewearenow  Email: yanchao727@gmail.com  WeChat: david9ml
