博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
TensorFlow on Android(3): Demo展示和准备工作
阅读量:4148 次
发布时间:2019-05-25

本文共 3601 字,大约阅读时间需要 12 分钟。

Demo

在做一件事之前,了解到这件事做成之后的样子是非常重要的,所以我们先来看看我们的应用会做成什么效果

enter image description here

我们从相册里面选取一张照片,之后程序就会识别出图片中的物体,我们可以看到在这图上面识别出多个人体,领带, 酒杯和餐桌,并用红色的框标识物体的位置,同时在边框的左上角有识别物体的名称。 注意的是 ‘莱纳德’ 并没有作为一个人体被识别出来,这不是程序的bug, 只是因为这个模型没有办法识别,可能是缺乏相关的训练数据之类的。 我们可以选用识别率更高的模型或者自己训练一个‘莱纳德’识别器来解决这个问题。

创建Android工程

我们使用Android 创建一个新的 Android 项目(Empty Activity), Compile SDK Version 25, Min SDK Version 19, Build Tools Version 26.0.2

导入Inference Interface

在上一个课时中我们下载了Inference Interface的nightly build的AAR文件, 我们需要把这个文件导入到我们的项目中,通常我们会把这个AAR文件放在 app/libs下面

enter image description here

为了导入这个AAR, 我们首先需要在app/build.gradle中声明一个本地的flatDir仓库

repositories {    flatDir {       dirs 'libs'  }}

然后指定依赖

compile name: 'tensorflow', ext: 'aar'

最后再做一个Project Sync就完成了Inference Interface的导入, 完整的 app/build.gradle应该是这样的

....repositories {   flatDir {      dirs 'libs'   }}android{  .....}dependencies{    .....    compile name: 'tensorflow', ext: 'aar'    .....}

导入Pre-trained Model

在上一课时中我们已经下载Pre-trained model的二进制包, 解压缩这个包,我们会发现里面有这些文件

enter image description here
其中graph.pbtext, 和model.ckpt.*是我们在训练自己的模型时会用到的文件,在这里我们暂时忽略;frozen_inference_graph.pb 文件正是我们需要的,开箱即用的模型文件, 我们把这个文件作为一个 asset 导入我们的项目中

在Android Studio中,我们点击 New > Folder > Assets Folder 创建一个 assets 目录,将frozen_inference_graph.pb 拷贝到 assets目录中,重命名为model.pb

这里需要说明的是,本系列文章中为了演示尽量简单, 将model文件做为asset 一起编译到最终的apk文件中, 一般来说model文件尺寸都比较大,几十M到几百M都有,在真实应用中你可以不把model文件作为apk的一部分, 而是单独部署到外部或者内部存储上面, 然后从这些位置加载model,这样都是OK的。

我们现在还缺一个东西:在机器学习的世界里面, 绝大部分的输入和输出数据都是数字,换句话来说, 当训练这个识别模型的时候,你不会告诉它这张图片上的是人, 而是告诉它这个图片上面的物体代号是1; 模型在输出识别结果的时候,也不会输出人,汽车这样的字符,而是输出1,2,3这样的的数字,那么我们怎么知道1,2,3代表的是什么,我们去哪里找这样的对应关系呢?

我们需要找到在训练这个模型的时候,训练数据中物体代号和物体的对应关系。TensorFlow Object Detection API中的模型训练时使用的是MS COCO的物体数据集合, 我们可以在下载到相应的标签文件, 我们打开这个文件

enter image description here
文件的内容就是数字和物体名称的对应的关系,我们稍微处理一下这个文件,去掉前面的数字, 变成下面这样
enter image description here
然后我们只需要依次将文件的每一行都按顺序读入一个数组, 那么假设模型输出识别结果为3,我们只需要找到这个数组中下标为3的元素,就是这个物体的名称了。

我们把这个文件也存到assets目录中,重命名为labels.txt, 现在assets目录应该是这样的

enter image description here

写点真正的代码

现在相关的资源都导入到项目里面了,我们来写一点代码把模型加载起来吧!

我们需要做2件事:

  1. 把labels.txt的内容读到数组中,供查询识别结果中的物体名称
  2. 加载model.pb, 获取一个TensorFlowInferenceInterface来进行后续操作

首先我们将labels.txt的内容依次读入到数组中

List
labels = new ArrayList<>();InputStream labelsInput = getAssets().open("labels.text");BufferedReader br = new BufferedReader(new InputStreamReader(labelsInput)); String line; while ((line = br.readLine()) != null) { labels.add(line); } br.close();

接着我们加载model.pb

TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(getAssets(), "model.pb");

我们获取到 TensorFlowInferenceInterface的对象之后就可以在这个对象上面输入图片数据并获取识别结果了

完整的代码如下,这里我写了一个类来进行封装

public class ObjectDetector {   private String labelFilename;   private String modelFilename;   private List
labels = new ArrayList<>(); private AssetManager assetManager; private TensorFlowInferenceInterface inferenceInterface; public ObjectDetector(String labelFileName, String modelFileName, AssetManager assetManager) { this.labelFilename = labelFileName; this.modelFilename = modelFileName; this.assetManager = assetManager;} public void load() throws IOException { InputStream labelsInput = assetManager.open(labelFilename); BufferedReader br = new BufferedReader(new InputStreamReader(labelsInput)); String line; while ((line = br.readLine()) != null) { labels.add(line); } br.close(); if (inferenceInterface != null) { inferenceInterface.close(); } inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);}

那么我们只需要这样调用来加载模型

detector = new ObjectDetector("labels.txt", "model.pb", getAssets());try {      detector.load();} catch (IOException e) {     //handle exception}

接下来我们开始输入图片数据开始识别吧!

转载地址:http://bsiti.baihongyu.com/

你可能感兴趣的文章
RMRK筹集600万美元,用于在Polkadot上建立先进的NFT系统标准
查看>>
JavaSE_day14 集合中的Map集合_键值映射关系
查看>>
异常 Java学习Day_15
查看>>
Mysql初始化的命令
查看>>
MySQL关键字的些许问题
查看>>
浅谈HTML
查看>>
css基础
查看>>
Servlet进阶和JSP基础
查看>>
servlet中的cookie和session
查看>>
过滤器及JSP九大隐式对象
查看>>
软件(项目)的分层
查看>>
菜单树
查看>>
【Python】学习笔记——-6.2、使用第三方模块
查看>>
【Python】学习笔记——-7.0、面向对象编程
查看>>
【Python】学习笔记——-7.2、访问限制
查看>>
【Python】学习笔记——-7.3、继承和多态
查看>>
【Python】学习笔记——-7.5、实例属性和类属性
查看>>
git中文安装教程
查看>>
虚拟机 CentOS7/RedHat7/OracleLinux7 配置静态IP地址 Ping 物理机和互联网
查看>>
Jackson Tree Model Example
查看>>