分享

最新Java教程:在Java中使用便携式ONNX AI模型

 码农9527 2021-04-06

 在我们关于2020年使用便携式神经网络的系列文章中,您将了解如何在x64架构上安装ONNX并在Java中使用它。

    微软与Facebook和AWS共同开发了ONNX。ONNX格式和ONNXRuntime都得到了业界的支持,以确保所有重要的框架都能够将它们的图导出到ONNX,并且这些模型可以在任何硬件配置上运行。

    ONNXRuntime是一个用于运行已转换为ONNX格式的机器学习模型的引擎。传统的机器学习模型和深度学习模型(神经网络)都可以输出到ONNX格式。该运行时可以在Linux、Windows和Mac上运行,并且可以在各种芯片架构上运行。它还可以利用GPU和TPU等硬件加速器。不过,并不是每一种操作系统、芯片架构和加速器的组合都有安装包,所以如果您没有使用常见的组合,可能需要从源码中构建运行时。请查看ONNX运行时网站,获取所需组合的安装说明。本文将介绍如何在使用默认CPU的x64架构和使用GPU的x64架构上安装ONNXRuntime。

Java

    除了能够在多种硬件配置上运行,运行时还可以从大多数流行的编程语言中调用。本文的目的是展示如何在Java中使用ONNX运行时。我将展示如何安装onnxruntime包。安装ONNXRuntime后,我将把之前导出的MNIST模型加载到ONNXRuntime中,并使用它进行预测。

    安装和导入ONNX运行时系统

    在使用ONNX运行时之前,您需要为您的构建工具添加适当的依赖性。Maven资源库是为Maven和Gradle等各种工具设置ONNX运行时的良好来源。要在x64架构和默认CPU上使用运行时,请参考以下链接。https:///artifact/org.bytedeco/onnxruntime-platform

    要在x64架构的GPU上使用运行时,请使用以下链接。https:///artifact/org.bytedeco/onnxruntime-platform-gpu

    一旦安装了运行时,就可以通过下图所示的导入语句将其导入到你的Java代码文件中。导入TensorProto工具的导入语句将帮助我们为ONNX模型创建输入,它还将帮助解释ONNX模型的输出(预测)。

import ai.onnxruntime.OnnxMl.TensorProto;import ai.onnxruntime.OnnxMl.TensorProto.DataType;import ai.onnxruntime.OrtSession.Result;import ai.onnxruntime.OrtSession.SessionOptions;import ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode;import ai.onnxruntime.OrtSession.SessionOptions.OptLevel;

    加载ONNX模型

    下面的片段显示了如何将ONNX模型加载到以Java运行的ONNXRuntime中。这段代码创建了一个会话对象,可用于进行预测。这里使用的模型是从PyTorch导出的ONNX模型。

    这里有几件事值得注意。首先,您需要查询会话以获取其输入。这是通过会话的getInputInfo方法完成的。我们的MNIST模型只有一个输入参数:一个由784个浮点组成的数组,代表MNIST数据集中的一张图像。如果您的模型有多个输入参数,那么InputMetadata将为每个参数设置一个条目。

Utilities.LoadTensorData();
String modelPath = "pytorch_mnist.onnx";try (OrtSession session = env.createSession(modelPath, options)) {
   Map<String, NodeInfo> inputMetaMap = session.getInputInfo();
   Map<String, OnnxTensor> container = new HashMap<>();
   NodeInfo inputMeta = inputMetaMap.values().iterator().next();   float[] inputData = Utilities.ImageData[imageIndex];
   string label = Utilities.ImageLabels[imageIndex];
   System.out.println("Selected image is the number: " + label);   // this is the data for only one input tensor for this model
   Object tensorData =
            OrtUtil.reshape(inputData, ((TensorInfo) inputMeta.getInfo()).getShape());
   OnnxTensor inputTensor = OnnxTensor.createTensor(env, tensorData);
   container.put(inputMeta.getName(), inputTensor);   // Run code omitted for brevity.}

    上面的代码中没有显示的是读取原始MNIST图像并将每幅图像转换为784个浮动数组的实用程序。每个图像的标签也从MNIST数据集中读取,这样就可以确定预测的准确性。这段代码是标准的Java代码,但我们仍然鼓励你检查并使用它。如果您需要读取与MNIST数据集相似的图像,它将为您节省时间。

    使用ONNX运行时间进行预测。

    下面的功能显示了如何使用我们加载ONNX模型时创建的ONNX会话。

try (OrtSession session = env.createSession(modelPath, options)) {   // Load code not shown for brevity.

   // Run the inference
   try (OrtSession.Result results = session.run(container)) {      // Only iterates once
      for (Map.Entry<String, OnnxValue> r : results) {
         OnnxValue resultValue = r.getValue();
         OnnxTensor resultTensor = (OnnxTensor) resultValue;
         resultTensor.getValue()
         System.out.println("Output Name: {0}", r.Name);         int prediction = MaxProbability(resultTensor);
         System.out.println("Prediction: " + prediction.ToString());
}
   }
}

    大多数神经网络不直接返回预测。它们会返回每个输出类的概率列表。在我们MNIST模型的情况下,每个图像的返回值将是一个10个概率的列表。具有最高概率的条目就是预测。您可以做的一个有趣的测试是,当ONNX模型在创建模型的框架内运行时,比较ONNX模型返回的概率和原始模型返回的概率。理想情况下,模型格式和运行时的变化不应改变任何产生的概率。这将是一个很好的单元测试,每次模型发生变化时都会运行。

    总结和下一步

    在本文中,我简要介绍了ONNX运行时和ONNX格式。然后,我展示了如何在ONNX运行时使用Java加载和运行ONNX模型。

    本文的代码示例包含一个工作控制台应用程序,演示了这里所展示的所有技术。该代码示例是Github资源库的一部分,该资源库探讨了如何使用神经网络来预测MNIST数据集中发现的数字。具体来说,有一些样本展示了如何在Keras、PyTorch、TensorFlow1.0和TensorFlow2.0中创建神经网络。

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多