当前位置: 首页 > news >正文

入门深度学习——基于全连接神经网络的手写数字识别案例(python代码实现)

入门深度学习——基于全连接神经网络的手写数字识别案例(python代码实现)

一、网络构建

1.1 问题导入

如图所示,数字五的图片作为输入,layer01层为输入层,layer02层为隐藏层,找出每列最大值对应索引为输出层。根据下图给出的网络结构搭建本案例用到的全连接神经网络
在这里插入图片描述

1.2 手写字数据集MINST

如图所示,MNIST数据集是机器学习领域中非常经典的一个数据集,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片。数据集也被嵌入到sklearn和pytorch框架中可以直接调用。这里我们默认已经安装了pytorch框架。不会使用的这里简单介绍一下。
大家可以用按住win+R键,打开运行窗口,输入cmd。
在这里插入图片描述
输入cmd,回车后,会显示如下。
在这里插入图片描述
输入以下的命令,可以看看自己的电脑的显卡是不是NVIDIA。如果是AMD的,那么就安装cpu的吧,毕竟CUDA内核,只支持NVIDIA的显卡。

#AMD显卡
pip install pytorch-cpu
#NVIDIA显卡
pip install pytorch
#如果速度慢的话,可以加入清华源的链接
pip install pytorch-cpu -i https://pypi.tuna.tsinghua.edu.cn/simple/
#NVIDIA显卡
pip install pytorch -i https://pypi.tuna.tsinghua.edu.cn/simple/

这样就完成了,仍然存在问题的小伙伴,可以参考小程序员推荐的这个up主的教程pytorch保姆级教程。
这里我们输出几张图片和对应的标签。作为对数据集的了解,也方便我们针对性的设计网络结构,做到心中有数。
在这里插入图片描述

二、采用Pytorch框架编写全连接神经网络代码实现手写字识别

2.1 导入必要的包

import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from torchvision import datasets,transforms
from torch.utils.data import DataLoader

2.2 定义一些数据预处理操作

pipline=transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])

2.3 下载数据集(训练集vs测试集)

train_dataset=datasets.MNIST('./data',train=True,transform=pipline,download=True)
test_dataset=datasets.MNIST('./data',train=False,transform=pipline,download=True)
print(len(train_dataset))
print(len(test_dataset))

60000
10000

2.4 分批加载训练集和测试集中的数据到内存里

train_loader=DataLoader(train_dataset,batch_size=32,shuffle=True)
test_loader=DataLoader(test_dataset,batch_size=32)

2.5 可视化数据集中的数据,做到心中有数

import matplotlib.pyplot as plt
examples=enumerate(train_loader)
_,(example_data,example_label)=next(examples)
print(example_data.shape)
for i in range(6):plt.subplot(2,3,i+1)plt.tight_layout()plt.imshow(example_data[i][0],cmap='gray')
#     plt.title('Ground Truth:{}'.format(example_label[i]))plt.title(f'Ground Truth:{example_label[i]}')

torch.Size([32, 1, 28, 28])
在这里插入图片描述

2.6 网络模型设计(有时也称为网络模型搭建)

class Net(nn.Module):def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):super(Net,self).__init__()self.layer1=nn.Sequential(nn.Linear(in_dim,n_hidden_1),nn.ReLU(True))self.layer2=nn.Sequential(nn.Linear(n_hidden_1,n_hidden_2),nn.Sigmoid())self.layer3=nn.Linear(n_hidden_2,out_dim)    def forward(self,x):x=self.layer1(x)x=self.layer2(x)x=self.layer3(x)return x
model=Net(28*28,300,100,10)
model

以下结果来自Jupyter Notebook
Net(
(layer1): Sequential(
(0): Linear(in_features=784, out_features=300, bias=True)
(1): ReLU(inplace=True)
)
(layer2): Sequential(
(0): Linear(in_features=300, out_features=100, bias=True)
(1): Sigmoid()
)
(layer3): Linear(in_features=100, out_features=10, bias=True)
)

import torch.optim as optim
criterion=nn.CrossEntropyLoss()   #选用Pytorch中nn模块封装好的交叉熵损失函数
optimizer=optim.SGD(model.parameters(),lr=0.01,momentum=0.5)  #选用随机梯度下降法(SGD)作为本模型的梯度下降法
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')   #确定代码运行设备究竟实在GPU还是CPU上跑
model.to(device)

2.7 训练网络模型

