最近在做的工作要在java中使用keras训练好的模型,但是刚接触这方面的知识,于是在网上找了很相关的博客和资料去看,然后记录一下在这个过程中遇到的一些问题以及解决办法.
1
https://blog.csdn.net/Butertfly/article/details/80952987
关于在keras中如何对数据进行预处理/创建模型/训练模型/保存模型/在java和python中调用,可以参考这篇文章.
1.关于模型保存中输入输出节点名称的问题
我主要是参考上面引用的博客中的代码,但是我发现运行的时候一直报错"",即使复制的一模一样的代码这部分还是出错(自己将上述博客中的代码运行了一遍)
# kera 模型保存为pb文件
sess = K.get_session()
frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=["dense_3/Softmax"])
#第一种保存方法
with open('model.pb', 'wb') as f:
f.write(frozen_graph_def.SerializeToString())
# 第二中保存方法
tf.train.write_graph(frozen_graph_def, 'model', 'test_model.pb', as_text=False)
# 查看模型中输入输出节点的名称和大小
# print(model.input.name, model.input.shape)
# print(model.output.name, model.output.shape)
此部分需要注意的是output_node_names为模型输出节点的名称,如果写错云运行的时候会出现"*** is not in graph"的错误,因此可以使用# print(model.input.name, model.input.shape)和 print(model.output.name, model.output.shape)来查看,这两个节点名称在后续的python和java中调用会用到.
2.python中调用pb文件
import tensorflow as tf
import numpy as np
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
# 读取保存的pb文件,将其解析成对应的GraphDef Protocol Buffer
with open('/home/thm/PycharmProjects/code/Extraction/Attacks_detection/step/model.pb', "rb") as f:
output_graph_def.ParseFromString(f.read())
# graph_def:将graph中保存的图片加载到当前图中,包含要导入到默认图中的操作的GraphDef protocol
# name:放在graph_def名称前面的前缀,但是并不适用于作为输入的名称,默认的是import
_ = tf.import_graph_def(output_graph_def, name="")
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init) # 初始化所有的变量
a = np.array([6, 1082, 1, 1, 1, 1, 1, 1, 2, 2, 2, 0, 2, 2, 2])
# 获取输入输出节点的名称
input_x = sess.graph.get_tensor_by_name("dense_1_input:0")
output = sess.graph.get_tensor_by_name("dense_3/Softmax:0")
# 传入想要计算的参数,将input_x的值替换(feed使用tensorflow值临时替换一个操作的输入参数,从而替换原来的输出结果)
result = sess.run(output, feed_dict={input_x: a.reshape(1, 15)})
print(result)
Class_dict = {'BENIGN': 0, 'syn_flood': 1, 'icmp_flood': 2, 'udp_flood': 3, 'sarfu': 4}
species_dict = {v: k for k, v in Class_dict.items()}
# 代码中v:k代表了v是key,k是value,而k,v则是表示key,value换个位置
print("\nPredicted species is: ")
print(species_dict[np.argmax(result)])
3.java中调用pb文件
首先要添加依赖关系
方法一: 右击项目/Maven/Add Dependency,在出现的对话框中添加依赖关系(代码中的groupid/artifactid/version和对话框中的三部分一一对应).
方法二:打开项目中的pox.xml文件,将下面的代码直接粘贴过去.
特别注意:导入的tensorflow包的version对应的是你自己tensroflow的版本
(在终端下查看tensorflow版本号
python
import tensorflow as tf
(如果输入该命令之后提示ModuleNotFoundError: No module named ‘tensorflow’,应该是没有在tensorflow环境下运行python,因此先激活tensorflow
source activate tensorflow
再运行python
)
tf.version)
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.6</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow</artifactId>
<version>1.10.0</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>proto</artifactId>
<version>1.10.0</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow_jni</artifactId>
<version>1.10.0</version>
</dependency>
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.FloatBuffer;
import org.apache.commons.io.IOUtils;
import org.tensorflow.*;
public class Test_model {
public static String PB_FILE_PATH = "pb模型保存的位置";
public static String INPUT_TENSOR_NAME = "dense_1_input:0";//前面提到的输入输出节点的名称
public static String OUTPUT_TENSOR_NAME = "dense_3/Softmax:0";
public static void main(String[] args) throws IOException {
try (Graph graph = new Graph()) {
//导入图
byte[] graphBytes = IOUtils.toByteArray(new FileInputStream(PB_FILE_PATH));
graph.importGraphDef(graphBytes);
float[] a = new float[]{6, 1082, 1, 1, 1, 1, 1, 1, 2, 2, 2, 0, 2, 2, 2};
long[] shape = new long[]{1,15};
Tensor<?> data = Tensor.create(shape, FloatBuffer.wrap(a));
//根据图建立Session
try (Session session = new Session(graph)) {
//相当于TensorFlow Python中的sess.run(z, feed_dict = {'x': 10.0})
Tensor<?> out = session.runner()
.feed(INPUT_TENSOR_NAME, data)
.fetch(OUTPUT_TENSOR_NAME).run().get(0);
System.out.println(out);
}
}
}
}
4.java中在原有项目添加完依赖关系后可能会出现"Exception in thread “main” java.lang.NoClassDefFoundError:"
此问题的具体解决方法可以参考这篇博客,讲的挺详细的https://blog.csdn.net/lz6363/article/details/82561292
————————————————
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
原文链接:https://blog.csdn.net/weixin_44945845/article/details/108430308