0.前言
随着深度模型的普及,线上越来越多的模型换成了深度模型,与此对应的线上模型的部署与调用方式也会发生变化。下面我们就来介绍一下分别用python代码与java代码调用训练好的模型。
1.模型训练
首先我们训练一个简单的模型 y = 3 x + 0.1 y = 3x + 0.1y=3x+0.1
#!/usr/bin/env python
# encoding: utf-8
"""
@author: wanglei
@time: 2020/7/30 上午9:25
"""
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import graph_util
train_X = np.linspace(-1, 1, 100)
train_Y = 3*train_X + 0.1
X = tf.placeholder("float",name='X')
Y = tf.placeholder("float",name='Y')
W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.zeros([1]), name="bias")
z = tf.multiply(X, W) + b
op = tf.add(tf.multiply(X, W),b,name='results')
cost = tf.reduce_mean(tf.square(Y – z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
init = tf.global_variables_initializer()
# 定义参数
training_epochs = 20
display_step = 2
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
plotdata = {'batchsize': [], 'loss': []}
for epoch in range(training_epochs):
for (x, y) in zip(train_X, train_Y):
sess.run(optimizer, feed_dict={X:x, Y:y})
if epoch % display_step == 0:
loss = sess.run(cost, {X: train_X, Y: train_Y})
print("Epoch:", epoch+1, "cost=", loss, "W=", sess.run(W), "b=",sess.run(b))
saver.save(sess, 'mymodels/first')
print("cost =", sess.run(cost, feed_dict={X: train_X, Y: train_Y}), "W=", sess.run(W), "b=", sess.run(b))
const_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["results"])
with tf.gfile.FastGFile("mymodels/first.pb", mode='wb') as f:
f.write(const_graph.SerializeToString())
print("Finished!")
上面的代码,会最终得到一个first.pd的模型。训练输出的结果如下
Epoch: 1 cost= 1.2481742 W= [1.2849432] b= [0.59788895]
Epoch: 3 cost= 0.10257308 W= [2.539106] b= [0.27416894]
Epoch: 5 cost= 0.006951133 W= [2.8805332] b= [0.1457994]
Epoch: 7 cost= 0.00046514577 W= [2.9691048] b= [0.11185519]
Epoch: 9 cost= 3.1108106e-05 W= [2.9920104] b= [0.10306594]
Epoch: 11 cost= 2.0804118e-06 W= [2.9979339] b= [0.10079286]
Epoch: 13 cost= 1.3938015e-07 W= [2.9994652] b= [0.10020526]
Epoch: 15 cost= 9.405541e-09 W= [2.999861] b= [0.10005322]
Epoch: 17 cost= 6.447206e-10 W= [2.9999638] b= [0.10001406]
Epoch: 19 cost= 5.3264854e-11 W= [2.9999897] b= [0.10000414]
cost = 2.6606464e-11 W= [2.9999924] b= [0.10000262]
最终得到的参数为W = 2.9999924 , b = 0.10000262 W=2.9999924, b=0.10000262W=2.9999924,b=0.10000262,与一次函数y = 3 x + 0.1 y=3x+0.1y=3x+0.1一致。
2.python代码调用模型
用python代码调用上面训练好的模型,示例如下
#!/usr/bin/env python
# encoding: utf-8
"""
@author: wanglei
@time: 2020/7/30 下午2:39
"""
import tensorflow as tf
from tensorflow.python.platform import gfile
sess = tf.Session()
with gfile.FastGFile('mymodels/first.pb', 'rb') as f:
graph = tf.GraphDef()
graph.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph, name='')
sess.run(tf.global_variables_initializer())
print(sess.run('weight:0'))
print(sess.run('bias:0'))
input_x = sess.graph.get_tensor_by_name('X:0')
op = sess.graph.get_tensor_by_name('results:0')
ret = sess.run(op, feed_dict={input_x: 2})
print("ret is: ", ret)
运行上面的代码,输出如下
[2.9999924]
[0.10000262]
ret is: [6.0999875]
3.java API调用
使用java解析调用模型的时候,需要加入tensorflow的官方依赖
<properties>
<grpc.version>1.20.0</grpc.version>
<tensorflow.version>1.13.1</tensorflow.version>
</properties>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>${tensorflow.version}</version>
</dependency>
因为我们代码里会使用到common-io里面的代码,所以加入相应的依赖
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.4</version>
</dependency>
然后开始解析模型并打分
import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import java.io.FileInputStream;
/**
* author: wanglei
* create: 2020-07-30
*/
public class Demo1 {
public static void test() throws Exception {
Graph graph = new Graph();
String filepath = "xxx/mymodels/first.pb";
// 模型的图结构
byte[] graphBytes = IOUtils.toByteArray(new FileInputStream(filepath));
graph.importGraphDef(graphBytes);
// 根据图结构建议sess
Session sess = new Session(graph);
Tensor X = Tensor.create(2.0f);
Tensor result = sess.runner().feed("X", X).fetch("results").run().get(0);
System.out.println(result);
float[] vector = new float[1];
result.copyTo(vector);
System.out.println("vector[0] is: " + vector[0]);
}
public static void main(String[] args) throws Exception {
test();
}
}
代码的输出为
FLOAT tensor with shape [1]
vector[0] is: 6.0999875
————————————————
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
原文链接:https://blog.csdn.net/bitcarmanlee/article/details/107692881