losses=[]
acces=[]eval_losses=[]
eval_acces=[]#训练轮数---epochfor epoch in range(10):train_loss=0train_acc=0model.train()   #启用网络模型隐藏层中的dropout和BN(批归一化)操作if epoch%5==0:   #控制训练轮数间隔optimizer.param_groups[0]['lr']*=0.9    #动态调整学习率for img,label in train_loader:img=img.to(device)   #将训练图片写到设备里label=label.to(device)  #将图片类别写到设备里img=img.view(img.size(0),-1)out=model(img)   #调用前向传播函数得到预测值loss=criterion(out,label)   #计算预测值和真实值的损失optimizer.zero_grad()  #在新一轮反向传播开始前,清空上一轮反向传播得到的梯度loss.backward()  #把上一部得到的损失执行反向传播,得到新的网络模型参数(权值)optimizer.step()   #把上一部得到的新的权值更新到网络模型里#在前面前向传播和反向传播的额基础上,计算一些训练算法性能指标train_loss+=loss.item()  #记录反向传播每一轮得到的损失_,pred=out.max(1)   #得到图片的预测类别num_correct=(pred==label).sum().item()   #获取预测正确的样本数量acc=num_correct/img.shape[0]      #每一批次的正确率train_acc+=acc       #每一轮次的额正确率losses.append(train_loss/len(train_loader))    #所有轮次训练完之后总的损失acces.append(train_acc/len(train_loader))     #所有轮次训练完之后总的正确率

2.8 在测试集上测试网络模型,检验模型效果

eval_loss=0
eval_acc=0
model.eval()   #继续沿用BN操作,但是不再使用dropout操作with torch.no_grad():for img,label in test_loader:img=img.to(device)label=label.to(device)img=img.view(img.size(0),-1)out=model(img)loss=criterion(out,label)eval_loss+=loss.item()   #记录每一批次的损失_,pred=out.max(1)num_correct=(pred==label).sum().item()acc=num_correct/img.shape[0]   #记录每一批次的准确率eval_acc+=acc     #记录每一轮的准确率eval_losses.append(eval_loss / len(test_loader))eval_acces.append(eval_acc / len(test_loader))print('epoch: {}, Train Loss: {:.4f}, Train Acc: {:.4f}, Test Loss: {:.4f}, Test Acc: {:.4f}'.format(epoch, train_loss / len(train_loader), train_acc / len(train_loader), eval_loss / len(test_loader), eval_acc / len(test_loader)))

epoch: 0, Train Loss: 1.1721, Train Acc: 0.6760, Test Loss: 0.4936, Test Acc: 0.8692
epoch: 1, Train Loss: 0.4093, Train Acc: 0.8866, Test Loss: 0.3368, Test Acc: 0.9020
epoch: 2, Train Loss: 0.3192, Train Acc: 0.9084, Test Loss: 0.2884, Test Acc: 0.9171
epoch: 3, Train Loss: 0.2755, Train Acc: 0.9194, Test Loss: 0.2552, Test Acc: 0.9271
epoch: 4, Train Loss: 0.2429, Train Acc: 0.9290, Test Loss: 0.2251, Test Acc: 0.9349
epoch: 5, Train Loss: 0.2160, Train Acc: 0.9367, Test Loss: 0.2001, Test Acc: 0.9405
epoch: 6, Train Loss: 0.1945, Train Acc: 0.9433, Test Loss: 0.1854, Test Acc: 0.9447
epoch: 7, Train Loss: 0.1761, Train Acc: 0.9494, Test Loss: 0.1716, Test Acc: 0.9504
epoch: 8, Train Loss: 0.1601, Train Acc: 0.9540, Test Loss: 0.1597, Test Acc: 0.9527
epoch: 9, Train Loss: 0.1468, Train Acc: 0.9572, Test Loss: 0.1434, Test Acc: 0.9567

2.10可视化训练及测试的损失值

plt.title('Train Loss')
plt.plot(np.arange(len(losses)),losses);
plt.legend(['Train Loss'],loc='upper right')                   

损失函数的结果:
在这里插入图片描述

三、代码文件

小程序员将代码文件和相关素材整理到了百度网盘里,因为文件大小基本不大,大家也不用担心限速问题。后期小程序员有能力的话,将在gitee或者github上上传相关素材。
链接:https://pan.baidu.com/s/1Ce14ZQYEYWJxhpNEP1ERhg?pwd=7mvf
提取码:7mvf

