介绍
@tensorflow/tfjs
是 TensorFlow 的 JavaScript 版本,它提供了一套完整的机器学习算法和工具箱,可以在浏览器和 Node.js 环境中使用。本文将介绍如何使用它来构建前端机器学习应用。
环境要求
- Node.js 8.0 及以上版本
- 一个现代的浏览器(推荐使用 Chrome)
安装
在命令行中运行以下命令来安装@tensorflow/tfjs
:
npm install @tensorflow/tfjs
简单使用
加载模型
首先,我们需要加载一个预训练的模型,如以下代码所示:
import * as tf from '@tensorflow/tfjs'; const model = await tf.loadLayersModel('path/to/model.json');
这里,我们使用 loadLayersModel
方法来加载一个 TensorFlow.js 模型。传递给它一个指向模型结构的 JSON 文件的路径。在加载模型之前,需要先安装 @tensorflow/tfjs-node
库。
数据预处理
接下来,我们需要将输入数据与模型期望的输入形状进行匹配。例如,如果模型期望输入张量为形状 (height, width, channel),而我们将一个 28×28 的灰度图像作为输入数据,则需要将其重新塑造为形状 (28, 28, 1)。这可以通过以下代码完成:
const image = /* 28x28 grayscale image as a 1D array of floats */; const input = tf.tensor4d(image, [1, 28, 28, 1]);
推理
最后一步是对输入数据进行推理,如以下代码所示:
const output = model.predict(input); const prediction = output.dataSync();
其中,model.predict
方法将输入数据传递给模型,并返回输出张量。在本例中,我们假设模型产生单个浮点数值作为输出,因此我们可以使用 dataSync
方法将输出张量转换为 JavaScript 数组。
示例
现在,我们将深入探讨一个著名的前端机器学习示例 - 手写数字识别,让您更加深入了解如何使用 TensorFlow.js 来构建完整的前端机器学习应用程序。
数据集
我们将使用 CIFAR-10 数据集。这是一个用于机器学习研究的图像分类数据集,其中包含来自 10 个不同类别的 60000 张 32x32 彩色图像。
我们可以使用 node-cifar10
包来加载并处理 CIFAR-10 数据集。
-- -------------------- ---- ------- ----- ----- - ------------------------ ----- -- - ---------------------------- ----- ----------- - ----- -- -- - ----- - ----------- - - ----- ---------------------------- ------ ----- - ---------- - - ----- --------------------------- ------ ----- ----------- --------- - ------------- ---------------------- -- - ----- ------ - ------------ -- --------- ----- -- - --------------------- -- ---------------------------------- ----- -- - ----------------------------- --------- ---- ------ - --- -- -- --- ------ - ---------- -------- -- --
此处我们从 node-cifar10
包中加载图像,然后将它们转化为 TensorFlow.js 所需的张量:
xs
张量包含所有的图像数据ys
张量包含对应的标签
创建模型
我们将使用一个简单的卷积神经网络 (CNN) 来处理 CIFAR-10 数据集。以下是一个示例:
-- -------------------- ---- ------- ----- ----------- - ------------ -- - ----- ----- - ---------------- ---------------------------- ----------- ---- --- --- ----------- -- -------- --- ----------- ------ ---- ---------------------------------- --------- -- -------- - ---- ------------------------------- --------------------------- ------ --- ----------- ------ ---- --------------------------- ------ ----------- ----------- --------- ---- ------ ------ -
这是一个简单的 CNN,其中包含一组 3×3 的滤波器,最大池化层和两个全连接层。我们选择 softmax 激活函数作为最后一层的激活函数,适用于分类任务。
训练模型
现在,我们需要使用训练数据集训练我们的模型,使用测试集对模型进行验证。
-- -------------------- ---- ------- ----- ---------- - ----- ----------- --------- -- - ----- ---------- - --- ----- ------ - --- ----- --------- - ---- ----- ----- - ------------------------ --------------- ---------- ---------------- ----- -------------------------- -------- ------------ --- ----- - --- ------- --- ------ - - ---------- ----- - --- ------ --- ----- - - --------- ----- ----------------- ------- - ------- ---------- --------------- ------- ------- -------- ---- --- ----- ------ - --------------------- ------ - --------- --- -------------------- ------ ------ --
在此代码中,我们使用了 fit
方法对 CNN 进行训练,使用 evaluate
方法对 CNN 进行评估。
评估模型
评估训练后的模型并使用测试数据集上训练的模型进行预测。
-- -------------------- ---- ------- ----- ------------- - ----- ------- --------- -- - ----- - --- ------ --- ----- - - --------- ----- ----------- - --------------------- ----- ----- - --------------------------------------- ----- ----- - --------------------------------------------- ------ - ------ ----- -- -
在此代码中,我们使用 argMax
方法找到最可能属于每个类的类别并将其转换为 JavaScript 数组。最后,我们可以使用这些数组计算模型的准确性。
结论
@tensorflow/tfjs
是一个非常强大的 JavaScript 库,提供了一套完整的机器学习算法和工具箱,可以直接在浏览器和 Node.js 环境中使用。在本文中,我们介绍了如何使用@tensorflow/tfjs
来构建前端机器学习应用程序,包括模型加载、数据预处理、模型推理以及构建一个手写数字识别实例。希望您现在对 TensorFlow.js 有更深入的了解。
来源:JavaScript中文网 ,转载请注明来源 https://www.javascriptcn.com/post/tensorflow-tfjs