CNN卷积网络实现MNIST数据集手写数字识别
步骤一:加载MNIST数据集
train_data = MNIST(root='./data',train=True,download=False,transform=transforms.ToTensor())
train_loader = DataLoader(train_data,shuffle=True,batch_size=64)
# 测试数据集
test_data = MNIST(root='./data',train=False,download=False,transform=transforms.ToTensor())
test_loader = DataLoader(test_data,shuffle=False,batch_size=64)
首先,通过MNIST
类创建了train_data
对象,指定了数据集的路径root='./data'
,并且将数据集标记为训练集train=True
。download=False
表示不自动从网络上下载数据集,而是使用已经下载好的数据集。我是之前自己已经下载过该数据集所以这里填的是False,如果之前没有下载的话就要填True。下面测试集也是一样。transforms.ToTensor()
将数据转换为张量形式。
然后,通过DataLoader
类创建了train_loader
对象,指定了使用train_data
作为数据源。shuffle=True
表示在每个epoch开始时,将数据打乱顺序。batch_size=64
表示每次抓取64个样本。
接下来,同样的步骤也被用来创建了测试集的数据加载器test_loader
。不同的是,这里将数据集标记为测试集train=False
,并且shuffle=False
表示不需要打乱顺序。
加载完的数据集存在MNIST文件夹的raw文件夹下内容如下:
其中t10k-images-idx3-ubyte是测试集的图像,t10k-labels-idx3-ubyte是测试集的标签。train-images-idx3-ubyte是训练集的图像,train-labels-idx1-ubyte是训练集的标签。
存下来的这些数据集是二进制的形式,可以通过下面的代码(1.py)读取:
"""
Created on Sat Jul 27 15:26:39 2024@author: wangyiyuan
"""
# 导入包
import struct
import numpy as np
from PIL import Imageclass MnistParser:# 加载图像def load_image(self, file_path):# 读取二进制数据binary = open(file_path,'rb').read()# 读取头文件fmt_head = '>iiii'offset = 0# 读取头文件magic_number,images_number,rows_number,columns_number = struct.unpack_from(fmt_head,binary,offset)# 打印头文件信息print('图片数量:%d,图片行数:%d,图片列数:%d'%(images_number,rows_number,columns_number))# 处理数据image_size = rows_number * columns_numberfmt_data = '>'+str(image_size)+'B'offset = offset + struct.calcsize(fmt_head)# 读取数据images = np.empty((images_number,rows_number,columns_number))for i in range(images_number):images[i] = np.array(struct.unpack_from(fmt_data, binary, offset)).reshape((rows_number, columns_number))offset = offset + struct.calcsize(fmt_data)# 每1万张打印一次信息if (i+1) % 10000 == 0:print('> 已读取:%d张图片'%(i+1))# 返回数据return images_number,rows_number,columns_number,images# 加载标签def load_labels(self, file_path):# 读取数据binary = open(file_path,'rb').read()# 读取头文件fmt_head = '>ii'offset = 0# 读取头文件magic_number,items_number = struct.unpack_from(fmt_head,binary,offset)# 打印头文件信息print('标签数:%d'%(items_number))# 处理数据fmt_data = '>B'offset = offset + struct.calcsize(fmt_head)# 读取数据labels = np.empty((items_number))for i in range(items_number):labels[i] = struct.unpack_from(fmt_data, binary, offset)[0]offset = offset + struct.calcsize(fmt_data)# 每1万张打印一次信息if (i+1)%10000 == 0:print('> 已读取:%d个标签'%(i+1))# 返回数据return items_number,labels# 图片可视化def visualaztion(self, images, labels, path):d = {0:0, 1:0, 2:0, 3:0, 4:0, 5:0, 6:0, 7:0, 8:0, 9:0}for i in range(images.__len__()):im = Image.fromarray(np.uint8(images[i]))im.save(path + "%d_%d.png"%(labels[i], d[labels[i]]))d[labels[i]] += 1# im.show()if (i+1)%10000 == 0:print('> 已保存:%d个图片'%(i+1))# 保存为图片格式
def change_and_save():mnist = MnistParser()trainImageFile = './train-images-idx3-ubyte'_, _, _, images = mnist.load_image(trainImageFile)trainLabelFile = './train-labels-idx1-ubyte'_, labels = mnist.load_labels(trainLabelFile)mnist.visualaztion(images, labels, "./images/train/")testImageFile = './train-images-idx3-ubyte'_, _, _, images = mnist.load_image(testImageFile)testLabelFile = './train-labels-idx1-ubyte'_, labels = mnist.load_labels(testLabelFile)mnist.visualaztion(images, labels, "./images/test/")# 测试
if __name__ == '__main__':change_and_save()
将这个1.py文件和下载好的数据集放在同一个文件夹下:
新建一个文件夹images,在文件夹images里面新建两个文件夹分别叫test和train。
运行完可以发现train和test里的内容如下:
步骤二:建立模型
class Model(nn.Module):def __init__(self):super(Model,self).__init__()self.linear1 = nn.Linear(784,256)self.linear2 = nn.Linear(256,64)self.linear3 = nn.Linear(64,10) # 10个手写数字对应的10个输出def forward(self,x):x = x.view(-1,784) # 变形x = torch.relu(self.linear1(x))x = torch.relu(self.linear2(x))# x = torch.relu(self.linear3(x))return x
这里是建立了一个神经网络模型类(Model)。这个模型有三个线性层(linear1、linear2、linear3)。输入维度为784(因为每一张图片的大小是28*28=784),输出维度为256、64、10(因为有十个类)。forward函数定义了模型的前向传播过程,其中x.view(-1, 784)将输入张量x变形为(batch_size, 784)的大小。然后经过三个线性层和relu激活函数进行运算,最后返回输出结果x。
步骤三:训练模型
model = Model()
criterion = nn.CrossEntropyLoss() # 交叉熵损失,相当于Softmax+Log+NllLoss
optimizer = torch.optim.SGD(model.parameters(),0.8) # 第一个参数是初始化参数值,第二个参数是学习率# 模型训练
# def train():
for index,data in enumerate(train_loader):input,target = data # input为输入数据,target为标签optimizer.zero_grad() # 梯度清零y_predict = model(input) # 模型预测loss = criterion(y_predict,target) # 计算损失loss.backward() # 反向传播optimizer.step() # 更新参数if index % 100 == 0: # 每一百次保存一次模型,打印损失torch.save(model.state_dict(),"./model/model.pkl") # 保存模型torch.save(optimizer.state_dict(),"./model/optimizer.pkl")print("损失值为:%.2f" % loss.item())
首先创建了一个模型对象model,一个损失函数对象criterion和一个优化器对象optimizer。然后使用一个for循环遍历训练数据集train_loader,每次取出一个batch的数据。接着将优化器的梯度清零,然后使用模型前向传播得到预测结果y_predict,计算损失值loss,然后进行反向传播和参数更新。每训练100个batch,保存模型和优化器的参数,并打印当前的损失值。
步骤四:保存模型参数
if os.path.exists('./model/model.pkl'):model.load_state_dict(torch.load("./model/model.pkl")) # 加载保存模型的参数
在当前文件夹下新建一个名叫model的文件夹。保存步骤三中训练完模型的参数。
步骤五:检验模型
correct = 0 # 正确预测的个数total = 0 # 总数with torch.no_grad(): # 测试不用计算梯度for data in test_loader:input,target = dataoutput=model(input) # output输出10个预测取值,其中最大的即为预测的数probability,predict=torch.max(output.data,dim=1) # 返回一个元组,第一个为最大概率值,第二个为最大值的下标total += target.size(0) # target是形状为(batch_size,1)的矩阵,使用size(0)取出该批的大小correct += (predict == target).sum().item() # predict和target均为(batch_size,1)的矩阵,sum()求出相等的个数print("准确率为:%.2f" % (correct / total))
参数说明:
correct
:记录正确预测的个数total
:记录总样本数test_loader
:测试集的数据加载器input
:输入数据target
:目标标签output
:模型的输出结果probability
:最大概率值predict
:最大值的下标
过程:
- 使用
torch.no_grad()
包装测试过程,表示不需要计算梯度 - 遍历测试集中的每个数据,获取输入数据和目标标签
- 将输入数据输入模型,得到模型的输出结果
- 使用
torch.max()
函数返回预测结果中的最大概率值和最大值的下标 - 更新总数和正确预测的个数
- 最后计算并输出准确率。
步骤六:检测自己的手写数据
if __name__ == '__main__':# 自定义测试image = Image.open('C:/Users/wangyiyuan/Desktop/20201116160729670.jpg') # 读取自定义手写图片image = image.resize((28,28)) # 裁剪尺寸为28*28image = image.convert('L') # 转换为灰度图像transform = transforms.ToTensor()image = transform(image)image = image.resize(1,1,28,28)output = model(image)probability,predict=torch.max(output.data,dim=1)print("此手写图片值为:%d,其最大概率为:%.2f" % (predict[0],probability))plt.title('此手写图片值为:{}'.format((int(predict))),fontname="SimHei")plt.imshow(image.squeeze())plt.show()
这里的C:/Users/wangyiyuan/Desktop/20201116160729670.jpg是我自己从网上找的的手写图片。这段代码意思如下:
- 打开并读取一张手写图片,图片的路径为'C:/Users/wangyiyuan/Desktop/20201116160729670.jpg'。
- 调整图片尺寸为28x28。
- 将图片转换为灰度图像,以便后续处理。
- 使用transforms.ToTensor()将图片转换为PyTorch张量。
- 调整图片尺寸为(1, 1, 28, 28)以适应模型的输入要求。
- 将处理后的图片输入模型,获取预测输出。
- 通过torch.max函数获得输出中的最大值及其索引,即预测的数字和其概率。
- 打印预测的数字和概率。
- 在图像上显示预测结果和手写图片。
- 展示图像。
步骤七:结果展示
我的原图是:
测试得到的结果为:
损失值为:4.16
损失值为:0.93
损失值为:0.31
损失值为:0.19
损失值为:0.24
损失值为:0.15
损失值为:0.13
损失值为:0.11
损失值为:0.18
损失值为:0.02
此手写图片值为:2,其最大概率为:6.57
相关文章:
CNN卷积网络实现MNIST数据集手写数字识别
步骤一:加载MNIST数据集 train_data MNIST(root./data,trainTrue,downloadFalse,transformtransforms.ToTensor()) train_loader DataLoader(train_data,shuffleTrue,batch_size64) # 测试数据集 test_data MNIST(root./data,trainFalse,downloadFalse,transfor…...
深入理解Java中的时间处理与时区管理
在Java开发中,时间处理和时区管理是常见的需求,特别是在全球化应用中。Java 8引入了新的时间API(java.time包),使时间处理变得更加直观和高效。本文将详细介绍Java中的时间处理与时区管理,通过丰富的代码示…...
虚拟机windows server创建域
目录 准备工作 一、新建域控制器 二、提升为域控制器添加新林 三、新建组织单位(OU),用户 四、将计算机加域 五、在域控中管理计算机 六、在域控中配置组策略 七、域内计算机验证组策略配置 准备工作 安装域前,如果有DNS…...
Java 集合框架:Java 中的 Set 集合(HashSet LinkedHashSet TreeSet)特点与实现解析
大家好,我是栗筝i,这篇文章是我的 “栗筝i 的 Java 技术栈” 专栏的第 017 篇文章,在 “栗筝i 的 Java 技术栈” 这个专栏中我会持续为大家更新 Java 技术相关全套技术栈内容。专栏的主要目标是已经有一定 Java 开发经验,并希望进一步完善自己对整个 Java 技术体系来充实自…...
springboot智能健康管理平台-计算机毕业设计源码57256
摘要 在当今社会,人们越来越重视健康饮食和健康管理。借助SpringBoot框架和MySQL数据库的支持,开发智能健康管理平台成为可能。该平台结合了小程序技术的便利性和SpringBoot框架的快速开发能力,为用户提供了便捷的健康管理解决方案。 通过智能…...
LetterBox图像预处理方法
LetterBox图像预处理方法就是要将不同分辨率的图像转换成固定分辨率,比如v8输入网络的固定分辨率为6406403,因此这里分享一下默认情况下对训练集、验证集和测试图片做的letterBox的方法。 1.LetterBox-Train 对于训练集,默认输入网络的图像尺寸为640640,假设有一张7201280…...
C++第五篇 类和对象(下) 初始化列表
目录 1.再探构造函数 2.类型转换 3.static成员 4.友元 friiend 1.再探构造函数 (1).之前我们实现构造函数时,初始化成员变量主要使用函数体内赋值,构造函数初始化还有一种方式,就是初始化列表,初始化列表的使用方式是以一个冒…...
C#中的通信
上位机应用开发-串口通信1、基于C#的串口通信对象:SerialPort 2、字段属性 PortName:获取或设置通信端口 BaudRate:获取或设置串行波特率-DataBits:获取或设置每个字节的标准数据位长度 Parity:获取或设置奇偶校验检查协仪I-StopBits;获取或设置每个字节的标准停止位数 3、…...
CVE-2022-21663: WordPress <5.8.3 版本对象注入漏洞深入分析
引言 在网络安全领域,技术的研究与讨论是不断进步的动力。本文针对WordPress的一个对象注入漏洞进行分析,旨在分享技术细节并提醒安全的重要性。特别强调:本文内容仅限技术研究,严禁用于非法目的。 漏洞背景 继WordPress CVE-2…...
C语言笔试题(三)
本专栏通过整理各专业方向的面试资料并咨询业界相关人士,整合不同方向的面试资料,希望能为您的面试道路点亮一盏灯! 1 简单题 如何声明一个二维数组? 答案: int arr[3][4];解析: 二维数组可以看作数组的数组。 union和struct…...
minio笔记之windows下安装使用
minio安装使用 去官网下载安装包启动访问管理平台创建桶创建用户、资源授权访问访问策略创建创建用户创建accessKey,用于应用程序开发 去官网下载安装包 直接安装即可 启动 设置密码 set MINIO_ROOT_USERadmin set MINIO_ROOT_PASSWORD12345678 cd到安装目录 mi…...
代码随想录算法训练营day31 | 56. 合并区间、738.单调递增的数字
碎碎念:加油 参考:代码随想录 56. 合并区间 题目链接 56. 合并区间 思想 这道题的核心还是判断重叠区间,本题和之前做过的452. 用最少数量的箭引爆气球、435. 无重叠区间的区别在于判断出重叠区间之后的操作,本题需要做的是合…...
利用 Python 制作图片轮播应用
在这篇博客中,我将向大家展示如何使用 xPython 创建一个图片轮播应用。这个应用能够从指定文件夹中加载图片,定时轮播,并提供按钮来保存当前图片到收藏夹或仅轮播收藏夹中的图片。我们还将实现退出按钮和全屏显示的功能。 C:\pythoncode\new\…...
报表系统之Cube.js
Cube.js 是一个开源的分析框架,专为构建数据应用和分析工具而设计。它的主要目的是简化和加速构建复杂的分析和数据可视化应用。以下是对 Cube.js 的详细介绍: 核心功能和特点 1. 多数据源支持 Cube.js 支持从多个数据源中提取数据,包括 SQ…...
代码随想录算法训练营第45天
115.不同的子序列 但相对于刚讲过 392.判断子序列,本题 就有难度了 ,感受一下本题和 392.判断子序列 的区别。 代码随想录 class Solution {public int numDistinct(String s, String t) {int lenS s.length();int lenT t.length();int[][] dp new …...
solidity合约创建
合约可以通过使用new关键字来创建其他合约的实例。 这个过程会执行被创建合约的构造函数(如果存在的话),并返回一个指向新创建合约的地址的引用。 这种方式允许智能合约动态地在区块链上部署新合约,并与它们交互。 通过 new 创…...
队列---循环队列实现
循环队列详解 概述 循环队列是一种基于数组实现的队列数据结构,其中队列的队首和队尾是通过模运算连接起来形成一个逻辑上的环形结构。这样可以有效地利用数组的空间,避免出现“假溢出”的情况。 结构体定义 循环队列的结构体定义如下: …...
【视频讲解】后端增删改查接口有什么用?
B站视频地址 B站视频地址 前言 “后端增删改查接口有什么用”,其实这句话可以拆解为下面3个问题。 接口是什么意思?后端接口是什么意思?后端接口中的增删改查接口有什么用? 1、接口 概念:接口的概念在不同的领域中…...
双指针hard题
[LeetCode]4. Median of Two Sorted Arrays 中文 - YouTube 依赖merge sort和priorityqueue的废物 正式变身山景城一姐小迷妹✪ω✪ 寻找正序数组中位数 class Solution {public double findMedianSortedArrays(int[] nums1, int[] nums2) {int len1 nums1.length;int len2 …...
前端实现【 批量任务调度管理器 】demo优化
一、前提介绍 我在前文实现过一个【批量任务调度管理器】的 demo,能实现简单的任务批量并发分组,过滤等操作。但是还有很多优化空间,所以查找一些优化的库, 主要想优化两个方面, 上篇提到的: 针对 3&…...
【数据结构】包装类和泛型
🎉欢迎大家收看,请多多支持🌹 🥰关注小哇,和我一起成长🚀个人主页🚀 ⭐在更专栏Java ⭐数据结构 ⭐已更专栏有C语言、计算机网络⭐ 👑目录 包装类🌙 ⭐基本类型对应的包…...
浅学爬虫-数据存储
在数据爬取完成后,我们需要将数据存储起来,以便于后续的分析和处理。常见的数据存储方式包括存储到CSV文件和存储到数据库。下面我们详细介绍如何实现这些存储方式。 存储到CSV CSV(Comma-Separated Values)文件是一种常用的文本…...
十六、maven git-快速上手(智慧云教育平台)
🌻🌻 目录 一、概述及项目管理工具介绍1.1 项目介绍1.2 maven 介绍及其配置1.2.1 maven 介绍1.2.2 maven 下载与配置 1.3 pom 中常见标签的使用1.4 后端项目环境的搭建1.5 Git 简介1.6 Git 的基本使用1.6.1 码云的注册与仓库创建1.6.2 上传代码到码云仓库…...
chrome/edge浏览器插件开发入门与加载使用
同学们可以私信我加入学习群! 正文开始 前言一、插件与普通前端项目二、开发插件——manifest.json三、插件使用edge浏览器中使用/加载插件chrome浏览器中使用/加载插件 总结 前言 chrome插件的出现,初衷可能是为了方便用户更好地控制浏览器,…...
【完美解决】 TypeError: ‘str’ object does not support item assignment
【完美解决】 TypeError: ‘str’ object does not support item assignment 在Python编程中,遇到TypeError: str object does not support item assignment这样的错误通常意味着你试图修改字符串中的某个字符,但字符串是不可变类型,不支持这…...
Android SurfaceFlinger——渲染开始帧(四十三)
通过前面的文章我们介绍了 SurfaceFlinger 图层合成的整体流程,已经对应步骤的前五步,这里我们开始介绍帧渲染流程的第一步——开始帧。 1.更新输出设备的色彩配置文件2.更新与合成相关的状态3.计划合成帧图层4.写入合成状态5.设置颜色矩阵6.开始帧7.准备帧数据以进行显示(异…...
fastadmin搜索栏实现某字段动态下拉搜索
记录:fastadmin搜索栏实现某字段动态下拉搜索 方式一:使用selectpicker组件,可多选 { field: travel_agency, title:__(Travel_agency),addClass:"selectpicker", operate:"IN",data:"multiple", searchList:…...
.NET未来路在何方?
简述 在软件开发的漫长旅程中,将代码打包成可执行的EXE文件是一项必不可少的技能。它不仅能够保护源代码,还能为用户提供便捷的安装体验。但手动打包过程繁琐且容易出错,自动化打包成为了开发者的福音。 在软件开发的浩瀚星空中,.…...
Vue开发环境搭建
文章目录 引言I 安装NVM1.1 Windows系统安装NVM,实现Node.js多版本管理1.2 配置下载镜像1.3 NVM常用操作命令II VUE项目的基础配置2.1 制定不同的环境配置2.2 正式环境隐藏日志2.3 vscode常用插件引言 开发工具: node.js 、npm 开发编辑器:vscode 开发框架:VUE I 安装NVM…...
【数据结构初阶】详解:实现循环队列、用栈实现队列、用队列实现栈
文章目录 一、循环队列1、题目简述2、方法讲解2.1、了解tail的指向2.2、了解空间是如何利用的2.3、如何判断队列是否为空(假溢出问题)?2.4、实现代码 二、用栈实现队列1、题目简述2、方法讲解2.1、讲解2.2、实现代码 三、用队列实现栈1、题目…...
如何自己做网站手机软件/福州seo博客
通过最近对 Flutter 开发的大致了解,感受最深的简单概括就是:Widget 就是一切外加组合和响应式,我们开发的界面,通过组合其他的 Widget 来实现,当界面发生变化时,不会像我们原来 iOS 或者 Andriod 开发一样…...
做众筹网站怎么赚钱/他达拉非片正确服用方法
随着经济增长变缓和全球竞争日益激烈,中国加工工业在性能提高、成本降低和工业升级方面已面临诸多挑战。中国加工工业如何才能实现从“中国制造”成功转型至“中国智造”?我们来看看霍尼韦尔过程控制部全球副总裁兼中国区总经理王春文先生如何回答这一问…...
it运维网/网站优化公司哪家效果好
android studio运行程序的时候,列表里找不到夜神模拟器,当然,模拟器是开着的。 解决方法: 1.桌面上找到夜神模拟器,右键-打开文件所在的位置,比如我的是F:\Program Files\Nox\bin 2.打开cmd命令窗口&…...
武汉网站建设网站推广/关键词权重
在pb125版本中,创建数据窗口时如果不是直接选择目标表,而是随便选了一个表,然后修改数据窗口sql语句,这种情况下,dw header区的enabled默认时不启用的,最直接的影响就是无法点击header去排序。...
企业网址格式/百度seo费用
Skype - 更新 - 3.8.4.182 [Skype - 官方网站] http://www.skype.com/http://skype.tom.com/download/[Skype - 下载] http://www.skype.com/intl/en/download/[Skype - 当前版本] 3.8.4.182...
做家政有什么网站做推广好/深圳seo网络优化公司
Python函数装饰器 装饰器(Decorators)是 Python 的一个重要部分。简单地说:他们是修改其他函数的功能的函数。他们有助于让我们的代码更简短,也更Pythonic(Python范儿)。大多数初学者不知道在哪儿使用它们,所以我将要…...