CNN卷积网络实现MNIST数据集手写数字识别
步骤一:加载MNIST数据集
train_data = MNIST(root='./data',train=True,download=False,transform=transforms.ToTensor())
train_loader = DataLoader(train_data,shuffle=True,batch_size=64)
# 测试数据集
test_data = MNIST(root='./data',train=False,download=False,transform=transforms.ToTensor())
test_loader = DataLoader(test_data,shuffle=False,batch_size=64)
首先,通过MNIST类创建了train_data对象,指定了数据集的路径root='./data',并且将数据集标记为训练集train=True。download=False表示不自动从网络上下载数据集,而是使用已经下载好的数据集。我是之前自己已经下载过该数据集所以这里填的是False,如果之前没有下载的话就要填True。下面测试集也是一样。transforms.ToTensor()将数据转换为张量形式。
然后,通过DataLoader类创建了train_loader对象,指定了使用train_data作为数据源。shuffle=True表示在每个epoch开始时,将数据打乱顺序。batch_size=64表示每次抓取64个样本。
接下来,同样的步骤也被用来创建了测试集的数据加载器test_loader。不同的是,这里将数据集标记为测试集train=False,并且shuffle=False表示不需要打乱顺序。
加载完的数据集存在MNIST文件夹的raw文件夹下内容如下:
其中t10k-images-idx3-ubyte是测试集的图像,t10k-labels-idx3-ubyte是测试集的标签。train-images-idx3-ubyte是训练集的图像,train-labels-idx1-ubyte是训练集的标签。
存下来的这些数据集是二进制的形式,可以通过下面的代码(1.py)读取:
"""
Created on Sat Jul 27 15:26:39 2024@author: wangyiyuan
"""
# 导入包
import struct
import numpy as np
from PIL import Imageclass MnistParser:# 加载图像def load_image(self, file_path):# 读取二进制数据binary = open(file_path,'rb').read()# 读取头文件fmt_head = '>iiii'offset = 0# 读取头文件magic_number,images_number,rows_number,columns_number = struct.unpack_from(fmt_head,binary,offset)# 打印头文件信息print('图片数量:%d,图片行数:%d,图片列数:%d'%(images_number,rows_number,columns_number))# 处理数据image_size = rows_number * columns_numberfmt_data = '>'+str(image_size)+'B'offset = offset + struct.calcsize(fmt_head)# 读取数据images = np.empty((images_number,rows_number,columns_number))for i in range(images_number):images[i] = np.array(struct.unpack_from(fmt_data, binary, offset)).reshape((rows_number, columns_number))offset = offset + struct.calcsize(fmt_data)# 每1万张打印一次信息if (i+1) % 10000 == 0:print('> 已读取:%d张图片'%(i+1))# 返回数据return images_number,rows_number,columns_number,images# 加载标签def load_labels(self, file_path):# 读取数据binary = open(file_path,'rb').read()# 读取头文件fmt_head = '>ii'offset = 0# 读取头文件magic_number,items_number = struct.unpack_from(fmt_head,binary,offset)# 打印头文件信息print('标签数:%d'%(items_number))# 处理数据fmt_data = '>B'offset = offset + struct.calcsize(fmt_head)# 读取数据labels = np.empty((items_number))for i in range(items_number):labels[i] = struct.unpack_from(fmt_data, binary, offset)[0]offset = offset + struct.calcsize(fmt_data)# 每1万张打印一次信息if (i+1)%10000 == 0:print('> 已读取:%d个标签'%(i+1))# 返回数据return items_number,labels# 图片可视化def visualaztion(self, images, labels, path):d = {0:0, 1:0, 2:0, 3:0, 4:0, 5:0, 6:0, 7:0, 8:0, 9:0}for i in range(images.__len__()):im = Image.fromarray(np.uint8(images[i]))im.save(path + "%d_%d.png"%(labels[i], d[labels[i]]))d[labels[i]] += 1# im.show()if (i+1)%10000 == 0:print('> 已保存:%d个图片'%(i+1))# 保存为图片格式
def change_and_save():mnist = MnistParser()trainImageFile = './train-images-idx3-ubyte'_, _, _, images = mnist.load_image(trainImageFile)trainLabelFile = './train-labels-idx1-ubyte'_, labels = mnist.load_labels(trainLabelFile)mnist.visualaztion(images, labels, "./images/train/")testImageFile = './train-images-idx3-ubyte'_, _, _, images = mnist.load_image(testImageFile)testLabelFile = './train-labels-idx1-ubyte'_, labels = mnist.load_labels(testLabelFile)mnist.visualaztion(images, labels, "./images/test/")# 测试
if __name__ == '__main__':change_and_save()
将这个1.py文件和下载好的数据集放在同一个文件夹下:
新建一个文件夹images,在文件夹images里面新建两个文件夹分别叫test和train。

