npm 包@tensorflow/tfjs 使用教程

阅读时长 7 分钟读完

介绍

@tensorflow/tfjs 是 TensorFlow 的 JavaScript 版本,它提供了一套完整的机器学习算法和工具箱,可以在浏览器和 Node.js 环境中使用。本文将介绍如何使用它来构建前端机器学习应用。

环境要求

  • Node.js 8.0 及以上版本
  • 一个现代的浏览器(推荐使用 Chrome)

安装

在命令行中运行以下命令来安装@tensorflow/tfjs:

简单使用

加载模型

首先,我们需要加载一个预训练的模型,如以下代码所示:

这里,我们使用 loadLayersModel 方法来加载一个 TensorFlow.js 模型。传递给它一个指向模型结构的 JSON 文件的路径。在加载模型之前,需要先安装 @tensorflow/tfjs-node 库。

数据预处理

接下来,我们需要将输入数据与模型期望的输入形状进行匹配。例如,如果模型期望输入张量为形状 (height, width, channel),而我们将一个 28×28 的灰度图像作为输入数据,则需要将其重新塑造为形状 (28, 28, 1)。这可以通过以下代码完成:

推理

最后一步是对输入数据进行推理,如以下代码所示:

其中,model.predict 方法将输入数据传递给模型,并返回输出张量。在本例中,我们假设模型产生单个浮点数值作为输出,因此我们可以使用 dataSync 方法将输出张量转换为 JavaScript 数组。

示例

现在,我们将深入探讨一个著名的前端机器学习示例 - 手写数字识别,让您更加深入了解如何使用 TensorFlow.js 来构建完整的前端机器学习应用程序。

数据集

我们将使用 CIFAR-10 数据集。这是一个用于机器学习研究的图像分类数据集,其中包含来自 10 个不同类别的 60000 张 32x32 彩色图像。

我们可以使用 node-cifar10 包来加载并处理 CIFAR-10 数据集。

-- -------------------- ---- -------
----- ----- - ------------------------
----- -- - ----------------------------

----- ----------- - ----- -- -- - 
    ----- - ----------- - - ----- ---------------------------- ------
    ----- - ---------- - - ----- --------------------------- ------
 
    ----- ----------- --------- - ------------- ---------------------- -- -
        ----- ------ - ------------ -- ---------
        ----- -- - --------------------- -- ----------------------------------
        ----- -- - ----------------------------- --------- ----
        ------ - --- -- --
    ---
 
    ------ - ---------- -------- --
--

此处我们从 node-cifar10 包中加载图像,然后将它们转化为 TensorFlow.js 所需的张量:

  • xs 张量包含所有的图像数据
  • ys 张量包含对应的标签

创建模型

我们将使用一个简单的卷积神经网络 (CNN) 来处理 CIFAR-10 数据集。以下是一个示例:

-- -------------------- ---- -------
----- ----------- - ------------ -- -
    ----- ----- - ----------------
    ----------------------------
        ----------- ---- --- ---
        ----------- --
        -------- ---
        ----------- ------
    ----
    ---------------------------------- --------- -- -------- - ----
    -------------------------------
    ---------------------------
        ------ ---
        ----------- ------
    ----
    ---------------------------
        ------ -----------
        ----------- ---------
    ----
 
    ------ ------
-

这是一个简单的 CNN,其中包含一组 3×3 的滤波器,最大池化层和两个全连接层。我们选择 softmax 激活函数作为最后一层的激活函数,适用于分类任务。

训练模型

现在,我们需要使用训练数据集训练我们的模型,使用测试集对模型进行验证。

-- -------------------- ---- -------
----- ---------- - ----- ----------- --------- -- -
    ----- ---------- - ---
    ----- ------ - ---
    ----- --------- - ----
 
    ----- ----- - ------------------------
 
    ---------------
        ---------- ----------------
        ----- --------------------------
        -------- ------------
    ---
 
    ----- - --- ------- --- ------ - - ----------
    ----- - --- ------ --- ----- - - ---------
 
    ----- ----------------- ------- -
        -------
        ----------
        --------------- ------- -------
        -------- ----
    ---
 
    ----- ------ - --------------------- ------ - --------- ---
    --------------------
 
    ------ ------
--

在此代码中,我们使用了 fit 方法对 CNN 进行训练,使用 evaluate 方法对 CNN 进行评估。

评估模型

评估训练后的模型并使用测试数据集上训练的模型进行预测。

-- -------------------- ---- -------
----- ------------- - ----- ------- --------- -- -
    ----- - --- ------ --- ----- - - ---------
 
    ----- ----------- - ---------------------
 
    ----- ----- - ---------------------------------------
    ----- ----- - ---------------------------------------------
 
    ------ - ------ ----- --
-

在此代码中,我们使用 argMax 方法找到最可能属于每个类的类别并将其转换为 JavaScript 数组。最后,我们可以使用这些数组计算模型的准确性。

结论

@tensorflow/tfjs 是一个非常强大的 JavaScript 库,提供了一套完整的机器学习算法和工具箱,可以直接在浏览器和 Node.js 环境中使用。在本文中,我们介绍了如何使用@tensorflow/tfjs 来构建前端机器学习应用程序,包括模型加载、数据预处理、模型推理以及构建一个手写数字识别实例。希望您现在对 TensorFlow.js 有更深入的了解。

来源:JavaScript中文网 ,转载请注明来源 https://www.javascriptcn.com/post/tensorflow-tfjs