分享

TensorFlow2.0 | 模型保存与加载

 LibraryPKU 2019-11-13

TensorFlow2.0学习
2/10/2019 晚 19:35

目录

模型保存与加载1 简单训练一个模型2 保存整个模型3 仅保存架构4 仅保存权重5 使用回调函数保存模型6 自定义训练中保存checkpoint

模型保存与加载

Note:

  • tf.keras模型保存为HDF5文件;

  • keras使用了h5py Python包;

  • h5py是keras的依赖项,应默认被安装。

1 简单训练一个模型

import tensorflow as tf
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
%matplotlib inline
(train_image, train_label),(test_image,test_label)=tf.keras.datasets.fashion_mnist.load_data()
train_image = train_image/255
test_image = test_image/255
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(2828)))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
model.summary()
Model: 'sequential'
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================

flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               100480    
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
=================================================================

Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
model.compile(optimizer='adam',
             loss='sparse_categorical_crossentropy',
             metrics=['acc'])
model.fit(train_image, train_label, epochs=3)
Train on 60000 samples
Epoch 1/3
60000/60000 [==============================] - 3s 44us/sample - loss: 0.5002 - acc: 0.8237
Epoch 2/3
60000/60000 [==============================] - 2s 29us/sample - loss: 0.3808 - acc: 0.8638
Epoch 3/3
60000/60000 [==============================] - 2s 30us/sample - loss: 0.3409 - acc: 0.8763





<tensorflow.python.keras.callbacks.History at 0x2db2275bb00>
model.evaluate(test_image, test_label, verbose=0)
[0.37253551368713380.8677]

2 保存整个模型

Description:

  • 整个模型可以保存到一个文件中,其中包含模型结构、权重、训练配置(损失函数和优化器)乃至优化器状态(允许准确地从上次训练结束的地方继续训练);

  • 可为模型设置checkpoint,并稍后从完全相同的状态继续训练,而无需访问源代码;

  • 在Keras中保存完全可正常使用的模型非常有用,可以在tensorflow.js中加载它们,然后在Web浏览器中训练和运行;或者使用Tensorflow Lite将其转化为在移动设备上运行。

  • Keras使用HDF5标准提供基本的保存格式。

Note:

  • 不建议使用pickle或cPickle来保存模型;

  • 使用model.save('path/to/my_model.h5')将整个模型保存到单个HDF5文件中。

# 保存模型
model.save('my_model.h5')
# 加载模型
new_model = tf.keras.models.load_model('my_model.h5')
# 查看模型配置
new_model.summary()
Model: 'sequential'
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================

flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               100480    
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
=================================================================

Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
# 使用新模型评估
new_model.evaluate(test_image, test_label, verbose=0)
[0.37253551368713380.8677]
# 还可以对新模型进一步训练,可以看到准确率进一步提高了
new_model.fit(train_image, train_label, epochs=3)
Train on 60000 samples
Epoch 1/3
60000/60000 [==============================] - 2s 30us/sample - loss: 0.3151 - acc: 0.8842
Epoch 2/3
60000/60000 [==============================] - 2s 28us/sample - loss: 0.2966 - acc: 0.8910
Epoch 3/3
60000/60000 [==============================] - 2s 29us/sample - loss: 0.2812 - acc: 0.8967





<tensorflow.python.keras.callbacks.History at 0x2db23371d68>

Note:
此方法保存以下所有内容:

  • 权重

  • 模型配置(架构)

  • 优化器配置

3 仅保存架构

  • 有时我们只对模型架构感兴趣,而无需保存权重或优化器;

  • 这种情况下,可以仅保存模型的“配置”。

# config = model.get_config() 得到python字典
json_config = model.to_json() # json格式方便保存到磁盘
json_config
'{'class_name': 'Sequential', 'config': {'name': 'sequential', 'layers': [{'class_name': 'Flatten', 'config': {'name': 'flatten', 'trainable': true, 'batch_input_shape': [null, 28, 28], 'dtype': 'float32', 'data_format': 'channels_last'}}, {'class_name': 'Dense', 'config': {'name': 'dense', 'trainable': true, 'dtype': 'float32', 'units': 128, 'activation': 'relu', 'use_bias': true, 'kernel_initializer': {'class_name': 'GlorotUniform', 'config': {'seed': null}}, 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, 'kernel_regularizer': null, 'bias_regularizer': null, 'activity_regularizer': null, 'kernel_constraint': null, 'bias_constraint': null}}, {'class_name': 'Dense', 'config': {'name': 'dense_1', 'trainable': true, 'dtype': 'float32', 'units': 10, 'activation': 'softmax', 'use_bias': true, 'kernel_initializer': {'class_name': 'GlorotUniform', 'config': {'seed': null}}, 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, 'kernel_regularizer': null, 'bias_regularizer': null, 'activity_regularizer': null, 'kernel_constraint': null, 'bias_constraint': null}}]}, 'keras_version': '2.2.4-tf', 'backend': 'tensorflow'}'
# 加载模型架构
# reinitialized_model = tf.keras.models.model_from_config(config)
reinitialized_model = tf.keras.models.model_from_json(json_config)
reinitialized_model.summary()
Model: 'sequential'
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================

flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               100480    
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
=================================================================

Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________

Note:

  • 这个时候想使用这个模型必须进行编译和训练。

  • 可以将配置文件写入磁盘,以备下次读取。

reinitialized_model.compile(optimizer='adam',
                            loss='sparse_categorical_crossentropy',
                            metrics=['acc'])
reinitialized_model.fit(train_image, train_label, epochs=3)
Train on 60000 samples
Epoch 1/3
60000/60000 [==============================] - 2s 32us/sample - loss: 0.4958 - acc: 0.8251
Epoch 2/3
60000/60000 [==============================] - 2s 28us/sample - loss: 0.3741 - acc: 0.8654
Epoch 3/3
60000/60000 [==============================] - 2s 28us/sample - loss: 0.3359 - acc: 0.8784





<tensorflow.python.keras.callbacks.History at 0x2db2364dc50>
reinitialized_model.evaluate(test_image, test_label, verbose=0)
[0.362303579092025760.8703]

4 仅保存权重

  • 有时我们只需要保存模型的状态(其权重值),而对模型架构不感兴趣;

  • 可以通过get_weights()获取权重值,并通过set_weights()设置权重值。

weights = model.get_weights()
weights
[array([[-0.14266095-0.02928461-0.03059634, ..., -0.08731885,
          0.06544596-0.10160343],
        [ 0.00358689-0.00829028-0.02895553, ..., -0.12510738,
          0.03449067-0.09857947],
        [-0.04950375,  0.07115625-0.03697051, ..., -0.07412314,
          0.02556015-0.06952383],
        ...,
        [-0.05137132,  0.02813634,  0.01432281, ..., -0.17519541,
         -0.23348525,  0.08433083],
        [-0.09285939,  0.08466033-0.14937393, ..., -0.10330624,
         -0.12008775,  0.17219253],
        [-0.15010561,  0.03778582-0.1945498 , ..., -0.01276695,
         -0.10709424,  0.02418207]], dtype=float32),
 array([-0.20789446-0.13141602,  0.16348973,  0.02229838,  0.20742084,
         0.24744403,  0.3935284 ,  0.01359873,  0.28906068,  0.16664377,
         0.11691718,  0.10966373-0.0312118 ,  0.29208302-0.06575458,
         0.22923304,  0.08814216,  0.25544012,  0.3141708 , -0.02325695,
         0.46262854,  0.253665  ,  0.1873551 ,  0.30536744,  0.09335868,
         0.12537828,  0.05300334-0.02818783,  0.4608134 , -0.09343797,
         0.17170556,  0.24527179-0.22953849,  0.03436952,  0.06337881,
         0.11532164,  0.01581998-0.06603459,  0.1451305 ,  0.37309167,
        -0.10412953-0.01045883-0.11555314,  0.30166924-0.00355601,
        -0.01543491,  0.34943712-0.03428617,  0.14934501,  0.19140476,
         0.08839089,  0.17634022-0.07683504-0.1476895 , -0.21361239,
         0.16388223,  0.10454813-0.01094854-0.2671146 ,  0.33385256,
        -0.01923951-0.02843694,  0.24996676,  0.10793806,  0.2111573 ,
         0.01344524-0.21662362,  0.05791511,  0.27268228-0.12840772,
         0.22646935-0.15056308-0.0342468 ,  0.07100234-0.07210353,
         0.18772861,  0.26645064,  0.15151191,  0.27598998-0.10759276,
         0.25381055,  0.09859606,  0.1382245 ,  0.11870526,  0.14192112,
         0.33792323,  0.25348547,  0.18790671,  0.24492127-0.01050841,
         0.28550985-0.18192717-0.3675344 , -0.13773952,  0.35116568,
         0.34428683,  0.2417321 ,  0.09293219,  0.24475743,  0.3457114 ,
        -0.01730274,  0.21840273,  0.09212765,  0.35753956,  0.39600298,
        -0.00486282,  0.22293994,  0.30595776-0.17183033-0.02702194,
         0.01568035,  0.26713175,  0.2891485 , -0.06422935-0.05320153,
         0.18866703-0.06518383,  0.27607   , -0.01198964,  0.12543014,
        -0.03183404-0.00545676,  0.12240074-0.01263201-0.01893605,
        -0.2026584 , -0.01084313-0.16518423], dtype=float32),
 array([[ 1.5630949e-01-2.0091803e-01-2.7308524e-01, ...,
          5.3857148e-02,  3.7073534e-02,  1.0175616e-01],
        [-5.4036319e-01-2.6854080e-01-4.1823533e-01, ...,
          2.1708187e-01-4.5078900e-01,  2.6437706e-01],
        [ 2.1154724e-02,  3.5432404e-01,  3.4085352e-02, ...,
         -2.9830724e-01,  2.1905501e-01-2.3593178e-01],
        ...,
        [-2.0776138e-01-8.0822185e-03-1.4333259e-01, ...,
         -1.3110130e-01,  1.4290324e-01,  1.9207600e-01],
        [-8.0469556e-02-4.4808078e-01-3.1039205e-01, ...,
          2.3106417e-01,  1.6513802e-01,  6.3845254e-02],
        [-1.2603492e-04-3.4845769e-02,  2.9384667e-02, ...,
         -2.3625530e-01-4.0597755e-02-3.7576398e-01]], dtype=float32),
 array([ 0.06370834-0.20797089,  0.10423248,  0.12863433-0.23937117,
         0.22634539,  0.08684748,  0.10348023-0.14702898-0.36164775],
       dtype=float32)]
