在深度学习实践中,我们通常要先搭建好模型如何经过长时间的训练才能使用。那么,对于训练好的模型,我们自然想把它保存起来以便调用。曾经,我写过一个比较大的深度学习模型,由于每次迭代需要长达1-2个小时,并且用的是nohup在服务器后台上运行,但是鉴于nohup有时会不稳定,因此我的后台程序随时可能会被kill掉,于是我就写了一个定时的模型保存程序,即每隔一定的时间就让其自动保存模型到磁盘文件中去,这样就可以保证即使程序遭到了不可抗拒的终止时,也不会落得前功尽弃的后果。由此可以看出,模型的自动保存是十分重要的。
在Tensorflow中,保存模型最简单的方法是使用tf.train.Saver对象,当我们构造了一个Saver对象以后,调用该对象的save方法即可将我们指定会话中的Tensorflow Graph模型保存到磁盘文件中去;而另一方面,我们可以调用对象的restore方法从磁盘中读取Tensorflow Graph模型。 例如,下面是一个保存模型的示例用法: l# Create some variables.
v1 = tf.Variable(..., name='v1')v2 = tf.Variable(..., name='v2')... # Add an op to initialize the variables. init_op = tf.initialize_all_variables()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, save the# variables to disk.
with tf.Session() as sess: sess.run(init_op) # Do some work with the model. .. # Save the variables to disk. save_path = saver.save(sess, '/tmp/model.ckpt') print('Model saved in file: %s' % save_path)
下面是导入模型的方法: # Create some variables.
v1 = tf.Variable(..., name='v1')v2 = tf.Variable(..., name='v2')... # Add ops to save and restore all the variables.
saver = tf.train.Saver() # Later, launch the model, use the saver to restore variables from disk, and# do some work with the model.
with tf.Session() as sess: # Restore variables from disk. saver.restore(sess, '/tmp/model.ckpt') print('Model restored.') # Do some work with the model ...
除此之外,Saver对象还可以自定义保存变量,即指定保存Graph中的某些变量。有了tf.train.Saver对象,再也不用担心训练好的模型丢失了!
题图:梵高《罗纳河上的星空》。
|