Chinaunix首页 | 论坛 | 博客
  • 博客访问: 3711965
  • 博文数量: 356
  • 博客积分: 10458
  • 博客等级: 上将
  • 技术积分: 4734
  • 用 户 组: 普通用户
  • 注册时间: 2008-03-24 14:59
文章分类

全部博文(356)

文章存档

2020年(17)

2019年(9)

2018年(26)

2017年(5)

2016年(11)

2015年(20)

2014年(2)

2013年(17)

2012年(15)

2011年(4)

2010年(7)

2009年(14)

2008年(209)

分类: 大数据

2018-07-05 10:31:51

怎么在新的项目上使用已训练过的模型,以识别花种类为例


现代的识别模型动辙有百万的参数,从头开始训练需要许多标记过的训练数据及训练时间,迁移学习可以很方便的复用已经训练过的模型应用到新的场景上,
以一个在ImageNet上训练过的图片分类模型为例,虽然不如重新训练一个模型精度好,但是对于几千而不是上百万的已打标签的训练数据而言,它工作的很高效,
能直接在一台没有GPU的手提电脑上30分钟跑完。



首先下载hub工程,tensorflow的一些训练过的模型已经移到hub.git上了
git clone https://github.com/tensorflow/hub.git
学习的顺序建议是
先Inception V3 然后再是 NASNet/PNASNet, MobileNet V1 与V2.
sudo pip3 install tensorflow_hub



制作或者准备训练数据,这是几种花的分类图片
cd /opt/flower
curl -LO http://download.tensorflow.org/example_images/flower_photos.tgz
tar xzf flower_photos.tgz
把图片都解压到了/opt/flower/flower_photos下


要执行的是hub/examples/image_retraining/retrain.py
python hub/examples/image_retraining/retrain.py --image_dir /opt/flower/flower_photos
如果tensorflow编译的是python3
python3 hub/examples/image_retraining/retrain.py --image_dir /opt/flower/flower_photos


要看所有参数,请用以下命令查看
python retrain.py -h


1)这个脚本会从网上下载模型:
会从'https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1',下载TF-Hub Module 
主要是以下几个文件(在/tmp/tfhub_modules目录下)
saved_model.pb   tfhub_module.pb  variables/variables.data-00000-of-00001  variables/variables.index


2) 加载预训练的模型,然后训练一个新的鲜花分类模型。
默认情况下,此脚本将运行4000个训练步骤。每一步从训练集随机选择十个图像
会有一堆如下这行的打印:
......
INFO:tensorflow:2018-05-23 12:59:03.476464: Step 3999: Train accuracy = 97.0%
INFO:tensorflow:2018-05-23 12:59:03.476647: Step 3999: Cross entropy = 0.111035
INFO:tensorflow:2018-05-23 12:59:03.662786: Step 3999: Validation accuracy = 92.0% (N=100)
这三行就是模型的准确度了,越高自然是越好
INFO:tensorflow:Initialize variable module/InceptionV3/Mixed_7c/Branch_3/Conv2d_0b_1x1/weights:0 from checkpoint b'/tmp/tfhub_modules/11d9faf945d073033780fd924b2b09ff42155763/variables/variables' with InceptionV3/Mixed_7c/Branch_3/Conv2d_0b_1x1/weights
最后是打印这三行,说明结束了
INFO:tensorflow:Restoring parameters from /tmp/_retrain_checkpoint
INFO:tensorflow:Froze 378 variables.
Converted 378 variables to const ops


这过程会生成一些文件在 /tmp/bottleneck下
daisy  dandelion  roses  sunflowers  tulips


3)在训练中或之后使用tensorboard查看训练过程的loss,命中,weight的变化
执行
tensorboard --logdir /tmp/retrain_logs
然后浏览器中输入
localhost:6006
会有图表显示出来


3)
训练后新的模型保存在
/tmp/output_graph.pb
标签在/tmp/output_labels.txt
这个模型结合了TF-Hub module与我们新加入的训练数据,原理是更新了原来模型的最顶层,而保持其他层。


4)有了训练出的模型,就可以拿来用了,加载模型,并输入一张图片,测试结果
先下载测试脚本:
curl -LO https://github.com/tensorflow/tensorflow/raw/master/tensorflow/examples/label_image/label_image.py


然后执行测试脚本加载模型进行测试:
python3 label_image.py \
--graph=/tmp/output_graph.pb --labels=/tmp/output_labels.txt \
--input_layer=Placeholder \
--output_layer=final_result \
--image=/opt/flower/flower_photos/daisy/21652746_cc379e0eea_m.jpg


如果是python2的请把python3改成python
输出如下
daisy 0.9982089
sunflowers 0.0013442119
dandelion 0.00030723808
tulips 0.00010841676
roses 3.1286174e-05


如果要自己生成图片进行训练,可以参考下flower_photos这个图片文件夹,每个子文件夹是一个分类,这个子文件夹名就是分类的名字,
比如rose就是玫瑰,注意要是全小写的且只带数字及字母.
而且每个分类至少要大于100张图片,图片尽量背景要各式各样,如果背景相似会干扰到训练,
比如过多的蓝色门与黑色的门,机器可能会以蓝色与黑色来做为区分。
还有提高精度的方法是,修改参数 --random_crop, --random_scale and --random_brightness
让机器自动对样本图片进行裁剪,缩放,调亮度等操作,生成更多的样本,
或者是指定参数--how_many_training_steps值修改训练次数.
--flip_left_right则是图片镜面对换操作,--learning_rate修改学习率,--train_batch_size批次大小


除了上面的模型,跑另外的模型也可以用retrain.py这个脚本,比如mobilenet
生成模型:
python retrain.py \
    --image_dir ~/flower_photos \
    --tfhub_module https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/1


测试模型:
python label_image.py \
--graph=/tmp/output_graph.pb --labels=/tmp/output_labels.txt \
--input_layer=Placeholder \
--output_layer=final_result \
--input_height=224 --input_width=224 \
--image=/opt/flower/flower_photos/daisy/21652746_cc379e0eea_m.jpg


参考https://www.tensorflow.org/tutorials/image_retraining

作者:帅得不敢出门

阅读(1296) | 评论(0) | 转发(0) |
给主人留下些什么吧!~~