Chinaunix Blog
Category: Python/Ruby

2020-03-24 14:17:46

Train the model and save the checkpoint (ckpt) files to ./test:
ls ./test/
checkpoint                                   events.out.tfevents.1565972704.103cfd64b10e  model.ckpt-170000.index
events.out.tfevents.1565752875.246d2b4c0eaa  model.ckpt-170000.data-00000-of-00001        model.ckpt-170000.meta
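The `checkpoint` file in this listing is a small text file that records which checkpoint prefix is the most recent. `tf.train.get_checkpoint_state`, used in the conversion script further down, reads it. The stdlib-only sketch below mimics that lookup, assuming the usual text format with `model_checkpoint_path` and `all_model_checkpoint_paths` entries. The helper names and the sample file contents are hypothetical, since the post does not show the file itself.

```python
import os


def read_checkpoint_state(text):
    """Parse a text-format `checkpoint` file (hypothetical helper).

    Returns (latest_prefix, all_prefixes). Only the two path fields are
    handled; any other fields (e.g. timestamps) are ignored.
    """
    latest, all_paths = None, []
    for line in text.splitlines():
        key, sep, value = line.partition(":")  # split on the first colon only
        if not sep:
            continue
        value = value.strip().strip('"')
        if key.strip() == "model_checkpoint_path":
            latest = value
        elif key.strip() == "all_model_checkpoint_paths":
            all_paths.append(value)
    return latest, all_paths


def checkpoint_files(ckpt_dir, prefix, num_shards=1):
    """List the files that make up one checkpoint prefix (hypothetical helper)."""
    if not os.path.isabs(prefix):  # relative prefixes are relative to ckpt_dir
        prefix = os.path.join(ckpt_dir, prefix)
    files = [prefix + ".index", prefix + ".meta"]
    files += ["%s.data-%05d-of-%05d" % (prefix, i, num_shards) for i in range(num_shards)]
    return files


# Assumed contents of ./test/checkpoint (not shown in the original post)
sample = '''model_checkpoint_path: "model.ckpt-170000"
all_model_checkpoint_paths: "model.ckpt-160000"
all_model_checkpoint_paths: "model.ckpt-170000"
'''
latest, history = read_checkpoint_state(sample)
for name in checkpoint_files("./test", latest):
    print(name)
```

With one shard, the three file names match the `model.ckpt-170000.*` entries in the listing above.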

Convert the checkpoint to a SavedModel and save it to ./test/saved:
ls ./test/saved/
saved_model.pb  variables
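`freeze_graph --input_saved_model_dir` expects a directory laid out like the one above: a `saved_model.pb` graph file next to a `variables/` directory. Below is a rough, hypothetical helper for checking that layout before running the later steps. It only looks at file names and does not validate the protobuf contents.

```python
import os
import tempfile


def looks_like_saved_model(export_dir):
    """Heuristic: a saved_model.pb (or .pbtxt) plus a variables/ directory."""
    has_graph = any(
        os.path.isfile(os.path.join(export_dir, name))
        for name in ("saved_model.pb", "saved_model.pbtxt")
    )
    has_variables = os.path.isdir(os.path.join(export_dir, "variables"))
    return has_graph and has_variables


# Usage: e.g. looks_like_saved_model("./test/saved"). Here a throwaway
# directory with the same layout is built so the check can run anywhere.
with tempfile.TemporaryDirectory() as d:
    before = looks_like_saved_model(d)
    open(os.path.join(d, "saved_model.pb"), "wb").close()
    os.mkdir(os.path.join(d, "variables"))
    after = looks_like_saved_model(d)
print(before, after)  # False True
```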

Try to freeze the model:
freeze_graph --input_saved_model_dir=./test/saved --output_node_names="superpoint/descriptors,superpoint/prob_nms" --saved_model_tags=serve --output_graph=./test/saved/freezed.pb
This fails with:
I1015 15:07:38.564536 140674071184704 saver.py:1280] Restoring parameters from ./test/saved/variables/variables
Traceback (most recent call last):
  File "/home/zm/tensor/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1356, in _do_call
    return fn(*args)
  File "/home/zm/tensor/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1339, in _run_fn
    self._extend_graph()
  File "/home/zm/tensor/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1374, in _extend_graph
    tf_session.ExtendSession(self._session)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot assign a device for operation superpoint/train_tower0/transpose_8: {{node superpoint/train_tower0/transpose_8}}was explicitly assigned to /device:GPU:0 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:XLA_CPU:0, /job:localhost/replica:0/task:0/device:XLA_GPU:0 ]. Make sure the device specification refers to a valid device.
         [[superpoint/train_tower0/transpose_8]]

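During handling of the above exception, another exception occurred:

The cause: operations in the graph are explicitly assigned to the GPU (/device:GPU:0), but that device is not available. The fix is to pass config = tf.ConfigProto(allow_soft_placement=True) to the session: if the specified device does not exist, TF is then allowed to place the operation on an available device automatically. Re-export the checkpoint as a SavedModel using such a session (the script takes the checkpoint directory, ./test in this example, as its first argument):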
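The error comes down to a placement decision: an op explicitly pinned to `/device:GPU:0` has no matching device in the list TensorFlow reports, and without soft placement that is fatal. The toy function below models only that decision. It is a simplification for illustration, not TensorFlow's actual placer, and `place` is a hypothetical name; the device list is copied from the error message.

```python
# Devices listed in the error message above
AVAILABLE = [
    "/job:localhost/replica:0/task:0/device:CPU:0",
    "/job:localhost/replica:0/task:0/device:XLA_CPU:0",
    "/job:localhost/replica:0/task:0/device:XLA_GPU:0",
]


def place(spec, available, allow_soft_placement=False):
    """Return the device an op pinned to `spec` would run on (toy model)."""
    # A partial spec such as "/device:GPU:0" matches a full name ending with it
    matches = [d for d in available if d.endswith(spec)]
    if matches:
        return matches[0]
    if not allow_soft_placement:
        raise ValueError("%s was explicitly assigned but available devices are %s"
                         % (spec, available))
    # Soft placement: fall back to the CPU device instead of failing
    return next(d for d in available if d.endswith("/device:CPU:0"))


hard_error = None
try:
    place("/device:GPU:0", AVAILABLE)
except ValueError as err:
    hard_error = err
print("without soft placement:", hard_error)
print("with soft placement:", place("/device:GPU:0", AVAILABLE, allow_soft_placement=True))
```

Note that `XLA_GPU:0` does not satisfy a request for `GPU:0`: it is a different device type, which is why the original run fails even though a GPU-related device is listed.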


```python
import sys

import tensorflow as tf  # TensorFlow 1.x API

# If an op is pinned to a device that does not exist (here /device:GPU:0),
# let TF place it on an available device instead of raising an error
config = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(graph=tf.Graph(), config=config) as session:
    # Load the checkpoint; sys.argv[1] is the checkpoint directory (./test in this example)
    ckpt = tf.train.get_checkpoint_state(sys.argv[1])
    saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')  # load the graph
    saver.restore(session, ckpt.model_checkpoint_path)  # load the weights

    # Export a SavedModel
    tf.saved_model.simple_save(
        session,
        './test_save',  # output directory for the generated model
        inputs={
            "superpoint/image": session.graph.get_tensor_by_name("superpoint/image:0"),
        },
        outputs={
            "superpoint/descriptors": session.graph.get_tensor_by_name("superpoint/descriptors:0"),
            "superpoint/prob_nms": session.graph.get_tensor_by_name("superpoint/prob_nms:0"),
        },
    )
```
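Two naming schemes appear in this workflow. `get_tensor_by_name` and `simple_save` take tensor names such as `superpoint/descriptors:0` (op name plus output index), while `freeze_graph --output_node_names` and the `--input_names`/`--output_names` flags of `optimize_for_inference.py` take bare node names. A small hypothetical helper that derives the node names from the tensor names used in the script:

```python
def node_name(tensor_name):
    """'superpoint/descriptors:0' -> 'superpoint/descriptors' (hypothetical helper)."""
    op, sep, index = tensor_name.rpartition(":")
    # Only strip the suffix if it really is a numeric output index
    return op if sep and index.isdigit() else tensor_name


outputs = ["superpoint/descriptors:0", "superpoint/prob_nms:0"]
flag = "--output_node_names=" + ",".join(node_name(t) for t in outputs)
print(flag)  # --output_node_names=superpoint/descriptors,superpoint/prob_nms
```

Passing `superpoint/descriptors:0` where a node name is expected is a common cause of "node not found" errors in these tools.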
Check the size of the generated model:
ll test_save/ -h
total 95M
-rw-r--r-- 1 zm users 95M Oct 15 15:55 saved_model.pb
drwxr-xr-x 2 zm users  66 Oct 15 15:55 variables

Freezing the model now succeeds:
freeze_graph --input_saved_model_dir=./test_save --input_node_names='superpoint/image' --output_node_names='superpoint/descriptors,superpoint/prob_nms' --saved_model_tags='serve' --output_graph=/tmp/tmp_freezed_graph.pb
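Conceptually, freezing does two things: it keeps only the nodes that the requested output nodes depend on, and it replaces each remaining variable with a constant holding the value restored from the checkpoint, so the result no longer needs the `variables/` directory. The sketch below shows that idea on a toy graph stored as a Python dict. It is not TensorFlow's implementation, and apart from `image` and `prob_nms` the node names and ops are made up for illustration.

```python
def freeze(graph, values, output_nodes):
    """Toy freezing.

    graph:  {node_name: {"op": str, "inputs": [node_name, ...]}}
    values: {variable_name: value} as restored from a checkpoint
    """
    # 1. Walk backwards from the outputs to find every node they depend on
    keep, stack = set(), list(output_nodes)
    while stack:
        name = stack.pop()
        if name not in keep:
            keep.add(name)
            stack.extend(graph[name]["inputs"])
    # 2. Copy the kept nodes, turning variables into constants
    frozen = {}
    for name in keep:
        node = graph[name]
        if node["op"] == "VariableV2":
            frozen[name] = {"op": "Const", "inputs": [], "value": values[name]}
        else:
            frozen[name] = dict(node)
    return frozen


graph = {
    "image":    {"op": "Placeholder", "inputs": []},
    "w":        {"op": "VariableV2", "inputs": []},
    "w/read":   {"op": "Identity", "inputs": ["w"]},
    "conv":     {"op": "Conv2D", "inputs": ["image", "w/read"]},
    "prob_nms": {"op": "Identity", "inputs": ["conv"]},
    # Training-only nodes that inference does not need
    "labels":   {"op": "Placeholder", "inputs": []},
    "loss":     {"op": "Mean", "inputs": ["conv", "labels"]},
    "train_op": {"op": "ApplyAdam", "inputs": ["w", "loss"]},
}
frozen = freeze(graph, {"w": [0.1, 0.2]}, ["prob_nms"])
print(sorted(frozen))  # ['conv', 'image', 'prob_nms', 'w', 'w/read']
```

Pruning training-only nodes such as `loss` and `train_op` is one plausible contributor to the smaller frozen file, although the post does not break the size down.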

The frozen graph is noticeably smaller (59M versus 95M, roughly 40% smaller):
ll -h /tmp/tmp_freezed_graph.pb
-rw-r--r-- 1 zm users 59M Oct 15 16:03 /tmp/tmp_freezed_graph.pb

Next, optimize the model with optimize_for_inference.py (this assumes the TensorFlow source tree is under /opt/tensorflow):
python /opt/tensorflow/tensorflow/python/tools/optimize_for_inference.py  --input=/tmp/tmp_freezed_graph.pb --output=/tmp/tmp_optimized_graph.pb --frozen_graph=true --input_names='superpoint/image' --output_names='superpoint/descriptors,superpoint/prob_nms'
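`optimize_for_inference.py` applies several rewrites meant for inference-only graphs, such as stripping nodes that are never reached, removing training and debug ops like `Identity` and `CheckNumerics`, and folding batch normalization into the preceding weights. The toy sketch below illustrates just one of them, removing `Identity` nodes and rewiring their consumers, on a dict-style graph. The `protected` set plays the role of the output names, which must survive. This is an illustration under those assumptions, not the tool's code.

```python
def remove_identity_nodes(graph, protected):
    """Drop unprotected Identity nodes and rewire their consumers (toy version)."""
    def resolve(name):
        # Follow chains of removable Identity nodes back to a real producer
        while graph[name]["op"] == "Identity" and name not in protected:
            name = graph[name]["inputs"][0]
        return name

    out = {}
    for name, node in graph.items():
        if node["op"] == "Identity" and name not in protected:
            continue
        out[name] = {**node, "inputs": [resolve(i) for i in node["inputs"]]}
    return out


graph = {
    "image":    {"op": "Placeholder", "inputs": []},
    "w":        {"op": "Const", "inputs": []},
    "w/read":   {"op": "Identity", "inputs": ["w"]},
    "conv":     {"op": "Conv2D", "inputs": ["image", "w/read"]},
    "prob_nms": {"op": "Identity", "inputs": ["conv"]},
}
optimized = remove_identity_nodes(graph, protected={"prob_nms"})
print(optimized["conv"]["inputs"])  # ['image', 'w']
```

`prob_nms` is itself an `Identity` node, but it is kept because it is a requested output.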

The file shrinks to about 5M:
ll -h /tmp/tmp_optimized_graph.pb
-rw-r--r-- 1 zm users 5.1M Oct 15 16:09 /tmp/tmp_optimized_graph.pb
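The sizes reported by `ll -h` at each step can be compared directly. Note that the 95M figure counts only `saved_model.pb`, not the weights under `variables/`, while a frozen graph stores its weights inside the .pb file, so the first comparison is not like for like. A quick calculation of the relative reductions:

```python
# File sizes reported by `ll -h` at each step above, in MB
sizes = [
    ("saved_model.pb", 95.0),
    ("tmp_freezed_graph.pb", 59.0),
    ("tmp_optimized_graph.pb", 5.1),
]


def reduction(before, after):
    """Percentage by which a file shrank."""
    return 100.0 * (before - after) / before


for (prev_name, prev), (name, size) in zip(sizes, sizes[1:]):
    print("%s -> %s: %.1f%% smaller" % (prev_name, name, reduction(prev, size)))
print("overall: %.1f%% smaller" % reduction(sizes[0][1], sizes[-1][1]))
```

This gives about 37.9% for freezing, 91.4% for optimization and 94.6% overall.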
Finally, convert the optimized graph file back into a SavedModel with the following script, conver_graph_to_saved_model.py:


```python
import os
import sys

import tensorflow as tf  # TensorFlow 1.x API


def get_graph_def_from_file(graph_filepath):
    """Read a serialized GraphDef (.pb file) from disk."""
    print(graph_filepath)
    with tf.gfile.GFile(graph_filepath, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    return graph_def


# Convert the optimized model (GraphDef) back into a SavedModel
def convert_graph_def_to_saved_model(saved_model_dir, graph_filepath):
    export_dir = os.path.join(saved_model_dir, 'optimised')

    # SavedModelBuilder refuses to write into an existing directory
    if tf.gfile.Exists(export_dir):
        tf.gfile.DeleteRecursively(export_dir)

    graph_def = get_graph_def_from_file(graph_filepath)

    with tf.Session(graph=tf.Graph()) as session:
        tf.import_graph_def(graph_def, name="")
        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
        # No signature_def_map is given, so the exported model has no serving signature
        builder.add_meta_graph_and_variables(
            session,
            [tf.saved_model.tag_constants.SERVING],
            signature_def_map=None,
            assets_collection=None)
        builder.save(as_text=False)


if __name__ == '__main__':
    graph_filepath = sys.argv[1]   # e.g. /tmp/tmp_optimized_graph.pb
    saved_model_dir = sys.argv[2]  # e.g. test_save
    convert_graph_def_to_saved_model(saved_model_dir, graph_filepath)
```

python conver_graph_to_saved_model.py /tmp/tmp_optimized_graph.pb test_save
The final model is written to test_save/optimised/:
ll -h test_save/optimised/
total 5.1M
-rw-r--r-- 1 zm users 5.1M Oct 15 17:13 saved_model.pb
drwxr-xr-x 2 zm users    6 Oct 15 17:13 variables

Author: 帅得不敢出门