相关文章:

入门深度学习——基于全连接神经网络的手写数字识别案例(python代码实现)

入门深度学习——基于全连接神经网络的手写数字识别案例(python代码实现) 一、网络构建 1.1 问题导入 如图所示,数字五的图片作为输入,layer01层为输入层,layer02层为隐藏层,找出每列最大值对应索引为输…...

预算砍砍砍,IT运维如何降本增效

疫情短暂过去,一个乐观的共识正在蔓延:2023年的互联网,绝对不会比2022年更差。 “降本”是过去一年许多公司的核心策略,营销大幅缩水、亏损业务大量撤裁,以及层出不穷的裁员消息。而2023年在可预期的经济复苏下&#…...

10.Jenkins用tags的方式自动发布java应用

Jenkins用tags的方式自动发布java应用1.配置jenkins,告诉jenkins,jdk的安装目录,maven的安装目录2.构建一个maven项目指定构建参数,选择Git Paramete在源码管理中,填写我们git项目的地址,调用变量构建前执行…...

2023新华为OD机试题 - 相同数字的积木游戏 1(JavaScript)

相同数字的积木游戏 1 题目 小华和小薇一起通过玩积木游戏学习数学。 他们有很多积木,每个积木块上都有一个数字, 积木块上的数字可能相同。 小华随机拿一些积木挨着排成一排,请小薇找到这排积木中数字相同且所处位置最远的 2 块积木块,计算他们的距离。 小薇请你帮忙替她…...

重构之改善既有代码的设计(一)

1.1 何为重构,为何重构 第一个定义是名词形式: 重构(名词):对软件内部结构的一种调整,目的是在不改变「软件可察行为」前提下,提高其可理解性,降低修改成本。 「重构」的另一个用…...

Kotlin data class 数据类用法

实验数据 {"code":1,"message":"成功","data":{"name":"周杰轮","gender":1} }kotlin数据类使用方便提供如下内部Api: equals()/hashCode()对 toString() componentN()按声明顺序与属性相…...

随笔-老子不想牺牲了

18年来到这个项目组,当时只有8个人,包括经常不在的架构师和经理。当时的工位在西区1栋A座,办公桌很宽敞。随着项目的发展,入职的人越来越多,项目的工位也是几经搬迁。基本上每次搬迁时,我的工位都是挑剩下的…...

三种查找Windows10环境变量的方法

文章目录一.在设置中查看二. 在我的电脑中查看三. 在资源管理器里查看一.在设置中查看 在系统中搜索设置 打开设置,在设置功能里,点击第一项 系统 在系统功能里,左侧菜单找到关于 在关于的相关设置里可以看到高级系统设置 点击高级系…...

STM32单片机DS18B20测温程序源代码

OLED液晶屏电路接口DS18B20电路接口STM32单片机DS18B20测温程序源代码#include "sys.h"#define LED_RED PBout(12)#define LED_GREEN PBout(13)#define LED_YELLOW PBout(14)#define LED_BLUE PBout(15)#define DS18B20_IO_IN() {GPIOA->CRL&0XFFFFFFF0;GPIOA…...

java日志查看工具finder介绍

目录 一、finder介绍 二、单节点部署 1、服务器需要安装Tomcat,以2.82.16.35为例 2、进入Tomcat下目录webapps下,创建FIND目录,进入FIDN目录 3、下载findweb插件,解压缩 4、登录页面,配置 5、添加日志路径 三、…...

手写现代前端框架diff算法-前端面试进阶

前言 在前端工程上,日益复杂的今天,性能优化已经成为必不可少的环境。前端需要从每一个细节的问题去优化。那么如何更优,当然与他的如何怎么实现的有关。比如key为什么不能使用index呢?为什么不使用随机数呢?答案当然…...

【半监督医学图像分割 2022 MICCAI】CLLE 论文翻译

文章目录【半监督医学图像分割 2022 MICCAI】CLLE 论文翻译摘要1. 简介2. 方法2.1 半监督框架概述2.2 监督局部对比学习2.3 下采样和块划分3. 实验4. 结论【半监督医学图像分割 2022 MICCAI】CLLE 论文翻译 论文题目:Semi-supervised Contrastive Learning for Labe…...

vivo官网App模块化开发方案-ModularDevTool

