0.前言

随着深度模型的普及,线上越来越多的模型换成了深度模型,与此对应的线上模型的部署与调用方式也会发生变化。下面我们就来介绍一下分别用python代码与java代码调用训练好的模型。


1.模型训练

首先我们训练一个简单的模型 y = 3 x + 0.1 y = 3x + 0.1y=3x+0.1


#!/usr/bin/env python

# encoding: utf-8


"""

@author: wanglei

@time: 2020/7/30 上午9:25

"""


import tensorflow as tf

import numpy as np

from tensorflow.python.framework import graph_util


train_X = np.linspace(-1, 1, 100)

train_Y = 3*train_X + 0.1



X = tf.placeholder("float",name='X')

Y = tf.placeholder("float",name='Y')

W = tf.Variable(tf.random_normal([1]), name="weight")

b = tf.Variable(tf.zeros([1]), name="bias")


z = tf.multiply(X, W) + b

op = tf.add(tf.multiply(X, W),b,name='results')


cost = tf.reduce_mean(tf.square(Y – z))

learning_rate = 0.01

optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()


# 定义参数

training_epochs = 20

display_step = 2


saver = tf.train.Saver()


with tf.Session() as sess:

    sess.run(init)

    plotdata = {'batchsize': [], 'loss': []}


    for epoch in range(training_epochs):

        for (x, y) in zip(train_X, train_Y):

            sess.run(optimizer, feed_dict={X:x, Y:y})


        if epoch % display_step == 0:

            loss = sess.run(cost, {X: train_X, Y: train_Y})

            print("Epoch:", epoch+1, "cost=", loss, "W=", sess.run(W), "b=",sess.run(b))


    saver.save(sess, 'mymodels/first')

    print("cost =", sess.run(cost, feed_dict={X: train_X, Y: train_Y}), "W=", sess.run(W), "b=", sess.run(b))

    const_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["results"])


    with tf.gfile.FastGFile("mymodels/first.pb", mode='wb') as f:

        f.write(const_graph.SerializeToString())


print("Finished!")

 

上面的代码,会最终得到一个first.pd的模型。训练输出的结果如下


Epoch: 1 cost= 1.2481742 W= [1.2849432] b= [0.59788895]

Epoch: 3 cost= 0.10257308 W= [2.539106] b= [0.27416894]

Epoch: 5 cost= 0.006951133 W= [2.8805332] b= [0.1457994]

Epoch: 7 cost= 0.00046514577 W= [2.9691048] b= [0.11185519]

Epoch: 9 cost= 3.1108106e-05 W= [2.9920104] b= [0.10306594]

Epoch: 11 cost= 2.0804118e-06 W= [2.9979339] b= [0.10079286]

Epoch: 13 cost= 1.3938015e-07 W= [2.9994652] b= [0.10020526]

Epoch: 15 cost= 9.405541e-09 W= [2.999861] b= [0.10005322]

Epoch: 17 cost= 6.447206e-10 W= [2.9999638] b= [0.10001406]

Epoch: 19 cost= 5.3264854e-11 W= [2.9999897] b= [0.10000414]

cost = 2.6606464e-11 W= [2.9999924] b= [0.10000262]

 

最终得到的参数为W = 2.9999924 , b = 0.10000262 W=2.9999924, b=0.10000262W=2.9999924,b=0.10000262,与一次函数y = 3 x + 0.1 y=3x+0.1y=3x+0.1一致。


2.python代码调用模型

用python代码调用上面训练好的模型,示例如下


#!/usr/bin/env python

# encoding: utf-8


"""

@author: wanglei

@time: 2020/7/30 下午2:39

"""


import tensorflow as tf

from tensorflow.python.platform import gfile


sess = tf.Session()


with gfile.FastGFile('mymodels/first.pb', 'rb') as f:

    graph = tf.GraphDef()

    graph.ParseFromString(f.read())

    sess.graph.as_default()

    tf.import_graph_def(graph, name='')


    sess.run(tf.global_variables_initializer())

    print(sess.run('weight:0'))

    print(sess.run('bias:0'))


    input_x = sess.graph.get_tensor_by_name('X:0')

    op = sess.graph.get_tensor_by_name('results:0')

    ret = sess.run(op, feed_dict={input_x: 2})

    print("ret is: ", ret)

 

运行上面的代码,输出如下


[2.9999924]

[0.10000262]

ret is:  [6.0999875]

 

3.java API调用

使用java解析调用模型的时候,需要加入tensorflow的官方依赖


    <properties>

        <grpc.version>1.20.0</grpc.version>

        <tensorflow.version>1.13.1</tensorflow.version>

    </properties>


        <dependency>

            <groupId>org.tensorflow</groupId>

            <artifactId>tensorflow</artifactId>

            <version>${tensorflow.version}</version>

        </dependency>


 

因为我们代码里会使用到common-io里面的代码,所以加入相应的依赖


        <dependency>

            <groupId>commons-io</groupId>

            <artifactId>commons-io</artifactId>

            <version>2.4</version>

        </dependency>

 

然后开始解析模型并打分


import org.apache.commons.io.IOUtils;

import org.tensorflow.Graph;

import org.tensorflow.Session;

import org.tensorflow.Tensor;


import java.io.FileInputStream;


/**

 * author: wanglei

 * create: 2020-07-30

 */

public class Demo1 {


    public static void test() throws Exception {

        Graph graph = new Graph();

        String filepath = "xxx/mymodels/first.pb";

        // 模型的图结构

        byte[] graphBytes = IOUtils.toByteArray(new FileInputStream(filepath));

        graph.importGraphDef(graphBytes);


// 根据图结构建议sess

        Session sess = new Session(graph);

        Tensor X = Tensor.create(2.0f);

        Tensor result = sess.runner().feed("X", X).fetch("results").run().get(0);

        System.out.println(result);

        float[] vector = new float[1];

        result.copyTo(vector);

        System.out.println("vector[0] is: " + vector[0]);

    }


    public static void main(String[] args) throws Exception {

        test();

    }

}

 

代码的输出为


FLOAT tensor with shape [1]

vector[0] is: 6.0999875

————————————————


                            版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。

                        

原文链接:https://blog.csdn.net/bitcarmanlee/article/details/107692881