本文共 3601 字,大约阅读时间需要 12 分钟。
在做一件事之前,了解到这件事做成之后的样子是非常重要的,所以我们先来看看我们的应用会做成什么效果
我们从相册里面选取一张照片,之后程序就会识别出图片中的物体,我们可以看到在这图上面识别出多个人体,领带, 酒杯和餐桌,并用红色的框标识物体的位置,同时在边框的左上角有识别物体的名称。 注意的是 ‘莱纳德’ 并没有作为一个人体被识别出来,这不是程序的bug, 只是因为这个模型没有办法识别,可能是缺乏相关的训练数据之类的。 我们可以选用识别率更高的模型或者自己训练一个‘莱纳德’识别器来解决这个问题。
我们使用Android 创建一个新的 Android 项目(Empty Activity), Compile SDK Version 25, Min SDK Version 19, Build Tools Version 26.0.2
在上一个课时中我们下载了Inference Interface的nightly build的AAR文件, 我们需要把这个文件导入到我们的项目中,通常我们会把这个AAR文件放在 app/libs下面
为了导入这个AAR, 我们首先需要在app/build.gradle中声明一个本地的flatDir仓库
repositories { flatDir { dirs 'libs' }}
然后指定依赖
compile name: 'tensorflow', ext: 'aar'
最后再做一个Project Sync就完成了Inference Interface的导入, 完整的 app/build.gradle应该是这样的
....repositories { flatDir { dirs 'libs' }}android{ .....}dependencies{ ..... compile name: 'tensorflow', ext: 'aar' .....}
在上一课时中我们已经下载Pre-trained model的二进制包, 解压缩这个包,我们会发现里面有这些文件
其中graph.pbtext, 和model.ckpt.*是我们在训练自己的模型时会用到的文件,在这里我们暂时忽略;frozen_inference_graph.pb 文件正是我们需要的,开箱即用的模型文件, 我们把这个文件作为一个 asset 导入我们的项目中在Android Studio中,我们点击 New > Folder > Assets Folder 创建一个 assets 目录,将frozen_inference_graph.pb 拷贝到 assets目录中,重命名为model.pb
这里需要说明的是,本系列文章中为了演示尽量简单, 将model文件做为asset 一起编译到最终的apk文件中, 一般来说model文件尺寸都比较大,几十M到几百M都有,在真实应用中你可以不把model文件作为apk的一部分, 而是单独部署到外部或者内部存储上面, 然后从这些位置加载model,这样都是OK的。
我们现在还缺一个东西:在机器学习的世界里面, 绝大部分的输入和输出数据都是数字,换句话来说, 当训练这个识别模型的时候,你不会告诉它这张图片上的是人, 而是告诉它这个图片上面的物体代号是1; 模型在输出识别结果的时候,也不会输出人,汽车这样的字符,而是输出1,2,3这样的的数字,那么我们怎么知道1,2,3代表的是什么,我们去哪里找这样的对应关系呢?
我们需要找到在训练这个模型的时候,训练数据中物体代号和物体的对应关系。TensorFlow Object Detection API中的模型训练时使用的是MS COCO的物体数据集合, 我们可以在下载到相应的标签文件, 我们打开这个文件
文件的内容就是数字和物体名称的对应的关系,我们稍微处理一下这个文件,去掉前面的数字, 变成下面这样 然后我们只需要依次将文件的每一行都按顺序读入一个数组, 那么假设模型输出识别结果为3,我们只需要找到这个数组中下标为3的元素,就是这个物体的名称了。我们把这个文件也存到assets目录中,重命名为labels.txt, 现在assets目录应该是这样的
现在相关的资源都导入到项目里面了,我们来写一点代码把模型加载起来吧!
我们需要做2件事:
首先我们将labels.txt的内容依次读入到数组中
Listlabels = new ArrayList<>();InputStream labelsInput = getAssets().open("labels.text");BufferedReader br = new BufferedReader(new InputStreamReader(labelsInput)); String line; while ((line = br.readLine()) != null) { labels.add(line); } br.close();
接着我们加载model.pb
TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(getAssets(), "model.pb");
我们获取到 TensorFlowInferenceInterface的对象之后就可以在这个对象上面输入图片数据并获取识别结果了
完整的代码如下,这里我写了一个类来进行封装
public class ObjectDetector { private String labelFilename; private String modelFilename; private Listlabels = new ArrayList<>(); private AssetManager assetManager; private TensorFlowInferenceInterface inferenceInterface; public ObjectDetector(String labelFileName, String modelFileName, AssetManager assetManager) { this.labelFilename = labelFileName; this.modelFilename = modelFileName; this.assetManager = assetManager;} public void load() throws IOException { InputStream labelsInput = assetManager.open(labelFilename); BufferedReader br = new BufferedReader(new InputStreamReader(labelsInput)); String line; while ((line = br.readLine()) != null) { labels.add(line); } br.close(); if (inferenceInterface != null) { inferenceInterface.close(); } inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);}
那么我们只需要这样调用来加载模型
detector = new ObjectDetector("labels.txt", "model.pb", getAssets());try { detector.load();} catch (IOException e) { //handle exception}
接下来我们开始输入图片数据开始识别吧!
转载地址:http://bsiti.baihongyu.com/