作者:vivo 互联网客户端团队- Wang Zhenyu 本文主要讲述了Android客户端模块化开发的痛点及解决方案,详细讲解了方案的实现思路和具体实现方法。 说明:本工具基于vivo互联网客户端团队内部开源的编译管理工具开发。 一、背景 现在客户端的业…...

Python基础-数据类型之数字类型

变量中的变量值是用来存储事物状态的,事物的状态分成不同的种类(例如:人的姓名、年龄,身高、职位、工资等),因此变量值有多种不同的数据类型。 age 18 # 用整型记录年龄 salary 3.1 # 用浮点型记录…...

基于Web的6个完美3D图形WebGL库

现代前端、游戏和Web开发正是WebGL可以转化为数字杰作的东西。使用GPU绘制在浏览器屏幕上生成的矢量元素,WebGL创建交互式Web图形,从而获得用户体验。视觉元素的质量和复杂性使该工具在HTML或CSS等其他方法中脱颖而出。WebGL基础WebGL不是一个图形套件。…...

界面组件DevExpress Reporting v22.2 - 增强的Web报表组件UI

DevExpress Reporting是.NET Framework下功能完善的报表平台,它附带了易于使用的Visual Studio报表设计器和丰富的报表控件集,包括数据透视表、图表,因此您可以构建无与伦比、信息清晰的报表。DevExpress Reporting v22.2版本已正式发布&…...

初学vector

目录 string的收尾 拷贝构造的现代写法: 浅拷贝: 拷贝构造的现代写法: swap函数: 内置类型有拷贝构造和赋值重载吗? 完善拷贝构造的现代写法: 赋值重载的现代写法: 更精简的现代写法&…...

Windows10 安装wsl2、Ubuntu相关操作

Windows10 安装wsl2、Ubuntu相关操作 安装wsl2 查看本机windows版本: 键盘上按下winr,输入winver,查看系统版本。必须运行 windows 10 版本 2004 及更高版本(内部版本 19041 及更高版本)或 windows 11。满足版本要求后&#xf…...

SpringBoot简单使用MongoDB

MongoDB介绍 SpringBoot简单使用MongoDB 一、配置步骤 1、application.yml 2、pom 3、entity 4、mapper 二、案例代码使用 1、库 前期准备上一篇安装MongoDB地址http://t.csdn.cn/G4oYJ 跟关系型数据库概念对比 Mysql MongoDB Database(数据库) Datab…...

Oracle Data Guard 角色转换(Role Transitions)

查询视图V$DATABASE的DATABASE_ROLE列可以看到数据库当前的角色。 1.角色转换介绍 Oracle Data Guard让你可以使用SQL语句或者通过Oracle Data Guard broker界面来动态更改数据库的角色,Oracle Data Guard支持以下的角色转换: 1&#xff0…...

opencv的TrackBar控件

大家好,我是csdn的博主:lqj_本人 这是我的个人博客主页: lqj_本人的博客_CSDN博客-微信小程序,前端,python领域博主lqj_本人擅长微信小程序,前端,python,等方面的知识https://blog.csdn.net/lbcyllqj?spm1011.2415.3001.5343哔哩哔哩欢迎关注…...

关于基线长度对双天线GNSS测姿精度的影响

文章目录一、GNSS测姿原理1. 载波相位双差求解基线向量2. GNSS姿态角表示二、基线长度对GNSS测姿精度的影响三、GNSS定向产品精度描述实例四、参考文献在GNSS定向模块或者板卡的指标参数中,我们一般会看到航向的测量精度和基线的长度相关。在实际使用,用…...

口交换机睿易 RG-NBS1826GC 24 口

接口形态不将就,标配光纤接口传输性能不将就,标配千兆上联口和大缓存设计端口数量不将就,8/16/24 三种选择楼宇对讲交换机不将就,保证开锁指令品质服务不将就,监控专用交换机接口形态不将就,标配光纤接口非…...

如何在Excel中向下拉列表中添加条件

在Excel中向下拉列表中添加条件 创建矩阵型数据集创建下拉列表创建第一个下拉列表创建第二个下拉列表你可以使用Microsoft Excel下拉列表来显示一个简单的列表,尽管有时需要更多的控制。假设你的人员分散在四个地区:北部、南部、东部和西部。你希望按地区与人员合作,而不是与…...

自定义bean 加载到spring IOC容器中

