当前位置: 首页 > news >正文

PyTorch实战:实现MNIST手写数字识别


前言

PyTorch可以说是三大主流框架中最适合初学者学习的了,相较于其他主流框架,PyTorch的简单易用性使其成为初学者们的首选。这样我想要强调的一点是,框架可以类比为编程语言,仅为我们实现项目效果的工具,也就是我们造车使用的轮子,我们重点需要的是理解如何使用Torch去实现功能而不要过度在意轮子是要怎么做出来的,那样会牵扯我们太多学习时间。以后就出一系列专门细解深度学习框架的文章,但是那是较后期我们对深度学习的理论知识和实践操作都比较熟悉才好开始学习,现阶段我们最需要的是学会如何使用这些工具。

深度学习的内容不是那么好掌握的,包含大量的数学理论知识以及大量的计算公式原理需要推理。且如果不进行实际操作很难够理解我们写的代码究极在神经网络计算框架中代表什么作用。不过我会尽可能将知识简化,转换为我们比较熟悉的内容,我将尽力让大家了解并熟悉神经网络框架,保证能够理解通畅以及推演顺利的条件之下,尽量不使用过多的数学公式和专业理论知识。以一篇文章快速了解并实现该算法,以效率最高的方式熟练这些知识。


博主专注数据建模四年,参与过大大小小数十来次数学建模,理解各类模型原理以及每种模型的建模流程和各类题目分析方法。此专栏的目的就是为了让零基础快速使用各类数学模型、机器学习和深度学习以及代码,每一篇文章都包含实战项目以及可运行代码。博主紧跟各类数模比赛,每场数模竞赛博主都会将最新的思路和代码写进此专栏以及详细思路和完全代码。希望有需求的小伙伴不要错过笔者精心打造的专栏。

一文速学-数学建模常用模型


一、数据集加载

MNIST(Modified National Institute of Standards and Technology)是一个手写数字数据集,通常用于训练各种图像处理系统。

它包含了大量的手写数字图像,这些数字从0到9。每个图像都是一个灰度图像,大小为28x28像素,表示了一个手写数字。

MNIST数据集分成两部分:训练集和测试集。训练集通常包含60,000张图像,用于训练模型。测试集包含10,000张图像,用于评估模型的性能。

MNIST数据集是一个非常受欢迎的数据集,被用于测试和验证各种机器学习和深度学习模型,特别是在图像识别任务中。大家可以直接访问官网下载或者是在程序中使用torchvision下载数据集。

官网:THE MNIST DATABASE

一共4个文件,训练集、训练集标签、测试集、测试集标签:

文件名称大小内容
train-labels-idx1-ubyte.gz9,681 kb55000张训练集,5000张验证集
train-labels-idx1-ubyte.gz29 kb训练集图片对应的标签
t10k-images-idx3-ubyte.gz1,611 kb10000张测试集
t10k-labels-idx1-ubyte.gz5 kb测试集图片对应的标签

 程序加载MNIST数据集:

