分享

Tensorflow|如何保存或导入训练好的模型

 LibraryPKU 2018-08-31




在深度学习实践中,我们通常要先搭建好模型如何经过长时间的训练才能使用。那么,对于训练好的模型,我们自然想把它保存起来以便调用。曾经,我写过一个比较大的深度学习模型,由于每次迭代需要长达1-2个小时,并且用的是nohup在服务器后台上运行,但是鉴于nohup有时会不稳定,因此我的后台程序随时可能会被kill掉,于是我就写了一个定时的模型保存程序,即每隔一定的时间就让其自动保存模型到磁盘文件中去,这样就可以保证即使程序遭到了不可抗拒的终止时,也不会落得前功尽弃的后果。由此可以看出,模型的自动保存是十分重要的。


在Tensorflow中,保存模型最简单的方法是使用tf.train.Saver对象,当我们构造了一个Saver对象以后,调用该对象的save方法即可将我们指定会话中的Tensorflow Graph模型保存到磁盘文件中去;而另一方面,我们可以调用对象的restore方法从磁盘中读取Tensorflow Graph模型。

例如,下面是一个保存模型的示例用法:

l# Create some variables.

v1 = tf.Variable(..., name='v1')v2 = tf.Variable(..., name='v2')...
# Add an op to initialize the variables.
init_op = tf.initialize_all_variables()

# Add ops to save and restore all the variables.

saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the# variables to disk.

with tf.Session() as sess:    sess.run(init_op)  
   # Do some work with the model.    ..  
   # Save the variables to disk.    save_path = saver.save(sess, '/tmp/model.ckpt')    print('Model saved in file: %s' % save_path)


下面是导入模型的方法:

# Create some variables.

v1 = tf.Variable(..., name='v1')v2 = tf.Variable(..., name='v2')...
# Add ops to save and restore all the variables.

saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and# do some work with the model.

with tf.Session() as sess:  
   # Restore variables from disk.    saver.restore(sess, '/tmp/model.ckpt')    print('Model restored.')  
   # Do some work with the model    ...


除此之外,Saver对象还可以自定义保存变量,即指定保存Graph中的某些变量。有了tf.train.Saver对象,再也不用担心训练好的模型丢失了!


题图:梵高《罗纳河上的星空》。

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多