自定义bean加载到spring容器中的两种方式: 1.在类上添加注解Controller、RestController(本质是Controller)、Service、Repository、Component2.使用Configuration和Bean 这篇文章主要介绍第二种方式原理(因为在实际使用中&#…...

[python入门㊻] - python装饰器和类的装饰器

目录 ❤ python装饰器介绍 ❤ 什么是装饰器 ❤ 装饰器的流程 ❤ 定义装饰器时通常会涉及以下3个函数 无参装饰器 有参装饰器 多重装饰器 ❤ 装饰器的用法(闭包) ❤ 装饰器语法糖 ❤ 时间计时器 ❤ 装饰器中wraps作用 不使用wraps装饰器 使用wraps装饰器解…...

企业级信息系统开发学习1.1 初识Spring——采用Spring配置文件管理Bean

文章目录一、Spring容器演示——采用Spring配置文件管理Bean(一)创建Maven项目(二)添加Spring依赖(三)创建杀龙任务类(四)创建勇敢骑士类(五)采用传统方式让勇…...

CSS盒子模型

盒子模型 CSS三大特性 继承性、层叠性、优先级 优先级比较 继承 < 通配符选择器 < 标签选择器 < 类选择器 < id选择器 < 行内样式 < !important 注意&#xff1a;!important不能提升继承的优先级&#xff0c;只要是继承优先级最低 复合选择器权重叠加计…...

Python基础学习笔记 —— 数据结构与算法

数据结构与算法1 数据结构基础1.1 数组1.2 链表1.3 队列1.4 栈1.5 二叉树2 排序算法2.1 冒泡排序2.2 快速排序2.3 &#xff08;简单&#xff09;选择排序2.4 堆排序2.5 &#xff08;直接&#xff09;插入排序3 查找3.1 二分查找1 数据结构基础 本章所需相关基础知识&#xff1a…...

笔记本连接wifi,浏览器访问页面,显示访问被拒绝

打开chrome、edge浏览器访问第1个第2个页面正常&#xff0c;后面再打开页面显示异常。 但手机连接正常&#xff0c;笔记本连接异常&#xff0c;起初完全没有怀疑是wifi问题 以为用了vpn软件问题&#xff0c;认为中了病毒。杀毒&#xff0c;并没有中毒。 1、关闭vpn代理&#…...

男女做那个全面视频网站/互联网广告推广好做吗

实验三用Excel软件进行绘图一、实验目的使学生较熟练地掌握资料整理和统计图表的绘制方法。要求会使用EXCEL绘制的图表、图形&#xff0c;以及公式的编辑和计算。二、实验器具计算机三、实验要求每位同学一台计算机独立完成操作&#xff0c;并结合习题按照操作情况写出。四、实…...

asp制作网站教程/推广普通话作文

文章目录前言1. 题目2. 题目分析3. 四个python内置函数3.1 lower()方法3.2 ord()方法3.3 bin()方法3.4 count()方法4. 代码前言 再一次感受到了python的强大&#xff0c;这个题这么复杂的操作&#xff0c;python只用了4个函数&#xff0c;12行就搞定了。这次总结一下这4个函数…...

如何做公众号小说网站赚钱/俄罗斯引擎搜索

CNC加工中心的高精高效&#xff0c;安全是前提。安全生产离不开优秀的车间管理&#xff0c;设备的精良保养以及丰富的加工经验。 1.预先开机 正式加工前可以进行开机空转&#xff0c;让CNC加工中心主轴空转几分钟&#xff0c;可以让主轴的轴承充分润滑&#xff0c;减少加工误…...

网站开发的系统测试/广告推广 精准引流

问题详细描述假设&#xff0c;我已经在我的本地系统中保存了一个JSON文件&#xff0c;同时创建了一个Javascript文件&#xff0c;以便读取JSON文件并打印数据。JSON文件如下&#xff1a;{"resource":"A","literals":["B","C"…...

做服装外贸的网站/网络销售工作靠谱吗

pg_stat_activity是PostgreSQL原生工具&#xff0c;官方说明如下&#xff1a;The pg_stat_activity view will have one row per server process, showing information related to the current activity of that process&#xff08;该pg_stat_activity视图将为每个服务器进程显…...

做网站的困难/外贸营销型网站制作

to_date 函数&#xff1a;TO_DATE( string1 [, format_mask] [, nls_language] ) 后面两个函数为可选 &#xff0c;意思将字符串类型转换为时间类型 &#xff0c; 可以自定义时间格式举例&#xff1a;获取日期 to_date(2004-09-01,YYYY-MM-DD) &#xff0c;to_date(20020315, …...