Maven依赖
- Maven 下载 TensorFlow 依赖时可能会失败,因为 tensorflow 依赖了 libtensorflow_jni(本地原生库,约 90 多 MB)。若下载失败,可手动下载该 jar 包并放入本地 Maven 仓库(例如使用 `mvn install:install-file` 安装)。
- 注:这里使用的是 TensorFlow 1.x 的 Java 依赖(1.13.1)。TensorFlow 2.x 的 Java API 是另一套依赖(`org.tensorflow:tensorflow-core-platform`),其 API 与 1.x 不兼容。
<!-- TensorFlow 1.x Java API. It pulls in libtensorflow_jni (the native library, ~90 MB),
     which may need to be installed into the local Maven repository by hand if the download fails. -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.13.1</version>
</dependency>
<!-- Apache Commons IO: provides IOUtils.toByteArray, used to read the model file into memory. -->
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.6</version>
</dependency>
代码:加载冻结模型(frozen graph),执行计算图
package com.gzw.javatensorflow1131model;
import org.apache.commons.io.IOUtils;
import org.tensorflow.*;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
/**
 * Loads a frozen TensorFlow 1.x graph from disk and evaluates its {@code op_add} node.
 *
 * <p>The graph must contain the nodes {@code input_a}, {@code input_b} and {@code op_add}.
 * The two inputs are fed with the float scalars 2.0 and 5.0, and the fetched result is printed.
 */
public class LoadModel
{
    /** Model file read when no path is given on the command line. */
    private static final String DEFAULT_MODEL_PATH = "src/main/java/my_freeze_graph.test2";

    /** Node names that must exist in the imported graph before it is run. */
    private static final String[] REQUIRED_NODES = { "input_a", "input_b", "op_add" };

    /**
     * Reads the model, lists the graph's operations, and runs {@code op_add}.
     *
     * @param args optional; {@code args[0]} overrides the model file path
     *             (defaults to {@value #DEFAULT_MODEL_PATH})
     * @throws IOException if the model file cannot be read
     */
    public static void main(String[] args) throws IOException
    {
        String modelPath = args.length > 0 ? args[0] : DEFAULT_MODEL_PATH;

        // Read the serialized GraphDef. try-with-resources guarantees the stream is closed.
        byte[] graphBytes;
        try (FileInputStream in = new FileInputStream(modelPath))
        {
            graphBytes = IOUtils.toByteArray(in);
        }

        // Graph, Session and Tensor wrap native TensorFlow resources, so every one of
        // them is closed explicitly rather than left to the garbage collector.
        try (Graph graph = new Graph())
        {
            graph.importGraphDef(graphBytes);

            // Index the first output of every operation by operation name.
            HashMap<String, Output<?>> tensorMap = new HashMap<>();
            Iterator<Operation> operations = graph.operations();
            while (operations.hasNext())
            {
                Operation operation = operations.next();
                // Ops without outputs (e.g. NoOp) have no output(0). Skip them so an
                // invalid Output never ends up in the map or in the printout below.
                if (operation.numOutputs() > 0)
                {
                    tensorMap.put(operation.name(), operation.output(0));
                }
            }
            System.out.println(tensorMap);

            // Make sure all target nodes exist. Report each missing one instead of exiting silently.
            boolean missing = false;
            for (String node : REQUIRED_NODES)
            {
                if (!tensorMap.containsKey(node))
                {
                    System.err.println("Required node not found in graph: " + node);
                    missing = true;
                }
            }
            if (missing)
            {
                return;
            }

            // Run the graph. The input tensors and the fetched result are all closed on exit.
            try (Session session = new Session(graph);
                 Tensor<?> inputA = Tensor.create(2.0f);
                 Tensor<?> inputB = Tensor.create(5.0f);
                 Tensor<?> opAddResult = session.runner()
                         .feed("input_a", inputA)
                         .feed("input_b", inputB)
                         .fetch(tensorMap.get("op_add"))
                         .run()
                         .get(0))
            {
                System.out.println("op_add_result: " + opAddResult.floatValue());
            }
        }
    }
}
执行结果