Maven依赖

  • Maven 加载 TensorFlow 依赖库时可能会失败,因为 TensorFlow 依赖了 libtensorflow_jni(90 多 MB)。若下载失败,手动下载该 jar 并放到本地 Maven 仓库即可。
  • 注:本文使用的是 TensorFlow 1.x 的依赖;TensorFlow 2.x 使用的是另外一个依赖(例如 org.tensorflow:tensorflow-core-platform)。
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow</artifactId>
            <version>1.13.1</version>
        </dependency>

        <dependency>
            <groupId>commons-io</groupId>
            <artifactId>commons-io</artifactId>
            <version>2.6</version>
        </dependency>
 

代码:加载模型,执行计算图

package com.gzw.javatensorflow1131model;

import org.apache.commons.io.IOUtils;
import org.tensorflow.*;

import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

public class LoadModel
{
    /** Model file used when no path is passed on the command line. */
    private static final String DEFAULT_MODEL_PATH = "src/main/java/my_freeze_graph.test2";

    /** Operation names this program feeds or fetches; all must exist in the graph. */
    private static final String[] REQUIRED_NODES = { "input_a", "input_b", "op_add" };

    /**
     * Loads a serialized GraphDef, prints every operation's first output, and evaluates
     * {@code op_add} with {@code input_a = 2.0f} and {@code input_b = 5.0f}.
     *
     * @param args optional; {@code args[0]} overrides the model file path
     *             (defaults to {@value #DEFAULT_MODEL_PATH})
     * @throws IOException if the model file cannot be read
     */
    public static void main(String[] args) throws IOException
    {
        String modelPath = args.length > 0 ? args[0] : DEFAULT_MODEL_PATH;
        byte[] graphBytes = readModelBytes( modelPath );

        // Graph, Session and Tensor hold native TensorFlow resources; close them deterministically.
        try( Graph graph = new Graph() )
        {
            graph.importGraphDef( graphBytes );

            Map<String, Output<?>> tensorMap = indexFirstOutputs( graph );
            System.out.println( tensorMap );

            // Feeding/fetching by name fails if an operation is absent, so check up front.
            List<String> missing = new ArrayList<>();
            for( String node : REQUIRED_NODES )
            {
                if( !tensorMap.containsKey( node ) )
                {
                    missing.add( node );
                }
            }
            if( !missing.isEmpty() )
            {
                System.err.println( "Missing required graph nodes: " + missing );
                return;
            }

            // Input tensors and the fetched result are resources too; if run() throws,
            // the already-created inputs are still closed by try-with-resources.
            try( Session session = new Session( graph );
                 Tensor<?> inputA = Tensor.create( 2.0f );
                 Tensor<?> inputB = Tensor.create( 5.0f );
                 Tensor<?> opAddResult = session.runner()
                         .feed( "input_a", inputA )
                         .feed( "input_b", inputB )
                         .fetch( tensorMap.get( "op_add" ) )
                         .run()
                         .get( 0 ) )
            {
                System.out.println( "op_add_result: " + opAddResult.floatValue() );
            }
        }
    }

    /** Reads the whole file at {@code path}, closing the stream even if reading fails. */
    private static byte[] readModelBytes( String path ) throws IOException
    {
        try( FileInputStream in = new FileInputStream( path ) )
        {
            return IOUtils.toByteArray( in );
        }
    }

    /**
     * Maps each operation name in {@code graph} to that operation's output 0.
     * Operations with no outputs are skipped, since index 0 is not a valid output for them.
     */
    private static Map<String, Output<?>> indexFirstOutputs( Graph graph )
    {
        Map<String, Output<?>> outputs = new HashMap<>();
        Iterator<Operation> operations = graph.operations();
        while( operations.hasNext() )
        {
            Operation operation = operations.next();
            if( operation.numOutputs() > 0 )
            {
                outputs.put( operation.name(), operation.output( 0 ) );
            }
        }
        return outputs;
    }
}
 

执行结果