DJ王大海

Back

PyTorch 框架 + MNIST 数据集手写数字识别Blur image

PyTorch#

  • 由 Facebook 人工智能研究小组开发的机器学习框架
  • 代码简洁、符合人类思维、易上手
  • 少量代码即可完成机器学习任务

MNIST 数据集#

  • 手写数字图片 7 万张
  • 训练集 6 万张 + 测试集 1 万张
  • 每张图片大小为 28*28 像素
  • 灰度值范围 0-255
  • 每张图片配有一个标记,既这张图片的真实值

MNIST示例

设计神经网络#

设有这样一张图片,大小为 5*5,内容为数字 7:

  1. 将像素重新排列成为一维阵列,构成神经网络的第零层节点,设第0层节点中的数值为 X00、X10、X20、……、X240

  2. 计算第一层节点的数值,例如 X01 = ∑ai,10 * Xi0 + bi,10

    字母 i 表示前一层的节点序号

  3. 同理构建第二、三、四层节点,节点传播公式也扩展为Xk+11 = ∑ai,1k * Xik + bi,1k

    字母 kk+1 表示网络层数

  4. 最后一层节点即为输出层,输出层有十个节点,每个节点对应 0-9 的数字,节点的数值就是该节点对应数字的概率,表示这张图片是某数字的概率

神经网络示例

输出节点归一化#

由于节点传播公式中的 ab 是任意的,所以输出层节点上的数值也应该是任意的,但概率的取值范围是 0-1,且十个概率加在一起应该等于 1。

所以将输出节点的数值进行如下处理:

  1. 用自然常数 e 对 X0、X1、X2、……、X9 进行一次指数运算,将变成每个数都变为正数,得到 eX0、eX1、X2、……、X9
  2. 将上一步得到的十个数字求和,得到 ∑eXi
  3. 用第二步中得到的数字作分母,用eX0、eX1、X2、……、X9 除以 ∑eXi

这样便得到了在 0-1 范围内,且总和为 1 的数组,以上便是 softmax 归一化

训练#

上面的步骤得到了一组看上去像是概率的数字,但并不具备概率的意义,因为任意取值得到的 ab 并不一定会使得到的数组接近真实概率,还需要经过训练才能让其具备识别的能力。

所以,可以把神经网络理解为一个函数。训练的过程,就是调整函数中的参数的过程。在这个例子里面,我们每次只用了一张图片来调整网络参数,也可以每次将几个图片打包为一个 batch,一起发给神经网络来调整参数。

上面的例子中,所有节点间的计算都是线性的,所以网络的总输入和总输出也是线性的。但对于很多问题,输入输出间存在非线性。所以我们会在在节点传播公式中套上一个非线性函数 f(),称为激活函数。常见的有对数函数、双曲函数、整流函数等等。

编写代码#

这段代码用到了 numpy、torch、torchvision、matplotlib 四个库,用下面的命令进行安装:

pip install numpy torch torchvision matplotlib
bash

下面是项目代码:

PyTorch 框架 + MNIST 数据集手写数字识别
https://astro-pure.js.org/blog/2023/pytorch-mnist
Author 小岛秀儿
Published at 2023年10月19日
Comment seems to stuck. Try to refresh?✨