运行完可以发现train和test里的内容如下:

步骤二:建立模型
class Model(nn.Module):def __init__(self):super(Model,self).__init__()self.linear1 = nn.Linear(784,256)self.linear2 = nn.Linear(256,64)self.linear3 = nn.Linear(64,10) # 10个手写数字对应的10个输出def forward(self,x):x = x.view(-1,784) # 变形x = torch.relu(self.linear1(x))x = torch.relu(self.linear2(x))# x = torch.relu(self.linear3(x))return x
这里是建立了一个神经网络模型类(Model)。这个模型有三个线性层(linear1、linear2、linear3)。输入维度为784(因为每一张图片的大小是28*28=784),输出维度为256、64、10(因为有十个类)。forward函数定义了模型的前向传播过程,其中x.view(-1, 784)将输入张量x变形为(batch_size, 784)的大小。然后经过三个线性层和relu激活函数进行运算,最后返回输出结果x。
步骤三:训练模型
model = Model()
criterion = nn.CrossEntropyLoss() # 交叉熵损失,相当于Softmax+Log+NllLoss
optimizer = torch.optim.SGD(model.parameters(),0.8) # 第一个参数是初始化参数值,第二个参数是学习率# 模型训练
# def train():
for index,data in enumerate(train_loader):input,target = data # input为输入数据,target为标签optimizer.zero_grad() # 梯度清零y_predict = model(input) # 模型预测loss = criterion(y_predict,target) # 计算损失loss.backward() # 反向传播optimizer.step() # 更新参数if index % 100 == 0: # 每一百次保存一次模型,打印损失torch.save(model.state_dict(),"./model/model.pkl") # 保存模型torch.save(optimizer.state_dict(),"./model/optimizer.pkl")print("损失值为:%.2f" % loss.item())
首先创建了一个模型对象model,一个损失函数对象criterion和一个优化器对象optimizer。然后使用一个for循环遍历训练数据集train_loader,每次取出一个batch的数据。接着将优化器的梯度清零,然后使用模型前向传播得到预测结果y_predict,计算损失值loss,然后进行反向传播和参数更新。每训练100个batch,保存模型和优化器的参数,并打印当前的损失值。
步骤四:保存模型参数
if os.path.exists('./model/model.pkl'):model.load_state_dict(torch.load("./model/model.pkl")) # 加载保存模型的参数
在当前文件夹下新建一个名叫model的文件夹。保存步骤三中训练完模型的参数。
步骤五:检验模型
correct = 0 # 正确预测的个数total = 0 # 总数with torch.no_grad(): # 测试不用计算梯度for data in test_loader:input,target = dataoutput=model(input) # output输出10个预测取值,其中最大的即为预测的数probability,predict=torch.max(output.data,dim=1) # 返回一个元组,第一个为最大概率值,第二个为最大值的下标total += target.size(0) # target是形状为(batch_size,1)的矩阵,使用size(0)取出该批的大小correct += (predict == target).sum().item() # predict和target均为(batch_size,1)的矩阵,sum()求出相等的个数print("准确率为:%.2f" % (correct / total))
参数说明:
correct:记录正确预测的个数total:记录总样本数test_loader:测试集的数据加载器input:输入数据target:目标标签output:模型的输出结果probability:最大概率值predict:最大值的下标
过程:
- 使用
torch.no_grad()包装测试过程,表示不需要计算梯度 - 遍历测试集中的每个数据,获取输入数据和目标标签
- 将输入数据输入模型,得到模型的输出结果
- 使用
torch.max()函数返回预测结果中的最大概率值和最大值的下标 - 更新总数和正确预测的个数
- 最后计算并输出准确率。
步骤六:检测自己的手写数据
if __name__ == '__main__':# 自定义测试image = Image.open('C:/Users/wangyiyuan/Desktop/20201116160729670.jpg') # 读取自定义手写图片image = image.resize((28,28)) # 裁剪尺寸为28*28image = image.convert('L') # 转换为灰度图像transform = transforms.ToTensor()image = transform(image)image = image.resize(1,1,28,28)output = model(image)probability,predict=torch.max(output.data,dim=1)print("此手写图片值为:%d,其最大概率为:%.2f" % (predict[0],probability))plt.title('此手写图片值为:{}'.format((int(predict))),fontname="SimHei")plt.imshow(image.squeeze())plt.show()
这里的C:/Users/wangyiyuan/Desktop/20201116160729670.jpg是我自己从网上找的的手写图片。这段代码意思如下:
- 打开并读取一张手写图片,图片的路径为'C:/Users/wangyiyuan/Desktop/20201116160729670.jpg'。
- 调整图片尺寸为28x28。
- 将图片转换为灰度图像,以便后续处理。
- 使用transforms.ToTensor()将图片转换为PyTorch张量。
- 调整图片尺寸为(1, 1, 28, 28)以适应模型的输入要求。
- 将处理后的图片输入模型,获取预测输出。
- 通过torch.max函数获得输出中的最大值及其索引,即预测的数字和其概率。
- 打印预测的数字和概率。
- 在图像上显示预测结果和手写图片。
- 展示图像。
步骤七:结果展示
我的原图是:

