分享

Tensorflow 如何存取网络模型

 雪柳花明 2017-03-13
    当我训练完网络模型之后,会想到如何去保存训练好的weightsbias等网络参数,并在将来进行分类或者识别的任务中重新载入(restore)这个训练好的网络。那么在tensorflow中是如何实现对网络模型的保存的呢?
    在tensorflow中,变量存储在二进制文件中,主要包含从变量名到tensor值的映射关系。当创建一个Saver对象时,可以选择性地为检查点文件中的变量设置变量名。
    具体的,首先,给变量赋值,不过要在其后加上参数name=“”,注意,这里的name即要保存到网络模型的变量名称,未来在进行网络模型的载入时需要通过该变量值进行数据读取,类似字典的感觉。
v1 = tf.Variable(..., name='v1')v2 = tf.Variable(..., name='v2”)
    之后,创建一个saver对象,来进行保存,同时不要忘记设定保存的路径。
saver = tf.train.Saver()save_path = saver.save(sess, './MNISTmodel/model.ckpt')print ('Model saved in file: ', save_path)
    模型保存好之后,在需要再次使用这个模型时,同样需要再创建一个saver对象。不要忘记,要将模型中之前保存好的变量名称再赋给需要载入的模型,即
v1 = tf.Variable(..., name='v1')v2 = tf.Variable(..., name=“v2”)
不过此时不需要对这些变量进行初始化了
saver = tf.train.Saver()......with tf.Session() as sess: # Restore variables from disk. saver.restore(sess, './MNISTmodel/model.ckpt') print 'Model restored.'
    这样就可以直接恢复之前训练好的模型了。经过我的验证,准确度与之前训练好的时刻准确度一致。证明网络模型确实被成功恢复了。
    模型的保存不仅为了将来再次使用它进行分类等任务,也可以用来做fine-tuning。

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多