分享

keras训练的h5模型转换为pb模型...

 行走在理想边缘 2022-05-27 发布于四川

最近使用keras训练了一个图像分割的模型(.h5),但最终需要在C++中调用该模型,由于keras没有C++接口,所以需要将.h5模型转换为.pb模型后通过tensorflow C++接口进行调用。

由于本人之前接触深度学习较少,很多东西不是很懂,所以在转换过程中遇到了很多问题,在此记录,共同学习。

1、转换之前需要注意的点

  • 本人在转换过程中发现tensorflow1.x和2.x存在区别,所以在转换之前最好确定训练模型时使用的tensorflow版本,转换过程使用的环境尽量和训练模型使用的环境保持一致,不然容易产生很多错误。
  • 导出模型时使用的API,是使用tensorflow.keras还是keras,和转换时使用的API保持一致,因为两者可能不兼容从而导致错误。
  • tensorflow版本:明确当前使用的是1.x还是2.x版本
  • keras版本:与tensorflow版本相匹配,查看tensorflow版本与keras版本对应关系
  • tensorflow各版本whl下载地址

2、tensorflow1.x转换方法

使用环境说明:python3.6.12+tensorflow1.15.0+keras2.2.4

1、常用的方法基本均来自github:keras_to_tensorflow,直接clone下来使用即可。
使用方法:
(1)命令行方式:input_model输入.h5路径,output输入保存.pb路径。

python keras_to_tensorflow.py 
    --input_model="./model.h5" 
    --output_model="./model.pb"

(2)将参数写入代码中,方便调试。
在这里插入图片描述
2、简化后的转换代码

from tensorflow.python.keras.models import load_model   #from keras.models import load_model
import tensorflow as tf
from tensorflow.python.keras import backend as K        #from keras import backend as K
from tensorflow.python.framework import graph_io
 
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph
 
 
"""----------------------------------配置路径-----------------------------------"""
h5_model_path='./unet.h5'  #填写.h5路径
output_path='.'
pb_model_name='unet.pb'    #填写保存.pb路径
 
 
"""----------------------------------导入keras模型------------------------------"""
K.set_learning_phase(0)
net_model = load_model(h5_model_path)

print('input is :', net_model.input.name)
print ('output is:', net_model.output.name)
 
"""----------------------------------保存为.pb格式------------------------------"""
sess = K.get_session()
frozen_graph = freeze_session(K.get_session(), output_names=[net_model.output.op.name])
graph_io.write_graph(frozen_graph, output_path, pb_model_name, as_text=False)

3、常见错误

(1) TypeError:Keras symbolic inputs/outputs do not implement op
在这里插入图片描述
解决方法:在使用tensorflow1.x时不会出现该错误,若使用tensorflow2.x会出现。此时可切换至tensorflow1.x;或将.op直接删除,修改为net_model.output.name也行,但是大概率还会在其它地方报错…
所以最好还是使用tensorflow1.x,若上述方法解决不了,且需要tensorflow2.x,请看tensorflow2.x转换方法。

(2) TypeError: __init__() got an unexpected keyword argument 'ragged'
在这里插入图片描述
解决方法:可能是因为导出模型和转换模型使用的API不一致导致,查看导出模型使用的API属于tensorflow.keras还是keras。若转换模型代码中使用:

from keras.models import load_model
from keras import backend as K

修改为:

from tensorflow.python.keras.models import load_model
from tensorflow.python.keras import backend as K

(3) TypeError: Unknown layer: Functional
在这里插入图片描述
解决方法:我报错的原因是因为模型使用tensorflow2.4.1训练,转换时使用tensorflow1.15.0,导致该错误,主要问题还是因为环境不一致。
或者是模型中存在自定义层,需要在load_model中显式指出,解决该问题的方法是在load_model函数中添加custom_objects参数,该参数接受一个字典,键值为自定义的层(参考博客)。

(4) TypeError: ('Unrecognized keyword arguments:', dict_keys(['ragged']))

解决方法:可能是因为tensorflow版本太低,建议使用更高版本。

(5) TypeError: module 'tensorflow' has no attribute 'global_variables'
在这里插入图片描述

解决方法:当前tensorflow版本为2.x,可尝试如下修改方式。

tf.compat.v1.global_variables()   #类似错误均可将tf修改为tf.compat.v1

