介绍
机器学习正在全球范围内得到普及和使用。它已经彻底改变了某些应用程序的构建方式,并且很可能将继续成为我们日常生活的重要组成部分(并且在不断增加)。
没有糖衣,机器学习并不简单。这相当艰巨,对许多人来说似乎非常复杂。
诸如Google之类的公司全力以赴,将机器学习概念带给开发人员更近的地方,并允许他们在主要帮助下逐步迈出第一步。
这样,诸如TensorFlow之类的框架就诞生了。
什么是TensorFlow?
TensorFlow是Google用Python和C ++开发的开源机器学习框架。
它可以帮助开发人员轻松获取数据,准备和训练模型,预测未来状态以及进行大规模机器学习。
有了它,我们可以训练和运行最常用于光学字符识别,图像识别/分类,自然语言处理等的深度神经网络。
张量和运算
TensorFlow基于计算图,您可以将其想象为具有节点和边的经典图。
每个节点都称为操作,它们将零个或多个张量引入,并产生零个或多个张量。一个操作可能非常简单,例如基本加法,但也可能非常复杂。
张量被描绘为图形的边缘,并且是核心数据单元。在将这些张量馈送给操作时,我们在这些张量上执行不同的功能。它们可以具有一个或多个维度,有时也称为它们的等级-(标量:等级0,向量:等级1,矩阵:等级2)
这些数据通过张量流经计算图,并受运算影响,因此命名为TensorFlow。
张量可以存储任意维度的数据,并且有三种主要类型的张量:占位符,变量和常量。
安装TensorFlow
使用Maven,安装TensorFlow就像添加依赖项一样简单:
<dependency> <groupId>org.tensorflow</groupId> <artifactId>tensorflow</artifactId> <version>1.13.1</version> </dependency>
如果您的设备支持GPU支持,请使用以下依赖项:
<dependency> <groupId>org.tensorflow</groupId> <artifactId>libtensorflow</artifactId> <version>1.13.1</version> </dependency> <dependency> <groupId>org.tensorflow</groupId> <artifactId>libtensorflow_jni_gpu</artifactId> <version>1.13.1</version> </dependency>
您可以使用TensorFlow
对象检查当前安装的TensorFlow的版本:
System.out.println(TensorFlow.version());
TensorFlow Java API
org.tensorflow
包装内包含Java API TensorFlow提供的内容。目前尚处于实验阶段,因此不能保证其稳定。
请注意,TensorFlow唯一受完全支持的语言是Python,而Java API几乎没有功能。
它向我们介绍了新的类,接口,枚举和异常。
类
通过API引入的新类是:
-
Graph
:表示TensorFlow计算的数据流程图 -
Operation
:在张量上执行计算的Graph节点 -
OperationBuilder
:操作的构建器类 -
Output<T>
:由操作产生的张量的符号句柄 -
SavedModelBundle
:代表从存储库加载的模型。 -
SavedModelBundle.Loader
:提供用于加载SavedModel的选项 -
Server
:进程内TensorFlow服务器,用于分布式培训 -
Session
:图形执行的驱动程序 -
Session.Run
:执行会话时获得的输出张量和元数据 -
Session.Runner
:运行操作并评估张量 -
Shape
:由操作产生的张量的可能部分已知的形状 -
Tensor<T>
:一个静态类型化的多维数组,其元素为T描述的类型 -
TensorFlow
:描述TensorFlow运行时的静态实用程序方法 -
Tensors
:用于创建Tensor对象的类型安全的工厂方法
枚举
-
DataType
:将Tensor中的元素类型表示为枚举
接口
-
Operand<T>
:由TensorFlow操作的操作数实现的接口
例外
-
TensorFlowException
:执行TensorFlow Graph时引发未经检查的异常
如果将所有这些与Python中的tf模块进行比较,则存在明显的区别。至少到目前为止,Java API的功能数量几乎不相同。
图
如前所述,TensorFlow基于计算图org.tensorflow.Graph
-Java的实现在哪里。
订阅我们的新闻
在收件箱中获取临时教程,指南和作业。从来没有垃圾邮件。随时退订。
订阅电子报
订阅
注意:尽管实例完成后我们需要显式释放Graph使用的资源,但它的实例是线程安全的。
让我们从一个空图开始:
Graph graph = new Graph();
该图没有太大意义,它是空的。要对其进行任何处理,我们首先需要使用Operation
s加载它。
为了使用操作加载它,我们使用opBuilder()
方法,该方法返回一个OperationBuilder
对象,一旦调用该.build()
方法,该对象便会将操作添加到图形中。
常量
让我们在图表中添加一个常量:
Operation x = graph.opBuilder("Const", "x") .setAttr("dtype", DataType.FLOAT) .setAttr("value", Tensor.create(3.0f)) .build();
占位符
占位符是变量的“类型”,在声明时没有值。它们的值将在以后分配。这使我们可以在没有任何实际数据的情况下使用操作来构建图形:
Operation y = graph.opBuilder("Placeholder", "y") .setAttr("dtype", DataType.FLOAT) .build();
函数
现在最后,要对此进行总结,我们需要添加某些功能。这些可以像乘法,除法或加法一样简单,也可以像矩阵乘法一样复杂。与之前相同,我们使用.opBuilder()
方法定义函数:
Operation xy = graph.opBuilder("Mul", "xy") .addInput(x.output(0)) .addInput(y.output(0)) .build();
注意:我们使用output(0)
的张量可以具有多个输出。
图可视化
令人遗憾的是,Java API尚未包括任何工具来像Python中那样可视化图形。Java API更新后,本文也会更新。
会话(Sessions)
如前所述,aSession
是执行的驱动Graph
程序。它封装了执行Operation
s和Graph
s来计算Tensor
s的环境。
这意味着,我们构建的图中的张量实际上不具有任何值,因为我们没有在会话中运行该图。
首先,将图形添加到会话中:
Session session = new Session(graph);
我们的计算只是将x
andy
值相乘。为了运行我们的图并计算它,我们fetch()
进行xy
运算并将其x
和y
值输入:
Tensor tensor = session.runner().fetch("xy").feed("x", Tensor.create(5.0f)).feed("y", Tensor.create(2.0f)).run().get(0); System.out.println(tensor.floatValue());
运行这段代码将产生:
10.0f
用Python保存模型并用Java加载
这听起来有些奇怪,但是由于Python是唯一受支持的语言,因此Java API仍然没有保存模型的功能。
这意味着Java API仅用于服务用例,至少在TensorFlow完全支持之前。至少,我们可以使用以下SavedModelBundle
类在Python中训练和保存模型,然后在Java中加载它们以为它们提供服务:
SavedModelBundle model = SavedModelBundle.load("./model", "serve"); Tensor tensor = model.session().runner().fetch("xy").feed("x", Tensor.create(5.0f)).feed("y", Tensor.create(2.0f)).run().get(0); System.out.println(tensor.floatValue());
结论
TensorFlow是一个功能强大,健壮且广泛使用的框架。它一直在不断改进,最近又引入了新语言-包括Java和JavaScript。
尽管Java API的功能还不如TensorFlow for Python那么多,但它仍然可以作为TensorFlow for Java开发人员的良好入门。
转自:https://zhuanlan.zhihu.com/p/343068127