from torch.utils.data import DataLoader
import torchvision.datasets as dsetstransform = transforms.Compose([transforms.Grayscale(num_output_channels=1),  # 将图像转为灰度transforms.ToTensor(),  # 将图像转为张量transforms.Normalize((0.1307,), (0.3081,))
])#MNIST dataset
train_dataset = dsets.MNIST(root = '/ml/pymnist',  #选择数据的根目录train = True,  #选择训练集transform = transform,  #不考虑使用任何数据预处理download = True  #从网络上下载图片)
test_dataset = dsets.MNIST(root = '/ml/pymnist',#选择数据的根目录train = False,#选择测试集transform = transform, #不考虑使用任何数据预处理download = True #从网络上下载图片)
#加载数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size = batch_size,shuffle = True #将数据打乱)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size = batch_size,shuffle = True)

图片展示:

import matplotlib.pyplot as plt
digit = train_dataset.train_data[0]
plt.imshow(digit,cmap=plt.cm.binary,interpolation='none')
plt.title("Labels: {}".format(train_dataset.train_labels[0]))
plt.show()

 之后需要切分数据集,分为训练集和测试集,MNIST数据集已经做好了直接使用就好了:

print("train_data:",train_dataset.train_data.size())
print("train_labels:",train_dataset.train_labels.size())
print("test_data:",test_dataset.test_data.size())
print("test_labels:",test_dataset.test_labels.size())

train_data: torch.Size([60000, 28, 28])
train_labels: torch.Size([60000])
test_data: torch.Size([10000, 28, 28])
test_labels: torch.Size([10000])

还需要确定批次的尺寸,在神经网络训练中,batch_size 是指每次迭代训练时,模型同时处理的样本数量。它在训练过程中起到了几个重要作用:

  • 加速训练过程:通过同时处理多个样本,利用了现代计算机的并行计算能力,可以加速训练过程,尤其在使用GPU时。
  • 减少内存消耗:一次性加载整个训练集可能会占用大量内存,而将数据分批次加载可以降低内存消耗,使得在内存受限的环境中也能进行训练。
  • 提高模型泛化能力:在训练过程中,模型会根据每个batch的数据调整权重,而不是依赖于整个训练集。这样可以提高模型对不同样本的泛化能力。
  • 避免陷入局部极小值:随机选择的小批量样本可以帮助模型避免陷入局部极小值。
  • 增加噪声鲁棒性:在每次迭代中,模型只看到了一个小样本,这可以看作一种随机噪声,有助于提高模型的鲁棒性。
  • 方便在线学习:对于在线学习任务,可以动态地加载新的数据批次,而不需要重新训练整个模型。

总的来说,合理选择合适的batch_size可以使训练过程更加高效、稳定,并且能够提高模型的泛化能力。然而,过大的batch_size可能会导致内存溢出或训练速度变慢,过小的batch_size可能会导致模型收敛困难。因此,选择合适的batch_size需要在实践中进行调试和优化。

print("批次的尺寸:",train_loader.batch_size)
print("load_train_data:",train_loader.dataset.train_data.shape)
print("load_train_labels:",train_loader.dataset.train_labels.shape)

 

批次的尺寸: 100
load_train_data: torch.Size([60000, 28, 28])
load_train_labels: torch.Size([60000])

 从输出结果中,可以看到原始数据集和数据打乱按照批次读取的数据集的总行数是一样的,实际操作中train_loader以及test_loader将作为神经网络的输入数据源。

二、定义神经网络

在前面的文章中已经带着大家搭建过好几遍神经网络了,注意初始化网络和对应的输入层,隐藏层和输出层。

import torch.nn as nn
import torchinput_size = 784 #mnist的像素为28*28
hidden_size = 500
num_classes = 10#输出为10个类别分别对应于0~9#创建神经网络模型
class Neural_net(nn.Module):
#初始化函数,接受自定义输入特征的维数,隐含层特征维数以及输出层特征维数def __init__(self,input_num,hidden_size,out_put):super(Neural_net,self).__init__()self.layer1 = nn.Linear(input_num,hidden_size) #从输入到隐藏层的线性处理self.layer2 = nn.Linear(hidden_size,out_put) #从隐藏层到输出层的线性处理def forward(self,x):x = self.layer1(x) #输入层到隐藏层的线性计算x = torch.relu(x) #隐藏层激活x = self.layer2(x) #输出层,注意,输出层直接接lossreturn xnet = Neural_net(input_size,hidden_size,num_classes)
print(net)
Neural_net((layer1): Linear(in_features=784, out_features=500, bias=True)(layer2): Linear(in_features=500, out_features=10, bias=True)
)

 super(Neural_net, self).init() 是 Python 中用于调用父类的方法或属性的一种方式。在这里,Neural_net 是你定义的神经网络模型的类名,它继承了 nn.Module 类,而 nn.Module 是 PyTorch 中用于构建神经网络模型的基类。也就是说,你的神经网络模型会继承 nn.Module 的所有属性和方法,这样你可以在 Neural_net 类中使用 nn.Module 中定义的各种功能,比如添加神经网络层、指定损失函数等。

三、训练模型

只有要注意一下Variable,之前的文章中又提到过Variable。

Variable是PyTorch早期版本(0.4版本之前)中用于构建计算图的抽象,它包含了data、grad和grad_fn等属性,可以用于构建计算图,并在反向传播时自动计算梯度。但从PyTorch 0.4版本开始,Variable被官方废弃,而Tensor直接支持了自动求导功能,不再需要显式地创建Variable。

因此,Autograd是PyTorch实现自动求导的核心机制,而Variable是早期版本中用于构建计算图的一种抽象,现在已经被Tensor所取代。 Autograd会自动追踪Tensor上的操作,并在需要时计算梯度,从而实现反向传播。

#optimization
import numpy as np
from torchvision import transformslearning_rate = 1e-3 #学习率
num_epoches = 5
criterion =nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(),lr = learning_rate) #随机梯度下降for epoch in range(num_epoches):print('current epoch = %d' % epoch)for i ,(images,labels) in enumerate(train_loader,0):images=images.view(-1,28*28)outputs = net(images) #将数据集传入网络做前向计算labels = torch.tensor(labels, dtype=torch.long)loss = criterion(outputs, labels) #计算lossoptimizer.zero_grad() #在做反向传播之前先清楚下网络状态loss.backward() #Loss反向传播optimizer.step() #更新参数if i % 100 == 0:print('current loss = %.5f' % loss.item())print('finished training')

 

current epoch = 1
current loss = 0.27720
current loss = 0.23612
current loss = 0.39341
current loss = 0.24683
current loss = 0.18913
current loss = 0.31647
current loss = 0.28518
current loss = 0.18053
current loss = 0.34957
current loss = 0.31319
current epoch = 2
current loss = 0.15138
current loss = 0.30887
current loss = 0.24257
current loss = 0.46326
current loss = 0.30790
current loss = 0.17516
current loss = 0.32319
current loss = 0.32325
current loss = 0.32066
current loss = 0.24271

 四、准确度测试

各层的权重通过随机梯度下降法更新Loss之后,针对测试集数字分类的准确率:

#prediction
total = 0
correct =0 
acc_list_test = []
for images,labels in test_loader:images=images.view(-1,28*28)outputs = net(images) #将数据集传入网络做前向计算_,predicts = torch.max(outputs.data,1)total += labels.size(0)correct += (predicts == labels).sum()acc_list_test.append(100 * correct / total)print('Accuracy = %.2f'%(100 * correct / total))
plt.plot(acc_list_test)
plt.xlabel('Epoch')
plt.ylabel('Accuracy On TestSet')
plt.show()

 

 

点关注,防走丢,如有纰漏之处,请留言指教,非常感谢

以上就是本期全部内容。我是fanstuck ,有问题大家随时留言讨论 ,我们下期见。


相关文章:

PyTorch实战:实现MNIST手写数字识别

前言 PyTorch可以说是三大主流框架中最适合初学者学习的了,相较于其他主流框架,PyTorch的简单易用性使其成为初学者们的首选。这样我想要强调的一点是,框架可以类比为编程语言,仅为我们实现项目效果的工具,也就是我们…...

【计算机网络】深入理解TCP协议二(连接管理机制、WAIT_TIME、滑动窗口、流量控制、拥塞控制)

TCP协议 1.连接管理机制2.再谈WAIT_TIME状态2.1理解WAIT_TIME状态2.2解决TIME_WAIT状态引起的bind失败的方法2.3监听套接字listen第二个参数介绍 3.滑动窗口3.1介绍3.2丢包情况分析 4.流量控制5.拥塞控制5.1介绍5.2慢启动 6.捎带应答、延时应答 1.连接管理机制 正常情况下&…...

springboot整合sentinel完成限流

1、直入正题,下载sentinel的jar包 1.1 直接到Sentinel官网里的releases下即可下载最新版本,Sentinel官方下载地址,直接下载jar包即可。不过慢,可能下载不下来 1.2 可以去gitee去下载jar包 1.3 下载完成后,进行打包…...

signal(SIGPIPE, SIG_IGN)

linux查看signal常见信号。 [rootplatform:]# kill -l1) HUP2) INT3) QUIT4) ILL5) TRAP6) ABRT7) BUS8) FPE9) KILL 10) USR1 11) SEGV 12) USR2 13) PIPE 14) ALRM 15) TERM 16) STKFLT 17) CHLD 18) CONT 19) STOP 20) TSTP 21) TTIN 22) TTOU 23) URG 24) XCPU 25) XFSZ 2…...

GAN学习笔记

1.原始的GAN 1.1原始的损失函数 1.1.1写法1参考1,参考2 1.1.2 写法2 where, G Generator D Discriminator Pdata(x) distribution of real data P(z) distribution of generator x sample from Pdata(x) z sample from P(z) D(x) Discriminator network G…...

layui框架学习(45: 工具集模块)

layui的工具集模块util支持固定条、倒计时等组件,同时提供辅助函数处理时间数据、字符转义、批量事件处理等操作。   util模块中的fixbar函数支持设置固定条(2.7版本的帮助文档中叫固定块),是指固定在页面一侧的工具条元素&…...

车道检测:Decoupling the Curve Modeling and Pavement Regression for Lane Detection

论文作者:Wencheng Han,Jianbing Shen 作者单位:University of Macau 论文链接:http://arxiv.org/abs/2309.10533v1 内容简介: 1)方向:车道检测 2)应用:车道检测 3&#xff09…...

【扩散生成模型】Diffusion Generative Models

提出扩散模型思想的论文: 《Deep Unsupervised Learning using Nonequilibrium Thermodynamics》理解 扩散模型综述: “扩散模型”首篇综述论文分类汇总,谷歌&北大最新研究 理论推导、代码实现: What are Diffusion Models?…...

美联储加息步伐“暂停”!BTC凌晨力守27000美元!

美东时间9月20日下午,美联储宣布放缓加息步伐,将联邦基金利率目标维持在5.25%至5.50%的区间不变,保持在22年来的最高点,符合市场预期。 在最新的FOMC声明中,美联储表示最近的指标表明,经济活动一直在稳步扩…...

微信小程序与idea后端如何进行数据交互

交互使用的其实就是调用的req.get(url)方法 进行路径访问,你要先保证自己的springboot项目已经成功运行了: 如下: 如何交互的? 微信小程序:如下为index.js页面 在onLoad()事件中调用方法Project.findAllCities() 要…...

Java 学习路线分享 maven 是什么?

Maven 是一款基于 Java 平台的项目管理和整合工具,它将项目的开发和管理过程抽象成一个项目对象模型(POM)。开发人员只需要做一些简单的配置,Maven 就可以自动完成项目的编译、测试、打包、发布以及部署等工作。 Maven 是使用 Ja…...

实战演练 | Navicat 常用功能之转储与运行 SQL 文件

数据库管理工作中,"转储 SQL 文件"和"运行 SQL 文件"是两个极为常见操作。一般来说,用户使用数据库管理工具或命令行工具来完成。Navicat 管理开发工具中的“转储 SQL 文件”和“运行 SQL 文件”功能具有直观易用的界面、多种文件格…...

MySQL的备份与恢复

备份与恢复 一、备份1.1 数据备份的必要性1.2 数据备份分类1.2.1 物理备份1.2.2 逻辑备份 1.3 数据库备份策略1.4 常用的备份方法和工具1.5 数据库上云迁移 二、MySQL完全备份2.1 简介2.2 物理冷备份与恢复2.2.1 物理冷备份2.2.2 解压恢复 2.3 mysqldump备份与恢复1&#xff09…...

Python中的函数未定义的错误

前言: 嗨喽~大家好呀,这里是魔王呐 ❤ ~! python更多源码/资料/解答/教程等 点击此处跳转文末名片免费获取 通过这个解释,我们将了解当Python程序显示类似NameError: name ‘’ is not defined的错误时,即使该函数存在于脚本中&…...

AG35学习笔记(二):安装编译SDK、CMakeLists编译app、Scons编译server

目录 一、概述二、安装SDK2.1 网盘SDK - 权限不够2.2 bj41 - 需要交叉source2.3 mullen - relocate_sdk.py路径有误 三、编译SDK3.1 /bin/sh: 1: gcc: not found3.2 curses.h: No such file or directory 四、CMakeLists - 编译app4.1 cmake - 项目构建4.2 make - 项目编译4.3 …...

多台服务器sessionId共享

目录 多台服务器sessionId共享解决方案:ASP.NET Core 参考代码(NET 7):登录处理登录(请求)过滤器过滤器使用BaseController 多台服务器sessionId共享 session id是服务器首次与浏览器创建连接时,生成的id值,存入浏览器…...

如何在Gazebo中实现多机器人编队仿真

文章目录 前言一、仿真前的配置二、实现步骤1.检查PC和台式机是否通讯成功2.编队中对单个机器人进行独立的控制3、对机器人进行编队控制 前言 实现在gazebo仿真环境中添加多个机器人后,接下来进行编队控制,对具体的实现过程进行记录。 一、仿真前的配置…...

迅为iTOP-iMX6QPLUS-Android6.0下uboot添加网卡驱动

本文档介绍在 iTOP-iMX6Q 和 iTOP-iMX6Q-PLUS 安卓 6.0 的 uboot 上添加网卡驱 动,添加完网卡驱动以后,uboot 就可以正常使用网络了。 1 具体步骤 1.1 修改 mx6sabre_common.h 文件 在 iTOP-iMX6_android6.0.1 源码目录下输入以下命令,打…...

sql server 触发器的使用

看数据库下的所有触发器及状态 SELECT a.name 数据表名 , sysobjects.name AS 触发器名 , sysobjects.crdate AS 创建时间 , sysobjects.info , sysobjects.status FROM sysobjects LEFT JOIN ( SELECT * FROM sysobjects WHERE xtype U ) AS a ON sysobjects.parent_obj a.…...

使用亚马逊云服务器在 G4 实例上运行 Android 应用程序

随着 Android 应用程序和游戏变得越来越丰富,其中有些甚至比 PC 上的软件更易于使用和娱乐,因此许多人希望能够在云上运行 Android 游戏或应用程序,而在 EC2 实例上运行 Android 的解决方案可以让开发人员更轻松地测试和运行 Android 应用程序…...

Direct3D融合技术

该技术能使我们将当前要进行光栅化的像素的颜色与先前已已光栅化并处于同一位置的像素的颜色进行合成,即将正在处理的图元颜色值与存储在后台缓存中的像素颜色值进行合成(混合),利用该技术我们可得到各种各样的效果,尤其是透明效果。 在融合…...

【计算机网络】信号处理接口 Signal API(1)

收发信号思想是 Linux 程序设计特性之一,一个信号可以认为是一种软中断,通过用来向进程通知异步事件。 本文讲述的 信号处理内容源自 Linux man。本文主要对各 API 进行详细介绍,从而更好的理解信号编程。 signal 遵循 C11,POSIX.…...

贝叶斯滤波计算4d毫米波聚类目标动静属性

机器人学中有些问题是二值问题,对于这种二值问题的概率评估问题可以用二值贝叶斯滤波器binary Bayes filter来解决的。比如机器人前方有一个门,机器人想判断这个门是开是关。这个二值状态是固定的,并不会随着测量数据变量的改变而改变。就像门…...

华为hcie认证考试怎么考?

华为HCIE认证考试怎么考? 前文腾科也说了HCIE认证考试的难度会比较大,具体是难在哪里呢?华为HCIE认证的考试需要考一门笔试,笔试主要是单选、多选、判断、填空、拖拽这几个题型,考试时长一般是一个半小时,…...

vue +element 删除按钮操作 (删除单个数据 +删除页码处理 )

1.配置接口deleteItemById: "/api/goods/deleteItemById", //删除商品操作 2.get请求接口 // 删除接口 后台给我们 返iddeleteItemById(params){return axios.get(base.deleteItemById,{params})}3.异步请求接口 async deleteItemById(id){let res await this.…...

更新GitLab上的项目

更新GitLab上的项目 如有需要,请参考这篇:上传项目到gitlab上 1.打开终端,进入到本地项目的根目录。 2.如果你还没有将远程GitLab仓库添加到本地项目,你可以使用以下命令: 比如: git remote add origin …...

K8S群集调度

K8S群集调度 一、调度约束1.概述2.Pod 启动典型创建过程(工作机制 )3.调度过程4.Predicate 的常见的算法5.常见的优先级选项6.指定调度节点: 二、亲和性1.节点亲和性2.Pod 亲和性3.键值运算关系4.示例5.Pod亲和性与反亲和性6.使用 Pod 反亲和…...

完美解决Echarts X坐标轴下方文字最后一个字体加粗颜色加深的问题

之前用Echarts画图的时候,X坐标轴最后一个字存在自动加粗的问题。也是在网上找过解决办法没有找到,后面自己研究明白了后,在某篇文章下评论了如何解决。但是好像大家没有看评论的习惯,所以单独拿出来一篇文章,希望能给…...

WebGL 计算平行光、环境光下的漫反射光颜色

目录 光照原理 光源类型 平行光 点光源 环境光 反射类型 漫反射 漫反射光颜色 计算公式 环境反射 环境反射光颜色 表面的反射光颜色(漫反射和环境反射同时存在时)计算公式 平行光下的漫反射 根据光线和法线方向计算入射角θ(以便…...

解决SpringMVC在JSP页面取不到ModelAndView中数据

版权声明 本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl 问题描述 ModelAndView携带数据跳转到指定JSP页面后在该页面通过EL表达式取不到原本存放在ModelAndView中的数据。 问题原因 在IDEA中创建Maven工程时web.xml中默认的约束…...

vatage wordpress主题/线上网络推广怎么做

Controller控制台Vuser运行的状态 Down:没有运行Pending:挂起Init:初始化Ready:准备就绪Run:正在运行Rendezvous:正在集结Passed:运行通过Failed:运行失败Error:出现故障Gradual Exting:逐一退出Exiting:退出Stopped:停止运行 Star Scenario:运行场景Stop:强制停止运行场景Rest…...

网站开发合同样本/免费涨1000粉丝网站

最近在想自己的文章有些是不是写的太难以理解了呢.........竟然好多人看了还是会直接问我很多问题....... 其实PID哈靠自己想像就能自己写出来自己的代码,也许是网上的讲的太过的高深什么积分微分,搞的晕头转向,本来这么实用的想法为什么偏偏说的那么的琢磨不透......感觉那些人…...

恒一信息深圳网站建设公司2/百度快照推广

如今,在各行各业作业生产中,都能看到工业连接器、插头插座的身影,它能够传输高速、高容量和高精度的信号和电力,具有防水、防尘、抗震动、抗干扰等特性,被广泛应用在工业控制、通讯、医疗、交通、航空、军事等领域&…...

以投资思维做网站/国内最新新闻大事

web标准的存放数据的范围有: pageContext域,request域,session域,application域(ServletContext)。 struts2自己又定义了一个容器来存放数据,即:ActionContext。 ActionContext是个Map集合,它持…...

网站建设费无形资产摊销/b站推广在哪里

因为工作原因&#xff0c;需要对ecshop二次开发&#xff0c;顺便记录一下对ecshop源代码的一些分析&#xff1a;首先是init.php文件&#xff0c;这个文件在ecshop每个页面都会 调用到&#xff0c;习惯就先分析它&#xff1a;<?php/** * ECSHOP 前台公用文件*///防止非法调用…...

wordpress 注入 实战/百度关键词指数查询工具

■ 定义 text-align属性用于设置元素内文本内容的水平对齐方式 ■ 使用说明 属性值(固定值)&#xff1a; ▶ left&#xff1a;左对齐(默认值) ▶ right&#xff1a;右对齐 ▶ center&#xff1a;居中对齐 ■ 示例 h1 {text-align: left; }h2 {text-align: center; }h3 {text-al…...