TensorFlow.js是一个基于deeplearn.js构建的强大而灵活的Javascript机器学习库,它可直接在浏览器上创建深度学习模块。使用它可以在浏览器上创建CNN(卷积神经网络)、RNN(循环神经网络)等等,且可以使用终端的GPU处理能力训练这些模型。接下来我们将学习如何建立一个简单的“可学习机器”——基于 TensorFlow.js 的迁移学习图像分类器。 加载 TensorFlow.js 和MobileNet 模型 在编辑器中创建一个HTML文件,命名为index.html,添加以下内容。 <!DOCTYPE html> <html> <head> <meta charset="utf-8" /> <title></title> <!--加载最新版本的TensorFlow.js --> <script src="https:///@tensorflow/tfjs"></script> <script src="https:///@tensorflow-models/mobilenet"></script> </head> <body> <div id="console></div> <!--添加一个用于测试的图像--> <img id="img" crossOrigin src="https://i./JlUvsxa.jpg" width=227 height=227/> <!-- 加载 index.js 在内容页之后--> <script src="js/index.js"></script> </body> </html> |
注意:在img里请使用有用的图片地址 在浏览器中设置 MobileNet 用于预测 在代码编辑器中打开/创建index.js 文件,添加以下代码: let net; async function app(){ console.log('Loading mobilenet..'); // 加载模型 net = await mobilenet.load(); console.log('Sucessfully loaded model'); // 通过模型预测图像 const imgEl = document.getElementById('img'); const result = await net.classify(imgEl); console.log(result); } app(); |
在浏览器中测试 MobileNet 的预测 运行index.html文件,调出JavaScript控制台,你将看见一张狗的照片,而这就是MobileNet 的预测结果。 通过网络摄像头图像在浏览器中执行 MobileNet 预测 接下来,我们来设置网络摄像头来预测由网络摄像头传输的图像。 现在,让我们让它更具交互性和实时性。让我们设置网络摄像头来预测由网络摄像头传输的图像。 首先要设置网络摄像头的视频元素。打开 index.html 文件,在 <body> 部分中添加如下行,并删除我们用于加载狗图像的<img> 标签。 <video autoplay playsinline muted id="webcam" width="224" height="224"></video> |
打开 index.js 文件并将webcamElement 添加到文件的最顶部。 const webcamElement = document.getElementById('webcam'); |
在同一个 index.js 文件中,在调用 “app()” 函数之前添加网络摄像头的设置函数: async function setupWebcam() { return new Promise((resolve, reject) => { const navigatorAny = navigator; navigator.getUserMedia = navigator.getUserMedia || navigatorAny.webkitGetUserMedia || navigatorAny.mozGetUserMedia || navigatorAny.msGetUserMedia; if (navigator.getUserMedia) { navigator.getUserMedia({video: true}, stream=> { webcamElement.srcObject = stream; webcamElement.addEventListener('loadeddata', () => resolve(), false); }, error=> reject()); }else{ reject(); } }); } |
在之前添加的 app() 函数中,你可以删除通过图像预测的部分,用一个无限循环,通过网络摄像头预测代替。 async function app(){ console.log('Loading mobilenet..'); // 加载模型 net = await mobilenet.load(); console.log('Sucessfully loaded model'); await setupWebcam(); while(true){ const result = await net.classify(webcamElement); document.getElementById('console').innerText =` prediction: ${result[0].className}\n probability: ${result[0].probability} `; // 给自己一些喘息的空间 // 等待下一个动画帧开始 await tf.nextFrame(); } } |
如果你在网页中打开控制台,现在你应该会看到 MobileNet 的预测和网络摄像头收集到的每一帧图像。 在 MobileNet 预测的基础上添加一个自定义的分类器 现在,让我们把它变得更加实用。我们使用网络摄像头动态创建一个自定义的 3 对象的分类器。我们将通过 MobileNet 进行分类,但这次我们将使用特定网络摄像头图像在模型的内部表示(激活值)来进行分类。 我们将使用一个叫做 "K-Nearest NeighborsClassifier" 的模块,他将有效的让我们把摄像头采集的图像(实际上是 MobileNet 中的激活值)分成不同的类别,当用户要求做出预测时,我们只需选择拥有与待预测图像最相似的激活值的类即可。 在 index.html 的<head> 标签的末尾添加 KNN 分类器的导入(你仍然需要 MobileNet,所以不要删除它的导入): <script src="https:///@tensorflow-models/knn-classifier"></script> |
在 index.html 视频部分下面添加三个按钮。这些按钮将用于向模型添加训练图像。 <button id="class-a">add a</button> <button id="class-b">add a</button> <button id="class-c">add a</button> |
在 index.js 的顶部,创建一个分类器 const classifier = knnClassifier.create(); |
更新 app 函数; async function app(){ console.log('Loading mobilenet..'); //losd the model net = await mobilenet.load(); console.log('Sucessfully loaded model'); await setupWebcam(); //从网络摄像头中读取图像并将其与特定类关联 const addExample classId=>{ // 获取 MobileNet 中间的 'conv_preds' 的激活值 // 并将其传递给 KNN 分类器 const activation = net.infer(webcamElement, 'conv_preds'); // 将中间激活值传递给分类器 classifier.addExample(activation, classId); }; // 单击该按钮是,为该类添加一个实例 document.getElementById("class-a").addEventListener('click',()=>addExample(0)); document.getElementById("class-b").addEventListener('click',()=>addExample(1)); document.getElementById("class-c").addEventListener('click',()=>addExample(2)); while(true){ if(classifier.getNumClasses() > 0) { // 获取 MobileNet 在网络摄像头中图像上的激活值 const activation = net.infer(webcamElement, 'conv_preds'); // 从分类器模块上获取最可能的类 const result = await classifier.predictClass(activation); const classes = ['A', 'B', 'C']; document.getElementById('console').innerText = ` prediction: ${classes[result.classIndex]}\n probability: ${result.confidences[result.classIndex]} `; } await tf.nextFrame(); } } |
当你加载 index.html 页面时,你可以使用常用对象或面部表情/手势为这三个类中的每一个类捕获图像。每次单击其中一个 "Add" 按钮,就会向该类添加一个图像作为训练实例。当你这样做的时候,模型会继续预测网络摄像头的图像,并实时显示结果。 注意:在这里可能会报错,出现:Uncaught (in promise)TypeError: Failed to fetch。 这个错误提示是网络连接超时的意思,解决办法如下:
|