Chinaunix首页 | 论坛 | 博客
  • 博客访问: 3745574
  • 博文数量: 356
  • 博客积分: 10458
  • 博客等级: 上将
  • 技术积分: 4734
  • 用 户 组: 普通用户
  • 注册时间: 2008-03-24 14:59
文章分类

全部博文(356)

文章存档

2020年(17)

2019年(9)

2018年(26)

2017年(5)

2016年(11)

2015年(20)

2014年(2)

2013年(17)

2012年(15)

2011年(4)

2010年(7)

2009年(14)

2008年(209)

分类: Python/Ruby

2019-08-31 22:20:58

tensorflow保存模型有多种方法

第一种:saver.save(sess, "./hello_model") # 生成ckpt模型文件, hello_model.data-00000-of-00001  hello_model.index  hello_model.meta

第二种:tf.train.write_graph(sess.graph_def, ./,  'hello.pb') # 生成hello.pb, 再通过freeze_graph把hello.pb与ckpt固化成新的pb文件

第三种:用tf.graph_util.convert_variables_to_constants把变量转成常量之后写入PB文件中

第四种:使用tf.saved_model.builder.SavedModelBuilder

具体看代码,及示例

保存模型的文件 saver_hello.py



点击(此处)折叠或打开

  1. import tensorflow as tf
  2. import sys
  3. import os

  4. # 把变量转成常量之后写入PB文件中
  5. def SaveFrozenPb(nodeNameList, pbFile):
  6.     gd = tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(), nodeNameList)
  7.     with tf.gfile.GFile(pbFile, 'wb') as f:
  8.         f.write(gd.SerializeToString())

  9. # 通过freeze_graph把tf.train.write_graph()生成的pb文件与tf.train.saver()生成的chkp文件固化之后重新生成一个pb文件
  10. # freeze_graph --input_graph=./hello.pb --input_checkpoint=./hello_model --output_node_names=hello,y --input_node_names=x --output_graph=./hello_frozen.pb
  11. # 如果不调用freeze_graph, 直接使用会报错‘google.protobuf.message.DecodeError: Error parsing message’
  12. def SavePbForFreezeGraph(pbDir, pbName):
  13.     tf.train.write_graph(sess.graph_def, pbDir, pbName)

  14. def SaveBuilderPb(pbDir):
  15.     if not os.path.exists(pbDir):
  16.         os.makedirs(pbDir)
  17.     builder = tf.saved_model.builder.SavedModelBuilder(pbDir)
  18.     builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING], signature_def_map=None, assets_collection=None)
  19.     builder.save()

  20. if __name__ == '__main__':
  21.     hello = tf.Variable(tf.constant('Hello World', name = "hello")) # 要save成功,需要tf.Variable, 否则会报错'ValueError: No variables to save'
  22.     x = tf.placeholder(tf.float32, name="x")
  23.     y = tf.multiply(x, 2, name="y")

  24.     init = tf.global_variables_initializer()
  25.     sess = tf.Session()
  26.     sess.run(init)

  27.     saver = tf.train.Saver()

  28.     typeStr = sys.argv[1]
  29.     if typeStr == 'ckpt' or typeStr == 'pbNotFrozen':
  30.         saver.save(sess, "./hello_model", write_meta_graph=True) # hello_model.data-00000-of-00001 hello_model.index hello_model.meta
  31.     if typeStr == 'pbNotFrozen':
  32.         SavePbForFreezeGraph('./', 'hello.pb') # 需要经由freeze_graph工具处理
  33.     elif typeStr == 'pbFrozen':
  34.         SaveFrozenPb(['x', 'y', 'hello'], './hello_frozen.pb') # 无需再经由freeze_graph工具处理
  35.     elif typeStr == 'builderPb':
  36.         SaveBuilderPb('./save/')

加载模型文件restore_hello.py

点击(此处)折叠或打开

  1. import tensorflow as tf
  2. import sys

  3. def RestoreMeta(sess, name):
  4.     #ckpt = tf.train.get_checkpoint_state('./')
  5.     #restore = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')
  6.     #restore.restore(sess, ckpt.model_checkpoint_path)
  7.     restore = tf.train.import_meta_graph(name)
  8.     restore.restore(sess, "hello_model")

  9. def RestorePb(sess, name):
  10.     # 二进制读取模型文件
  11.     with tf.gfile.FastGFile(name, 'rb') as f:
  12.        graph_def = tf.GraphDef()
  13.        graph_def.ParseFromString(f.read())
  14.        sess.graph.as_default()
  15.        tf.import_graph_def(graph_def, name='') # 导入计算图

  16. def RestoreBuilderPb(sess, pbDir):
  17.     tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], pbDir)

  18. if __name__ == '__main__':
  19.     sess = tf.Session()
  20.     typeStr = sys.argv[1]
  21.     if typeStr == 'ckpt':
  22.         RestoreMeta(sess, 'hello_model.meta')
  23.     elif typeStr == 'pbFrozen':
  24.         RestorePb(sess, './hello_frozen.pb')
  25.     elif typeStr == 'builderPb':
  26.         RestoreBuilderPb(sess, './save/')

  27.     x = tf.get_default_graph().get_tensor_by_name("x:0")
  28.     y = tf.get_default_graph().get_tensor_by_name("y:0")
  29.     hello = tf.get_default_graph().get_tensor_by_name("hello:0")

  30.     print(sess.run(y, feed_dict={x:5})) # 10.0
  31.     print(sess.run(hello)) # b'Hello World'

第一种:ckpt

保存模型

python3 ./saver_hello.py ckpt

生成checkpoint hello_model.data-00000-of-00001  hello_model.index  hello_model.meta
加载模型

python3 ./restore_hello.py ckpt

运行结果

10.0
b'Hello World'

第二种:ckpt+pb+固化 

python3 ./saver_hello.py pbNotFrozen
生成checkpoint   hello_model.data-00000-of-00001  hello_model.index  hello_model.meta  hello.pb

固化

freeze_graph --input_graph=./hello.pb --input_checkpoint=./hello_model --output_node_names=hello,y --input_node_names=x --output_graph=./hello_frozen.pb

加载

python3 ./restore_hello.py pbFrozen

第三种:固化的pb

保存

python3 ./saver_hello.py pbFrozen

加载

python3 ./restore_hello.py pbFrozen


第四种:

python3 ./saver_hello.py builderPb
python3 ./restore_hello.py builderPb

阅读(176950) | 评论(1) | 转发(0) |
给主人留下些什么吧!~~

帅得不敢出门2019-10-11 11:16:48

更正下
def SaveBuilderPb(pbDir):
    if not os.path.exists(pbDir):
        os.makedirs(pbDir)
要改成
def SaveBuilderPb(pbDir):
删除掉2行