Train the model and save the resulting checkpoint (ckpt) to the ./test directory:
ls ./test/
checkpoint events.out.tfevents.1565972704.103cfd64b10e model.ckpt-170000.index
events.out.tfevents.1565752875.246d2b4c0eaa model.ckpt-170000.data-00000-of-00001 model.ckpt-170000.meta
Convert the ckpt into a SavedModel and save it to the ./test/saved directory:
ls ./test/saved/
saved_model.pb variables
Try to freeze the model:
freeze_graph --input_saved_model_dir=./test/saved --output_node_names="superpoint/descriptors,superpoint/prob_nms" --saved_model_tags=serve --output_graph=./test/saved/freezed.pb
This fails with an error:
I1015 15:07:38.564536 140674071184704 saver.py:1280] Restoring parameters from ./test/saved/variables/variables
Traceback (most recent call last):
File "/home/zm/tensor/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1356, in _do_call
return fn(*args)
File "/home/zm/tensor/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1339, in _run_fn
self._extend_graph()
File "/home/zm/tensor/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1374, in _extend_graph
tf_session.ExtendSession(self._session)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot assign a device for operation superpoint/train_tower0/transpose_8: {{node superpoint/train_tower0/transpose_8}}was explicitly assigned to /device:GPU:0 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:XLA_CPU:0, /job:localhost/replica:0/task:0/device:XLA_GPU:0 ]. Make sure the device specification refers to a valid device.
[[superpoint/train_tower0/transpose_8]]
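During handling of the above exception, another exception occurred:

The cause is that the graph explicitly assigns operations to a GPU, but no such device is available. The fix is to pass config = tf.ConfigProto(allow_soft_placement=True) to the session used for the conversion, which lets TF assign a device automatically when the specified one does not exist: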

import sys

import tensorflow as tf

# Let TF fall back to an available device when the requested one does not exist
config = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(graph=tf.Graph(), config=config) as session:
    # load ckpt
    ckpt = tf.train.get_checkpoint_state(sys.argv[1])  # get ckpt; pass the directory containing the ckpt, ./test in this example
    saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')  # load graph
    saver.restore(session, ckpt.model_checkpoint_path)

    # save saved_model
    tf.saved_model.simple_save(session,
        './test_save',  # directory where the generated model is stored
        inputs={
            "superpoint/image": session.graph.get_tensor_by_name("superpoint/image:0"),
        },
        outputs={
            "superpoint/descriptors": session.graph.get_tensor_by_name("superpoint/descriptors:0"),
            "superpoint/prob_nms": session.graph.get_tensor_by_name("superpoint/prob_nms:0"),
        }
    )

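As a possible alternative to soft placement, the device pins can be dropped when the meta graph is imported. The sketch below is not the method used in this post: it assumes TensorFlow 1.x, the function name export_saved_model is made up for illustration, and it relies on the clear_devices argument of tf.train.import_meta_graph. Whether it fully replaces allow_soft_placement for this particular model is untested here.

```python
def export_saved_model(ckpt_dir, export_dir):
    """Re-export the latest checkpoint in ckpt_dir as a SavedModel (sketch)."""
    # Imported lazily; assumes TensorFlow 1.x, as used elsewhere in this post.
    import tensorflow as tf

    with tf.Session(graph=tf.Graph()) as session:
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        # clear_devices=True strips the device assignments (e.g. /device:GPU:0)
        # stored in the meta graph, so the imported graph is not pinned to a GPU.
        saver = tf.train.import_meta_graph(
            ckpt.model_checkpoint_path + ".meta", clear_devices=True)
        saver.restore(session, ckpt.model_checkpoint_path)

        graph = session.graph
        tf.saved_model.simple_save(
            session,
            export_dir,
            inputs={
                "superpoint/image": graph.get_tensor_by_name("superpoint/image:0"),
            },
            outputs={
                "superpoint/descriptors": graph.get_tensor_by_name("superpoint/descriptors:0"),
                "superpoint/prob_nms": graph.get_tensor_by_name("superpoint/prob_nms:0"),
            },
        )


# Hypothetical usage:
# export_saved_model("./test", "./test_save")
```

Note that simple_save refuses to write into an export directory that already exists, so remove ./test_save first if it is left over from an earlier run.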
Check the size of the generated model:
ll test_save/ -h
total 95M
-rw-r--r-- 1 zm users 95M Oct 15 15:55 saved_model.pb
drwxr-xr-x 2 zm users 66 Oct 15 15:55 variables
Freezing the model now succeeds:
freeze_graph --input_saved_model_dir=./test_save --input_node_names='superpoint/image' --output_node_names='superpoint/descriptors,superpoint/prob_nms' --saved_model_tags='serve' --output_graph=/tmp/tmp_freezed_graph.pb
The frozen graph is 59M, down from 95M for saved_model.pb:
ll -h /tmp/tmp_freezed_graph.pb
-rw-r--r-- 1 zm users 59M Oct 15 16:03 /tmp/tmp_freezed_graph.pb
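If the freeze step has to be repeated for several models, the command line can be assembled from Python. This is only a convenience sketch: the helper name freeze_graph_command is made up, and it passes only the flags that already appear in the freeze_graph commands in this post.

```python
def freeze_graph_command(saved_model_dir, output_graph, output_nodes, tags="serve"):
    """Build the argument list for the freeze_graph command-line tool."""
    return [
        "freeze_graph",
        "--input_saved_model_dir=" + saved_model_dir,
        "--output_node_names=" + ",".join(output_nodes),
        "--saved_model_tags=" + tags,
        "--output_graph=" + output_graph,
    ]


cmd = freeze_graph_command(
    "./test_save",
    "/tmp/tmp_freezed_graph.pb",
    ["superpoint/descriptors", "superpoint/prob_nms"],
)
print(" ".join(cmd))

# To actually run it (requires freeze_graph on PATH):
# import subprocess
# subprocess.run(cmd, check=True)
```

Passing the arguments as a list avoids shell quoting issues with the comma-separated node names.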
Assuming the TensorFlow source has been downloaded to /opt/tensorflow, the next step is to optimize the model with optimize_for_inference.py:
python /opt/tensorflow/tensorflow/python/tools/optimize_for_inference.py --input=/tmp/tmp_freezed_graph.pb --output=/tmp/tmp_optimized_graph.pb --frozen_graph=true --input_names='superpoint/image' --output_names='superpoint/descriptors,superpoint/prob_nms'
The size is now about 5M:
ll -h /tmp/tmp_optimized_graph.pb
-rw-r--r-- 1 zm users 5.1M Oct 15 16:09 /tmp/tmp_optimized_graph.pb
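To put the three sizes in perspective, the reduction at each step can be computed directly. The numbers below are simply the sizes from the listings (95M, 59M, 5.1M); the helper name reduction_percent is made up for this illustration.

```python
def reduction_percent(before_mb, after_mb):
    """Percentage by which the size shrank going from before_mb to after_mb."""
    return (before_mb - after_mb) / before_mb * 100.0


# Sizes in MB, as reported by the listings in this post.
saved_model_mb = 95.0
frozen_mb = 59.0
optimized_mb = 5.1

print("freeze:   %.1f%%" % reduction_percent(saved_model_mb, frozen_mb))     # 37.9%
print("optimize: %.1f%%" % reduction_percent(frozen_mb, optimized_mb))       # 91.4%
print("overall:  %.1f%%" % reduction_percent(saved_model_mb, optimized_mb))  # 94.6%
```

Freezing removes a bit over a third of the original size, and most of the remaining savings come from the optimize_for_inference step.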
Finally, convert the graph file back into a SavedModel. The conversion script, conver_graph_to_saved_model.py, is as follows:

import os
import sys

import tensorflow as tf


def get_graph_def_from_file(graph_filepath):
    """Read a serialized GraphDef from graph_filepath."""
    print(graph_filepath)
    print("")

    with tf.Graph().as_default():
        with tf.gfile.GFile(graph_filepath, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            return graph_def


# Convert the optimized model (GraphDef) back into a SavedModel
def convert_graph_def_to_saved_model(saved_model_dir, graph_filepath):
    export_dir = os.path.join(saved_model_dir, 'optimised')

    # Remove the output of any previous run
    if tf.gfile.Exists(export_dir):
        tf.gfile.DeleteRecursively(export_dir)

    graph_def = get_graph_def_from_file(graph_filepath)

    with tf.Session(graph=tf.Graph()) as session:
        tf.import_graph_def(graph_def, name="")
        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(session, [tf.saved_model.tag_constants.SERVING], signature_def_map=None, assets_collection=None)
        builder.save(as_text=False)


if __name__ == '__main__':
    graph_filepath = sys.argv[1]
    saved_model_dir = sys.argv[2]
    convert_graph_def_to_saved_model(saved_model_dir, graph_filepath)

python conver_graph_to_saved_model.py /tmp/tmp_optimized_graph.pb test_save
The final model is generated in test_save/optimised/:
ll -h test_save/optimised/
total 5.1M
-rw-r--r-- 1 zm users 5.1M Oct 15 17:13 saved_model.pb
drwxr-xr-x 2 zm users 6 Oct 15 17:13 variables
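As a sanity check, the optimized SavedModel can be loaded back and inspected to confirm that the input and output nodes survived the optimization. This is a minimal sketch assuming TensorFlow 1.x; the helper names list_saved_model_ops and missing_nodes are made up for illustration.

```python
def list_saved_model_ops(export_dir, tag="serve"):
    """Load a SavedModel and return the names of all ops in its graph."""
    # Imported lazily; assumes TensorFlow 1.x, as used elsewhere in this post.
    import tensorflow as tf

    with tf.Session(graph=tf.Graph()) as session:
        tf.saved_model.loader.load(session, [tag], export_dir)
        return [op.name for op in session.graph.get_operations()]


def missing_nodes(op_names, expected):
    """Return the expected node names that are absent from op_names."""
    present = set(op_names)
    return [name for name in expected if name not in present]


# Hypothetical usage:
# ops = list_saved_model_ops("test_save/optimised")
# print(missing_nodes(ops, ["superpoint/image",
#                           "superpoint/descriptors",
#                           "superpoint/prob_nms"]))
```

An empty list from missing_nodes means all three nodes are still present in the optimized graph.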
Author: 帅得不敢出门