(6) TypeError: *** is not in graph
在这里插入图片描述
解决方法:错误说明:输出节点不在图中(如图中conv2d_18/Relu:0)。

  • convert_variables_to_constants函数用来指定保存的 节点名称 而不是 张量的名称 , “conv2d_18/Relu:0” 是张量的名称而 “conv2d_18/Relu” 表示的是节点的名称。尝试将参数output_names直接指定为“conv2d_18/Relu”,如:
output_names= ['conv2d_18/Relu']
frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                              output_names, freeze_var_names)
  • 确认输出节点名称是否正确。
  • 若上述方法无法解决,大概率还是tensorflow版本不一致导致。

3、tensorflow2.x转换方法

使用环境说明:python3.8.5+tensorflow2.4.1+keras2.4.0

1、转换代码

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

def h5_to_pb(h5_save_path):
    model = tf.keras.models.load_model(h5_save_path, compile=False)
    model.summary()
    full_model = tf.function(lambda Input: model(Input))
    full_model = full_model.get_concrete_function(tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

    # Get frozen ConcreteFunction
    frozen_func = convert_variables_to_constants_v2(full_model)
    frozen_func.graph.as_graph_def()

    layers = [op.name for op in frozen_func.graph.get_operations()]
    print("-" * 50)
    print("Frozen model layers: ")
    for layer in layers:
        print(layer)

    print("-" * 50)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)

    # Save frozen graph from frozen ConcreteFunction to hard drive
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                      logdir="./pb",
                      name="model.pb",
                      as_text=False)   #可设置.pb存储路径


h5_to_pb('./unet.h5')   #此处填入.h5路径

4、python环境下测试pb

需要根据实际情况进行修改,需要修改的地方在代码中已经标出。

import tensorflow as tf
import numpy as np
import cv2

def recognize(jpg_path, pb_file_path):
    with tf.Graph().as_default():
        output_graph_def = tf.compat.v1.GraphDef()
 
        # 打开.pb模型
        with open(pb_file_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            tensors = tf.import_graph_def(output_graph_def, name="")
            print("tensors:",tensors)
 
        # 在一个session中去run一个前向
        with tf.compat.v1.Session() as sess:
            init = tf.compat.v1.global_variables_initializer()
            sess.run(init)
 
            op = sess.graph.get_operations()
 
            # 打印图中有的操作
            for i,m in enumerate(op):
                print('op{}:'.format(i),m.values())
 
            input_x = sess.graph.get_tensor_by_name("Input:0")  # 具体名称看上一段代码的input.name
            print("input_X:",input_x)
 
            out_softmax = sess.graph.get_tensor_by_name("Identity:0")  # 具体名称看上一段代码的output.name
            print("Output:",out_softmax)
 
            # 读入图片
            img = cv2.imread(jpg_path, 1)    #注意读入灰度图还是彩图
            img=cv2.resize(img,(512,512))    #需要和训练模型时图像resize大小保持一致
            #img=img.astype(np.float32)
            #img=1-img/255;
            # img=np.reshape(img,(1,28,28,1))
            print("img data type:",img.dtype)
 
            # 显示图片内容
            # for row in range(512):
            #     for col in range(512):
            #         if col!=511:
            #             print(img[row][col],' ',end='')
            #         else:
            #             print(img[row][col])
 
            img_out_softmax = sess.run(out_softmax,
                                       feed_dict={input_x: np.reshape(img,(1,512,512,3))})    #图像大小保持一致

            #转换为可保存图像
            show_image = img_out_softmax.reshape(512, 512, 3)          
            show_image = show_image.astype(np.uint8)        
            cv2.imwrite('result.jpg',show_image)
 
            print("img_out_softmax:", img_out_softmax)
            for i,prob in enumerate(img_out_softmax[0]):
                print('class {} prob:{}'.format(i,prob))
            prediction_labels = np.argmax(img_out_softmax, axis=1)
            print("Final class if:",prediction_labels)
            print("prob of label:",img_out_softmax[0,prediction_labels])
 
pb_path = './model.pb'   #pb路径
img = '1107.bmp'         #图像路径
recognize(img, pb_path)

5、总结

综上所述:大多数问题均是由tensorflow版本不一致导致,在解决bug之前最好将版本对应,这样可能会很快解决你的问题。

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多