reweighted_model = tf.keras.models.model_from_json(json_config)
# 使用已保存的权重
reweighted_model.set_weights(weights)

Note:

  • 此时需要对模型进行编译,之后才能继续训练和测试;

  • 因为虽然恢复了架构和权重,但还少一个优化器配置。

reweighted_model.compile(optimizer='adam',
                         loss='sparse_categorical_crossentropy',
                         metrics=['acc'])
# 现在又可以接着训练和测试了
reweighted_model.evaluate(test_image, test_label, verbose=0)
[0.37253551368713380.8677]
# 直接将weights保存到磁盘上
model.save_weights('my_weights.h5')
# 从磁盘加载权重
reweighted_model.load_weights('my_weights.h5')
reweighted_model.evaluate(test_image, test_label, verbose=0)
[0.37253551368713380.8677]

5 使用回调函数保存模型

tf.keras.callbacks.ModelCheckpoint

  • 在训练期间或者训练结束时自动保存checkpoint;

  • 这样就可以使用经过训练的模型,而无需重新训练;

  • 或者从上次暂停的地方继续训练,防止训练过程中断。

# 保存路径
checkpoint_path = 'model_cp/cp.ckpt'
# 定义回调函数,可以查看里面的参数说明,这里仅保存权重
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                 save_weights_only=True)
model_cp = tf.keras.Sequential()
model_cp.add(tf.keras.layers.Flatten(input_shape=(2828)))
model_cp.add(tf.keras.layers.Dense(128, activation='relu'))
model_cp.add(tf.keras.layers.Dense(10, activation='softmax'))
model_cp.compile(optimizer='adam',
                 loss='sparse_categorical_crossentropy',
                 metrics=['acc'])
model_cp.fit(train_image, train_label, epochs=3, callbacks=[cp_callback])
Train on 60000 samples
Epoch 1/3
60000/60000 [==============================] - 2s 34us/sample - loss: 0.3136 - acc: 0.8851
Epoch 2/3
60000/60000 [==============================] - 2s 32us/sample - loss: 0.2941 - acc: 0.8904
Epoch 3/3
60000/60000 [==============================] - 2s 30us/sample - loss: 0.2808 - acc: 0.8957





<tensorflow.python.keras.callbacks.History at 0x2db24357710>

Note:

  • 此时目录里应该就有checkpoint文件了。

  • 加载权重的用法与前面一致(model_cp.load_weights(checkpoint_path)),就不多说了。

6 自定义训练中保存checkpoint

model_ud = tf.keras.Sequential()
model_ud.add(tf.keras.layers.Flatten(input_shape=(2828)))
model_ud.add(tf.keras.layers.Dense(128, activation='relu'))
model_ud.add(tf.keras.layers.Dense(10, activation='softmax'))
optimizer = tf.keras.optimizers.Adam()
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
def loss(model, x, y):
    y_ = model(x)
    return loss_func(y, y_)
def train_step(model, images, labels):
    with tf.GradientTape() as t:
        pred = model(images)
        loss_step = loss_func(labels, pred)
    grads = t.gradient(loss_step, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_loss(loss_step)
    train_accuracy(labels, pred)
train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('train_accuracy')
test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('test_accuracy')
dataset = tf.data.Dataset.from_tensor_slices((train_image, train_label))
dataset = dataset.shuffle(10000).batch(32)
# 定义保存目录及文件名前缀
cp_dir = './customtrain_cp'
cp_prefix = os.path.join(cp_dir, 'nick')
# 定义要保存哪些东西
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model_ud)
def train():
    for epoch in range(5):
        for (batch, (images, labels)) in enumerate(dataset):
            train_step(model_ud, images, labels)
        print('Epoch{} loss is {}'.format(epoch, train_loss.result()))
        print('Epoch{} accuracy is {}'.format(epoch, train_accuracy.result()))
        train_loss.reset_states()
        train_accuracy.reset_states()
        if (epoch + 1) % 2 == 0:
            checkpoint.save(file_prefix=cp_prefix)
train()
Epoch0 loss is 1.767974615097046
Epoch0 accuracy is 0.699833333492279
Epoch1 loss is 1.710296869277954
Epoch1 accuracy is 0.7519999742507935
Epoch2 loss is 1.6218923330307007
Epoch2 accuracy is 0.8409833312034607
Epoch3 loss is 1.6016263961791992
Epoch3 accuracy is 0.8611000180244446
Epoch4 loss is 1.5949829816818237
Epoch4 accuracy is 0.8670333623886108
# 模型加载
checkpoint.restore(tf.train.latest_checkpoint(cp_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x2db24369dd8>

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多