tensorflow保存模型有多种方法
第一种:saver.save(sess, "./hello_model") # 生成ckpt模型文件, hello_model.data-00000-of-00001 hello_model.index hello_model.meta
第二种:tf.train.write_graph(sess.graph_def, ./, 'hello.pb') # 生成hello.pb, 再通过freeze_graph把hello.pb与ckpt固化成新的pb文件
第三种:用tf.graph_util.convert_variables_to_constants把变量转成常量之后写入PB文件中
第四种:使用tf.saved_model.builder.SavedModelBuilder
具体看代码,及示例
保存模型的文件 saver_hello.py
-
import tensorflow as tf
-
import sys
-
import os
-
-
# 把变量转成常量之后写入PB文件中
-
def SaveFrozenPb(nodeNameList, pbFile):
-
gd = tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(), nodeNameList)
-
with tf.gfile.GFile(pbFile, 'wb') as f:
-
f.write(gd.SerializeToString())
-
-
# 通过freeze_graph把tf.train.write_graph()生成的pb文件与tf.train.saver()生成的chkp文件固化之后重新生成一个pb文件
-
# freeze_graph --input_graph=./hello.pb --input_checkpoint=./hello_model --output_node_names=hello,y --input_node_names=x --output_graph=./hello_frozen.pb
-
# 如果不调用freeze_graph, 直接使用会报错‘google.protobuf.message.DecodeError: Error parsing message’
-
def SavePbForFreezeGraph(pbDir, pbName):
-
tf.train.write_graph(sess.graph_def, pbDir, pbName)
-
-
def SaveBuilderPb(pbDir):
-
if not os.path.exists(pbDir):
-
os.makedirs(pbDir)
-
builder = tf.saved_model.builder.SavedModelBuilder(pbDir)
-
builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING], signature_def_map=None, assets_collection=None)
-
builder.save()
-
-
if __name__ == '__main__':
-
hello = tf.Variable(tf.constant('Hello World', name = "hello")) # 要save成功,需要tf.Variable, 否则会报错'ValueError: No variables to save'
-
x = tf.placeholder(tf.float32, name="x")
-
y = tf.multiply(x, 2, name="y")
-
-
init = tf.global_variables_initializer()
-
sess = tf.Session()
-
sess.run(init)
-
-
saver = tf.train.Saver()
-
-
typeStr = sys.argv[1]
-
if typeStr == 'ckpt' or typeStr == 'pbNotFrozen':
-
saver.save(sess, "./hello_model", write_meta_graph=True) # hello_model.data-00000-of-00001 hello_model.index hello_model.meta
-
if typeStr == 'pbNotFrozen':
-
SavePbForFreezeGraph('./', 'hello.pb') # 需要经由freeze_graph工具处理
-
elif typeStr == 'pbFrozen':
-
SaveFrozenPb(['x', 'y', 'hello'], './hello_frozen.pb') # 无需再经由freeze_graph工具处理
-
elif typeStr == 'builderPb':
-
SaveBuilderPb('./save/')
加载模型文件restore_hello.py
-
import tensorflow as tf
-
import sys
-
-
def RestoreMeta(sess, name):
-
#ckpt = tf.train.get_checkpoint_state('./')
-
#restore = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')
-
#restore.restore(sess, ckpt.model_checkpoint_path)
-
restore = tf.train.import_meta_graph(name)
-
restore.restore(sess, "hello_model")
-
-
def RestorePb(sess, name):
-
# 二进制读取模型文件
-
with tf.gfile.FastGFile(name, 'rb') as f:
-
graph_def = tf.GraphDef()
-
graph_def.ParseFromString(f.read())
-
sess.graph.as_default()
-
tf.import_graph_def(graph_def, name='') # 导入计算图
-
-
def RestoreBuilderPb(sess, pbDir):
-
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], pbDir)
-
-
if __name__ == '__main__':
-
sess = tf.Session()
-
typeStr = sys.argv[1]
-
if typeStr == 'ckpt':
-
RestoreMeta(sess, 'hello_model.meta')
-
elif typeStr == 'pbFrozen':
-
RestorePb(sess, './hello_frozen.pb')
-
elif typeStr == 'builderPb':
-
RestoreBuilderPb(sess, './save/')
-
-
x = tf.get_default_graph().get_tensor_by_name("x:0")
-
y = tf.get_default_graph().get_tensor_by_name("y:0")
-
hello = tf.get_default_graph().get_tensor_by_name("hello:0")
-
-
print(sess.run(y, feed_dict={x:5})) # 10.0
-
print(sess.run(hello)) # b'Hello World'
第一种:ckpt
保存模型
python3 ./saver_hello.py ckpt
生成checkpoint hello_model.data-00000-of-00001 hello_model.index hello_model.meta
加载模型
python3 ./restore_hello.py ckpt
运行结果
10.0
b'Hello World'
第二种:ckpt+pb+固化
python3 ./saver_hello.py pbNotFrozen
生成checkpoint hello_model.data-00000-of-00001 hello_model.index hello_model.meta hello.pb
固化
freeze_graph --input_graph=./hello.pb --input_checkpoint=./hello_model --output_node_names=hello,y --input_node_names=x --output_graph=./hello_frozen.pb
加载
python3 ./restore_hello.py pbFrozen
第三种:固化的pb
保存
python3 ./saver_hello.py pbFrozen
加载
python3 ./restore_hello.py pbFrozen
第四种:
python3 ./saver_hello.py builderPb
python3 ./restore_hello.py builderPb
阅读(181525) | 评论(1) | 转发(0) |