最近在做的工作要在java中使用keras训练好的模型,但是刚接触这方面的知识,于是在网上找了很相关的博客和资料去看,然后记录一下在这个过程中遇到的一些问题以及解决办法.

1

https://blog.csdn.net/Butertfly/article/details/80952987

关于在keras中如何对数据进行预处理/创建模型/训练模型/保存模型/在java和python中调用,可以参考这篇文章.


1.关于模型保存中输入输出节点名称的问题

我主要是参考上面引用的博客中的代码,但是我发现运行的时候一直报错"",即使复制的一模一样的代码这部分还是出错(自己将上述博客中的代码运行了一遍)


# kera 模型保存为pb文件

sess = K.get_session()

frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=["dense_3/Softmax"])

#第一种保存方法

with open('model.pb', 'wb') as f:

    f.write(frozen_graph_def.SerializeToString())

# 第二中保存方法

tf.train.write_graph(frozen_graph_def, 'model', 'test_model.pb', as_text=False)

# 查看模型中输入输出节点的名称和大小

# print(model.input.name, model.input.shape)

# print(model.output.name, model.output.shape)

 

此部分需要注意的是output_node_names为模型输出节点的名称,如果写错云运行的时候会出现"*** is not in graph"的错误,因此可以使用# print(model.input.name, model.input.shape)和 print(model.output.name, model.output.shape)来查看,这两个节点名称在后续的python和java中调用会用到.


2.python中调用pb文件


import tensorflow as tf

import numpy as np

with tf.Graph().as_default():

    output_graph_def = tf.GraphDef()

    # 读取保存的pb文件,将其解析成对应的GraphDef Protocol Buffer

    with open('/home/thm/PycharmProjects/code/Extraction/Attacks_detection/step/model.pb', "rb") as f:

        output_graph_def.ParseFromString(f.read())

        # graph_def:将graph中保存的图片加载到当前图中,包含要导入到默认图中的操作的GraphDef protocol

        # name:放在graph_def名称前面的前缀,但是并不适用于作为输入的名称,默认的是import

        _ = tf.import_graph_def(output_graph_def, name="")

    with tf.Session() as sess:

        init = tf.global_variables_initializer()

        sess.run(init)  # 初始化所有的变量

        a = np.array([6, 1082, 1, 1, 1, 1, 1, 1, 2, 2, 2, 0, 2, 2, 2])

        # 获取输入输出节点的名称

        input_x = sess.graph.get_tensor_by_name("dense_1_input:0")

        output = sess.graph.get_tensor_by_name("dense_3/Softmax:0")

        # 传入想要计算的参数,将input_x的值替换(feed使用tensorflow值临时替换一个操作的输入参数,从而替换原来的输出结果)

        result = sess.run(output, feed_dict={input_x: a.reshape(1, 15)})

        print(result)

        Class_dict = {'BENIGN': 0, 'syn_flood': 1, 'icmp_flood': 2, 'udp_flood': 3, 'sarfu': 4}

        species_dict = {v: k for k, v in Class_dict.items()}

        # 代码中v:k代表了v是key,k是value,而k,v则是表示key,value换个位置

        print("\nPredicted species is: ")

        print(species_dict[np.argmax(result)])

 

3.java中调用pb文件

首先要添加依赖关系

方法一: 右击项目/Maven/Add Dependency,在出现的对话框中添加依赖关系(代码中的groupid/artifactid/version和对话框中的三部分一一对应).

方法二:打开项目中的pox.xml文件,将下面的代码直接粘贴过去.

特别注意:导入的tensorflow包的version对应的是你自己tensroflow的版本

(在终端下查看tensorflow版本号

python

import tensorflow as tf

(如果输入该命令之后提示ModuleNotFoundError: No module named ‘tensorflow’,应该是没有在tensorflow环境下运行python,因此先激活tensorflow

source activate tensorflow

再运行python

tf.version)


<dependency>

<groupId>commons-io</groupId>

<artifactId>commons-io</artifactId>

<version>2.6</version>

</dependency>

<dependency>

<groupId>org.tensorflow</groupId>

<artifactId>libtensorflow</artifactId>

<version>1.10.0</version>

</dependency>

<dependency>

         <groupId>org.tensorflow</groupId>

         <artifactId>proto</artifactId>

         <version>1.10.0</version>

       </dependency>

       <dependency>

<groupId>org.tensorflow</groupId>

<artifactId>libtensorflow_jni</artifactId>

<version>1.10.0</version>

       </dependency>

 

image.png

import java.io.FileInputStream;

import java.io.IOException;

import java.nio.FloatBuffer;

import org.apache.commons.io.IOUtils;

import org.tensorflow.*;


public class Test_model {

    public static String PB_FILE_PATH = "pb模型保存的位置";

    public static String INPUT_TENSOR_NAME = "dense_1_input:0";//前面提到的输入输出节点的名称

    public static String OUTPUT_TENSOR_NAME = "dense_3/Softmax:0";

 

    public static void main(String[] args) throws IOException {

        try (Graph graph = new Graph()) {

            //导入图

            byte[] graphBytes = IOUtils.toByteArray(new FileInputStream(PB_FILE_PATH));

            graph.importGraphDef(graphBytes);

            float[] a = new float[]{6, 1082, 1, 1, 1, 1, 1, 1, 2, 2, 2, 0, 2, 2, 2};

            long[] shape = new long[]{1,15};

            Tensor<?> data = Tensor.create(shape, FloatBuffer.wrap(a));

            //根据图建立Session

            try (Session session = new Session(graph)) {

                //相当于TensorFlow Python中的sess.run(z, feed_dict = {'x': 10.0})

                Tensor<?> out = session.runner()

                        .feed(INPUT_TENSOR_NAME, data)

                        .fetch(OUTPUT_TENSOR_NAME).run().get(0);

                System.out.println(out);

            }

        }

    }

}

 

4.java中在原有项目添加完依赖关系后可能会出现"Exception in thread “main” java.lang.NoClassDefFoundError:"

此问题的具体解决方法可以参考这篇博客,讲的挺详细的https://blog.csdn.net/lz6363/article/details/82561292



————————————————


                            版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。

                        

原文链接:https://blog.csdn.net/weixin_44945845/article/details/108430308