测试得到的结果为:

损失值为:4.16
损失值为:0.93
损失值为:0.31
损失值为:0.19
损失值为:0.24
损失值为:0.15
损失值为:0.13
损失值为:0.11
损失值为:0.18
损失值为:0.02
此手写图片值为:2,其最大概率为:6.57
相关文章:
CNN卷积网络实现MNIST数据集手写数字识别
步骤一:加载MNIST数据集 train_data MNIST(root./data,trainTrue,downloadFalse,transformtransforms.ToTensor()) train_loader DataLoader(train_data,shuffleTrue,batch_size64) # 测试数据集 test_data MNIST(root./data,trainFalse,downloadFalse,transfor…...
深入理解Java中的时间处理与时区管理
在Java开发中,时间处理和时区管理是常见的需求,特别是在全球化应用中。Java 8引入了新的时间API(java.time包),使时间处理变得更加直观和高效。本文将详细介绍Java中的时间处理与时区管理,通过丰富的代码示…...
虚拟机windows server创建域
目录 准备工作 一、新建域控制器 二、提升为域控制器添加新林 三、新建组织单位(OU),用户 四、将计算机加域 五、在域控中管理计算机 六、在域控中配置组策略 七、域内计算机验证组策略配置 准备工作 安装域前,如果有DNS…...
Java 集合框架:Java 中的 Set 集合(HashSet LinkedHashSet TreeSet)特点与实现解析
大家好,我是栗筝i,这篇文章是我的 “栗筝i 的 Java 技术栈” 专栏的第 017 篇文章,在 “栗筝i 的 Java 技术栈” 这个专栏中我会持续为大家更新 Java 技术相关全套技术栈内容。专栏的主要目标是已经有一定 Java 开发经验,并希望进一步完善自己对整个 Java 技术体系来充实自…...
springboot智能健康管理平台-计算机毕业设计源码57256
摘要 在当今社会,人们越来越重视健康饮食和健康管理。借助SpringBoot框架和MySQL数据库的支持,开发智能健康管理平台成为可能。该平台结合了小程序技术的便利性和SpringBoot框架的快速开发能力,为用户提供了便捷的健康管理解决方案。 通过智能…...
LetterBox图像预处理方法
LetterBox图像预处理方法就是要将不同分辨率的图像转换成固定分辨率,比如v8输入网络的固定分辨率为6406403,因此这里分享一下默认情况下对训练集、验证集和测试图片做的letterBox的方法。 1.LetterBox-Train 对于训练集,默认输入网络的图像尺寸为640640,假设有一张7201280…...
C++第五篇 类和对象(下) 初始化列表
目录 1.再探构造函数 2.类型转换 3.static成员 4.友元 friiend 1.再探构造函数 (1).之前我们实现构造函数时,初始化成员变量主要使用函数体内赋值,构造函数初始化还有一种方式,就是初始化列表,初始化列表的使用方式是以一个冒…...
C#中的通信
上位机应用开发-串口通信1、基于C#的串口通信对象:SerialPort 2、字段属性 PortName:获取或设置通信端口 BaudRate:获取或设置串行波特率-DataBits:获取或设置每个字节的标准数据位长度 Parity:获取或设置奇偶校验检查协仪I-StopBits;获取或设置每个字节的标准停止位数 3、…...
CVE-2022-21663: WordPress <5.8.3 版本对象注入漏洞深入分析
引言 在网络安全领域,技术的研究与讨论是不断进步的动力。本文针对WordPress的一个对象注入漏洞进行分析,旨在分享技术细节并提醒安全的重要性。特别强调:本文内容仅限技术研究,严禁用于非法目的。 漏洞背景 继WordPress CVE-2…...
C语言笔试题(三)
本专栏通过整理各专业方向的面试资料并咨询业界相关人士,整合不同方向的面试资料,希望能为您的面试道路点亮一盏灯! 1 简单题 如何声明一个二维数组? 答案: int arr[3][4];解析: 二维数组可以看作数组的数组。 union和struct…...
minio笔记之windows下安装使用
minio安装使用 去官网下载安装包启动访问管理平台创建桶创建用户、资源授权访问访问策略创建创建用户创建accessKey,用于应用程序开发 去官网下载安装包 直接安装即可 启动 设置密码 set MINIO_ROOT_USERadmin set MINIO_ROOT_PASSWORD12345678 cd到安装目录 mi…...
代码随想录算法训练营day31 | 56. 合并区间、738.单调递增的数字
碎碎念:加油 参考:代码随想录 56. 合并区间 题目链接 56. 合并区间 思想 这道题的核心还是判断重叠区间,本题和之前做过的452. 用最少数量的箭引爆气球、435. 无重叠区间的区别在于判断出重叠区间之后的操作,本题需要做的是合…...
利用 Python 制作图片轮播应用
在这篇博客中,我将向大家展示如何使用 xPython 创建一个图片轮播应用。这个应用能够从指定文件夹中加载图片,定时轮播,并提供按钮来保存当前图片到收藏夹或仅轮播收藏夹中的图片。我们还将实现退出按钮和全屏显示的功能。 C:\pythoncode\new\…...
报表系统之Cube.js
Cube.js 是一个开源的分析框架,专为构建数据应用和分析工具而设计。它的主要目的是简化和加速构建复杂的分析和数据可视化应用。以下是对 Cube.js 的详细介绍: 核心功能和特点 1. 多数据源支持 Cube.js 支持从多个数据源中提取数据,包括 SQ…...
代码随想录算法训练营第45天
115.不同的子序列 但相对于刚讲过 392.判断子序列,本题 就有难度了 ,感受一下本题和 392.判断子序列 的区别。 代码随想录 class Solution {public int numDistinct(String s, String t) {int lenS s.length();int lenT t.length();int[][] dp new …...
solidity合约创建
合约可以通过使用new关键字来创建其他合约的实例。 这个过程会执行被创建合约的构造函数(如果存在的话),并返回一个指向新创建合约的地址的引用。 这种方式允许智能合约动态地在区块链上部署新合约,并与它们交互。 通过 new 创…...
队列---循环队列实现
循环队列详解 概述 循环队列是一种基于数组实现的队列数据结构,其中队列的队首和队尾是通过模运算连接起来形成一个逻辑上的环形结构。这样可以有效地利用数组的空间,避免出现“假溢出”的情况。 结构体定义 循环队列的结构体定义如下: …...
【视频讲解】后端增删改查接口有什么用?
B站视频地址 B站视频地址 前言 “后端增删改查接口有什么用”,其实这句话可以拆解为下面3个问题。 接口是什么意思?后端接口是什么意思?后端接口中的增删改查接口有什么用? 1、接口 概念:接口的概念在不同的领域中…...
双指针hard题
[LeetCode]4. Median of Two Sorted Arrays 中文 - YouTube 依赖merge sort和priorityqueue的废物 正式变身山景城一姐小迷妹✪ω✪ 寻找正序数组中位数 class Solution {public double findMedianSortedArrays(int[] nums1, int[] nums2) {int len1 nums1.length;int len2 …...
前端实现【 批量任务调度管理器 】demo优化
一、前提介绍 我在前文实现过一个【批量任务调度管理器】的 demo,能实现简单的任务批量并发分组,过滤等操作。但是还有很多优化空间,所以查找一些优化的库, 主要想优化两个方面, 上篇提到的: 针对 3&…...
后进先出(LIFO)详解
LIFO 是 Last In, First Out 的缩写,中文译为后进先出。这是一种数据结构的工作原则,类似于一摞盘子或一叠书本: 最后放进去的元素最先出来 -想象往筒状容器里放盘子: (1)你放进的最后一个盘子(…...
树莓派超全系列教程文档--(62)使用rpicam-app通过网络流式传输视频
使用rpicam-app通过网络流式传输视频 使用 rpicam-app 通过网络流式传输视频UDPTCPRTSPlibavGStreamerRTPlibcamerasrc GStreamer 元素 文章来源: http://raspberry.dns8844.cn/documentation 原文网址 使用 rpicam-app 通过网络流式传输视频 本节介绍来自 rpica…...
黑马Mybatis
Mybatis 表现层:页面展示 业务层:逻辑处理 持久层:持久数据化保存 在这里插入图片描述 Mybatis快速入门 函数 这是最简单的方法,包含在stdlib.h头文件中: #include <stdlib.h>int main() {system("ls -l"); // 执行ls -l命令retu…...
macOS多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用
文章目录 问题现象问题原因解决办法 问题现象 macOS启动台(Launchpad)多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用。 问题原因 很明显,都是Google家的办公全家桶。这些应用并不是通过独立安装的…...
laravel8+vue3.0+element-plus搭建方法
创建 laravel8 项目 composer create-project --prefer-dist laravel/laravel laravel8 8.* 安装 laravel/ui composer require laravel/ui 修改 package.json 文件 "devDependencies": {"vue/compiler-sfc": "^3.0.7","axios": …...
uniapp手机号一键登录保姆级教程(包含前端和后端)
目录 前置条件创建uniapp项目并关联uniClound云空间开启一键登录模块并开通一键登录服务编写云函数并上传部署获取手机号流程(第一种) 前端直接调用云函数获取手机号(第三种)后台调用云函数获取手机号 错误码常见问题 前置条件 手机安装有sim卡手机开启…...
【学习笔记】erase 删除顺序迭代器后迭代器失效的解决方案
目录 使用 erase 返回值继续迭代使用索引进行遍历 我们知道类似 vector 的顺序迭代器被删除后,迭代器会失效,因为顺序迭代器在内存中是连续存储的,元素删除后,后续元素会前移。 但一些场景中,我们又需要在执行删除操作…...
python爬虫——气象数据爬取
一、导入库与全局配置 python 运行 import json import datetime import time import requests from sqlalchemy import create_engine import csv import pandas as pd作用: 引入数据解析、网络请求、时间处理、数据库操作等所需库。requests:发送 …...
